조현아

get 1 random policy

import os
import fire
import json
from pprint import pprint
import pickle

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from utils import *

# command
# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/'
def eval(model_path):
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    with open(kwargs_path) as f:
        kwargs = json.load(f)
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    cp_path = os.path.join(model_path, 'augmentation.cp')

    writer = SummaryWriter(log_dir=model_path)

    print('\n[+] Load transform')
    # pickled list of augmentation policies (one transform per policy)
    with open(cp_path, 'rb') as f:
        aug_transform_list = pickle.load(f)

    # one zero-width (240, 0) tensor per test image; each policy's output is
    # concatenated onto it along dim=1 to form a horizontal strip
    augmented_image_list = [torch.empty(240, 0)] * len(get_dataset(args, None, 'test'))

    print('\n[+] Load dataset')
    for aug_idx, aug_transform in enumerate(aug_transform_list):
        dataset = get_dataset(args, aug_transform, 'test')
        loader = iter(get_aug_dataloader(args, dataset))

        for i, (images, target) in enumerate(loader):
            # assumes batch_size == 1 and 240x240 grayscale images
            images = images.view(240, 240)

            # append this policy's version of the image to its strip
            augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim=1)

            if i % 1000 == 0:
                print("\n images size: ", augmented_image_list[i].size())  # [240, 240]

            # only the first image is processed, as a quick sanity check
            break

    print('\n[+] Write on tensorboard')
    if writer:
        for i, data in enumerate(augmented_image_list):
            tag = 'img/' + str(i)
            # add_image's default layout is CHW, hence the (1, 240, W) view
            writer.add_image(tag, data.view(1, 240, -1), global_step=0)
            break

    writer.close()


if __name__ == '__main__':
    fire.Fire(eval)
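
The concatenation trick above is worth calling out: each strip starts as a zero-width tensor, so the first torch.cat needs no special case. A minimal standalone sketch of the pattern (shapes taken from the script, dummy data, three policies assumed):

import torch

strip = torch.empty(240, 0)             # zero-width seed; nothing special-cased
for _ in range(3):                      # e.g. three augmentation policies
    img = torch.rand(240, 240)          # stand-in for one augmented image
    strip = torch.cat([strip, img], dim=1)
print(strip.shape)                      # torch.Size([240, 720])

The same pattern appears in the train-split variant below.
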
import os
import fire
import json
from pprint import pprint
import pickle

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from utils import *

# command
# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/'
def eval(model_path):
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    with open(kwargs_path) as f:
        kwargs = json.load(f)
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    cp_path = os.path.join(model_path, 'augmentation.cp')

    writer = SummaryWriter(log_dir=model_path)

    print('\n[+] Load transform')
    # pickled list of augmentation policies (one transform per policy)
    with open(cp_path, 'rb') as f:
        aug_transform_list = pickle.load(f)

    # one zero-width (240, 0) tensor per training image; each policy's output
    # is concatenated onto it along dim=1 to form a horizontal strip
    augmented_image_list = [torch.empty(240, 0)] * len(get_dataset(args, None, 'train'))

    print('\n[+] Load dataset')
    for aug_idx, aug_transform in enumerate(aug_transform_list):
        dataset = get_dataset(args, aug_transform, 'train')
        loader = iter(get_aug_dataloader(args, dataset))

        # unlike the sanity-check variant above, every image is processed
        for i, (images, target) in enumerate(loader):
            # assumes batch_size == 1 and 240x240 grayscale images
            images = images.view(240, 240)

            # append this policy's version of the image to its strip
            augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim=1)

    print('\n[+] Write on tensorboard')
    if writer:
        for i, data in enumerate(augmented_image_list):
            tag = 'img/' + str(i)
            # add_image's default layout is CHW, hence the (1, 240, W) view
            writer.add_image(tag, data.view(1, 240, -1), global_step=0)

    writer.close()


if __name__ == '__main__':
    fire.Fire(eval)
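
The (1, 240, -1) view matters because SummaryWriter.add_image expects CHW data by default, so the strip of width 240 x num_policies becomes a single-channel image. A small sketch with dummy data (two policies assumed):

import torch

strip = torch.rand(240, 480)   # one 240x240 image under two policies
chw = strip.view(1, 240, -1)   # add_image's default dataformats is 'CHW'
print(chw.shape)               # torch.Size([1, 240, 480])
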
import os
import fire
import json
from pprint import pprint
import pickle
import random

import torch
import torch.nn as nn
# explicit import; transforms is otherwise only picked up via `from utils import *`
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter

from utils import *

# command
# python getAugmented_saveimg.py --model_path='logs/April_26_00:55:16__resnet50__None/'

def eval(model_path):
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    with open(kwargs_path) as f:
        kwargs = json.load(f)
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    cp_path = os.path.join(model_path, 'augmentation.cp')

    writer = SummaryWriter(log_dir=model_path)

    print('\n[+] Load transform')
    # pickled list of policies; RandomChoice applies one, picked at random, per image
    with open(cp_path, 'rb') as f:
        aug_transform_list = pickle.load(f)

    transform = transforms.RandomChoice(aug_transform_list)

    print('\n[+] Load dataset')
    dataset = get_dataset(args, transform, 'train')
    loader = iter(get_aug_dataloader(args, dataset))

    print('\n[+] Save 1 random policy')
    save_dir = os.path.join(model_path, 'augmented_imgs')
    os.makedirs(save_dir, exist_ok=True)  # exist_ok avoids crashing on re-runs

    for i, (image, target) in enumerate(loader):
        # assumes batch_size == 1 and 240x240 grayscale images
        image = image.view(240, 240)
        save_image(image, os.path.join(save_dir, 'aug_' + str(i) + '.png'))

        if i % 100 == 0:
            print("\n saved images: ", i)

    print('\n[+] Finished saving')


if __name__ == '__main__':
    fire.Fire(eval)
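
transforms.RandomChoice applies exactly one transform from the list, picked uniformly at random, each time it is called, which is what makes this a one-random-policy dump per image. A standalone sketch (the two transforms here are placeholders, not the learned policies):

import torchvision.transforms as transforms
from PIL import Image

choice = transforms.RandomChoice([
    transforms.RandomHorizontalFlip(p=1.0),
    transforms.RandomRotation(30),
])
img = Image.new('L', (240, 240))   # dummy grayscale image
out = choice(img)                  # one transform chosen per call
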
import os
import time
import importlib
import collections
import pickle as cp
import glob
import numpy as np
import pandas as pd

from natsort import natsorted
from PIL import Image
import torch
import torchvision
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold

from networks import *


TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/'
TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv'
# VAL_DATASET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid/'
# VAL_TARGET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid_targets.csv'

current_epoch = 0


def split_dataset(args, dataset, k):
    # load dataset
    X = list(range(len(dataset)))
    Y = dataset.targets

    # split to k-fold
    assert len(X) == len(Y)

    def _it_to_list(_it):
        return list(zip(*list(_it)))

    sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
    Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))

    return Dm_indexes, Da_indexes
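
# Toy run (hypothetical numbers): with a 20-sample dataset and k=2,
# sss.split(X, Y) yields two (train_idx, test_idx) pairs, and zipping them
# groups the train arrays into Dm_indexes and the test arrays into Da_indexes:
#
#   X = list(range(20)); Y = [0] * 10 + [1] * 10
#   sss = StratifiedShuffleSplit(n_splits=2, random_state=0, test_size=0.1)
#   Dm_indexes, Da_indexes = zip(*list(sss.split(X, Y)))
#   # len(Dm_indexes) == 2; each entry holds 18 train indices, Da_indexes 2 each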


def get_model_name(args):
    from datetime import datetime, timedelta, timezone
    now = datetime.now(timezone.utc)
    tz = timezone(timedelta(hours=9))
    now = now.astimezone(tz)
    date_time = now.strftime("%B_%d_%H:%M:%S")
    model_name = '__'.join([date_time, args.network, str(args.seed)])
    return model_name


def dict_to_namedtuple(d):
    Args = collections.namedtuple('Args', sorted(d.keys()))

    for k, v in d.items():
        if type(v) is dict:
            d[k] = dict_to_namedtuple(v)

        elif type(v) is str:
            # literal strings like '0.0001' or 'None' become Python values;
            # anything that fails to eval is kept as a plain string
            try:
                d[k] = eval(v)
            except Exception:
                d[k] = v

    args = Args(**d)
    return args
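
# Example (hypothetical values): dict_to_namedtuple({'seed': 'None', 'lr': '0.0001'})
# returns Args(lr=0.0001, seed=None) since both strings eval() to Python
# literals, while a value like 'logs/run1/' fails to eval and stays a string.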


def parse_args(kwargs):
    # combine with default args
    kwargs.setdefault('dataset', 'BraTS')
    kwargs.setdefault('network', 'resnet50')
    kwargs.setdefault('optimizer', 'adam')
    kwargs.setdefault('learning_rate', 0.0001)
    kwargs.setdefault('seed', None)
    kwargs.setdefault('use_cuda', True)
    kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()
    kwargs.setdefault('num_workers', 4)
    kwargs.setdefault('print_step', 500)
    kwargs.setdefault('val_step', 500)
    kwargs.setdefault('scheduler', 'exp')
    kwargs.setdefault('batch_size', 128)
    kwargs.setdefault('start_step', 0)
    kwargs.setdefault('max_step', 5000)
    kwargs.setdefault('augment_path', None)

    # to named tuple
    args = dict_to_namedtuple(kwargs)
    return args, kwargs
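
# Example: every key is optional, so a kwargs.json containing only
# {"batch_size": 64} still parses: args.batch_size == 64,
# args.max_step == 5000, and args.use_cuda is True only when
# torch.cuda.is_available().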


def select_model(args):
    # grayResNet2
    resnet_dict = {'resnet18': grayResNet2.resnet18(), 'resnet34': grayResNet2.resnet34(),
                   'resnet50': grayResNet2.resnet50(), 'resnet101': grayResNet2.resnet101(),
                   'resnet152': grayResNet2.resnet152()}

    if args.network in resnet_dict:
        backbone = resnet_dict[args.network]
        model = basenet.BaseNet(backbone, args)
    else:
        Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
        model = Net(args)

    # print(model)  # print model architecture
    return model


def select_optimizer(args, model):
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001)
    elif args.optimizer == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    else:
        raise Exception('Unknown Optimizer')
    return optimizer


def select_scheduler(args, optimizer):
    if not args.scheduler or args.scheduler == 'None':
        return None
    elif args.scheduler == 'clr':
        return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False)
    elif args.scheduler == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1)
    else:
        raise Exception('Unknown Scheduler')


class CustomDataset(Dataset):
    def __init__(self, data_path, csv_path, transform=None):
        self.path = data_path
        self.transform = transform
        # natural sort keeps e.g. img_2 before img_10
        self.imgs = natsorted(os.listdir(data_path))
        self.len = len(self.imgs)  # must come after self.imgs is set

        # look up each file's class label in the csv
        df = pd.read_csv(csv_path)
        targets_list = []
        for fname in self.imgs:
            row = df.loc[df['filename'] == fname]
            targets_list.append(row.iloc[0, 1])
        self.targets = targets_list

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        img_loc = os.path.join(self.path, self.imgs[idx])
        targets = self.targets[idx]
        image = Image.open(img_loc)
        if self.transform:
            image = self.transform(image)
        return image, targets


def get_dataset(args, transform, split='train'):
    assert split in ['train', 'val', 'test', 'trainval']

    if split in ['train']:
        dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform)
    else:  # val / test splits need VAL_DATASET_PATH / VAL_TARGET_PATH defined above
        dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform)

    return dataset


def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=shuffle,
                                              num_workers=args.num_workers,
                                              pin_memory=pin_memory)
    return data_loader


# identical to get_dataloader; kept as a separate entry point for augmentation loaders
def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True):
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=shuffle,
                                              num_workers=args.num_workers,
                                              pin_memory=pin_memory)
    return data_loader


def get_inf_dataloader(args, dataset):
    # infinite batch generator: restarts the loader and bumps the global
    # current_epoch counter whenever the underlying iterator is exhausted
    global current_epoch
    data_loader = iter(get_dataloader(args, dataset, shuffle=True))

    while True:
        try:
            batch = next(data_loader)
        except StopIteration:
            current_epoch += 1
            data_loader = iter(get_dataloader(args, dataset, shuffle=True))
            batch = next(data_loader)

        yield batch
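
# Usage sketch: the generator never raises StopIteration itself, so a training
# loop can just pull max_step batches while current_epoch tracks full passes:
#
#   inf_loader = get_inf_dataloader(args, dataset)
#   for step in range(args.start_step, args.max_step):
#       batch = next(inf_loader)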


def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
    model.train()
    images, target = batch

    if device:
        images = images.to(device)
        target = target.to(device)
    elif args.use_cuda:
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

    # compute output
    start_t = time.time()
    output, first = model(images)
    forward_t = time.time() - start_t
    loss = criterion(output, target)

    # measure accuracy and record loss
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    acc1 /= images.size(0)
    acc5 /= images.size(0)

    # compute gradient and do SGD step
    optimizer.zero_grad()
    start_t = time.time()
    loss.backward()
    backward_t = time.time() - start_t
    optimizer.step()
    if scheduler:
        scheduler.step()

    if writer and step % args.print_step == 0:
        n_imgs = min(images.size(0), 10)
        tag = 'train/' + str(step)
        for j in range(n_imgs):
            writer.add_image(tag,
                             concat_image_features(images[j], first[j]), global_step=step)

    return acc1, acc5, loss, forward_t, backward_t


# _acc1, _acc5 = accuracy(output, target, topk=(1, 5))
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Returns correct-prediction counts (not percentages); callers divide by batch size.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape (not view): correct[:k] is non-contiguous after t()
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k)
        return res
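
# Worked example (hypothetical logits, batch of 2, 6 classes):
#
#   output = torch.tensor([[0.10, 0.90, 0.05, 0.04, 0.03, 0.02],
#                          [0.80, 0.10, 0.05, 0.04, 0.03, 0.02]])
#   target = torch.tensor([1, 2])
#   acc1, acc5 = accuracy(output, target, topk=(1, 5))
#   # acc1 == tensor([1.]): only sample 0's top prediction matches its target
#   # acc5 == tensor([2.]): class 2 also appears in sample 1's top-5
#
# train_step() then divides these counts by images.size(0) to get fractions.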