get_mean_std.py
1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL.ImageOps import grayscale
from PIL import Image
from torchvision.datasets import ImageFolder
class MyDataset(ImageFolder):
    """Pass-through ``ImageFolder`` subclass.

    Construction and item access are both delegated to the parent class;
    items come back as ``(image, label)`` pairs.
    """

    def __init__(self, root, trainsform):
        """Initialise the underlying ``ImageFolder``.

        Args:
            root: Forwarded as the first positional argument of
                ``ImageFolder``.
            trainsform: Forwarded as the second positional argument of
                ``ImageFolder``. The parameter spelling is kept unchanged
                so keyword callers keep working.
        """
        super().__init__(root, trainsform)

    def __getitem__(self, index):
        """Return the parent's ``(image, label)`` pair for ``index``."""
        image, label = super().__getitem__(index)
        return image, label
def get_params(path, resize_size, batch_size=256, num_workers=8):
    """Compute per-channel mean and std statistics for an image folder.

    Each image loaded through ``MyDataset`` is resized to
    ``resize_size`` x ``resize_size``, converted to grayscale and turned
    into a tensor. Per-image, per-channel statistics are then averaged over
    all images.

    Note:
        The returned ``std`` is the average of the per-image standard
        deviations. It is not the standard deviation of all pixels pooled
        across the dataset, which is generally larger.

    Args:
        path: Dataset root directory passed to ``MyDataset``.
        resize_size: Target height and width (in pixels) for every image.
        batch_size: Number of images per ``DataLoader`` batch. Defaults to
            256, the value that used to be hard-coded.
        num_workers: Number of ``DataLoader`` worker processes. Defaults to
            8, the value that used to be hard-coded; pass 0 to load in the
            main process.

    Returns:
        A ``(mean, std)`` tuple of tensors with one entry per image channel.
    """
    my_transform = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ])
    my_dataset = MyDataset(path, my_transform)
    loader = torch.utils.data.DataLoader(
        my_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    mean = 0.
    std = 0.
    nb_samples = 0.
    # Labels play no role in the statistics, so they are discarded.
    for data, _ in loader:
        batch_samples = data.size(0)
        # Flatten the spatial dimensions: (batch, channels, pixels).
        data = data.view(batch_samples, data.size(1), -1)
        # Summing per-image values (rather than averaging per batch) keeps a
        # smaller final batch correctly weighted.
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        nb_samples += batch_samples

    mean /= nb_samples
    std /= nb_samples
    print(f"mean : {mean} , std : {std}")
    return mean, std
"""
my_transform = transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor()
])
my_dataset = MyDataset("../data/Third_data/not_binary", my_transform)
loader = torch.utils.data.DataLoader(
my_dataset,
batch_size=256,
num_workers=8,
shuffle=False
)
mean = 0.
std = 0.
nb_samples = 0.
for i, (data, target) in enumerate(loader):
batch_samples = data.size(0)
data = data.view(batch_samples, data.size(1), -1)
mean += data.mean(2).sum(0)
std += data.std(2).sum(0)
nb_samples += batch_samples
mean /= nb_samples
std /= nb_samples
print(f"mean : {mean}, std : {std}")
"""