Showing
1 changed file
with
67 additions
and
19 deletions
... | @@ -28,8 +28,6 @@ DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_fr | ... | @@ -28,8 +28,6 @@ DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_fr |
28 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | 28 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
29 | VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' | 29 | VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' |
30 | 30 | ||
31 | -TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv' | ||
32 | -VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv' | ||
33 | 31 | ||
34 | current_epoch = 0 | 32 | current_epoch = 0 |
35 | 33 | ||
... | @@ -65,14 +63,36 @@ def concat_image_features(image, features, max_features=3): | ... | @@ -65,14 +63,36 @@ def concat_image_features(image, features, max_features=3): |
65 | image_feature = image.clone() | 63 | image_feature = image.clone() |
66 | 64 | ||
67 | for i in range(max_features): | 65 | for i in range(max_features): |
66 | + # features torch.Size([64, 16, 16]) | ||
67 | + | ||
68 | feature = features[i:i+1] | 68 | feature = features[i:i+1] |
69 | + #torch.Size([1, 16, 16]) | ||
70 | + | ||
69 | _min, _max = torch.min(feature), torch.max(feature) | 71 | _min, _max = torch.min(feature), torch.max(feature) |
70 | feature = (feature - _min) / (_max - _min + 1e-6) | 72 | feature = (feature - _min) / (_max - _min + 1e-6) |
71 | - feature = torch.cat([feature]*3, 0) | 73 | + # torch.Size([1, 16, 16]) |
72 | - feature = feature.view(1, 3, feature.size(1), feature.size(2)) | 74 | + |
75 | + feature = torch.cat([feature]*1, 0) | ||
76 | + #feature = torch.cat([feature]*3, 0) | ||
77 | + # torch.Size([3, 16, 16]) -> [1, 16, 16] | ||
78 | + | ||
79 | + feature = feature.view(1, 1, feature.size(1), feature.size(2)) | ||
80 | + #feature = feature.view(1, 3, feature.size(1), feature.size(2)) | ||
81 | + # torch.Size([1, 3, 16, 16])-> [1, 1, 16, 16] | ||
82 | + | ||
73 | feature = F.upsample(feature, size=(h,w), mode="bilinear") | 83 | feature = F.upsample(feature, size=(h,w), mode="bilinear") |
74 | - feature = feature.view(3, h, w) | 84 | + # torch.Size([1, 3, 32, 32])-> [1, 1, 32, 32] |
75 | - image_feature = torch.cat((image_feature, feature), 2) | 85 | + |
86 | + feature = feature.view(1, h, w) #(3, h, w) input of size 3072 | ||
87 | + # torch.Size([3, 32, 32])->[1, 32, 32] | ||
88 | + | ||
89 | + print("img_feature & feature size:\n", image_feature.size(),"\n", feature.size()) | ||
90 | + # img_feature & feature size: | ||
91 | + # torch.Size([1, 32, 32]) -> [1, 32, 64] | ||
92 | + # torch.Size([3, 32, 32] ->[1, 32, 32] | ||
93 | + | ||
94 | + | ||
95 | + image_feature = torch.cat((image_feature, feature), 2) ### dim = 2 | ||
76 | 96 | ||
77 | return image_feature | 97 | return image_feature |
78 | 98 | ||
... | @@ -148,7 +168,7 @@ def select_model(args): | ... | @@ -148,7 +168,7 @@ def select_model(args): |
148 | Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | 168 | Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') |
149 | model = Net(args) | 169 | model = Net(args) |
150 | 170 | ||
151 | - print(model) | 171 | + #print(model) # print model architecture |
152 | return model | 172 | return model |
153 | 173 | ||
154 | 174 | ||
... | @@ -197,10 +217,38 @@ class CustomDataset(Dataset): | ... | @@ -197,10 +217,38 @@ class CustomDataset(Dataset): |
197 | targets = self.targets[idx] | 217 | targets = self.targets[idx] |
198 | #img = self.img[idx] | 218 | #img = self.img[idx] |
199 | image = Image.open(img_loc) | 219 | image = Image.open(img_loc) |
220 | + #print("Image:\n", image) | ||
221 | + #print("type of img:\n", type(image)) #<class 'PIL.PngImagePlugin.PngImageFile'> | ||
222 | + #w, h = image.size | ||
223 | + #print(image.size) #(240, 240) | ||
224 | + #image = image.reshape(w, h) | ||
200 | 225 | ||
226 | + # image = np.array(image) * 255 | ||
227 | + # image = image.astype('uint8') | ||
228 | + # image = Image.fromarray(image, mode = 'L') | ||
229 | + | ||
230 | + | ||
201 | if self.transform is not None: | 231 | if self.transform is not None: |
202 | #img = self.transform(img) | 232 | #img = self.transform(img) |
203 | - tensor_image = self.transform(image) | 233 | + # print("\ngetitem image max:\n", np.amax(np.array(image)), np.array(image).shape) |
234 | + #image [0, 255] | ||
235 | + | ||
236 | + tensor_image = self.transform(image) ## | ||
237 | + | ||
238 | + """ | ||
239 | + range [0, 1] -> [0, 255] | ||
240 | + RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same | ||
241 | + # tensor_image = np.array(tensor_image) * 255 | ||
242 | + # tensor_image = tensor_image.astype('uint8') | ||
243 | + # tensor_image = np.reshape(tensor_image, (32, 32)) | ||
244 | + # tensor_image = Image.fromarray(tensor_image, mode = 'L') | ||
245 | + # tensor_image = np.reshape(tensor_image, (1, 32, 32)) | ||
246 | + # tensor_image = tensor_image.astype('float') | ||
247 | + """ | ||
248 | + | ||
249 | + #print("\ngetitem tensor_image max:\n", np.amax(np.array(tensor_image)), np.array(tensor_image).shape) | ||
250 | + # tensor_image range: [0, 1], shape: (1, 32, 32) | ||
251 | + | ||
204 | #return img, targets | 252 | #return img, targets |
205 | return tensor_image, targets | 253 | return tensor_image, targets |
206 | 254 | ||
... | @@ -273,7 +321,7 @@ def get_inf_dataloader(args, dataset): | ... | @@ -273,7 +321,7 @@ def get_inf_dataloader(args, dataset): |
273 | 321 | ||
274 | def get_train_transform(args, model, log_dir=None): | 322 | def get_train_transform(args, model, log_dir=None): |
275 | if args.fast_auto_augment: | 323 | if args.fast_auto_augment: |
276 | - assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet | 324 | + #assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet |
277 | 325 | ||
278 | from fast_auto_augment import fast_auto_augment | 326 | from fast_auto_augment import fast_auto_augment |
279 | if args.augment_path: | 327 | if args.augment_path: |
... | @@ -281,7 +329,7 @@ def get_train_transform(args, model, log_dir=None): | ... | @@ -281,7 +329,7 @@ def get_train_transform(args, model, log_dir=None): |
281 | os.system('cp {} {}'.format( | 329 | os.system('cp {} {}'.format( |
282 | args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) | 330 | args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) |
283 | else: | 331 | else: |
284 | - transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) | 332 | + transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) ## |
285 | if log_dir: | 333 | if log_dir: |
286 | cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) | 334 | cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) |
287 | 335 | ||
... | @@ -302,14 +350,14 @@ def get_train_transform(args, model, log_dir=None): | ... | @@ -302,14 +350,14 @@ def get_train_transform(args, model, log_dir=None): |
302 | transforms.ToTensor() | 350 | transforms.ToTensor() |
303 | ]) | 351 | ]) |
304 | 352 | ||
305 | - elif args.dataset == 'BraTS': | 353 | + # elif args.dataset == 'BraTS': |
306 | - resize_h, resize_w = 256, 256 | 354 | + # resize_h, resize_w = 256, 256 |
307 | - transform = transforms.Compose([ | 355 | + # transform = transforms.Compose([ |
308 | - transforms.Resize([resize_h, resize_w]), | 356 | + # transforms.Resize([resize_h, resize_w]), |
309 | - transforms.RandomCrop(model.img_size), | 357 | + # transforms.RandomCrop(model.img_size), |
310 | - transforms.RandomHorizontalFlip(), | 358 | + # transforms.RandomHorizontalFlip(), |
311 | - transforms.ToTensor() | 359 | + # transforms.ToTensor() |
312 | - ]) | 360 | + # ]) |
313 | else: | 361 | else: |
314 | raise Exception('Unknown Dataset') | 362 | raise Exception('Unknown Dataset') |
315 | 363 | ||
... | @@ -393,7 +441,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): | ... | @@ -393,7 +441,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): |
393 | infer_t = 0 | 441 | infer_t = 0 |
394 | 442 | ||
395 | with torch.no_grad(): | 443 | with torch.no_grad(): |
396 | - for i, (images, target) in enumerate(valid_loader): | 444 | + for i, (images, target) in enumerate(valid_loader): ## |
397 | 445 | ||
398 | start_t = time.time() | 446 | start_t = time.time() |
399 | if device: | 447 | if device: | ... | ... |
-
Please register or login to post a comment