Showing 4 changed files with 359 additions and 9 deletions
... | @@ -33,11 +33,8 @@ def eval(model_path): | ... | @@ -33,11 +33,8 @@ def eval(model_path): |
33 | model.load_state_dict(torch.load(weight_path)) | 33 | model.load_state_dict(torch.load(weight_path)) |
34 | 34 | ||
35 | print('\n[+] Load dataset') | 35 | print('\n[+] Load dataset') |
36 | - test_transform = get_valid_transform(args, model) | ||
37 | - #print('\nTEST Transform\n', test_transform) | ||
38 | test_dataset = get_dataset(args, 'test') | 36 | test_dataset = get_dataset(args, 'test') |
39 | 37 | ||
40 | - | ||
41 | 38 | ||
42 | test_loader = iter(get_dataloader(args, test_dataset)) ### | 39 | test_loader = iter(get_dataloader(args, test_dataset)) ### |
43 | 40 | ... | ... |
... | @@ -16,6 +16,8 @@ class BaseNet(nn.Module): | ... | @@ -16,6 +16,8 @@ class BaseNet(nn.Module): |
16 | x = self.after(f) | 16 | x = self.after(f) |
17 | x = x.reshape(x.size(0), -1) | 17 | x = x.reshape(x.size(0), -1) |
18 | x = self.fc(x) | 18 | x = self.fc(x) |
19 | + | ||
20 | + # output, first | ||
19 | return x, f | 21 | return x, f |
20 | 22 | ||
21 | """ | 23 | """ | ... | ... |
... | @@ -24,7 +24,7 @@ def train(**kwargs): | ... | @@ -24,7 +24,7 @@ def train(**kwargs): |
24 | 24 | ||
25 | print('\n[+] Create log dir') | 25 | print('\n[+] Create log dir') |
26 | model_name = get_model_name(args) | 26 | model_name = get_model_name(args) |
27 | - log_dir = os.path.join('/content/drive/My Drive/CD2 Project/classify/', model_name) | 27 | + log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name) |
28 | os.makedirs(os.path.join(log_dir, 'model')) | 28 | os.makedirs(os.path.join(log_dir, 'model')) |
29 | json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) | 29 | json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) |
30 | writer = SummaryWriter(log_dir=log_dir) | 30 | writer = SummaryWriter(log_dir=log_dir) |
... | @@ -42,13 +42,11 @@ def train(**kwargs): | ... | @@ -42,13 +42,11 @@ def train(**kwargs): |
42 | if args.use_cuda: | 42 | if args.use_cuda: |
43 | model = model.cuda() | 43 | model = model.cuda() |
44 | criterion = criterion.cuda() | 44 | criterion = criterion.cuda() |
45 | - writer.add_graph(model) | 45 | + #writer.add_graph(model) |
46 | 46 | ||
47 | print('\n[+] Load dataset') | 47 | print('\n[+] Load dataset') |
48 | - transform = get_train_transform(args, model, log_dir) | 48 | + train_dataset = get_dataset(args, 'train') |
49 | - val_transform = get_valid_transform(args, model) | 49 | + valid_dataset = get_dataset(args, 'val') |
50 | - train_dataset = get_dataset(args, transform, 'train') | ||
51 | - valid_dataset = get_dataset(args, val_transform, 'val') | ||
52 | train_loader = iter(get_inf_dataloader(args, train_dataset)) | 50 | train_loader = iter(get_inf_dataloader(args, train_dataset)) |
53 | max_epoch = len(train_dataset) // args.batch_size | 51 | max_epoch = len(train_dataset) // args.batch_size |
54 | best_acc = -1 | 52 | best_acc = -1 |
... | @@ -82,6 +80,7 @@ def train(**kwargs): | ... | @@ -82,6 +80,7 @@ def train(**kwargs): |
82 | print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) | 80 | print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) |
83 | 81 | ||
84 | if step % args.val_step == args.val_step-1: | 82 | if step % args.val_step == args.val_step-1: |
83 | + # print("\nstep, args.val_step: ", step, args.val_step) | ||
85 | valid_loader = iter(get_dataloader(args, valid_dataset)) | 84 | valid_loader = iter(get_dataloader(args, valid_dataset)) |
86 | _valid_res = validate(args, model, criterion, valid_loader, step, writer) | 85 | _valid_res = validate(args, model, criterion, valid_loader, step, writer) |
87 | print('\n[+] Valid results') | 86 | print('\n[+] Valid results') | ... | ... |
code/classifier/utils.py
0 → 100644
1 | +import os | ||
2 | +import time | ||
3 | +import importlib | ||
4 | +import collections | ||
5 | +import pickle as cp | ||
6 | +import glob | ||
7 | +import numpy as np | ||
8 | +import pandas as pd | ||
9 | + | ||
10 | +from natsort import natsorted | ||
11 | +from PIL import Image | ||
12 | +import torch | ||
13 | +import torchvision | ||
14 | +import torch.nn.functional as F | ||
15 | +import torchvision.models as models | ||
16 | +import torchvision.transforms as transforms | ||
17 | +from torch.utils.data import Subset | ||
18 | +from torch.utils.data import Dataset, DataLoader | ||
19 | + | ||
20 | +from sklearn.model_selection import StratifiedShuffleSplit | ||
21 | +from sklearn.model_selection import train_test_split | ||
22 | +from sklearn.model_selection import KFold | ||
23 | + | ||
24 | +from networks import basenet, grayResNet2 | ||
25 | + | ||
26 | + | ||
# Google Drive locations of the pre-split image folders and their label CSVs.
# Each CSV maps an image filename to its integer class target.
TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/'
TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv'
VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_val/'
VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/val_nonaug_classify_target.csv'
TEST_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_test/'
TEST_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/test_nonaug_classify_target.csv'

