# ViT example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random
"""
1. Load the CIFAR-10 dataset
"""
# Define data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    # Add additional transforms if needed
])
# Load CIFAR10 training and test datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# Create data loaders (num_workers is machine-dependent; lower it if your CPU has fewer cores)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=24)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=24)
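# Optional sanity check: peek at one batch to confirm tensor shapes
# (CIFAR-10 images are 3x32x32, so a batch should be [64, 3, 32, 32]).
images, labels = next(iter(train_loader))
print(f'Batch images: {images.shape}, labels: {labels.shape}')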
"""
2. Visualizing a single sample
"""
import matplotlib.pyplot as plt
# Get the class names from the dataset
classes = train_dataset.classes
# Get a single sample from the training dataset
image, label = train_dataset[0]
class_name = classes[label]
# Convert the image tensor to a NumPy array and transpose dimensions for plotting
image = image.numpy().transpose((1, 2, 0))
# Display the image with class name
plt.imshow(image)
plt.title(f'Label: {label}, Class: {class_name}')
plt.show()
"""
3. Define the ViT
"""
# Define the Patch Embedding module
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        # The [CLS] token is a learnable embedding that represents the entire image
        # after the Transformer encoder processes the sequence of patch embeddings.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # The ViT architecture opts for learnable positional embeddings rather than
        # fixed ones; initialize them randomly so they can be learned during training.
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches + 1, embed_dim)
        )

    def forward(self, x):
        B = x.shape[0]
        x = self.proj(x)       # Shape: [B, embed_dim, H', W']
        x = x.flatten(2)       # Shape: [B, embed_dim, N]
        x = x.transpose(1, 2)  # Shape: [B, N, embed_dim]
        cls_tokens = self.cls_token.expand(B, -1, -1)  # Shape: [B, 1, embed_dim]
        assert cls_tokens.shape == (B, 1, self.embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # Shape: [B, N+1, embed_dim]
        x = x + self.pos_embed  # Add positional embeddings
        return x
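# Optional shape check: a minimal sketch verifying that PatchEmbedding maps
# [B, 3, 32, 32] to [B, N+1, embed_dim] with the defaults above (N = (32/4)^2 = 64).
_pe = PatchEmbedding()
_out = _pe(torch.randn(2, 3, 32, 32))
assert _out.shape == (2, 65, 128), _out.shape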
# Define the Transformer Encoder Block
class TransformerEncoderBlock(nn.Module):
    def __init__(
        self, embed_dim=128, num_heads=4, mlp_ratio=4.0, dropout_rate=0.1
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout_rate
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
        )
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Pre-norm self-attention with a residual connection.
        # nn.MultiheadAttention expects [seq_len, batch, embed_dim] by default,
        # hence the transposes around the call.
        x_residual = x
        x = self.norm1(x)
        x, _ = self.attn(x.transpose(0, 1), x.transpose(0, 1), x.transpose(0, 1))
        x = x.transpose(0, 1)
        x = x_residual + self.dropout(x)
        # Pre-norm MLP with a residual connection
        x_residual = x
        x = self.norm2(x)
        x = x_residual + self.dropout(self.mlp(x))
        return x
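# Optional shape check: the encoder block is shape-preserving, mapping
# [B, N+1, embed_dim] -> [B, N+1, embed_dim] (a sketch with a dummy token sequence).
_blk = TransformerEncoderBlock()
_tokens = torch.randn(2, 65, 128)  # 64 patches + 1 [CLS] token
assert _blk(_tokens).shape == _tokens.shape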
# Define the Vision Transformer model
class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        num_classes=10,
        embed_dim=128,
        depth=6,
        num_heads=4,
        mlp_ratio=4.0,
        dropout_rate=0.1,
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        self.encoder_layers = nn.ModuleList(
            [
                TransformerEncoderBlock(
                    embed_dim, num_heads, mlp_ratio, dropout_rate
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        for layer in self.encoder_layers:
            x = layer(x)
        x = self.norm(x)
        cls_token = x[:, 0]  # Classify from the [CLS] token only
        logits = self.head(cls_token)
        return logits
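# Optional end-to-end shape check (a sketch): with 10 classes, the logits for a
# batch of CIFAR-10-sized inputs should come out as [B, 10].
_vit = VisionTransformer(depth=2)
assert _vit(torch.randn(2, 3, 32, 32)).shape == (2, 10)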
# Instantiate the model
model = VisionTransformer(
    img_size=32,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    embed_dim=128,
    depth=2,
    num_heads=4,
    mlp_ratio=4.0,
    dropout_rate=0.1,
)
# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
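# Optional: report model capacity by counting trainable parameters
# (a quick sketch; useful for comparing depth/embed_dim choices).
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {num_params:,}')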
"""
4. Training on a small subset
"""
### 4.1. Preparing a small subset
# Define data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    # Add additional transforms if needed
])
# Load the full CIFAR10 training dataset
full_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Set the desired number of images per class
images_per_class = 2 # For 10 classes, total images will be 20
# Create a dictionary to hold indices for each class
class_indices = [[] for _ in range(len(full_train_dataset.classes))]
# Populate the class_indices list (iterating .targets avoids decoding every image)
for idx, label in enumerate(full_train_dataset.targets):
    class_indices[label].append(idx)
# Select images_per_class indices per class
subset_indices = []
for class_idx in range(len(full_train_dataset.classes)):
    indices = class_indices[class_idx]
    selected_indices = random.sample(indices, images_per_class)
    subset_indices.extend(selected_indices)
# Create the small subset dataset
small_train_dataset = Subset(full_train_dataset, subset_indices)
# Create a data loader for the small dataset
small_train_loader = DataLoader(small_train_dataset, batch_size=4, shuffle=True, num_workers=4)
# Optionally, print the size of the small dataset
print(f'Total number of images in the small dataset: {len(small_train_dataset)}')
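# Optional: verify the subset is class-balanced (a sketch using the dataset's
# .targets list, which holds the integer label of every training image).
from collections import Counter
subset_counts = Counter(full_train_dataset.targets[i] for i in subset_indices)
print(f'Images per class in the subset: {dict(subset_counts)}')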
### 4.2. Train on the small subset
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Training loop parameters
num_epochs = 30 # Adjust the number of epochs as needed
# Training loop
loss_values = []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in small_train_loader:
        images = images.to(device)
        labels = labels.to(device)
        # Zero gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Calculate average loss for the epoch
    avg_loss = running_loss / len(small_train_loader)
    loss_values.append(avg_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
print('Training complete.')
### 4.3. Visualize results
# Evaluation: with only 20 training images the model should be able to
# (over)fit them, so near-100% accuracy here confirms the pipeline works.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in small_train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'Accuracy on the small training set: {accuracy:.2f}%')
# Optional: Visualize Training Loss
plt.plot(range(1, num_epochs + 1), loss_values, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()
"""
5. Full training
"""
# Instantiate the model
model = VisionTransformer(
    img_size=32,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    embed_dim=128,
    depth=2,
    num_heads=4,
    mlp_ratio=4.0,
    dropout_rate=0.1,
)
# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Training loop parameters
num_epochs = 30 # Adjust the number of epochs as needed
# Training loop
loss_values = []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        # Zero gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Calculate average loss for the epoch
    avg_loss = running_loss / len(train_loader)
    loss_values.append(avg_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
print('Training complete.')
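# Optional: evaluate the fully trained model on the held-out test set
# (a sketch mirroring the evaluation loop in 4.3, using the test_loader from step 1).
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy: {100 * correct / total:.2f}%')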