@ruvnet
Last active March 11, 2025 00:25
Diffusion-Based Coding Model with PyTorch
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Diffusion-Based Coding Model with PyTorch\n",
"\n",
"Welcome to the Diffusion-Based Coding Model notebook. This notebook provides a comprehensive, step-by-step guide to building a diffusion-based coding model from scratch using PyTorch. Our approach blends state-of-the-art ideas from diffusion models with code generation tasks. The notebook covers:\n",
"\n",
"- **Introduction**: Learn about the project, its features, and benefits.\n",
"- **Requirements & Installation**: How to set up your environment.\n",
"- **Data Collection & Preprocessing**: Fetching code snippets from GitHub, tokenizing, and augmenting data.\n",
"- **Model Architecture**: A baseline diffusion model with plans to extend to iterative denoising.\n",
"- **Training & Evaluation**: How to train, validate, and evaluate the model.\n",
"- **Inference**: Generate code snippets using the trained model.\n",
"- **Model Saving & Error Handling**: Best practices for saving models and robust logging.\n",
"\n",
"## Features\n",
"\n",
"- **Comprehensive Pipeline**: Data collection, preprocessing, augmentation, training, evaluation, and deployment.\n",
"- **Diffusion Model Foundations**: While simplified here, the structure is designed to be extended with iterative denoising steps typical of diffusion models.\n",
"- **Robust Data Handling**: Incorporates code tokenization, augmentation (insertion, deletion, swapping) to enhance model robustness.\n",
"- **Flexible Architecture**: A baseline LSTM-based model which can be replaced or extended with Transformer-based denoising architectures for diffusion.\n",
"\n",
"## Benefits\n",
"\n",
"- **Faster Inference Potential**: Diffusion models offer opportunities for parallel generation and iterative refinement, promising much faster token generation compared to autoregressive models.\n",
"- **Improved Global Consistency**: The iterative refinement process can help maintain code consistency across longer sequences.\n",
"- **Scalability**: Designed to be extended to distributed and large-scale training setups, which is critical for real-world coding assistants.\n",
"\n",
"Let's get started!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Requirements & Installation\n",
"\n",
"Before running this notebook, install the required packages. You can install all dependencies by running the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"!pip install torch torchvision torchtext requests beautifulsoup4 scikit-learn pandas numpy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Collection\n",
"Collect code snippets and their corresponding prompts from GitHub repositories. In this example, we fetch code from a specified repository URL."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from bs4 import BeautifulSoup\n",
"\n",
"def fetch_code_snippets(repo_url):\n",
" response = requests.get(repo_url)\n",
" if response.status_code == 200:\n",
" soup = BeautifulSoup(response.content, 'html.parser')\n",
" code_snippets = []\n",
" for code_tag in soup.find_all('code'):\n",
" code_snippet = code_tag.get_text()\n",
" if code_snippet.strip():\n",
" code_snippets.append(code_snippet.strip())\n",
" return code_snippets\n",
" else:\n",
" raise Exception(f'Failed to fetch code snippets from {repo_url}')\n",
"\n",
"repo_url = 'https://github.com/example/repo'\n",
"code_snippets = fetch_code_snippets(repo_url)\n",
"print(f'Fetched {len(code_snippets)} code snippets.')"
]
},
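{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optional: fetching files through the GitHub API\n",
"Scraping rendered repository pages for `<code>` tags is brittle. As an optional alternative (a sketch, not part of the main pipeline), the cell below fetches one file's raw contents through the GitHub REST API contents endpoint. The `owner`, `repo`, and `path` values are placeholders."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch: fetch one file's raw contents via the GitHub REST API instead of scraping HTML.\n",
"# The contents endpoint returns file metadata with the file content base64-encoded.\n",
"import base64\n",
"import requests\n",
"\n",
"def fetch_file_via_api(owner, repo, path):\n",
"    url = f'https://api.github.com/repos/{owner}/{repo}/contents/{path}'\n",
"    response = requests.get(url)\n",
"    response.raise_for_status()\n",
"    payload = response.json()\n",
"    return base64.b64decode(payload['content']).decode('utf-8')\n",
"\n",
"# Example usage with placeholder values:\n",
"# source = fetch_file_via_api('example', 'repo', 'main.py')"
]
},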
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Preprocessing\n",
"Tokenize code snippets and convert them into sequences of tokens using a basic tokenizer. In a real-world scenario, you would choose a tokenizer that understands code syntax better."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torchtext.data.utils import get_tokenizer\n",
"\n",
"def tokenize_code(code_snippet):\n",
" tokenizer = get_tokenizer('basic_english')\n",
" tokens = tokenizer(code_snippet)\n",
" return tokens\n",
"\n",
"def preprocess_data(code_snippets):\n",
" tokenized_snippets = [tokenize_code(snippet) for snippet in code_snippets]\n",
" all_tokens = [token for snippet in tokenized_snippets for token in snippet]\n",
" unique_tokens = list(set(all_tokens))\n",
" token_to_idx = {token: idx for idx, token in enumerate(unique_tokens)}\n",
" idx_to_token = {idx: token for token, idx in token_to_idx.items()}\n",
" return tokenized_snippets, token_to_idx, idx_to_token\n",
"\n",
"tokenized_snippets, token_to_idx, idx_to_token = preprocess_data(code_snippets)\n",
"print(f'Total unique tokens: {len(token_to_idx)}')"
]
},
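{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optional: a more code-aware tokenizer\n",
"The `basic_english` tokenizer lowercases text and ignores code structure. As a sketch of the code-aware alternative mentioned above, the cell below uses Python's standard-library `tokenize` module for Python snippets and falls back to `tokenize_code` when a snippet is not valid Python."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of a syntax-aware tokenizer for Python source using the stdlib 'tokenize' module.\n",
"import io\n",
"import tokenize\n",
"\n",
"def tokenize_python_code(code_snippet):\n",
"    try:\n",
"        tokens = []\n",
"        for tok in tokenize.generate_tokens(io.StringIO(code_snippet).readline):\n",
"            if tok.string.strip():  # skip structural tokens with empty or whitespace-only text\n",
"                tokens.append(tok.string)\n",
"        return tokens\n",
"    except (tokenize.TokenError, SyntaxError):\n",
"        # Not valid Python (or an incomplete fragment): fall back to the basic tokenizer above\n",
"        return tokenize_code(code_snippet)\n",
"\n",
"print(tokenize_python_code('x = 1 + 2'))"
]
},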
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Augmentation\n",
"Enhance the training data to make the model more robust. Here, we apply random insertion, deletion, and swapping of tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def augment_data(tokenized_snippets, token_to_idx):\n",
" augmented_snippets = []\n",
" for snippet in tokenized_snippets:\n",
" # Random Insertion\n",
" if np.random.rand() < 0.1:\n",
" insertion_index = np.random.randint(0, len(snippet) + 1)\n",
" inserted_token = np.random.choice(list(token_to_idx.keys()))\n",
" snippet.insert(insertion_index, inserted_token)\n",
" # Random Deletion\n",
" if np.random.rand() < 0.1 and len(snippet) > 1:\n",
" deletion_index = np.random.randint(0, len(snippet))\n",
" del snippet[deletion_index]\n",
" # Random Swap\n",
" if np.random.rand() < 0.1 and len(snippet) > 1:\n",
" swap_index1, swap_index2 = np.random.choice(len(snippet), 2, replace=False)\n",
" snippet[swap_index1], snippet[swap_index2] = snippet[swap_index2], snippet[swap_index1]\n",
" augmented_snippets.append(snippet)\n",
" return augmented_snippets\n",
"\n",
"augmented_snippets = augment_data(tokenized_snippets, token_to_idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Architecture\n",
"Define the diffusion model architecture using PyTorch. In a full diffusion model, you would include a noise scheduler and iterative denoising steps. Here, we illustrate a baseline model (using an LSTM) that can serve as a foundation for adding diffusion-specific components."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class DiffusionModel(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, hidden_dim):\n",
" super(DiffusionModel, self).__init__()\n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
" # In a full diffusion model, replace the LSTM with a Transformer-based denoiser that accepts time-step conditioning\n",
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n",
" self.fc = nn.Linear(hidden_dim, vocab_size)\n",
"\n",
" def forward(self, x):\n",
" x = self.embedding(x)\n",
" x, _ = self.lstm(x)\n",
" x = self.fc(x)\n",
" return x\n",
"\n",
"vocab_size = len(token_to_idx)\n",
"embedding_dim = 256\n",
"hidden_dim = 512\n",
"learning_rate = 0.001\n",
"\n",
"model = DiffusionModel(vocab_size, embedding_dim, hidden_dim)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
]
},
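{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sketch: noise schedule and denoising step\n",
"The cell below is illustrative only: it sketches what the diffusion-specific pieces mentioned above could look like on top of the baseline model. The number of steps and the uniform token-corruption scheme are assumptions, and a full implementation would also condition the denoiser on the timestep."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (assumptions, not the final design): a linear corruption schedule and\n",
"# a single denoising training step that corrupts token ids and asks the model to recover them.\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"num_diffusion_steps = 10  # assumed number of corruption levels\n",
"\n",
"def corrupt_tokens(tokens, t, num_steps=num_diffusion_steps, vocab_size=vocab_size):\n",
"    # Replace a fraction t / num_steps of tokens with uniformly random token ids\n",
"    corruption_prob = t / num_steps\n",
"    noise_mask = torch.rand_like(tokens, dtype=torch.float) < corruption_prob\n",
"    random_tokens = torch.randint(0, vocab_size, tokens.shape)\n",
"    return torch.where(noise_mask, random_tokens, tokens)\n",
"\n",
"def denoising_step(model, clean_tokens):\n",
"    # Sample a corruption level, corrupt the batch, and score the model's reconstruction.\n",
"    # A full denoiser would also receive t as conditioning input.\n",
"    t = torch.randint(1, num_diffusion_steps + 1, (1,)).item()\n",
"    noisy_tokens = corrupt_tokens(clean_tokens, t)\n",
"    logits = model(noisy_tokens)\n",
"    return F.cross_entropy(logits.view(-1, vocab_size), clean_tokens.view(-1), ignore_index=0)\n",
"\n",
"# Example: one denoising loss on a random batch of token ids (ids 1..vocab_size-1, since 0 is padding)\n",
"dummy_batch = torch.randint(1, vocab_size, (2, 16))\n",
"print(denoising_step(model, dummy_batch).item())"
]
},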
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Preparation\n",
"Prepare the data for training by converting tokens to indices and creating batches using a custom Dataset class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Dataset, DataLoader\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"class CodeDataset(Dataset):\n",
" def __init__(self, tokenized_snippets, token_to_idx, seq_length=100):\n",
" self.tokenized_snippets = tokenized_snippets\n",
" self.token_to_idx = token_to_idx\n",
" self.seq_length = seq_length\n",
" def __len__(self):\n",
" return len(self.tokenized_snippets)\n",
" def __getitem__(self, idx):\n",
" snippet = self.tokenized_snippets[idx]\n",
" input_seq = snippet[:self.seq_length]\n",
" target_seq = snippet[1:self.seq_length+1]\n",
" input_seq = [self.token_to_idx[token] for token in input_seq]\n",
" target_seq = [self.token_to_idx[token] for token in target_seq]\n",
" return torch.tensor(input_seq), torch.tensor(target_seq)\n",
"\n",
"seq_length = 100\n",
"batch_size = 32\n",
"\n",
"dataset = CodeDataset(augmented_snippets, token_to_idx, seq_length)\n",
"train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, random_state=42)\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)"
]
},
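{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, optional sanity check, the next cell inspects one batch from `train_loader`; both tensors should have shape `[batch_size, seq_length]` (the last batch may be smaller)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: each batch should be a pair of [batch_size, seq_length] LongTensors\n",
"sample_inputs, sample_targets = next(iter(train_loader))\n",
"print(sample_inputs.shape, sample_targets.shape)"
]
},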
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"Train the model using the prepared data. In a full diffusion model, you would incorporate a noise schedule and iterative denoising. Here, we demonstrate a standard training loop for a baseline model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_epochs = 10\n",
"scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)\n",
"\n",
"for epoch in range(num_epochs):\n",
" model.train()\n",
" total_loss = 0\n",
" for inputs, targets in train_loader:\n",
" optimizer.zero_grad()\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" scheduler.step()\n",
" avg_loss = total_loss / len(train_loader)\n",
" print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss}')"
]
},
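{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, the same loop can run on a GPU. The cell below only checks availability; to actually use a GPU, the model and each batch would be moved with `.to(device)` before the forward pass."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: check whether a GPU is available. To use it, move the model and each batch with\n",
"# .to(device) before the forward pass (not done here so the CPU-only cells below still run).\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f'CUDA available: {torch.cuda.is_available()} (selected device: {device})')"
]
},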
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation\n",
"Evaluate the model on the validation set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.eval()\n",
"total_loss = 0\n",
"with torch.no_grad():\n",
" for inputs, targets in val_loader:\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))\n",
" total_loss += loss.item()\n",
"\n",
"val_loss = total_loss / len(val_loader)\n",
"print(f'Validation Loss: {val_loss}')"
]
},
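{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, the average validation loss can be converted into perplexity (the exponential of the cross-entropy), a common companion metric for token-level models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Perplexity = exp(average cross-entropy loss); lower is better\n",
"import math\n",
"print(f'Validation Perplexity: {math.exp(val_loss):.2f}')"
]
},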
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference\n",
"Generate code snippets using the trained model. This function takes a prompt and iteratively predicts the next token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_code(model, prompt, max_length=100):\n",
" model.eval()\n",
" input_seq = [token_to_idx[token] for token in tokenize_code(prompt)]\n",
" input_seq = torch.tensor(input_seq).unsqueeze(0)\n",
" generated_code = []\n",
" with torch.no_grad():\n",
" for _ in range(max_length):\n",
" outputs = model(input_seq)\n",
" _, predicted = torch.max(outputs[:, -1, :], 1)\n",
" generated_code.append(idx_to_token[predicted.item()])\n",
" input_seq = torch.cat([input_seq, predicted.unsqueeze(0)], dim=1)\n",
" return ' '.join(generated_code)\n",
"\n",
"prompt = 'def hello_world():'\n",
"generated_code = generate_code(model, prompt)\n",
"print(generated_code)"
]
},
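{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional variant (a sketch, not the notebook's main decoding path), the next cell samples from the model's softmax distribution with a temperature instead of always taking the argmax, which usually produces more varied output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sampling-based decoding sketch; 'temperature' is an assumed knob (lower values sharpen the distribution)\n",
"import torch.nn.functional as F\n",
"\n",
"def generate_code_sampled(model, prompt, max_length=100, temperature=0.8):\n",
"    model.eval()\n",
"    input_seq = torch.tensor([[token_to_idx.get(token, 0) for token in tokenize_code(prompt)]])\n",
"    generated = []\n",
"    with torch.no_grad():\n",
"        for _ in range(max_length):\n",
"            logits = model(input_seq)[:, -1, :] / temperature\n",
"            next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)\n",
"            generated.append(idx_to_token[next_token.item()])\n",
"            input_seq = torch.cat([input_seq, next_token], dim=1)\n",
"    return ' '.join(generated)\n",
"\n",
"print(generate_code_sampled(model, 'def hello_world():'))"
]
},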
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Saving and Loading\n",
"Save and load the trained model for later use."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def save_model(model, path):\n",
" torch.save(model.state_dict(), path)\n",
"\n",
"def load_model(model, path):\n",
" model.load_state_dict(torch.load(path))\n",
" model.eval()\n",
"\n",
"model_path = 'diffusion_model.pth'\n",
"save_model(model, model_path)\n",
"load_model(model, model_path)"
]
},
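{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, a fuller checkpoint can also store the optimizer state and the current epoch so training can be resumed later. The file name `diffusion_checkpoint.pth` below is an arbitrary choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: save/restore a resumable checkpoint (model weights + optimizer state + epoch)\n",
"def save_checkpoint(model, optimizer, epoch, path):\n",
"    torch.save({\n",
"        'epoch': epoch,\n",
"        'model_state_dict': model.state_dict(),\n",
"        'optimizer_state_dict': optimizer.state_dict()\n",
"    }, path)\n",
"\n",
"def load_checkpoint(model, optimizer, path):\n",
"    checkpoint = torch.load(path)\n",
"    model.load_state_dict(checkpoint['model_state_dict'])\n",
"    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
"    return checkpoint['epoch']\n",
"\n",
"save_checkpoint(model, optimizer, num_epochs, 'diffusion_checkpoint.pth')\n",
"print(load_checkpoint(model, optimizer, 'diffusion_checkpoint.pth'))"
]
},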
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Error Handling and Logging\n",
"Robust error handling and logging are essential for tracking training and inference processes. The cell below demonstrates how to integrate logging into the training loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"\n",
"logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
"\n",
"try:\n",
" for epoch in range(num_epochs):\n",
" model.train()\n",
" total_loss = 0\n",
" for inputs, targets in train_loader:\n",
" optimizer.zero_grad()\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" scheduler.step()\n",
" avg_loss = total_loss / len(train_loader)\n",
" logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss}')\n",
"except Exception as e:\n",
" logging.error(f'An error occurred: {e}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.5",
"file_extension": ".py",
"mimetype": "text/x-python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}