# Number of completed passes over the training set; incremented by
# get_inf_dataloader each time its underlying loader is exhausted.
current_epoch = 0
35 | + | ||
def concat_image_features(image, features, max_features=3):
    """Concatenate an image with each of its feature maps along the width.

    Each feature map is min-max normalised to [0, 1] and bilinearly resized
    to the image resolution before being appended.

    Args:
        image: tensor of shape (1, H, W).
        features: tensor of shape (K, fh, fw).
        max_features: currently ignored — every one of the K maps is used
            (the original cap was deliberately overridden; kept for API
            compatibility).

    Returns:
        Tensor of shape (1, H, W * (1 + K)).
    """
    _, h, w = image.shape
    max_features = features.size(0)  # intentionally overrides the argument
    image_feature = image.clone()

    for i in range(max_features):
        feature = features[i:i + 1]                          # (1, fh, fw)

        # min-max normalise; epsilon guards against constant maps
        _min, _max = torch.min(feature), torch.max(feature)
        feature = (feature - _min) / (_max - _min + 1e-6)

        # resize to the image resolution.  F.interpolate replaces the
        # deprecated F.upsample; align_corners=False matches the old default.
        feature = feature.view(1, 1, feature.size(1), feature.size(2))
        feature = F.interpolate(feature, size=(h, w), mode='bilinear',
                                align_corners=False)
        feature = feature.view(1, h, w)

        image_feature = torch.cat((image_feature, feature), 2)  # grow width

    return image_feature
80 | + | ||
def get_model_name(args):
    """Build a unique run name: '<KST timestamp>__<network>__<seed>'."""
    from datetime import datetime, timedelta, timezone
    kst = timezone(timedelta(hours=9))  # Korea Standard Time (UTC+9)
    stamp = datetime.now(timezone.utc).astimezone(kst).strftime("%B_%d_%H:%M:%S")
    return '__'.join([stamp, args.network, str(args.seed)])
89 | + | ||
90 | + | ||
91 | + | ||
def dict_to_namedtuple(d):
    """Recursively convert dict `d` into an immutable namedtuple.

    Nested dicts become nested namedtuples.  String values are passed
    through eval() so that e.g. '0.001' becomes a float and 'None' becomes
    None; strings that fail to evaluate are kept verbatim.

    NOTE(review): eval() on config values is unsafe for untrusted input —
    consider ast.literal_eval if the config is ever user-supplied.
    """
    Args = collections.namedtuple('Args', sorted(d.keys()))

    for k, v in d.items():
        if type(v) is dict:
            d[k] = dict_to_namedtuple(v)
        elif type(v) is str:
            try:
                d[k] = eval(v)  # best-effort literal conversion
            except Exception:  # was a bare except; keep unevaluable strings
                d[k] = v

    args = Args(**d)
    return args
107 | + | ||
108 | + | ||
def parse_args(kwargs):
    """Fill in defaults for every training option missing from `kwargs`,
    then freeze the result into a namedtuple.

    Returns (args, kwargs): the namedtuple view and the (mutated) dict.
    """
    defaults = {
        'dataset': 'BraTS',
        'network': 'resnet50',
        'optimizer': 'adam',
        'learning_rate': 0.001,
        'seed': None,
        'use_cuda': True,
        'num_workers': 4,
        'print_step': 100,
        'val_step': 100,
        'scheduler': 'exp',
        'batch_size': 32,
        'start_step': 0,
        'max_step': 2500,
        'augment_path': None,
    }
    for key, value in defaults.items():
        kwargs.setdefault(key, value)

    # never request CUDA on a machine without it
    kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()

    args = dict_to_namedtuple(kwargs)
    return args, kwargs
