Showing
5 changed files
with
208 additions
and
272 deletions
code/FAA2_VM/getAugmented_1.py
0 → 100644
1 | +import os | ||
2 | +import fire | ||
3 | +import json | ||
4 | +from pprint import pprint | ||
5 | +import pickle | ||
6 | + | ||
7 | +import torch | ||
8 | +import torch.nn as nn | ||
9 | +from torch.utils.tensorboard import SummaryWriter | ||
10 | + | ||
11 | +from utils import * | ||
12 | + | ||
13 | +# command | ||
14 | +# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/' | ||
15 | + | ||
16 | +def eval(model_path): | ||
17 | + print('\n[+] Parse arguments') | ||
18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
20 | + args, kwargs = parse_args(kwargs) | ||
21 | + pprint(args) | ||
22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
23 | + | ||
24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
25 | + | ||
26 | + writer = SummaryWriter(log_dir=model_path) | ||
27 | + | ||
28 | + | ||
29 | + print('\n[+] Load transform') | ||
30 | + # list | ||
31 | + with open(cp_path, 'rb') as f: | ||
32 | + aug_transform_list = pickle.load(f) | ||
33 | + | ||
34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'test')) | ||
35 | + | ||
36 | + | ||
37 | + print('\n[+] Load dataset') | ||
38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
39 | + dataset = get_dataset(args, aug_transform, 'test') | ||
40 | + | ||
41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
42 | + | ||
43 | + for i, (images, target) in enumerate(loader): | ||
44 | + images = images.view(240, 240) | ||
45 | + | ||
46 | + # concat image | ||
47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
48 | + | ||
49 | + if i % 1000 == 0: | ||
50 | + print("\n images size: ", augmented_image_list[i].size()) # [240, 240] | ||
51 | + | ||
52 | + break | ||
53 | + # break | ||
54 | + | ||
55 | + | ||
56 | + # print(augmented_image_list) | ||
57 | + | ||
58 | + | ||
59 | + print('\n[+] Write on tensorboard') | ||
60 | + if writer: | ||
61 | + for i, data in enumerate(augmented_image_list): | ||
62 | + tag = 'img/' + str(i) | ||
63 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
64 | + break | ||
65 | + | ||
66 | + writer.close() | ||
67 | + | ||
68 | + | ||
69 | + # if writer: | ||
70 | + # for j in range(): | ||
71 | + # tag = 'img/' + str(img_count) + '_' + str(j) | ||
72 | + # # writer.add_image(tag, | ||
73 | + # # concat_image_features(images[j], first[j]), global_step=step) | ||
74 | + # # if j > 0: | ||
75 | + # # fore = concat_image_features(fore, images[j]) | ||
76 | + | ||
77 | + # writer.add_image(tag, fore, global_step=0) | ||
78 | + # img_count = img_count + 1 | ||
79 | + | ||
80 | + # writer.close() | ||
81 | + | ||
82 | +if __name__ == '__main__': | ||
83 | + fire.Fire(eval) |
code/FAA2_VM/getAugmented_all.py
0 → 100644
1 | +import os | ||
2 | +import fire | ||
3 | +import json | ||
4 | +from pprint import pprint | ||
5 | +import pickle | ||
6 | + | ||
7 | +import torch | ||
8 | +import torch.nn as nn | ||
9 | +from torch.utils.tensorboard import SummaryWriter | ||
10 | + | ||
11 | +from utils import * | ||
12 | + | ||
13 | +# command | ||
14 | +# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/' | ||
15 | + | ||
16 | +def eval(model_path): | ||
17 | + print('\n[+] Parse arguments') | ||
18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
20 | + args, kwargs = parse_args(kwargs) | ||
21 | + pprint(args) | ||
22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
23 | + | ||
24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
25 | + | ||
26 | + writer = SummaryWriter(log_dir=model_path) | ||
27 | + | ||
28 | + | ||
29 | + print('\n[+] Load transform') | ||
30 | + # list | ||
31 | + with open(cp_path, 'rb') as f: | ||
32 | + aug_transform_list = pickle.load(f) | ||
33 | + | ||
34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'train')) | ||
35 | + | ||
36 | + | ||
37 | + print('\n[+] Load dataset') | ||
38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
39 | + dataset = get_dataset(args, aug_transform, 'train') | ||
40 | + | ||
41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
42 | + | ||
43 | + for i, (images, target) in enumerate(loader): | ||
44 | + images = images.view(240, 240) | ||
45 | + | ||
46 | + # concat image | ||
47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
48 | + | ||
49 | + | ||
50 | + | ||
51 | + | ||
52 | + print('\n[+] Write on tensorboard') | ||
53 | + if writer: | ||
54 | + for i, data in enumerate(augmented_image_list): | ||
55 | + tag = 'img/' + str(i) | ||
56 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
57 | + | ||
58 | + writer.close() | ||
59 | + | ||
60 | + | ||
61 | +if __name__ == '__main__': | ||
62 | + fire.Fire(eval) |
code/FAA2_VM/getAugmented_saveimg.py
0 → 100644
1 | +import os | ||
2 | +import fire | ||
3 | +import json | ||
4 | +from pprint import pprint | ||
5 | +import pickle | ||
6 | +import random | ||
7 | + | ||
8 | +import torch | ||
9 | +import torch.nn as nn | ||
10 | +from torchvision.utils import save_image | ||
11 | +from torch.utils.tensorboard import SummaryWriter | ||
12 | + | ||
13 | +from utils import * | ||
14 | + | ||
15 | +# command | ||
16 | +# python getAugmented_saveimg.py --model_path='logs/April_26_00:55:16__resnet50__None/' | ||
17 | + | ||
18 | +def eval(model_path): | ||
19 | + print('\n[+] Parse arguments') | ||
20 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
21 | + kwargs = json.loads(open(kwargs_path).read()) | ||
22 | + args, kwargs = parse_args(kwargs) | ||
23 | + pprint(args) | ||
24 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
25 | + | ||
26 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
27 | + | ||
28 | + writer = SummaryWriter(log_dir=model_path) | ||
29 | + | ||
30 | + | ||
31 | + print('\n[+] Load transform') | ||
32 | + # list to tensor | ||
33 | + with open(cp_path, 'rb') as f: | ||
34 | + aug_transform_list = pickle.load(f) | ||
35 | + | ||
36 | + transform = transforms.RandomChoice(aug_transform_list) | ||
37 | + | ||
38 | + | ||
39 | + print('\n[+] Load dataset') | ||
40 | + | ||
41 | + dataset = get_dataset(args, transform, 'train') | ||
42 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
43 | + | ||
44 | + | ||
45 | + print('\n[+] Save 1 random policy') | ||
46 | + os.makedirs(os.path.join(model_path, 'augmented_imgs')) | ||
47 | + save_dir = os.path.join(model_path, 'augmented_imgs') | ||
48 | + | ||
49 | + for i, (image, target) in enumerate(loader): | ||
50 | + image = image.view(240, 240) | ||
51 | + # save img | ||
52 | + save_image(image, os.path.join(save_dir, 'aug_'+ str(i) + '.png')) | ||
53 | + | ||
54 | + if(i % 100 == 0): | ||
55 | + print("\n saved images: ", i) | ||
56 | + | ||
57 | + print('\n[+] Finished to save') | ||
58 | + | ||
59 | +if __name__ == '__main__': | ||
60 | + fire.Fire(eval) | ||
61 | + | ||
62 | + | ||
63 | + |
code/classifier/classify_normal_lesion.ipynb
0 → 100644
This diff is collapsed. Click to expand it.
code/classifier/utils/util.py
deleted
100644 → 0
1 | -import os | ||
2 | -import time | ||
3 | -import importlib | ||
4 | -import collections | ||
5 | -import pickle as cp | ||
6 | -import glob | ||
7 | -import numpy as np | ||
8 | -import pandas as pd | ||
9 | - | ||
10 | -from natsort import natsorted | ||
11 | -from PIL import Image | ||
12 | -import torch | ||
13 | -import torchvision | ||
14 | -import torch.nn.functional as F | ||
15 | -import torchvision.models as models | ||
16 | -import torchvision.transforms as transforms | ||
17 | -from torch.utils.data import Subset | ||
18 | -from torch.utils.data import Dataset, DataLoader | ||
19 | - | ||
20 | -from sklearn.model_selection import StratifiedShuffleSplit | ||
21 | -from sklearn.model_selection import train_test_split | ||
22 | -from sklearn.model_selection import KFold | ||
23 | - | ||
24 | -from networks import * | ||
25 | - | ||
26 | - | ||
27 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/' | ||
28 | -TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv' | ||
29 | -# VAL_DATASET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid/' | ||
30 | -# VAL_TARGET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid_targets.csv' | ||
31 | - | ||
32 | -current_epoch = 0 | ||
33 | - | ||
34 | - | ||
35 | -def split_dataset(args, dataset, k): | ||
36 | - # load dataset | ||
37 | - X = list(range(len(dataset))) | ||
38 | - Y = dataset.targets | ||
39 | - | ||
40 | - # split to k-fold | ||
41 | - assert len(X) == len(Y) | ||
42 | - | ||
43 | - def _it_to_list(_it): | ||
44 | - return list(zip(*list(_it))) | ||
45 | - | ||
46 | - sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
47 | - Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
48 | - | ||
49 | - return Dm_indexes, Da_indexes | ||
50 | - | ||
51 | - | ||
52 | - | ||
53 | -def get_model_name(args): | ||
54 | - from datetime import datetime, timedelta, timezone | ||
55 | - now = datetime.now(timezone.utc) | ||
56 | - tz = timezone(timedelta(hours=9)) | ||
57 | - now = now.astimezone(tz) | ||
58 | - date_time = now.strftime("%B_%d_%H:%M:%S") | ||
59 | - model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
60 | - return model_name | ||
61 | - | ||
62 | - | ||
63 | -def dict_to_namedtuple(d): | ||
64 | - Args = collections.namedtuple('Args', sorted(d.keys())) | ||
65 | - | ||
66 | - for k,v in d.items(): | ||
67 | - if type(v) is dict: | ||
68 | - d[k] = dict_to_namedtuple(v) | ||
69 | - | ||
70 | - elif type(v) is str: | ||
71 | - try: | ||
72 | - d[k] = eval(v) | ||
73 | - except: | ||
74 | - d[k] = v | ||
75 | - | ||
76 | - args = Args(**d) | ||
77 | - return args | ||
78 | - | ||
79 | - | ||
80 | -def parse_args(kwargs): | ||
81 | - # combine with default args | ||
82 | - kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'BraTS' | ||
83 | - kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet50' | ||
84 | - kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
85 | - kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.0001 | ||
86 | - kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
87 | - kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
88 | - kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
89 | - kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
90 | - kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 500 | ||
91 | - kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 500 | ||
92 | - kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
93 | - kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
94 | - kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
95 | - kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 5000 | ||
96 | - kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
97 | - | ||
98 | - # to named tuple | ||
99 | - args = dict_to_namedtuple(kwargs) | ||
100 | - return args, kwargs | ||
101 | - | ||
102 | - | ||
103 | -def select_model(args): | ||
104 | - # grayResNet2 | ||
105 | - resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | ||
106 | - 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | ||
107 | - | ||
108 | - if args.network in resnet_dict: | ||
109 | - backbone = resnet_dict[args.network] | ||
110 | - model = basenet.BaseNet(backbone, args) | ||
111 | - else: | ||
112 | - Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
113 | - model = Net(args) | ||
114 | - | ||
115 | - #print(model) # print model architecture | ||
116 | - return model | ||
117 | - | ||
118 | - | ||
119 | -def select_optimizer(args, model): | ||
120 | - if args.optimizer == 'sgd': | ||
121 | - optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
122 | - elif args.optimizer == 'rms': | ||
123 | - optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
124 | - elif args.optimizer == 'adam': | ||
125 | - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
126 | - else: | ||
127 | - raise Exception('Unknown Optimizer') | ||
128 | - return optimizer | ||
129 | - | ||
130 | - | ||
131 | -def select_scheduler(args, optimizer): | ||
132 | - if not args.scheduler or args.scheduler == 'None': | ||
133 | - return None | ||
134 | - elif args.scheduler =='clr': | ||
135 | - return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
136 | - elif args.scheduler =='exp': | ||
137 | - return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
138 | - else: | ||
139 | - raise Exception('Unknown Scheduler') | ||
140 | - | ||
141 | - | ||
142 | -class CustomDataset(Dataset): | ||
143 | - def __init__(self, data_path, csv_path): | ||
144 | - self.len = len(self.imgs) | ||
145 | - self.path = data_path | ||
146 | - self.imgs = natsorted(os.listdir(data_path)) | ||
147 | - | ||
148 | - df = pd.read_csv(csv_path) | ||
149 | - targets_list = [] | ||
150 | - | ||
151 | - for fname in self.imgs: | ||
152 | - row = df.loc[df['filename'] == fname] | ||
153 | - targets_list.append(row.iloc[0, 1]) | ||
154 | - | ||
155 | - self.targets = targets_list | ||
156 | - | ||
157 | - def __len__(self): | ||
158 | - return self.len | ||
159 | - | ||
160 | - def __getitem__(self, idx): | ||
161 | - img_loc = os.path.join(self.path, self.imgs[idx]) | ||
162 | - targets = self.targets[idx] | ||
163 | - image = Image.open(img_loc) | ||
164 | - return image, targets | ||
165 | - | ||
166 | - | ||
167 | - | ||
168 | -def get_dataset(args, transform, split='train'): | ||
169 | - assert split in ['train', 'val', 'test', 'trainval'] | ||
170 | - | ||
171 | - if split in ['train']: | ||
172 | - dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform) | ||
173 | - else: #test | ||
174 | - dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform) | ||
175 | - | ||
176 | - return dataset | ||
177 | - | ||
178 | - | ||
179 | -def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
180 | - data_loader = torch.utils.data.DataLoader(dataset, | ||
181 | - batch_size=args.batch_size, | ||
182 | - shuffle=shuffle, | ||
183 | - num_workers=args.num_workers, | ||
184 | - pin_memory=pin_memory) | ||
185 | - return data_loader | ||
186 | - | ||
187 | - | ||
188 | -def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
189 | - data_loader = torch.utils.data.DataLoader(dataset, | ||
190 | - batch_size=args.batch_size, | ||
191 | - shuffle=shuffle, | ||
192 | - num_workers=args.num_workers, | ||
193 | - pin_memory=pin_memory) | ||
194 | - return data_loader | ||
195 | - | ||
196 | - | ||
197 | -def get_inf_dataloader(args, dataset): | ||
198 | - global current_epoch | ||
199 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
200 | - | ||
201 | - while True: | ||
202 | - try: | ||
203 | - batch = next(data_loader) | ||
204 | - | ||
205 | - except StopIteration: | ||
206 | - current_epoch += 1 | ||
207 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
208 | - batch = next(data_loader) | ||
209 | - | ||
210 | - yield batch | ||
211 | - | ||
212 | - | ||
213 | - | ||
214 | - | ||
215 | -def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
216 | - model.train() | ||
217 | - images, target = batch | ||
218 | - | ||
219 | - if device: | ||
220 | - images = images.to(device) | ||
221 | - target = target.to(device) | ||
222 | - | ||
223 | - elif args.use_cuda: | ||
224 | - images = images.cuda(non_blocking=True) | ||
225 | - target = target.cuda(non_blocking=True) | ||
226 | - | ||
227 | - # compute output | ||
228 | - start_t = time.time() | ||
229 | - output, first = model(images) | ||
230 | - forward_t = time.time() - start_t | ||
231 | - loss = criterion(output, target) | ||
232 | - | ||
233 | - # measure accuracy and record loss | ||
234 | - acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
235 | - acc1 /= images.size(0) | ||
236 | - acc5 /= images.size(0) | ||
237 | - | ||
238 | - # compute gradient and do SGD step | ||
239 | - optimizer.zero_grad() | ||
240 | - start_t = time.time() | ||
241 | - loss.backward() | ||
242 | - backward_t = time.time() - start_t | ||
243 | - optimizer.step() | ||
244 | - if scheduler: scheduler.step() | ||
245 | - | ||
246 | - if writer and step % args.print_step == 0: | ||
247 | - n_imgs = min(images.size(0), 10) | ||
248 | - tag = 'train/' + str(step) | ||
249 | - for j in range(n_imgs): | ||
250 | - writer.add_image(tag, | ||
251 | - concat_image_features(images[j], first[j]), global_step=step) | ||
252 | - | ||
253 | - return acc1, acc5, loss, forward_t, backward_t | ||
254 | - | ||
255 | - | ||
256 | -#_acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
257 | -def accuracy(output, target, topk=(1,)): | ||
258 | - """Computes the accuracy over the k top predictions for the specified values of k""" | ||
259 | - with torch.no_grad(): | ||
260 | - maxk = max(topk) | ||
261 | - batch_size = target.size(0) | ||
262 | - | ||
263 | - _, pred = output.topk(maxk, 1, True, True) | ||
264 | - pred = pred.t() | ||
265 | - correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
266 | - | ||
267 | - res = [] | ||
268 | - for k in topk: | ||
269 | - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
270 | - res.append(correct_k) | ||
271 | - return res | ||
272 | - |
-
Please register or login to post a comment