@sert121
Last active March 17, 2025 02:19
A simplified implementation of the GraphRouter paper
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import logging
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
torch.manual_seed(42)
np.random.seed(42)
# load distilbert-base-uncased as the PLM (used as a frozen feature extractor)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
plm_model = AutoModel.from_pretrained("distilbert-base-uncased")
plm_model.eval()
def sample_dataset(dataset, field, num_samples):
    ds = dataset.shuffle(seed=42)
    samples = ds.select(range(num_samples))
    texts = [sample[field] for sample in samples]
    return texts
def plot_losses(train_losses, val_losses, val_accuracies, num_epochs):
    # plots for train loss, val loss and val accuracy
    epochs_range = np.arange(1, num_epochs + 1)
    plt.figure(figsize=(14, 4))
    plt.subplot(1, 3, 1)
    plt.plot(epochs_range, train_losses, label='Train Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.subplot(1, 3, 2)
    plt.plot(epochs_range, val_losses, label='Val Loss', color='orange')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Validation Loss')
    plt.legend()
    plt.subplot(1, 3, 3)
    plt.plot(epochs_range, val_accuracies, label='Val Accuracy', color='green')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Validation Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.savefig('graphrouter_performance_v0.png')
def get_text_embedding(text, tokenizer, model):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0)  # [hidden_dim]
    return embedding
class GraphRouterGNN(nn.Module):
    def __init__(self, proj_dim, num_layers=2, initial_task_emb=None, initial_llm_emb=None):
        super(GraphRouterGNN, self).__init__()
        self.proj_dim = proj_dim
        self.num_layers = num_layers
        # initialize task and llm embeddings using PLM-based projections if provided.
        if initial_task_emb is not None:
            self.task_emb = nn.Parameter(initial_task_emb)
        else:
            self.task_emb = nn.Parameter(torch.randn(num_tasks, proj_dim))
        if initial_llm_emb is not None:
            self.llm_emb = nn.Parameter(initial_llm_emb)
        else:
            self.llm_emb = nn.Parameter(torch.randn(num_llms, proj_dim))
        # message passing layers for query nodes.
        self.lin_query = nn.Linear(proj_dim, proj_dim)
        self.lin_task_to_query = nn.Linear(proj_dim, proj_dim)
        self.lin_llm_to_query = nn.Linear(proj_dim, proj_dim)
        # update layers for task nodes.
        self.lin_task_self = nn.Linear(proj_dim, proj_dim)
        self.lin_query_to_task = nn.Linear(proj_dim, proj_dim)
        # update layers for llm nodes.
        self.lin_llm_self = nn.Linear(proj_dim, proj_dim)
        self.lin_query_to_llm = nn.Linear(proj_dim, proj_dim)
        # mlp to combine query and corresponding task embeddings.
        self.mlp_qt = nn.Sequential(
            nn.Linear(proj_dim * 2, proj_dim),
            nn.ReLU(),
            nn.Linear(proj_dim, proj_dim)
        )

    def forward(self, query_feats, query_task_ids):
        # initialize node features.
        tasks = self.task_emb      # shape: [num_tasks, proj_dim]
        llms = self.llm_emb        # shape: [num_llms, proj_dim]
        queries = query_feats      # shape: [num_queries, proj_dim]
        # perform message passing.
        for layer in range(self.num_layers):
            # --- update query nodes ---
            task_msg = tasks[query_task_ids]  # for each query, get the corresponding task node.
            llm_msg = llms.mean(dim=0, keepdim=True).expand_as(queries)  # average message from llm nodes.
            queries = torch.relu(self.lin_query(queries) +
                                 self.lin_task_to_query(task_msg) +
                                 self.lin_llm_to_query(llm_msg))
            # --- update task nodes ---
            new_tasks = []
            for t in range(self.task_emb.size(0)):  # number of task nodes
                mask = (query_task_ids == t)
                if mask.sum() > 0:
                    query_mean = queries[mask].mean(dim=0, keepdim=True)
                else:
                    query_mean = torch.zeros(1, self.proj_dim, device=queries.device)
                updated_t = torch.relu(self.lin_task_self(tasks[t:t+1]) +
                                       self.lin_query_to_task(query_mean))
                new_tasks.append(updated_t)
            tasks = torch.cat(new_tasks, dim=0)
            # --- update llm nodes ---
            query_mean_all = queries.mean(dim=0, keepdim=True)
            llms = torch.relu(self.lin_llm_self(llms) +
                              self.lin_query_to_llm(query_mean_all))
        # combine query and corresponding task embeddings.
        task_for_query = tasks[query_task_ids]
        combined = torch.cat([queries, task_for_query], dim=1)
        h_qt = self.mlp_qt(combined)
        # compute logits: dot product with llm embeddings.
        logits = torch.matmul(h_qt, llms.t())
        return logits
if __name__ == "__main__":
    num_tasks = 2      # 0: Alpaca, 1: GSM8K
    num_llms = 2       # 0: gpt-4o-mini, 1: claude-3-5-haiku
    num_samples = 200  # from each dataset

    try:
        alpaca_dataset = load_dataset("tatsu-lab/alpaca")
        logging.info("Loaded Alpaca dataset from HF")
    except Exception as e:
        logging.error("Error loading Alpaca dataset: %s", e)
        raise
    try:
        gsm8k_dataset = load_dataset("gsm8k", "main")
        logging.info("Loaded GSM8K dataset from Hugging Face")
    except Exception as e:
        logging.error("Error loading GSM8K dataset: %s", e)
        raise

    alpaca_texts = sample_dataset(alpaca_dataset["train"], "instruction", num_samples)
    gsm8k_texts = sample_dataset(gsm8k_dataset["train"], "question", num_samples)
    query_texts = alpaca_texts + gsm8k_texts
    query_task_ids_list = [0] * num_samples + [1] * num_samples

    # descriptive texts for task and llm nodes (generated by gpt-4o)
    task_texts = [
        "Alpaca: An instruction-following dataset requiring coherent responses.",
        "GSM8K: A dataset of grade school math problems requiring multi-step reasoning."
    ]
    llm_texts = [
        "gpt-4o-mini: A compact model optimized for instruction following tasks.",
        "claude-3-5-haiku: A model optimized for mathematical reasoning tasks."
    ]

    # encode task and llm nodes using the plm
    task_emb_list = [get_text_embedding(text, tokenizer, plm_model) for text in task_texts]
    llm_emb_list = [get_text_embedding(text, tokenizer, plm_model) for text in llm_texts]
    proj_dim = 32
    projection = nn.Linear(768, proj_dim)  # fixed (untrained) random projection; may help since we deal with a small dataset + few tasks
    # project initial task and llm embeddings.
    with torch.no_grad():
        task_emb_init = torch.stack(task_emb_list, dim=0)
        llm_emb_init = torch.stack(llm_emb_list, dim=0)
        task_emb_init_proj = projection(task_emb_init)  # shape: [2, proj_dim]
        llm_emb_init_proj = projection(llm_emb_init)    # shape: [2, proj_dim]

    # encode queries using the plm, project to the lower dimension
    query_emb_list = []
    for text in query_texts:
        emb = get_text_embedding(text, tokenizer, plm_model)
        query_emb_list.append(emb)
    query_emb = torch.stack(query_emb_list, dim=0)
    query_emb = projection(query_emb)  # shape: [400, proj_dim]
    query_emb = query_emb.detach()     # detach to prevent reusing the computation graph
    query_task_ids_all = torch.tensor(query_task_ids_list, dtype=torch.long)

    # ground truth: route task 0 -> llm 0, task 1 -> llm 1.
    query_llm_labels = query_task_ids_all.clone()

    # split data into train, validation, and test sets
    num_queries = query_emb.size(0)
    indices = torch.randperm(num_queries)
    train_end = int(0.7 * num_queries)
    val_end = int(0.8 * num_queries)
    train_idx = indices[:train_end]
    val_idx = indices[train_end:val_end]
    test_idx = indices[val_end:]
    train_queries = query_emb[train_idx]
    train_task_ids = query_task_ids_all[train_idx]
    train_labels = query_llm_labels[train_idx]
    val_queries = query_emb[val_idx]
    val_task_ids = query_task_ids_all[val_idx]
    val_labels = query_llm_labels[val_idx]
    test_queries = query_emb[test_idx]
    test_task_ids = query_task_ids_all[test_idx]
    test_labels = query_llm_labels[test_idx]
    # training and evaluation
    num_epochs = 20
    model = GraphRouterGNN(proj_dim, num_layers=2,
                           initial_task_emb=task_emb_init_proj,
                           initial_llm_emb=llm_emb_init_proj)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    train_losses = []
    val_losses = []
    val_accuracies = []
    logging.info("Starting training...")
    # full-batch training: one forward/backward pass per epoch over all training queries
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(train_queries, train_task_ids)
        loss = criterion(logits, train_labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            val_logits = model(val_queries, val_task_ids)
            loss_val = criterion(val_logits, val_labels)
            val_losses.append(loss_val.item())
            preds = torch.argmax(val_logits, dim=1)
            val_acc = (preds == val_labels).float().mean().item()
            val_accuracies.append(val_acc)
        logging.info(f"Epoch {epoch+1:02d}: Train Loss = {loss.item():.4f}, Val Loss = {loss_val.item():.4f}, Val Acc = {val_acc:.4f}")

    # test evaluation
    model.eval()
    with torch.no_grad():
        test_logits = model(test_queries, test_task_ids)
        test_preds = torch.argmax(test_logits, dim=1)
        test_acc = (test_preds == test_labels).float().mean().item()
    logging.info(f"Test Accuracy on held-out samples: {test_acc:.4f}")

    plot_losses(train_losses, val_losses, val_accuracies, num_epochs)
sert121 commented Mar 17, 2025

[image: plots of training loss, validation loss, and validation accuracy over epochs]
A preview of the training and validation losses.

2025-03-16 22:01:43,840 - INFO - Loaded Alpaca dataset from alpaca_data.json
2025-03-16 22:01:47,027 - INFO - Loaded GSM8K dataset from Hugging Face
2025-03-16 22:01:54,680 - INFO - Starting training...
2025-03-16 22:01:54,739 - INFO - Epoch 01: Train Loss = 0.6932, Val Loss = 0.6926, Val Acc = 0.5500
2025-03-16 22:01:54,743 - INFO - Epoch 02: Train Loss = 0.6929, Val Loss = 0.6922, Val Acc = 0.5500
2025-03-16 22:01:54,747 - INFO - Epoch 03: Train Loss = 0.6926, Val Loss = 0.6918, Val Acc = 0.5500
2025-03-16 22:01:54,750 - INFO - Epoch 04: Train Loss = 0.6923, Val Loss = 0.6914, Val Acc = 0.5500
2025-03-16 22:01:54,754 - INFO - Epoch 05: Train Loss = 0.6920, Val Loss = 0.6909, Val Acc = 0.5500
2025-03-16 22:01:54,758 - INFO - Epoch 06: Train Loss = 0.6915, Val Loss = 0.6903, Val Acc = 0.5500
2025-03-16 22:01:54,762 - INFO - Epoch 07: Train Loss = 0.6910, Val Loss = 0.6896, Val Acc = 0.5500
2025-03-16 22:01:54,766 - INFO - Epoch 08: Train Loss = 0.6904, Val Loss = 0.6888, Val Acc = 0.5500
2025-03-16 22:01:54,769 - INFO - Epoch 09: Train Loss = 0.6896, Val Loss = 0.6878, Val Acc = 0.5500
2025-03-16 22:01:54,773 - INFO - Epoch 10: Train Loss = 0.6887, Val Loss = 0.6867, Val Acc = 0.5500
2025-03-16 22:01:54,777 - INFO - Epoch 11: Train Loss = 0.6876, Val Loss = 0.6853, Val Acc = 0.5500
2025-03-16 22:01:54,781 - INFO - Epoch 12: Train Loss = 0.6862, Val Loss = 0.6836, Val Acc = 0.5500
2025-03-16 22:01:54,785 - INFO - Epoch 13: Train Loss = 0.6846, Val Loss = 0.6816, Val Acc = 0.5500
2025-03-16 22:01:54,789 - INFO - Epoch 14: Train Loss = 0.6827, Val Loss = 0.6792, Val Acc = 0.5500
2025-03-16 22:01:54,792 - INFO - Epoch 15: Train Loss = 0.6804, Val Loss = 0.6764, Val Acc = 0.5500
2025-03-16 22:01:54,796 - INFO - Epoch 16: Train Loss = 0.6775, Val Loss = 0.6732, Val Acc = 0.7250
2025-03-16 22:01:54,800 - INFO - Epoch 17: Train Loss = 0.6743, Val Loss = 0.6693, Val Acc = 1.0000
2025-03-16 22:01:54,804 - INFO - Epoch 18: Train Loss = 0.6703, Val Loss = 0.6647, Val Acc = 1.0000
2025-03-16 22:01:54,808 - INFO - Epoch 19: Train Loss = 0.6656, Val Loss = 0.6593, Val Acc = 1.0000
2025-03-16 22:01:54,812 - INFO - Epoch 20: Train Loss = 0.6601, Val Loss = 0.6529, Val Acc = 1.0000
2025-03-16 22:01:54,815 - INFO - Epoch 21: Train Loss = 0.6536, Val Loss = 0.6454, Val Acc = 1.0000
2025-03-16 22:01:54,819 - INFO - Epoch 22: Train Loss = 0.6460, Val Loss = 0.6369, Val Acc = 1.0000
2025-03-16 22:01:54,823 - INFO - Epoch 23: Train Loss = 0.6370, Val Loss = 0.6270, Val Acc = 1.0000
2025-03-16 22:01:54,827 - INFO - Epoch 24: Train Loss = 0.6267, Val Loss = 0.6154, Val Acc = 1.0000
2025-03-16 22:01:54,831 - INFO - Epoch 25: Train Loss = 0.6147, Val Loss = 0.6020, Val Acc = 1.0000
2025-03-16 22:01:54,834 - INFO - Epoch 26: Train Loss = 0.6009, Val Loss = 0.5866, Val Acc = 1.0000
2025-03-16 22:01:54,838 - INFO - Epoch 27: Train Loss = 0.5853, Val Loss = 0.5687, Val Acc = 1.0000
2025-03-16 22:01:54,842 - INFO - Epoch 28: Train Loss = 0.5670, Val Loss = 0.5489, Val Acc = 1.0000
2025-03-16 22:01:54,846 - INFO - Epoch 29: Train Loss = 0.5472, Val Loss = 0.5264, Val Acc = 1.0000
2025-03-16 22:01:54,850 - INFO - Epoch 30: Train Loss = 0.5242, Val Loss = 0.5010, Val Acc = 1.0000
2025-03-16 22:01:54,850 - INFO - Test Accuracy on held-out samples: 1.0000

sert121 commented Mar 17, 2025

Some notes:
The high held-out accuracy is not surprising: we use very few samples, and the routing label is fully determined by the task (task 0 -> gpt-4o-mini, task 1 -> claude-3-5-haiku).
For the same reason, the model converges in relatively few epochs.
Training could be run for longer than 30 epochs; it is capped here for now.
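
As a follow-up, here is a minimal inference sketch (not part of the gist above) showing how the trained router could pick an LLM for a new query. It reuses the model, projection, tokenizer, plm_model, and get_text_embedding objects from the script; the query text, the llm_names list, and the assumption that the query's task id is known are hypothetical.

# hypothetical routing example, reusing objects defined in the script above
llm_names = ["gpt-4o-mini", "claude-3-5-haiku"]
new_query = "If a train travels 60 miles in 1.5 hours, what is its average speed?"
new_task_id = torch.tensor([1], dtype=torch.long)  # assumed known: 1 = GSM8K-style math

model.eval()
with torch.no_grad():
    q_emb = get_text_embedding(new_query, tokenizer, plm_model)  # [768]
    q_emb = projection(q_emb.unsqueeze(0))                       # [1, proj_dim]
    routing_logits = model(q_emb, new_task_id)                   # [1, num_llms]
chosen_llm = llm_names[routing_logits.argmax(dim=1).item()]
print(f"Routing query to: {chosen_llm}")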
