조현아

update classifier

@@ -33,12 +33,9 @@ def eval(model_path):
     model.load_state_dict(torch.load(weight_path))
 
     print('\n[+] Load dataset')
-    test_transform = get_valid_transform(args, model)
-    #print('\nTEST Transform\n', test_transform)
     test_dataset = get_dataset(args, 'test')
 
 
-
     test_loader = iter(get_dataloader(args, test_dataset)) ###
 
     print('\n[+] Start testing')
......
@@ -16,6 +16,8 @@ class BaseNet(nn.Module):
         x = self.after(f)
         x = x.reshape(x.size(0), -1)
         x = self.fc(x)
+
+        # output, first
         return x, f
 
     """
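The comment added above return x, f documents the two-value return: x holds the classifier logits and f the backbone feature map that the visualization helper expects. A caller unpacks it the way train_step in the new utility module does (sketch only; model, images, target and criterion are stand-ins):

    output, first = model(images)      # logits for the loss, feature map for add_image visualization
    loss = criterion(output, target)   # the loss is computed on the logits only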
......
@@ -24,7 +24,7 @@ def train(**kwargs):
 
     print('\n[+] Create log dir')
     model_name = get_model_name(args)
-    log_dir = os.path.join('/content/drive/My Drive/CD2 Project/classify/', model_name)
+    log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name)
     os.makedirs(os.path.join(log_dir, 'model'))
     json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w'))
     writer = SummaryWriter(log_dir=log_dir)
@@ -42,13 +42,11 @@ def train(**kwargs):
     if args.use_cuda:
         model = model.cuda()
         criterion = criterion.cuda()
-        writer.add_graph(model)
+        #writer.add_graph(model)
 
     print('\n[+] Load dataset')
-    transform = get_train_transform(args, model, log_dir)
-    val_transform = get_valid_transform(args, model)
-    train_dataset = get_dataset(args, transform, 'train')
-    valid_dataset = get_dataset(args, val_transform, 'val')
+    train_dataset = get_dataset(args, 'train')
+    valid_dataset = get_dataset(args, 'val')
     train_loader = iter(get_inf_dataloader(args, train_dataset))
     max_epoch = len(train_dataset) // args.batch_size
     best_acc = -1
@@ -82,6 +80,7 @@ def train(**kwargs):
             print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000))
 
         if step % args.val_step == args.val_step-1:
+            # print("\nstep, args.val_step: ", step, args.val_step)
            valid_loader = iter(get_dataloader(args, valid_dataset))
            _valid_res = validate(args, model, criterion, valid_loader, step, writer)
            print('\n[+] Valid results')
......
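The remaining hunk adds a new data/utility module (its file path is collapsed in the diff). It hard-codes the Google Drive dataset paths and owns all preprocessing, which is why get_dataset in train() and eval() above no longer receives a transform. A minimal sketch of the call pattern those hunks rely on (illustrative only):

    train_dataset = get_dataset(args, 'train')
    valid_dataset = get_dataset(args, 'val')
    train_loader = iter(get_inf_dataloader(args, train_dataset))   # infinite loader, reshuffles each epoch
    valid_loader = iter(get_dataloader(args, valid_dataset))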
+import os
+import time
+import importlib
+import collections
+import pickle as cp
+import glob
+import numpy as np
+import pandas as pd
+
+from natsort import natsorted
+from PIL import Image
+import torch
+import torchvision
+import torch.nn.functional as F
+import torchvision.models as models
+import torchvision.transforms as transforms
+from torch.utils.data import Subset
+from torch.utils.data import Dataset, DataLoader
+
+from sklearn.model_selection import StratifiedShuffleSplit
+from sklearn.model_selection import train_test_split
+from sklearn.model_selection import KFold
+
+from networks import basenet, grayResNet2
+
+
+TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/'
+TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv'
+VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_val/'
+VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/val_nonaug_classify_target.csv'
+TEST_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_test/'
+TEST_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/test_nonaug_classify_target.csv'
+
+current_epoch = 0
+
+def concat_image_features(image, features, max_features=3):
+    """Tile normalized feature-map channels to the right of the input image for visualization."""
+    _, h, w = image.shape   # e.g. (1, 240, 240) for the grayscale input
+
+    # use every channel of the feature map (e.g. 64) instead of capping at max_features
+    #max_features = min(features.size(0), max_features)
+    max_features = features.size(0)
+    image_feature = image.clone()
+
+    for i in range(max_features):
+        feature = features[i:i+1]                     # e.g. [1, 16, 16]
+
+        # min-max normalise the single channel
+        _min, _max = torch.min(feature), torch.max(feature)
+        feature = (feature - _min) / (_max - _min + 1e-6)
+
+        # keep one (grayscale) channel; the 3-channel version is left for reference
+        feature = torch.cat([feature]*1, 0)
+        #feature = torch.cat([feature]*3, 0)
+
+        feature = feature.view(1, 1, feature.size(1), feature.size(2))   # [1, 1, 16, 16]
+        #feature = feature.view(1, 3, feature.size(1), feature.size(2))
+
+        # F.upsample is deprecated; F.interpolate with align_corners=False is the equivalent call
+        feature = F.interpolate(feature, size=(h, w), mode="bilinear", align_corners=False)
+        feature = feature.view(1, h, w)
+
+        # append the upscaled channel to the right of the running strip, e.g. [1, 240, 720]
+        image_feature = torch.cat((image_feature, feature), 2)   # dim=2: grow along the width
+
+    return image_feature
+
+def get_model_name(args):
+    from datetime import datetime, timedelta, timezone
+    # timestamp in KST (UTC+9)
+    now = datetime.now(timezone.utc)
+    tz = timezone(timedelta(hours=9))
+    now = now.astimezone(tz)
+    date_time = now.strftime("%B_%d_%H:%M:%S")
+    model_name = '__'.join([date_time, args.network, str(args.seed)])
+    return model_name
+
+
+def dict_to_namedtuple(d):
+    Args = collections.namedtuple('Args', sorted(d.keys()))
+
+    for k, v in d.items():
+        if type(v) is dict:
+            d[k] = dict_to_namedtuple(v)
+
+        elif type(v) is str:
+            # strings that evaluate to Python literals are converted, others are kept as-is
+            try:
+                d[k] = eval(v)
+            except:
+                d[k] = v
+
+    args = Args(**d)
+    return args
+
+
+def parse_args(kwargs):
+    # combine with default args
+    kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'BraTS'
+    kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet50'
+    kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam'
+    kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.001
+    kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None
+    kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True
+    kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()
+    kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4
+    kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 100
+    kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 100
+    kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp'
+    kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 32
+    kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0
+    kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 2500
+    kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None
+
+    # to named tuple
+    args = dict_to_namedtuple(kwargs)
+    return args, kwargs
+
+
+def select_model(args):
+    # grayResNet2 backbones
+    resnet_dict = {'resnet18': grayResNet2.resnet18(), 'resnet34': grayResNet2.resnet34(),
+                   'resnet50': grayResNet2.resnet50(), 'resnet101': grayResNet2.resnet101(),
+                   'resnet152': grayResNet2.resnet152()}
+
+    if args.network in resnet_dict:
+        backbone = resnet_dict[args.network]
+        model = basenet.BaseNet(backbone, args)
+    else:
+        Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
+        model = Net(args)
+
+    #print(model)  # print model architecture
+    return model
+
+
+def select_optimizer(args, model):
+    if args.optimizer == 'sgd':
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001)
+    elif args.optimizer == 'rms':
+        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate)
+    elif args.optimizer == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    else:
+        raise Exception('Unknown Optimizer')
+    return optimizer
+
+
+def select_scheduler(args, optimizer):
+    if not args.scheduler or args.scheduler == 'None':
+        return None
+    elif args.scheduler == 'clr':
+        return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2',
+                                                 step_size_up=250000, cycle_momentum=False)
+    elif args.scheduler == 'exp':
+        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1)
+    else:
+        raise Exception('Unknown Scheduler')
+
+
+class CustomDataset(Dataset):
+    def __init__(self, data_path, csv_path):
+        self.path = data_path
+        self.imgs = natsorted(os.listdir(data_path))
+        self.len = len(self.imgs)
+        self.transform = transforms.Compose([
+            transforms.Resize([240, 240]),
+            transforms.ToTensor()
+        ])
+
+        # look up each image's label in the target csv by filename
+        df = pd.read_csv(csv_path)
+        targets_list = []
+
+        for fname in self.imgs:
+            row = df.loc[df['filename'] == fname]
+            targets_list.append(row.iloc[0, 1])
+
+        self.targets = targets_list
+
+    def __len__(self):
+        return self.len
+
+    def __getitem__(self, idx):
+        img_loc = os.path.join(self.path, self.imgs[idx])
+        targets = self.targets[idx]
+        image = Image.open(img_loc)
+        image = self.transform(image)
+        return image, targets
+
+
+def get_dataset(args, split='train'):
+    # transforms are handled inside CustomDataset, so only the split name is needed
+    assert split in ['train', 'val', 'test']
+
+    if split == 'train':
+        dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH)
+    elif split == 'val':
+        dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH)
+    else:  # test
+        dataset = CustomDataset(TEST_DATASET_PATH, TEST_TARGET_PATH)
+
+    return dataset
+
+
+def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
+    data_loader = torch.utils.data.DataLoader(dataset,
+                                              batch_size=args.batch_size,
+                                              shuffle=shuffle,
+                                              num_workers=args.num_workers,
+                                              pin_memory=pin_memory)
+    return data_loader
+
+
+def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True):
+    data_loader = torch.utils.data.DataLoader(dataset,
+                                              batch_size=args.batch_size,
+                                              shuffle=shuffle,
+                                              num_workers=args.num_workers,
+                                              pin_memory=pin_memory)
+    return data_loader
+
+
+def get_inf_dataloader(args, dataset):
+    # infinite iterator over the dataset; bumps the global current_epoch on every wrap-around
+    global current_epoch
+    data_loader = iter(get_dataloader(args, dataset, shuffle=True))
+
+    while True:
+        try:
+            batch = next(data_loader)
+
+        except StopIteration:
+            current_epoch += 1
+            data_loader = iter(get_dataloader(args, dataset, shuffle=True))
+            batch = next(data_loader)
+
+        yield batch
+
+
+def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
+    model.train()
+    images, target = batch
+
+    if device:
+        images = images.to(device)
+        target = target.to(device)
+
+    elif args.use_cuda:
+        images = images.cuda(non_blocking=True)
+        target = target.cuda(non_blocking=True)
+
+    # compute output
+    start_t = time.time()
+    output, first = model(images)
+    forward_t = time.time() - start_t
+    loss = criterion(output, target)
+
+    # measure accuracy and record loss (accuracy() returns correct counts, so normalise by batch size)
+    acc1, acc5 = accuracy(output, target, topk=(1, 5))
+    acc1 /= images.size(0)
+    acc5 /= images.size(0)
+
+    # compute gradient and do SGD step
+    optimizer.zero_grad()
+    start_t = time.time()
+    loss.backward()
+    backward_t = time.time() - start_t
+    optimizer.step()
+    if scheduler: scheduler.step()
+
+    # if writer and step % args.print_step == 0:
+    #     n_imgs = min(images.size(0), 10)
+    #     tag = 'train/' + str(step)
+    #     for j in range(n_imgs):
+    #         writer.add_image(tag,
+    #             concat_image_features(images[j], first[j]), global_step=step)
+
+    return acc1, acc5, loss, forward_t, backward_t
+
+
+#_acc1, _acc5 = accuracy(output, target, topk=(1, 5))
+def accuracy(output, target, topk=(1,)):
+    """Computes the number of correct top-k predictions for the specified values of k."""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            # reshape rather than view: correct can be non-contiguous after the transpose
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+            res.append(correct_k)
+        return res
+
+def validate(args, model, criterion, valid_loader, step, writer, device=None):
+    # switch to evaluate mode
+    model.eval()
+
+    acc1, acc5 = 0, 0
+    samples = 0
+    infer_t = 0
+
+    with torch.no_grad():
+        for i, (images, target) in enumerate(valid_loader):
+
+            start_t = time.time()
+            if device:
+                images = images.to(device)
+                target = target.to(device)
+
+            elif args.use_cuda:
+                images = images.cuda(non_blocking=True)
+                target = target.cuda(non_blocking=True)
+
+            # compute output
+            output, first = model(images)
+            loss = criterion(output, target)
+            infer_t += time.time() - start_t
+
+            # accumulate correct counts; normalised by the number of samples below
+            _acc1, _acc5 = accuracy(output, target, topk=(1, 5))
+            acc1 += _acc1
+            acc5 += _acc5
+            samples += images.size(0)
+
+    acc1 /= samples
+    acc5 /= samples
+
+    # if writer:
+    #     n_imgs = min(images.size(0), 10)
+    #     for j in range(n_imgs):
+    #         writer.add_image('valid/input_image',
+    #             concat_image_features(images[j], first[j]), global_step=step)
+
+    return acc1, acc5, loss, infer_t
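Taken together, the helpers compose roughly as sketched below. This is a hedged smoke-test sketch, not part of the commit: it assumes the new module's functions are importable (its path is collapsed in the diff), that the hard-coded Drive paths exist, and that the classifier outputs at least five classes, otherwise the top-5 term in accuracy() needs adjusting.

    import torch.nn as nn

    # illustrative kwargs; parse_args() fills in the remaining defaults
    args, kwargs = parse_args({'network': 'resnet18', 'batch_size': 8,
                               'use_cuda': False, 'num_workers': 0})

    model = select_model(args)                   # grayResNet2 backbone wrapped in basenet.BaseNet
    optimizer = select_optimizer(args, model)    # adam by default
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()

    train_dataset = get_dataset(args, 'train')   # CustomDataset: Resize([240, 240]) + ToTensor
    train_loader = iter(get_inf_dataloader(args, train_dataset))

    batch = next(train_loader)                   # (images, targets)
    acc1, acc5, loss, fw_t, bw_t = train_step(args, model, optimizer, scheduler,
                                              criterion, batch, step=0, writer=None)
    print('loss {:.4f}  top-1 acc {:.3f}'.format(loss.item(), acc1.item()))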