130 | + | ||
131 | + | ||
def select_model(args):
    """Instantiate the classification model named by args.network.

    Known ResNet variants are wrapped in basenet.BaseNet around a
    grayscale-input backbone; any other name is resolved dynamically as
    networks.<name>.Net.
    """
    # Map names to constructors so only the requested backbone is built.
    # (The original dict eagerly instantiated all five ResNets per call.)
    resnet_builders = {
        'resnet18': grayResNet2.resnet18,
        'resnet34': grayResNet2.resnet34,
        'resnet50': grayResNet2.resnet50,
        'resnet101': grayResNet2.resnet101,
        'resnet152': grayResNet2.resnet152,
    }

    if args.network in resnet_builders:
        backbone = resnet_builders[args.network]()
        model = basenet.BaseNet(backbone, args)
    else:
        Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
        model = Net(args)

    return model
146 | + | ||
147 | + | ||
def select_optimizer(args, model):
    """Return the torch optimizer requested by args.optimizer
    ('sgd' | 'rms' | 'adam'); any other name raises."""
    params = model.parameters()
    lr = args.learning_rate

    if args.optimizer == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0001)
    if args.optimizer == 'rms':
        return torch.optim.RMSprop(params, lr=lr)
    if args.optimizer == 'adam':
        return torch.optim.Adam(params, lr=lr)
    raise Exception('Unknown Optimizer')
158 | + | ||
159 | + | ||
def select_scheduler(args, optimizer):
    """Return an LR scheduler for `optimizer`, or None when disabled.

    'clr' -> CyclicLR (triangular2), 'exp' -> ExponentialLR; any other
    non-empty name raises.
    """
    name = args.scheduler
    if not name or name == 'None':
        return None
    if name == 'clr':
        return torch.optim.lr_scheduler.CyclicLR(
            optimizer, 0.01, 0.015, mode='triangular2',
            step_size_up=250000, cycle_momentum=False)
    if name == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=0.9999283, last_epoch=-1)
    raise Exception('Unknown Scheduler')
169 | + | ||
170 | + | ||
class CustomDataset(Dataset):
    """Image-folder dataset whose targets come from a CSV file.

    The CSV is expected to have the image filename in its 'filename'
    column and the class target in its second column.  Images are resized
    to 240x240 and converted to tensors.
    """

    def __init__(self, data_path, csv_path):
        self.path = data_path
        # natural sort makes the ordering deterministic across platforms
        self.imgs = natsorted(os.listdir(data_path))
        self.len = len(self.imgs)
        self.transform = transforms.Compose([
            transforms.Resize([240, 240]),
            transforms.ToTensor()
        ])

        # Build a filename -> target map once (O(n + m)) instead of scanning
        # the whole DataFrame for every image as before (O(n * m)).
        df = pd.read_csv(csv_path)
        target_of = dict(zip(df['filename'], df.iloc[:, 1]))
        # KeyError here means an image file has no row in the CSV
        self.targets = [target_of[fname] for fname in self.imgs]

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # Returns (image_tensor, target).  Channel count follows the file on
        # disk (grayscale files yield 1 channel) — presumably intended given
        # the grayscale backbones; TODO confirm.
        img_loc = os.path.join(self.path, self.imgs[idx])
        targets = self.targets[idx]
        image = Image.open(img_loc)
        image = self.transform(image)
        return image, targets
199 | + | ||
200 | + | ||
201 | + | ||
def get_dataset(args, split='train'):
    """Return the CustomDataset for `split` ('train' | 'val' | 'test').

    BUGFIX: the previous signature was (args, transform, split='train') but
    callers (train.py / eval.py) now call get_dataset(args, 'val') etc., so
    the split string was silently bound to the unused `transform` parameter
    and 'val'/'test' always returned the TRAINING set.  The dead `transform`
    parameter has been removed; transforms are built inside CustomDataset.
    """
    assert split in ['train', 'val', 'test']

    if split == 'train':
        return CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH)
    if split == 'val':
        return CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH)
    return CustomDataset(TEST_DATASET_PATH, TEST_TARGET_PATH)
