A simplified implementation of the GraphRouter paper

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import logging
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

torch.manual_seed(42)
np.random.seed(42)

# load distilbert-base-uncased for PLM
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
plm_model = AutoModel.from_pretrained("distilbert-base-uncased")
plm_model.eval()

def sample_dataset(dataset, field, num_samples):
    ds = dataset.shuffle(seed=42)
    samples = ds.select(range(num_samples))
    texts = [sample[field] for sample in samples]
    return texts

def plot_losses(train_losses, val_losses, val_accuracies, num_epochs):
    # plots for train, val loss and val accuracy
    epochs_range = np.arange(1, num_epochs + 1)
    plt.figure(figsize=(14, 4))

    plt.subplot(1, 3, 1)
    plt.plot(epochs_range, train_losses, label='Train Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()

    plt.subplot(1, 3, 2)
    plt.plot(epochs_range, val_losses, label='Val Loss', color='orange')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Validation Loss')
    plt.legend()

    plt.subplot(1, 3, 3)
    plt.plot(epochs_range, val_accuracies, label='Val Accuracy', color='green')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Validation Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig('graphrouter_performance_v0.png')

def get_text_embedding(text, tokenizer, model):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0)  # [hidden_dim]
    return embedding

class GraphRouterGNN(nn.Module):
    def __init__(self, proj_dim, num_layers=2, initial_task_emb=None, initial_llm_emb=None):
        super(GraphRouterGNN, self).__init__()
        self.proj_dim = proj_dim
        self.num_layers = num_layers
        # initialize task and llm embeddings using PLM-based projections if provided;
        # the random fallback relies on module-level num_tasks / num_llms being defined.
        if initial_task_emb is not None:
            self.task_emb = nn.Parameter(initial_task_emb)
        else:
            self.task_emb = nn.Parameter(torch.randn(num_tasks, proj_dim))
        if initial_llm_emb is not None:
            self.llm_emb = nn.Parameter(initial_llm_emb)
        else:
            self.llm_emb = nn.Parameter(torch.randn(num_llms, proj_dim))
        # message-passing layers for query nodes.
        self.lin_query = nn.Linear(proj_dim, proj_dim)
        self.lin_task_to_query = nn.Linear(proj_dim, proj_dim)
        self.lin_llm_to_query = nn.Linear(proj_dim, proj_dim)
        # update layers for task nodes.
        self.lin_task_self = nn.Linear(proj_dim, proj_dim)
        self.lin_query_to_task = nn.Linear(proj_dim, proj_dim)
        # update layers for llm nodes.
        self.lin_llm_self = nn.Linear(proj_dim, proj_dim)
        self.lin_query_to_llm = nn.Linear(proj_dim, proj_dim)
        # mlp to combine query and corresponding task embeddings.
        self.mlp_qt = nn.Sequential(
            nn.Linear(proj_dim * 2, proj_dim),
            nn.ReLU(),
            nn.Linear(proj_dim, proj_dim)
        )

    def forward(self, query_feats, query_task_ids):
        # initialize node features.
        tasks = self.task_emb    # shape: [num_tasks, proj_dim]
        llms = self.llm_emb      # shape: [num_llms, proj_dim]
        queries = query_feats    # shape: [num_queries, proj_dim]
        # perform message passing.
        for layer in range(self.num_layers):
            # --- update query nodes ---
            task_msg = tasks[query_task_ids]  # for each query, gather its corresponding task node
            llm_msg = llms.mean(dim=0, keepdim=True).expand_as(queries)  # average message from llm nodes
            queries = torch.relu(self.lin_query(queries) +
                                 self.lin_task_to_query(task_msg) +
                                 self.lin_llm_to_query(llm_msg))
            # --- update task nodes ---
            new_tasks = []
            for t in range(tasks.size(0)):  # use the embedding table size instead of the module-level num_tasks
                mask = (query_task_ids == t)
                if mask.sum() > 0:
                    query_mean = queries[mask].mean(dim=0, keepdim=True)
                else:
                    query_mean = torch.zeros(1, self.proj_dim, device=queries.device)
                updated_t = torch.relu(self.lin_task_self(tasks[t:t+1]) +
                                       self.lin_query_to_task(query_mean))
                new_tasks.append(updated_t)
            tasks = torch.cat(new_tasks, dim=0)
            # --- update llm nodes ---
            query_mean_all = queries.mean(dim=0, keepdim=True)
            llms = torch.relu(self.lin_llm_self(llms) +
                              self.lin_query_to_llm(query_mean_all))
        # combine each query with its (updated) task embedding.
        task_for_query = tasks[query_task_ids]
        combined = torch.cat([queries, task_for_query], dim=1)
        h_qt = self.mlp_qt(combined)
        # compute logits: dot product with llm embeddings.
        logits = torch.matmul(h_qt, llms.t())
        return logits
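
# A brief shape walkthrough, added for clarity; the values reflect the settings
# used below (proj_dim=32, num_tasks=2, num_llms=2) and a batch of B queries:
#   queries        [B, 32]   projected query embeddings
#   tasks, llms    [2, 32]   learnable node embeddings, initialized from the PLM
#   task_msg       [B, 32]   task node gathered per query
#   llm_msg        [B, 32]   mean of llm nodes, broadcast to every query
#   combined       [B, 64]   concatenation of each query and its updated task embedding
#   h_qt           [B, 32]   output of mlp_qt
#   logits         [B, 2]    dot products of h_qt with each llm embedding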

if __name__ == "__main__":
    num_tasks = 2      # 0: Alpaca, 1: GSM8K
    num_llms = 2       # 0: gpt-4o-mini, 1: claude-3-5-haiku
    num_samples = 200  # from each dataset

    try:
        alpaca_dataset = load_dataset("tatsu-lab/alpaca")
        logging.info("Loaded Alpaca dataset from HF")
    except Exception as e:
        logging.error("Error loading Alpaca dataset: %s", e)
        raise
    try:
        gsm8k_dataset = load_dataset("gsm8k", "main")
        logging.info("Loaded GSM8K dataset from Hugging Face")
    except Exception as e:
        logging.error("Error loading GSM8K dataset: %s", e)
        raise

    alpaca_texts = sample_dataset(alpaca_dataset["train"], "instruction", num_samples)
    gsm8k_texts = sample_dataset(gsm8k_dataset["train"], "question", num_samples)
    query_texts = alpaca_texts + gsm8k_texts
    query_task_ids_list = [0] * num_samples + [1] * num_samples

    # descriptive texts for tasks and llm nodes (generated by gpt-4o)
    task_texts = [
        "Alpaca: An instruction-following dataset requiring coherent responses.",
        "GSM8K: A dataset of grade school math problems requiring multi-step reasoning."
    ]
    llm_texts = [
        "gpt-4o-mini: A compact model optimized for instruction following tasks.",
        "claude-3-5-haiku: A model optimized for mathematical reasoning tasks."
    ]

    # encode tasks and llm nodes using plm
    task_emb_list = [get_text_embedding(text, tokenizer, plm_model) for text in task_texts]
    llm_emb_list = [get_text_embedding(text, tokenizer, plm_model) for text in llm_texts]

    proj_dim = 32
    projection = nn.Linear(768, proj_dim)  # projection may help as we deal with small dataset + few tasks

    # project initial task and llm embeddings.
    with torch.no_grad():
        task_emb_init = torch.stack(task_emb_list, dim=0)
        llm_emb_init = torch.stack(llm_emb_list, dim=0)
        task_emb_init_proj = projection(task_emb_init)  # shape: [2, proj_dim]
        llm_emb_init_proj = projection(llm_emb_init)    # shape: [2, proj_dim]

    # encode queries using plm, project to lower dimension
    query_emb_list = []
    for text in query_texts:
        emb = get_text_embedding(text, tokenizer, plm_model)
        query_emb_list.append(emb)
    query_emb = torch.stack(query_emb_list, dim=0)
    query_emb = projection(query_emb)  # shape: [400, proj_dim]
    query_emb = query_emb.detach()     # detach to prevent reusing the computation graph

    query_task_ids_all = torch.tensor(query_task_ids_list, dtype=torch.long)
    # ground truth: route task 0 -> llm 0, task 1 -> llm 1.
    query_llm_labels = query_task_ids_all.clone()

    # split data into train, validation, and test sets
    num_queries = query_emb.size(0)
    indices = torch.randperm(num_queries)
    train_end = int(0.7 * num_queries)
    val_end = int(0.8 * num_queries)
    train_idx = indices[:train_end]
    val_idx = indices[train_end:val_end]
    test_idx = indices[val_end:]

    train_queries = query_emb[train_idx]
    train_task_ids = query_task_ids_all[train_idx]
    train_labels = query_llm_labels[train_idx]
    val_queries = query_emb[val_idx]
    val_task_ids = query_task_ids_all[val_idx]
    val_labels = query_llm_labels[val_idx]
    test_queries = query_emb[test_idx]
    test_task_ids = query_task_ids_all[test_idx]
    test_labels = query_llm_labels[test_idx]

    # training and evaluation
    num_epochs = 20
    model = GraphRouterGNN(proj_dim, num_layers=2,
                           initial_task_emb=task_emb_init_proj,
                           initial_llm_emb=llm_emb_init_proj)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    train_losses = []
    val_losses = []
    val_accuracies = []

    logging.info("Starting training...")
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(train_queries, train_task_ids)
        loss = criterion(logits, train_labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            val_logits = model(val_queries, val_task_ids)
            loss_val = criterion(val_logits, val_labels)
            val_losses.append(loss_val.item())
            preds = torch.argmax(val_logits, dim=1)
            val_acc = (preds == val_labels).float().mean().item()
            val_accuracies.append(val_acc)
        logging.info(f"Epoch {epoch+1:02d}: Train Loss = {loss.item():.4f}, Val Loss = {loss_val.item():.4f}, Val Acc = {val_acc:.4f}")

    # test evaluation
    model.eval()
    with torch.no_grad():
        test_logits = model(test_queries, test_task_ids)
        test_preds = torch.argmax(test_logits, dim=1)
        test_acc = (test_preds == test_labels).float().mean().item()
    logging.info(f"Test Accuracy on held-out samples: {test_acc:.4f}")

    plot_losses(train_losses, val_losses, val_accuracies, num_epochs)

Some notes:
The high holdout-set accuracy is likely explained by the very small number of samples considered here.
Similarly, the model needs relatively few epochs to converge.
Training could be run for longer; it is capped at 20 epochs for now.
A preview of the training and validation losses (the plot saved as graphrouter_performance_v0.png).
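
As a usage illustration: a minimal sketch of routing a single new query with the trained router, assuming the script above has already been run in the same session (so model, projection, tokenizer, plm_model, and get_text_embedding are in scope). The llm_names list and the route_query helper are illustrative additions, not part of the original script.

# illustrative helper, not in the original script
llm_names = ["gpt-4o-mini", "claude-3-5-haiku"]  # assumed to follow the order of llm_texts

def route_query(text, task_id):
    model.eval()
    with torch.no_grad():
        emb = get_text_embedding(text, tokenizer, plm_model)  # [768]
        q = projection(emb.unsqueeze(0))                      # [1, proj_dim]
        task_ids = torch.tensor([task_id], dtype=torch.long)
        logits = model(q, task_ids)                           # [1, num_llms]
    return llm_names[torch.argmax(logits, dim=1).item()]

# example: a GSM8K-style query uses task_id=1
print(route_query("If a pen costs 3 dollars, how much do 7 pens cost?", task_id=1))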