# get_mean_std.py (1.98 KB)
import os
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL.ImageOps import grayscale
from PIL import Image
from torchvision.datasets import ImageFolder

class MyDataset(ImageFolder):
    """Image-folder dataset that yields ``(image, label)`` pairs.

    Behaves like the ``ImageFolder`` base class; both methods simply
    delegate to the parent implementation.
    """

    def __init__(self, root, trainsform):
        """Build the dataset.

        Args:
            root: Directory passed to ``ImageFolder`` as its first argument.
            trainsform: Passed to ``ImageFolder`` as its second positional
                argument (the parameter name is kept as-is so existing
                keyword callers continue to work).
        """
        super().__init__(root, trainsform)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair the parent class produces for ``index``."""
        sample = super().__getitem__(index)
        # Unpack explicitly so anything other than a 2-item result still fails loudly.
        image, label = sample
        return image, label
 


def get_params(path, resize_size, *, batch_size=256, num_workers=8):
    """Compute per-channel mean and std statistics of an image folder.

    Every image found by ``MyDataset`` under ``path`` is resized to
    ``resize_size`` x ``resize_size``, converted to grayscale and turned
    into a tensor. Statistics are accumulated batch by batch, so the whole
    dataset never has to fit in memory at once.

    Args:
        path: Root directory handed to ``MyDataset``.
        resize_size: Edge length (in pixels) of the square every image is
            resized to before the statistics are taken.
        batch_size: Number of images loaded per step. Only affects memory
            use and speed, not the statistics (up to floating-point
            rounding). Defaults to 256.
        num_workers: Number of ``DataLoader`` worker processes; pass 0 to
            load in the main process. Defaults to 8.

    Returns:
        A ``(mean, std)`` tuple of tensors with one entry per image channel.

    Note:
        ``mean`` is the average of the per-image channel means. ``std`` is
        the average of the *per-image* standard deviations, which is not
        the same as the standard deviation of all pixels pooled together.
        The result is also printed to stdout.
    """
    my_transform = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ])

    my_dataset = MyDataset(path, my_transform)

    loader = torch.utils.data.DataLoader(
        my_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,  # order is irrelevant for summed statistics
    )

    mean = 0.
    std = 0.
    nb_samples = 0.
    for data, _ in loader:  # labels are not needed for the statistics
        batch_samples = data.size(0)
        # Flatten H x W into one axis: (batch, channels, pixels).
        data = data.view(batch_samples, data.size(1), -1)
        # Reduce over pixels per image, then sum over the images in the batch.
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        nb_samples += batch_samples

    # Turn the running sums into averages over all images seen.
    mean /= nb_samples
    std /= nb_samples
    print(f"mean : {mean} , std : {std}")
    return mean, std

"""
my_transform = transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor()
])

my_dataset = MyDataset("../data/Third_data/not_binary", my_transform)

loader = torch.utils.data.DataLoader(
    my_dataset,
    batch_size=256,
    num_workers=8,
    shuffle=False
)

mean = 0.
std = 0.
nb_samples = 0.
for i, (data, target) in enumerate(loader):
    batch_samples = data.size(0)
    data = data.view(batch_samples, data.size(1), -1)
    mean += data.mean(2).sum(0)
    std += data.std(2).sum(0)
    nb_samples += batch_samples

mean /= nb_samples
std /= nb_samples

print(f"mean : {mean}, std : {std}")
"""