Skip to content

Instantly share code, notes, and snippets.

import torch
import torch.nn as nn
class MultiQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
dur = {}
class Tensor:
def __init__(self, value, requires_grad=False):
self.value = value # Store the value of the tensor
self.grad = 0 # Gradient initialized to zero
self.requires_grad = requires_grad
self._backward = lambda: None # Function to compute gradient
self._prev = set() # Track previous nodes for backpropagation
def backward(self):
"""Computes gradients using reverse-mode automatic differentiation."""
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load GPT-2 model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
def compare_outputs(prompt):
print('Prompt:', prompt)
@danyaljj
danyaljj / sae
Created December 30, 2024 19:31
# Results
## Feature 1: Biden's health and cognitive abilities.
### Stats
F1 score: 0.5
Pearson Correlation: -0.38014296063485287
import numpy as np
import matplotlib.pyplot as plt
# Parameters
T = 35 # Number of trials
beta = 0.5 # Probability parameter for z
alpha_1 = 0.2 # Probability parameter for x if z=0
alpha_2 = 0.6 # Probability parameter for x if z=1
# Simulate z
@danyaljj
danyaljj / sql_agent.py
Created July 19, 2024 21:44
sql_agent.py
def get_response_sql(user_query, chat_history, plot=False):
# Specify the path to the SQLite database
db_path = "metadataDB/output_database.db"
# Connect to the SQLite database
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
underspecified = classify_underspecified_query(user_query , chat_history)
def get_response_sql(user_query, chat_history, plot=False):
# Specify the path to the SQLite database
db_path = "metadataDB/output_database.db"
# Connect to the SQLite database
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
underspecified = classify_underspecified_query(user_query , chat_history)
#!/usr/bin/env python
from typing import Iterable
from collections import Counter
import os
import logging
import sys
import json
import click
import datasets
import numpy as np
import json
import argparse
from typing import Optional, Union, Tuple
import torch
torch.manual_seed(0)
from transformers import BertModel, BertTokenizer, PreTrainedModel, BertConfig
from transformers.modeling_outputs import MultipleChoiceModelOutput