Trying to understand https://transformer-circuits.pub/2024/scaling-monosemanticity/index.html
# !pip install torch transformers scikit-learn umap-learn matplotlib datasets joblib pandas zstandard

'''
Code outline:
1. Loading the pre-trained GPT-2 model and tokenizer.
2. Loading or downloading the dataset and saving it to disk with compression.
3. Collecting and normalizing activations from the middle layer of the model.
4. Training multiple SAEs with different feature sizes and saving them to disk.
5. Finding the feature that responds to the phrase "Golden Gate Bridge" in each SAE.
6. Visualizing the local neighborhood of the target feature using UMAP.
7. Analyzing nearby features for the target feature by computing cosine similarity and retrieving the top similar features.
'''
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.decomposition import DictionaryLearning
import numpy as np
import umap
import matplotlib.pyplot as plt
from datasets import load_dataset, Dataset
import joblib
import os
import pandas as pd
import zstandard as zstd
# Function to save dataset to disk with zstd compression
def save_dataset(dataset, file_path):
    df = dataset.to_pandas()
    with open(file_path, 'wb') as f:
        cctx = zstd.ZstdCompressor()
        with cctx.stream_writer(f) as compressor:
            # to_csv returns a str; the zstd stream writer expects bytes
            compressor.write(df.to_csv(index=False).encode("utf-8"))
# Function to load dataset from disk with zstd compression
def load_dataset_from_disk(file_path):
    with open(file_path, 'rb') as f:
        dctx = zstd.ZstdDecompressor()
        with dctx.stream_reader(f) as decompressor:
            df = pd.read_csv(decompressor)
    return Dataset.from_pandas(df)
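# Quick round-trip check for the two helpers above (a hypothetical sanity
# test, not part of the original gist; uncomment to run):
# _demo = Dataset.from_dict({"text": ["hello", "world"]})
# save_dataset(_demo, "demo.csv.zst")
# assert load_dataset_from_disk("demo.csv.zst")["text"] == ["hello", "world"]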
# Load a smaller pre-trained model and tokenizer, e.g., GPT-2
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# GPT-2 has no padding token by default; reuse the EOS token
tokenizer.pad_token = tokenizer.eos_token
# Check if the dataset file exists, otherwise download it
dataset_file = "pile_uncopyrighted_dataset.csv.zst"
if os.path.exists(dataset_file):
    print("Loading dataset from disk...")
    dataset = load_dataset_from_disk(dataset_file)
else:
    try:
        print("Downloading dataset...")
        dataset = load_dataset("monology/pile-uncopyrighted", split="train[:1000]")  # use a smaller subset
        print("Saving dataset to disk...")
        save_dataset(dataset, dataset_file)
    except Exception as e:
        print(f"An error occurred while loading the dataset: {e}")
        raise
# Check if activations already exist, otherwise calculate them
activations_file = "activations.pt"
if os.path.exists(activations_file):
    saved = torch.load(activations_file)
    activations, text_offsets = saved["activations"], saved["text_offsets"]
else:
    # Collect activations, recording where each text's tokens start so that a
    # flat row index can later be mapped back to (text, token position)
    activations = []
    text_offsets = [0]
    for text in dataset["text"]:
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        middle_layer_index = len(hidden_states) // 2
        residual_stream = hidden_states[middle_layer_index]
        flat = residual_stream.reshape(-1, residual_stream.shape[-1])
        activations.append(flat)
        text_offsets.append(text_offsets[-1] + flat.shape[0])
    activations = torch.cat(activations, dim=0)
    text_offsets = np.asarray(text_offsets)
    # Normalize activations: rescale every token vector to a common norm
    # (residual_stream_dim is a Python int, so use ** 0.5 rather than torch.sqrt)
    residual_stream_dim = activations.shape[-1]
    activations_norm = torch.sqrt(torch.mean(activations**2, dim=-1, keepdim=True))
    activations = (residual_stream_dim ** 0.5) * activations / activations_norm
    torch.save({"activations": activations, "text_offsets": text_offsets}, activations_file)
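# Sanity checks (sketch): there should be one activation row per kept token,
# and normalization should give every row the same L2 norm. These prints are
# illustrative additions, not part of the original gist; uncomment to inspect:
# print(activations.shape)                           # (total_tokens, hidden_dim)
# print(text_offsets[-1])                            # == total_tokens
# print(torch.linalg.norm(activations[:3], dim=-1))  # constant across rows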
# Define the feature sizes for different SAEs
feature_sizes = [1024, 2048, 4096]

# Train multiple SAEs with different feature sizes and save them
saes = []
for i, n_components in enumerate(feature_sizes):
    model_file = f"sae_{i+1}.joblib"
    if os.path.exists(model_file):
        print(f"Loading SAE {i+1} from disk...")
        sae = joblib.load(model_file)
    else:
        print(f"Training SAE with {n_components} features...")
        sae = DictionaryLearning(n_components=n_components, transform_alpha=5, random_state=0, max_iter=10)
        sae.fit(activations.detach().cpu().numpy())
        print(f"Saving SAE {i+1} to disk...")
        joblib.dump(sae, model_file)
    saes.append(sae)
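# Note: scikit-learn's DictionaryLearning is only a sparse-coding stand-in for
# the paper's sparse autoencoder. A closer match to the architecture described
# in "Scaling Monosemanticity" is a one-hidden-layer autoencoder with a ReLU
# encoder and an L1 sparsity penalty on the feature activations. The class
# below is a minimal sketch under that assumption (names and the l1_coeff
# value are illustrative, not from the original gist); it is defined for
# reference only and not used by the rest of this script.
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    def __init__(self, d_model, n_features, l1_coeff=1e-3):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_features)
        self.decoder = nn.Linear(n_features, d_model)
        self.l1_coeff = l1_coeff

    def forward(self, x):
        features = torch.relu(self.encoder(x))  # sparse, non-negative features
        x_hat = self.decoder(features)          # reconstruction of the input
        loss = ((x_hat - x) ** 2).mean() + self.l1_coeff * features.abs().mean()
        return loss, features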
# Find the feature that responds to "Golden Gate Bridge" in each SAE
target_phrase = "Golden Gate Bridge"
for i, sae in enumerate(saes):
    features = sae.transform(activations.detach().cpu().numpy())
    target_feature_index = None
    for j in range(sae.n_components):
        feature = features[:, j]
        top_activating_indices = feature.argsort()[-5:][::-1]
        for example_idx in top_activating_indices:
            # Map the flat token index back to the text it came from
            text_idx = int(np.searchsorted(text_offsets, example_idx, side="right")) - 1
            if target_phrase.lower() in dataset["text"][text_idx].lower():
                target_feature_index = j
                break
        if target_feature_index is not None:
            break
    if target_feature_index is None:
        print(f"No feature found that responds to '{target_phrase}' in SAE {i+1}")
        continue  # nothing to visualize or analyze for this SAE
    print(f"Feature {target_feature_index} responds to '{target_phrase}' in SAE {i+1}")
    # Visualize the local neighborhood of the target feature using UMAP
    umap_embeddings = umap.UMAP(n_neighbors=15, min_dist=0.1, metric='cosine').fit_transform(sae.components_)

    # Plot the UMAP visualization, highlighting the target feature
    plt.figure(figsize=(10, 10))
    plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=5, alpha=0.5)
    plt.scatter(umap_embeddings[target_feature_index, 0], umap_embeddings[target_feature_index, 1], s=100, color='red', label=f"Feature {target_feature_index}")
    plt.gca().set_aspect('equal', 'datalim')
    plt.title(f'UMAP Visualization of Learned Features (SAE {i+1})')
    plt.legend()
    plt.show()
    # Analyze nearby features for the target feature
    target_feature = sae.components_[target_feature_index]

    # Compute cosine similarity between the target feature and all other features
    similarity_scores = np.dot(sae.components_, target_feature) / (np.linalg.norm(sae.components_, axis=1) * np.linalg.norm(target_feature))

    # Get the indices of the top similar features, skipping the target itself
    # (its self-similarity of 1.0 would otherwise always rank first)
    top_similar_indices = [idx for idx in similarity_scores.argsort()[::-1] if idx != target_feature_index][:5]
    print(f"Top similar features to Feature {target_feature_index} in SAE {i+1}:")
    for idx in top_similar_indices:
        print(f"Feature {idx}")
        top_activating_indices = features[:, idx].argsort()[-5:][::-1]
        for example_idx in top_activating_indices:
            # Map the flat row index back to (text, position within that text)
            text_idx = int(np.searchsorted(text_offsets, example_idx, side="right")) - 1
            token_idx = int(example_idx) - int(text_offsets[text_idx])
            input_ids = tokenizer(dataset["text"][text_idx], return_tensors="pt", truncation=True)["input_ids"][0]
            if token_idx < len(input_ids):
                print(tokenizer.decode(input_ids[token_idx]))
        print()