ViT example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random
""" | |
1. Load CIFAR10 dataset # | |
""" | |
# Define data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    # Add additional transforms if needed
])
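# A hedged sketch (not in the original gist): CIFAR10 inputs are commonly
# normalized with per-channel mean/std. `transform_normalized` is a hypothetical
# alternative pipeline using the widely quoted CIFAR10 statistics.
transform_normalized = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])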
# Load CIFAR10 training and test datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders (num_workers is machine-specific; lower it if you have fewer CPU cores)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=24)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=24)
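# Quick sanity check (a sketch, not in the original gist): pull one batch and
# confirm the shapes match CIFAR10's 3x32x32 images and the batch size of 64.
sample_images, sample_labels = next(iter(train_loader))
print(f'Batch shape: {sample_images.shape}, labels shape: {sample_labels.shape}')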
""" | |
2. Visualizating a single data | |
""" | |
import matplotlib.pyplot as plt

# Get the class names from the dataset
classes = train_dataset.classes

# Get a single sample from the training dataset
image, label = train_dataset[0]
class_name = classes[label]

# Convert the image tensor to a NumPy array and transpose dimensions for plotting
image = image.numpy().transpose((1, 2, 0))

# Display the image with its class name
plt.imshow(image)
plt.title(f'Label: {label}, Class: {class_name}')
plt.show()
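# Optional extension (a sketch, not in the original gist): plot a grid of the
# first 16 training images with torchvision.utils.make_grid to eyeball variety.
from torchvision.utils import make_grid
grid = make_grid(torch.stack([train_dataset[i][0] for i in range(16)]), nrow=4)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.axis('off')
plt.title('First 16 CIFAR10 training images')
plt.show()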
""" | |
3. Define the ViT | |
""" | |
# Define the Patch Embedding module
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        # The CLS token is a learnable embedding that represents the entire image
        # after the Transformer encoder processes the sequence of patch embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # The ViT architecture opts for learnable positional embeddings rather than
        # fixed ones, so they are initialized with random values and trained
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches + 1, embed_dim)
        )
    def forward(self, x):
        B = x.shape[0]
        x = self.proj(x)       # Shape: [B, embed_dim, H', W']
        x = x.flatten(2)       # Shape: [B, embed_dim, N]
        x = x.transpose(1, 2)  # Shape: [B, N, embed_dim]
        cls_tokens = self.cls_token.expand(B, -1, -1)  # Shape: [B, 1, embed_dim]
        assert cls_tokens.shape == (B, 1, self.embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # Shape: [B, N+1, embed_dim]
        x = x + self.pos_embed  # Add positional embeddings
        return x
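# Optional shape check (a sketch, not in the original gist): with img_size=32 and
# patch_size=4 there are (32 // 4) ** 2 = 64 patches, so the output should be
# [B, 65, 128] once the CLS token is prepended.
_pe_out = PatchEmbedding()(torch.randn(2, 3, 32, 32))
assert _pe_out.shape == (2, 65, 128), _pe_out.shape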
# Define the Transformer Encoder Block
class TransformerEncoderBlock(nn.Module):
    def __init__(
        self, embed_dim=128, num_heads=4, mlp_ratio=4.0, dropout_rate=0.1
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout_rate
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
        )
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, x):
        # Pre-norm self-attention; nn.MultiheadAttention expects [N, B, D] by
        # default, hence the transposes around the call
        x_residual = x
        x = self.norm1(x)
        x, _ = self.attn(x.transpose(0, 1), x.transpose(0, 1), x.transpose(0, 1))
        x = x.transpose(0, 1)
        x = x_residual + self.dropout(x)
        # Pre-norm MLP with a second residual connection
        x_residual = x
        x = self.norm2(x)
        x = x_residual + self.dropout(self.mlp(x))
        return x
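# Optional smoke test (a sketch, not in the original gist): the encoder block
# should preserve the [B, N+1, embed_dim] shape of its input.
assert TransformerEncoderBlock()(torch.randn(2, 65, 128)).shape == (2, 65, 128)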
# Define the Vision Transformer model
class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        num_classes=10,
        embed_dim=128,
        depth=6,
        num_heads=4,
        mlp_ratio=4.0,
        dropout_rate=0.1,
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        self.encoder_layers = nn.ModuleList(
            [
                TransformerEncoderBlock(
                    embed_dim, num_heads, mlp_ratio, dropout_rate
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        for layer in self.encoder_layers:
            x = layer(x)
        x = self.norm(x)
        cls_token = x[:, 0]  # Classify from the CLS token only
        logits = self.head(cls_token)
        return logits
# Instantiate the model
model = VisionTransformer(
    img_size=32,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    embed_dim=128,
    depth=2,
    num_heads=4,
    mlp_ratio=4.0,
    dropout_rate=0.1,
)

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
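# Optional (a sketch, not in the original gist): report the model size, useful
# when comparing depth/embed_dim settings.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {num_params:,}')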
""" | |
4. Training with a small data | |
""" | |
### 4.1. preparing a small data | |
# Define data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    # Add additional transforms if needed
])

# Load the full CIFAR10 training dataset
full_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Set the desired number of images per class
images_per_class = 2  # For 10 classes, total images will be 20

# Create a list of index lists, one per class
class_indices = [[] for _ in range(len(full_train_dataset.classes))]

# Populate class_indices from the dataset's label list (avoids decoding every image)
for idx, label in enumerate(full_train_dataset.targets):
    class_indices[label].append(idx)

# Select images_per_class indices per class
subset_indices = []
for class_idx in range(len(full_train_dataset.classes)):
    indices = class_indices[class_idx]
    selected_indices = random.sample(indices, images_per_class)
    subset_indices.extend(selected_indices)

# Create the small subset dataset
small_train_dataset = Subset(full_train_dataset, subset_indices)

# Create a data loader for the small dataset
small_train_loader = DataLoader(small_train_dataset, batch_size=4, shuffle=True, num_workers=4)

# Optionally, print the size of the small dataset
print(f'Total number of images in the small dataset: {len(small_train_dataset)}')
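# Optional check (a sketch, not in the original gist): confirm the subset holds
# exactly images_per_class samples of each class.
from collections import Counter
print('Class counts in subset:', Counter(full_train_dataset.targets[i] for i in subset_indices))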
### 4.2. Train on the small dataset
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop parameters
num_epochs = 30  # Adjust the number of epochs as needed

# Training loop
loss_values = []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in small_train_loader:
        images = images.to(device)
        labels = labels.to(device)
        # Zero gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Calculate average loss for the epoch (over the small loader, not the full one)
    avg_loss = running_loss / len(small_train_loader)
    loss_values.append(avg_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
print('Training complete.')
### 4.3. Visualize results
# Evaluation on the small training set (the model should overfit it easily)
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in small_train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'Accuracy on the small training set: {accuracy:.2f}%')
# Optional: Visualize Training Loss
plt.plot(range(1, num_epochs + 1), loss_values, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()
""" | |
5. Full training | |
""" | |
# Re-instantiate the model so full training starts from fresh weights
model = VisionTransformer(
    img_size=32,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    embed_dim=128,
    depth=2,
    num_heads=4,
    mlp_ratio=4.0,
    dropout_rate=0.1,
)

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Training loop parameters
num_epochs = 30  # Adjust the number of epochs as needed

# Training loop
loss_values = []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        # Zero gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Calculate average loss for the epoch
    avg_loss = running_loss / len(train_loader)
    loss_values.append(avg_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
print('Training complete.')
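# Evaluation on the held-out test set (a sketch, not in the original gist, which
# builds test_loader but never uses it).
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy: {100 * correct / total:.2f}%')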