# Gist by @whit3rabbit, created May 27, 2024 18:58
# !pip install torch transformers scikit-learn umap-learn matplotlib datasets joblib pandas zstandard
'''
Code outline:
1. Loading the pre-trained GPT-2 model and tokenizer.
2. Loading or downloading the dataset and saving it to disk with compression.
3. Collecting and normalizing activations from the middle layer of the model.
4. Training multiple SAEs with different feature sizes and saving them to disk.
5. Finding the feature that responds to the phrase "Golden Gate Bridge" in each SAE.
6. Visualizing the local neighborhood of the target feature using UMAP.
7. Analyzing nearby features for the target feature by computing cosine similarity and retrieving the top similar features.
'''
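# Note: the "SAEs" trained below are approximated with scikit-learn's DictionaryLearning
# (sparse dictionary learning on the collected activations) rather than a tied-weight sparse
# autoencoder trained with SGD; the learned dictionary atoms in `components_` play the role
# of the SAE's feature directions throughout the analysis.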
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.decomposition import DictionaryLearning
import numpy as np
import umap
import matplotlib.pyplot as plt
from datasets import load_dataset, Dataset
import joblib
import os
import pandas as pd
import zstandard as zstd
# Function to save dataset to disk with zstd compression
def save_dataset(dataset, file_path):
    df = pd.DataFrame(dataset)
    cctx = zstd.ZstdCompressor()
    with open(file_path, 'wb') as f:
        with cctx.stream_writer(f) as compressor:
            # Encode the CSV to bytes explicitly so it can be written to the binary zstd stream
            compressor.write(df.to_csv(index=False).encode("utf-8"))

# Function to load dataset from disk with zstd compression
def load_dataset_from_disk(file_path):
    with open(file_path, 'rb') as f:
        dctx = zstd.ZstdDecompressor()
        with dctx.stream_reader(f) as decompressor:
            df = pd.read_csv(decompressor)
    return Dataset.from_pandas(df)
# Load a smaller pre-trained model and tokenizer, e.g., GPT-2
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Add a padding token to the tokenizer
tokenizer.pad_token = tokenizer.eos_token
# Check if the dataset file exists, otherwise download it
dataset_file = "pile_uncopyrighted_dataset.csv.zst"
if os.path.exists(dataset_file):
    print("Loading dataset from disk...")
    dataset = load_dataset_from_disk(dataset_file)
else:
    try:
        print("Downloading dataset...")
        dataset = load_dataset("monology/pile-uncopyrighted", split="train[:1000]")  # Use a smaller subset
        print("Saving dataset to disk...")
        save_dataset(dataset, dataset_file)
    except Exception as e:
        print(f"An error occurred while loading the dataset: {e}")
        raise
# Check if activations (and their token-to-text index maps) already exist, otherwise calculate them
activations_file = "activations.pt"
if os.path.exists(activations_file):
    saved = torch.load(activations_file)
    activations = saved["activations"]
    token_text_idx = saved["token_text_idx"]  # maps each activation row to its source text index
    token_pos_idx = saved["token_pos_idx"]    # maps each activation row to its token position
else:
    # Collect middle-layer activations, remembering which text and token each row came from
    activations = []
    token_text_idx = []
    token_pos_idx = []
    for t_idx, text in enumerate(dataset["text"]):
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        middle_layer_index = len(hidden_states) // 2
        residual_stream = hidden_states[middle_layer_index]  # shape (1, seq_len, hidden_dim)
        seq_len = residual_stream.shape[1]
        activations.append(residual_stream.reshape(-1, residual_stream.shape[-1]))
        token_text_idx.extend([t_idx] * seq_len)
        token_pos_idx.extend(range(seq_len))
    activations = torch.cat(activations, dim=0)
    token_text_idx = torch.tensor(token_text_idx)
    token_pos_idx = torch.tensor(token_pos_idx)
    # Normalize activations: scale each row by sqrt(d) / RMS so magnitudes are comparable
    residual_stream_dim = activations.shape[-1]
    activations_norm = torch.sqrt(torch.mean(activations**2, dim=-1, keepdim=True))
    activations = (residual_stream_dim ** 0.5) * activations / activations_norm
    torch.save({"activations": activations,
                "token_text_idx": token_text_idx,
                "token_pos_idx": token_pos_idx}, activations_file)
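# Quick sanity check (illustrative, based on the normalization above): every activation
# row should now have an L2 norm equal to the residual stream dimension (768 for GPT-2).
print("activations:", tuple(activations.shape),
      "- mean L2 norm per row:", activations.norm(dim=-1).mean().item())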
# Define the feature sizes for different SAEs
feature_sizes = [1024, 2048, 4096]
# Train multiple SAEs with different feature sizes and save them
saes = []
for i, n_components in enumerate(feature_sizes):
    model_file = f"sae_{i+1}.joblib"
    if os.path.exists(model_file):
        print(f"Loading SAE {i+1} from disk...")
        sae = joblib.load(model_file)
    else:
        print(f"Training SAE with {n_components} features...")
        sae = DictionaryLearning(n_components=n_components, transform_alpha=5, random_state=0, max_iter=10)
        sae.fit(activations.detach().cpu().numpy())
        print(f"Saving SAE {i+1} to disk...")
        joblib.dump(sae, model_file)
    saes.append(sae)
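# Optional diagnostic (a rough sanity check, not part of the original pipeline): report
# reconstruction error and average sparsity of each learned dictionary on a small sample.
X_sample = activations.detach().cpu().numpy()[:1000]
for i, sae in enumerate(saes):
    codes = sae.transform(X_sample)    # sparse codes, shape (n_samples, n_components)
    recon = codes @ sae.components_    # reconstruct activations from the dictionary atoms
    mse = float(((X_sample - recon) ** 2).mean())
    active_per_token = float((codes != 0).sum(axis=1).mean())
    dead = int(((codes != 0).sum(axis=0) == 0).sum())
    print(f"SAE {i+1}: reconstruction MSE {mse:.4f}, "
          f"~{active_per_token:.1f} active features per token, {dead} unused features on this sample")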
# Find the feature that responds to "Golden Gate Bridge" in each SAE
target_phrase = "Golden Gate Bridge"
for i, sae in enumerate(saes):
    features = sae.transform(activations.detach().cpu().numpy())
    target_feature_index = None
    for j in range(sae.n_components):
        feature = features[:, j]
        top_activating_indices = feature.argsort()[-5:][::-1]
        for example_idx in top_activating_indices:
            text_idx = int(token_text_idx[example_idx])
            if target_phrase.lower() in dataset["text"][text_idx].lower():
                target_feature_index = j
                break
        if target_feature_index is not None:
            break
    if target_feature_index is None:
        print(f"No feature found that responds to '{target_phrase}' in SAE {i+1}")
        continue
    print(f"Feature {target_feature_index} responds to '{target_phrase}' in SAE {i+1}")

    # Visualize the local neighborhood of the target feature using UMAP
    umap_embeddings = umap.UMAP(n_neighbors=15, min_dist=0.1, metric='cosine').fit_transform(sae.components_)

    # Plot the UMAP visualization
    plt.figure(figsize=(10, 10))
    plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=5, alpha=0.5)
    plt.scatter(umap_embeddings[target_feature_index, 0], umap_embeddings[target_feature_index, 1],
                s=100, color='red', label=f"Feature {target_feature_index}")
    plt.gca().set_aspect('equal', 'datalim')
    plt.title(f'UMAP Visualization of Learned Features (SAE {i+1})')
    plt.legend()
    plt.show()

    # Analyze nearby features for the target feature
    target_feature = sae.components_[target_feature_index]

    # Compute cosine similarity between the target feature and all other features
    similarity_scores = np.dot(sae.components_, target_feature) / (
        np.linalg.norm(sae.components_, axis=1) * np.linalg.norm(target_feature))

    # Get the indices of the top similar features
    top_similar_indices = similarity_scores.argsort()[-5:][::-1]
    print(f"Top similar features to Feature {target_feature_index} in SAE {i+1}:")
    for idx in top_similar_indices:
        print(f"Feature {idx}")
        top_activating_indices = features[:, idx].argsort()[-5:][::-1]
        for example_idx in top_activating_indices:
            text_idx = int(token_text_idx[example_idx])
            token_idx = int(token_pos_idx[example_idx])
            # Re-tokenize with the same truncation settings used when collecting activations
            input_ids = tokenizer(dataset["text"][text_idx], return_tensors="pt", truncation=True)["input_ids"][0]
            if token_idx < len(input_ids):
                print(tokenizer.decode(input_ids[token_idx]))
        print()
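# Optional follow-up (an illustrative sketch, not in the original script): print a short window
# of context around the tokens that most strongly activate the target feature found for the
# last SAE above. Relies on `features`, `target_feature_index`, `token_text_idx`, and
# `token_pos_idx` still being in scope from the loop.
if target_feature_index is not None:
    target_activations = features[:, target_feature_index]
    for example_idx in target_activations.argsort()[-5:][::-1]:
        text_idx = int(token_text_idx[example_idx])
        token_idx = int(token_pos_idx[example_idx])
        input_ids = tokenizer(dataset["text"][text_idx], return_tensors="pt", truncation=True)["input_ids"][0]
        start, end = max(0, token_idx - 10), min(len(input_ids), token_idx + 10)
        print(repr(tokenizer.decode(input_ids[start:end])))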