Simple Mixture-of-Experts (MoE) implementation using TensorFlow/Keras
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
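
# Note: this script assumes TensorFlow 2.x (eager execution enabled by default),
# since it calls .numpy() on layer outputs and uses the integrated tf.keras API.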

class ExpertLayer(layers.Layer):
    """Individual expert network."""

    def __init__(self, hidden_dim, output_dim, **kwargs):
        super(ExpertLayer, self).__init__(**kwargs)
        self.hidden_layer = layers.Dense(hidden_dim, activation='relu')
        self.output_layer = layers.Dense(output_dim, activation='linear')

    def call(self, inputs):
        x = self.hidden_layer(inputs)
        return self.output_layer(x)

class GatingNetwork(layers.Layer):
    """Gating network that determines which experts to use for each input."""

    def __init__(self, num_experts, **kwargs):
        super(GatingNetwork, self).__init__(**kwargs)
        self.num_experts = num_experts
        self.gate = layers.Dense(num_experts, activation='softmax')

    def call(self, inputs):
        return self.gate(inputs)
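
# Note: this is a dense (soft) gate: a single softmax layer produces a weight
# for every expert, so every expert runs on every input. Sparse MoE variants
# (e.g. top-k routing) would instead zero out most of these weights.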

class MixtureOfExperts(Model):
    """Mixture of Experts model."""

    def __init__(self, num_experts, hidden_dim, output_dim, **kwargs):
        super(MixtureOfExperts, self).__init__(**kwargs)
        # Create experts
        self.experts = [ExpertLayer(hidden_dim, output_dim, name=f'expert_{i}')
                        for i in range(num_experts)]
        # Create gating network
        self.gating_network = GatingNetwork(num_experts)

    def call(self, inputs):
        # Get expert outputs
        expert_outputs = tf.stack([expert(inputs) for expert in self.experts], axis=1)
        # Get gating weights
        gating_weights = self.gating_network(inputs)
        # Shape gating weights for proper broadcasting
        gating_weights = tf.expand_dims(gating_weights, axis=2)
        # Combine expert outputs using gating weights
        final_output = tf.reduce_sum(expert_outputs * gating_weights, axis=1)
        return final_output
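
# The combination in MixtureOfExperts.call computes, per sample,
#     output = sum_i gate_i(x) * expert_i(x),
# where the gate weights sum to 1 (softmax). For a batch of size B with N
# experts and output_dim D: expert_outputs is (B, N, D), the expanded gating
# weights are (B, N, 1), and the reduced sum yields (B, D).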

# Example usage
def create_dataset():
    """Create a simple synthetic dataset for demonstration."""
    np.random.seed(42)

    # Generate random input data
    X = np.random.rand(1000, 10).astype(np.float32)

    # Generate target data - different rules for different input ranges
    y = np.zeros((1000, 1), dtype=np.float32)

    # First expert: sum of first 3 features
    mask1 = X[:, 0] < 0.3
    y[mask1] = np.sum(X[mask1, :3], axis=1, keepdims=True)

    # Second expert: product of features 4-6
    mask2 = (X[:, 0] >= 0.3) & (X[:, 0] < 0.6)
    y[mask2] = np.prod(X[mask2, 3:6], axis=1, keepdims=True)

    # Third expert: difference between max and min of features 7-9
    mask3 = X[:, 0] >= 0.6
    y[mask3] = (np.max(X[mask3, 6:9], axis=1) - np.min(X[mask3, 6:9], axis=1)).reshape(-1, 1)

    # Add some noise
    y += np.random.normal(0, 0.1, y.shape).astype(np.float32)

    # Split into train and test
    train_size = int(0.8 * len(X))
    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]

    return (X_train, y_train), (X_test, y_test)
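
# Note: all three target rules are keyed on X[:, 0], so a well-trained gate can
# route samples by thresholding that single feature; the "expert" labels in the
# comments above describe the intended specialization, not a hard assignment.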

def train_and_evaluate_moe():
    # Create dataset
    (X_train, y_train), (X_test, y_test) = create_dataset()

    # Define model parameters
    num_experts = 3
    input_dim = X_train.shape[1]
    hidden_dim = 16
    output_dim = 1

    # Create MoE model
    moe_model = MixtureOfExperts(num_experts, hidden_dim, output_dim)

    # Compile model
    moe_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='mse'
    )

    # Train model
    history = moe_model.fit(
        X_train, y_train,
        epochs=50,
        batch_size=32,
        validation_data=(X_test, y_test),
        verbose=1
    )

    # Evaluate model
    test_loss = moe_model.evaluate(X_test, y_test)
    print(f"Test Loss: {test_loss}")

    # Visualize gating network decisions
    gating_outputs = moe_model.gating_network(X_test).numpy()
    print("Gating Network Output Examples:")
    for i in range(5):
        print(f"Sample {i}: Expert weights = {gating_outputs[i]}")

    # Plot training history
    import matplotlib.pyplot as plt
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper right')
    plt.grid(True)

    # Plot expert selection distribution
    plt.subplot(1, 2, 2)
    selected_experts = np.argmax(gating_outputs, axis=1)
    plt.hist(selected_experts, bins=range(num_experts + 1), alpha=0.7)
    plt.title('Expert Selection Distribution')
    plt.xlabel('Expert Index')
    plt.ylabel('Count')
    plt.xticks(range(num_experts))
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    return moe_model

if __name__ == "__main__":
    # Check TensorFlow version
    print(f"TensorFlow version: {tf.__version__}")

    # Train and evaluate MoE model
    model = train_and_evaluate_moe()
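
A minimal inference sketch (not part of the original gist; the variable names below are illustrative): once train_and_evaluate_moe() has returned the trained model, it can be called directly on new batches, and the gating network can be queried for per-sample expert weights.

new_samples = np.random.rand(4, 10).astype(np.float32)  # hypothetical new inputs
predictions = model(new_samples)                         # combined MoE output, shape (4, 1)
expert_weights = model.gating_network(new_samples)       # softmax gate weights, shape (4, num_experts)
print(predictions.numpy())
print(expert_weights.numpy())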