nzw0301/b94b509b5b0fc049501a9b46a1494bc0, created August 14, 2019 20:38
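```python
# CIFAR-100: per-channel mean and std of the training images
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100

# ToTensor() converts each PIL image to a float tensor in [0, 1] with shape (C, H, W).
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_set = CIFAR100(
    root='~/data',
    train=True,
    download=True,
    transform=train_transform,
)

# batch_size=50_000 equals the training-set size, so a single batch holds every image.
train_loader = DataLoader(
    train_set,
    batch_size=50_000,
    shuffle=True,
)

# Use the built-in next(); the iterator's `.next()` alias is not available in recent PyTorch.
images, _ = next(iter(train_loader))

# Reduce over batch, height and width (dims 0, 2, 3), leaving one value per RGB channel.
print(
    images.mean(dim=[0, 2, 3]).numpy(),
    images.std(dim=[0, 2, 3]).numpy(),
)
```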
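Loading all 50,000 training images as one batch holds about 614 MB of float32 data in memory at once. A minimal alternative sketch, assuming the images arrive as NumPy arrays of shape (N, C, H, W), accumulates per-channel sums and sums of squares one batch at a time. `channel_mean_std` is a hypothetical helper, not part of torchvision. It returns the population std, while `torch.Tensor.std` applies Bessel's correction by default; with over 51 million values per channel the two agree to many decimal places.

```python
import numpy as np


def channel_mean_std(batches):
    """Per-channel mean and std over an iterable of (N, C, H, W) arrays.

    Hypothetical helper: keeps running sums so only one batch needs to be
    in memory at a time. Returns the population (biased) std.
    """
    total = None
    total_sq = None
    count = 0
    for batch in batches:
        batch = np.asarray(batch, dtype=np.float64)
        if total is None:
            total = np.zeros(batch.shape[1])
            total_sq = np.zeros(batch.shape[1])
        # Sum over batch, height and width, keeping the channel axis.
        total += batch.sum(axis=(0, 2, 3))
        total_sq += (batch ** 2).sum(axis=(0, 2, 3))
        count += batch.shape[0] * batch.shape[2] * batch.shape[3]
    mean = total / count
    # Var = E[x^2] - E[x]^2; clip tiny negative round-off before the sqrt.
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)
```

With PyTorch, the batches could come from a smaller loader, e.g. `channel_mean_std(x.numpy() for x, _ in DataLoader(train_set, batch_size=1000))`, which yields the same statistics without materializing the whole dataset.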
Output:
[0.5070779 0.48654884 0.44091937] [0.26733428 0.25643846 0.27615047]
So, I use these values for `Normalize`:
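In torchvision this would typically be `transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))`, placed after `transforms.ToTensor()` in the `Compose` list. The sketch below uses a hypothetical NumPy helper, `normalize_chw`, to show the per-channel arithmetic that `Normalize` applies, `out[c] = (x[c] - mean[c]) / std[c]`. It uses the printed statistics, which are assumed to be reused unchanged for the test set.

```python
import numpy as np

# Per-channel statistics printed above for the CIFAR-100 training set (R, G, B).
CIFAR100_MEAN = np.array([0.5070779, 0.48654884, 0.44091937])
CIFAR100_STD = np.array([0.26733428, 0.25643846, 0.27615047])


def normalize_chw(image, mean=CIFAR100_MEAN, std=CIFAR100_STD):
    """Standardize one (C, H, W) image channel by channel.

    Hypothetical helper mirroring the arithmetic of transforms.Normalize:
    out[c] = (image[c] - mean[c]) / std[c].
    """
    image = np.asarray(image, dtype=np.float64)
    # Reshape (C,) -> (C, 1, 1) so the stats broadcast over height and width.
    return (image - mean[:, None, None]) / std[:, None, None]
```

A pixel equal to the channel mean maps to 0, and one a single std above it maps to 1. After normalization, each channel of the training set therefore has mean about 0 and std about 1.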