"""Simple MoE (Mixture of Experts) implementation using tf/keras."""
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import matplotlib.pyplot as plt


class ExpertLayer(layers.Layer):
    """Individual expert network."""

    def __init__(self, hidden_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_layer = layers.Dense(hidden_dim, activation='relu')
        self.output_layer = layers.Dense(output_dim, activation='linear')

    def call(self, inputs):
        x = self.hidden_layer(inputs)
        return self.output_layer(x)


class GatingNetwork(layers.Layer):
    """Gating network that determines which experts to use for each input."""

    def __init__(self, num_experts, **kwargs):
        super().__init__(**kwargs)
        self.num_experts = num_experts
        self.gate = layers.Dense(num_experts, activation='softmax')

    def call(self, inputs):
        return self.gate(inputs)
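

# Hedged illustration (added; not part of the original gist): the softmax gate
# emits one probability distribution over experts per input row, so each row
# sums to 1 and the expert outputs can be combined as a convex mixture. The
# batch size of 2 and input width of 10 below are arbitrary assumptions.
def _demo_gating_weights():
    gate = GatingNetwork(num_experts=3)
    weights = gate(tf.zeros((2, 10)))       # shape: (2, 3)
    print(tf.reduce_sum(weights, axis=1))   # each row sums to ~1.0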


class MixtureOfExperts(Model):
    """Mixture of Experts model."""

    def __init__(self, num_experts, hidden_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Create experts
        self.experts = [ExpertLayer(hidden_dim, output_dim, name=f'expert_{i}')
                        for i in range(num_experts)]
        # Create gating network
        self.gating_network = GatingNetwork(num_experts)

    def call(self, inputs):
        # Get expert outputs
        expert_outputs = tf.stack([expert(inputs) for expert in self.experts], axis=1)
        # Get gating weights
        gating_weights = self.gating_network(inputs)
        # Shape gating weights for proper broadcasting
        gating_weights = tf.expand_dims(gating_weights, axis=2)
        # Combine expert outputs using gating weights
        final_output = tf.reduce_sum(expert_outputs * gating_weights, axis=1)
        return final_output
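

# Hedged sanity-check sketch (added; not part of the original gist): traces the
# shape flow through MixtureOfExperts.call. Stacking the per-expert outputs
# gives (batch, num_experts, output_dim); the gating weights broadcast as
# (batch, num_experts, 1); the weighted sum collapses to (batch, output_dim).
# The batch size of 4 and input width of 10 are arbitrary assumptions.
def _demo_moe_shapes():
    moe = MixtureOfExperts(num_experts=3, hidden_dim=16, output_dim=1)
    out = moe(tf.zeros((4, 10)))
    assert out.shape == (4, 1)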


# Example usage
def create_dataset():
    """Create a simple synthetic dataset for demonstration."""
    np.random.seed(42)
    # Generate random input data
    X = np.random.rand(1000, 10).astype(np.float32)
    # Generate target data - different rules for different input ranges
    y = np.zeros((1000, 1), dtype=np.float32)
    # First expert: sum of first 3 features
    mask1 = X[:, 0] < 0.3
    y[mask1] = np.sum(X[mask1, :3], axis=1, keepdims=True)
    # Second expert: product of features 4-6
    mask2 = (X[:, 0] >= 0.3) & (X[:, 0] < 0.6)
    y[mask2] = np.prod(X[mask2, 3:6], axis=1, keepdims=True)
    # Third expert: difference between max and min of features 7-9
    mask3 = X[:, 0] >= 0.6
    y[mask3] = (np.max(X[mask3, 6:9], axis=1) - np.min(X[mask3, 6:9], axis=1)).reshape(-1, 1)
    # Add some noise
    y += np.random.normal(0, 0.1, y.shape).astype(np.float32)
    # Split into train and test
    train_size = int(0.8 * len(X))
    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]
    return (X_train, y_train), (X_test, y_test)
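

# Note (added; not in the original gist): feature 0 alone selects which rule
# generated each target, so an ideal gate would learn to route on X[:, 0] with
# thresholds at 0.3 and 0.6, roughly one regime per expert. A quick check of
# the regime sizes (~30% / 30% / 40% of samples, since X[:, 0] is uniform):
def _demo_regime_split():
    (X_train, _), _ = create_dataset()
    print(np.sum(X_train[:, 0] < 0.3),
          np.sum((X_train[:, 0] >= 0.3) & (X_train[:, 0] < 0.6)),
          np.sum(X_train[:, 0] >= 0.6))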


def train_and_evaluate_moe():
    # Create dataset
    (X_train, y_train), (X_test, y_test) = create_dataset()
    # Define model parameters
    num_experts = 3
    input_dim = X_train.shape[1]
    hidden_dim = 16
    output_dim = 1
    # Create MoE model
    moe_model = MixtureOfExperts(num_experts, hidden_dim, output_dim)
    # Compile model
    moe_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='mse'
    )
    # Train model
    history = moe_model.fit(
        X_train, y_train,
        epochs=50,
        batch_size=32,
        validation_data=(X_test, y_test),
        verbose=1
    )
    # Evaluate model
    test_loss = moe_model.evaluate(X_test, y_test)
    print(f"Test Loss: {test_loss}")
    # Visualize gating network decisions
    gating_outputs = moe_model.gating_network(X_test).numpy()
    print("Gating Network Output Examples:")
    for i in range(5):
        print(f"Sample {i}: Expert weights = {gating_outputs[i]}")
    # Plot training history
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper right')
    plt.grid(True)
    # Plot expert selection distribution
    plt.subplot(1, 2, 2)
    selected_experts = np.argmax(gating_outputs, axis=1)
    plt.hist(selected_experts, bins=range(num_experts + 1), alpha=0.7)
    plt.title('Expert Selection Distribution')
    plt.xlabel('Expert Index')
    plt.ylabel('Count')
    plt.xticks(range(num_experts))
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    return moe_model


if __name__ == "__main__":
    # Check TensorFlow version
    print(f"TensorFlow version: {tf.__version__}")
    # Train and evaluate MoE model
    model = train_and_evaluate_moe()
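    # Hedged usage sketch (added; not part of the original gist): score one
    # random sample with the trained model and inspect the gate's routing.
    # The sample values themselves are an arbitrary assumption.
    sample = np.random.rand(1, 10).astype(np.float32)
    print(f"Prediction: {model(sample).numpy()[0, 0]:.4f}")
    print(f"Gate weights: {model.gating_network(sample).numpy()[0]}")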