214 | + | ||
215 | + | ||
def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
    """Wrap `dataset` in a DataLoader using batch_size and num_workers
    taken from `args`."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
223 | + | ||
224 | + | ||
def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True):
    """DataLoader for augmented datasets.

    Settings are identical to get_dataloader; delegate to it instead of
    duplicating the DataLoader construction (the two bodies had drifted
    into byte-for-byte copies).  The separate name is kept for call-site
    clarity and backward compatibility.
    """
    return get_dataloader(args, dataset, shuffle=shuffle, pin_memory=pin_memory)
232 | + | ||
233 | + | ||
def get_inf_dataloader(args, dataset):
    """Yield shuffled batches forever.

    When the underlying loader is exhausted, bump the module-level
    `current_epoch` counter and restart with a freshly shuffled loader.
    """
    global current_epoch
    loader = iter(get_dataloader(args, dataset, shuffle=True))

    while True:
        try:
            batch = next(loader)
        except StopIteration:
            # epoch boundary: count it and begin a new shuffled pass
            current_epoch += 1
            loader = iter(get_dataloader(args, dataset, shuffle=True))
            batch = next(loader)
        yield batch
248 | + | ||
249 | + | ||
250 | + | ||
251 | + | ||
def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
    """Run one optimisation step on `batch`.

    Returns (acc1, acc5, loss, forward_time_s, backward_time_s), where
    acc1/acc5 are the top-1/top-5 correct counts divided by the batch size.
    An explicit `device` takes precedence over args.use_cuda.
    """
    model.train()
    images, target = batch

    # move the batch to the requested device
    if device:
        images = images.to(device)
        target = target.to(device)

    elif args.use_cuda:
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

    # forward pass (timed); the model returns (logits, feature map)
    start_t = time.time()
    output, first = model(images)
    forward_t = time.time() - start_t
    loss = criterion(output, target)

    # accuracy() returns correct counts; normalise to per-sample fractions
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    acc1 /= images.size(0)
    acc5 /= images.size(0)

    # backward pass (timed) and parameter / LR update
    optimizer.zero_grad()
    start_t = time.time()
    loss.backward()
    backward_t = time.time() - start_t
    optimizer.step()
    if scheduler: scheduler.step()

    # Feature-map visualisation, currently disabled.
    # if writer and step % args.print_step == 0:
    #     n_imgs = min(images.size(0), 10)
    #     tag = 'train/' + str(step)
    #     for j in range(n_imgs):
    #         writer.add_image(tag,
    #             concat_image_features(images[j], first[j]), global_step=step)

    return acc1, acc5, loss, forward_t, backward_t
291 | + | ||
292 | + | ||
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy counts for the given logits.

    Args:
        output: (batch, num_classes) logits.
        target: (batch,) integer class labels.
        topk: tuple of k values.

    Returns:
        A list with one 1-element float tensor per k: the NUMBER of samples
        whose target is among the top-k predictions (not yet divided by the
        batch size — callers normalise).
    """
    with torch.no_grad():
        maxk = max(topk)

        # top-maxk class indices per sample, transposed to (maxk, batch)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape(-1) instead of view(-1): the slice is not guaranteed
            # contiguous on every torch version (cf. pytorch/examples fix)
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k)
        return res
311 | + | ||
def validate(args, model, criterion, valid_loader, step, writer, device=None):
    """Evaluate `model` over every batch of `valid_loader`.

    Returns (acc1, acc5, loss, infer_time_s): acc1/acc5 are dataset-wide
    top-1/top-5 fractions; NOTE `loss` is the loss of the LAST batch only.
    An explicit `device` takes precedence over args.use_cuda.
    """
    # switch to evaluate mode
    model.eval()

    acc1, acc5 = 0, 0
    samples = 0
    infer_t = 0

    with torch.no_grad():
        for i, (images, target) in enumerate(valid_loader):

            start_t = time.time()
            if device:
                images = images.to(device)
                target = target.to(device)

            elif args.use_cuda:
                # BUGFIX: was `args.use_cuda is not None`, which is True even
                # when use_cuda=False and crashed CPU-only runs by calling
                # .cuda(); now consistent with train_step.
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output (timed, including the device transfer)
            output, first = model(images)
            loss = criterion(output, target)
            infer_t += time.time() - start_t

            # accumulate correct counts; normalised after the loop
            _acc1, _acc5 = accuracy(output, target, topk=(1, 5))
            acc1 += _acc1
            acc5 += _acc5
            samples += images.size(0)

    acc1 /= samples
    acc5 /= samples

    return acc1, acc5, loss, infer_t
-
Please register or login to post a comment