# mnist_m.py
import torch.utils.data as data
from PIL import Image
import os
import params
from torchvision import transforms
import torch
import torch.utils.data as data_utils
class GetLoader(data.Dataset):
    """MNIST-M dataset backed by an on-disk image directory and a label file.

    Each line of ``data_list`` is expected to look like
    ``<relative/image/path> <label>``, e.g. ``00000001.png 5``.
    """

    def __init__(self, data_root, data_list, transform=None):
        """
        Args:
            data_root: directory containing the image files.
            data_list: path to the text file listing ``path label`` pairs,
                one per line.
            transform: optional torchvision transform applied to each image
                in ``__getitem__``.
        """
        self.root = data_root
        self.transform = transform

        # Context manager guarantees the file handle is closed even if
        # reading raises (the original left it open on error).
        with open(data_list, 'r') as f:
            lines = f.readlines()

        self.n_data = len(lines)
        self.img_paths = []
        self.img_labels = []
        for line in lines:
            # Split on the LAST space so labels may have more than one
            # digit and a missing trailing newline is tolerated.  (The
            # original sliced fixed offsets: path = line[:-3],
            # label = line[-2], which assumed exactly "path D\n".)
            path, label = line.rstrip('\n').rsplit(' ', 1)
            self.img_paths.append(path)
            self.img_labels.append(label)

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``.

        The image is loaded from disk, converted to RGB, and passed
        through ``self.transform`` if one was given; the label is
        returned as an int.
        """
        img_path, label = self.img_paths[item], self.img_labels[item]
        img = Image.open(os.path.join(self.root, img_path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, int(label)

    def __len__(self):
        """Number of entries listed in the label file."""
        return self.n_data
def get_mnist_m(train, adp=False, size=0):
    """Build a shuffled DataLoader over the MNIST-M train or test split.

    Args:
        train: if True, load the training split; otherwise the test split.
        adp: if True, batch with ``params.adp_batch_size`` instead of
            ``params.batch_size``.
        size: when ``train`` is True, randomly subsample this many
            examples.  0 (the default) or any value >= the split size
            keeps the full split.  (Previously the default of 0 produced
            an *empty* training dataset, and an oversized value crashed
            in ``random_split``.)

    Returns:
        torch.utils.data.DataLoader yielding (image, label) batches,
        shuffled, dropping the last incomplete batch.
    """
    mode = 'train' if train else 'test'
    list_path = os.path.join(params.mnist_m_dataset_root,
                             'mnist_m_{}_labels.txt'.format(mode))

    # Image pre-processing: resize only; ToTensor leaves pixels in [0, 1]
    # (normalization / grayscale conversion is intentionally left out).
    pre_process = transforms.Compose([
        transforms.Resize(params.image_size),
        transforms.ToTensor(),
    ])

    dataset_target = GetLoader(
        data_root=os.path.join(params.mnist_m_dataset_root,
                               'mnist_m_{}'.format(mode)),
        data_list=list_path,
        transform=pre_process)

    # Optionally subsample the training split down to ``size`` examples.
    if train and 0 < size < len(dataset_target):
        dataset_target, _ = data_utils.random_split(
            dataset_target, [size, len(dataset_target) - size])

    return torch.utils.data.DataLoader(
        dataset=dataset_target,
        batch_size=params.adp_batch_size if adp else params.batch_size,
        shuffle=True,
        drop_last=True)