조현아

resolved concat size err

...@@ -28,8 +28,6 @@ DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_fr ...@@ -28,8 +28,6 @@ DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_fr
28 TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' 28 TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
29 VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' 29 VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
30 30
31 -TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv'
32 -VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv'
33 31
34 current_epoch = 0 32 current_epoch = 0
35 33
...@@ -65,14 +63,36 @@ def concat_image_features(image, features, max_features=3): ...@@ -65,14 +63,36 @@ def concat_image_features(image, features, max_features=3):
65 image_feature = image.clone() 63 image_feature = image.clone()
66 64
67 for i in range(max_features): 65 for i in range(max_features):
66 + # features torch.Size([64, 16, 16])
67 +
68 feature = features[i:i+1] 68 feature = features[i:i+1]
69 + #torch.Size([1, 16, 16])
70 +
69 _min, _max = torch.min(feature), torch.max(feature) 71 _min, _max = torch.min(feature), torch.max(feature)
70 feature = (feature - _min) / (_max - _min + 1e-6) 72 feature = (feature - _min) / (_max - _min + 1e-6)
71 - feature = torch.cat([feature]*3, 0) 73 + # torch.Size([1, 16, 16])
72 - feature = feature.view(1, 3, feature.size(1), feature.size(2)) 74 +
75 + feature = torch.cat([feature]*1, 0)
76 + #feature = torch.cat([feature]*3, 0)
77 + # torch.Size([3, 16, 16]) -> [1, 16, 16]
78 +
79 + feature = feature.view(1, 1, feature.size(1), feature.size(2))
80 + #feature = feature.view(1, 3, feature.size(1), feature.size(2))
81 + # torch.Size([1, 3, 16, 16])-> [1, 1, 16, 16]
82 +
73 feature = F.upsample(feature, size=(h,w), mode="bilinear") 83 feature = F.upsample(feature, size=(h,w), mode="bilinear")
74 - feature = feature.view(3, h, w) 84 + # torch.Size([1, 3, 32, 32])-> [1, 1, 32, 32]
75 - image_feature = torch.cat((image_feature, feature), 2) 85 +
86 + feature = feature.view(1, h, w) #(3, h, w) input of size 3072
87 + # torch.Size([3, 32, 32])->[1, 32, 32]
88 +
89 + print("img_feature & feature size:\n", image_feature.size(),"\n", feature.size())
90 + # img_feature & feature size:
91 + # torch.Size([1, 32, 32]) -> [1, 32, 64]
92 + # torch.Size([3, 32, 32] ->[1, 32, 32]
93 +
94 +
95 + image_feature = torch.cat((image_feature, feature), 2) ### dim = 2
76 96
77 return image_feature 97 return image_feature
78 98
...@@ -148,7 +168,7 @@ def select_model(args): ...@@ -148,7 +168,7 @@ def select_model(args):
148 Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') 168 Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
149 model = Net(args) 169 model = Net(args)
150 170
151 - print(model) 171 + #print(model) # print model architecture
152 return model 172 return model
153 173
154 174
...@@ -197,10 +217,38 @@ class CustomDataset(Dataset): ...@@ -197,10 +217,38 @@ class CustomDataset(Dataset):
197 targets = self.targets[idx] 217 targets = self.targets[idx]
198 #img = self.img[idx] 218 #img = self.img[idx]
199 image = Image.open(img_loc) 219 image = Image.open(img_loc)
220 + #print("Image:\n", image)
221 + #print("type of img:\n", type(image)) #<class 'PIL.PngImagePlugin.PngImageFile'>
222 + #w, h = image.size
223 + #print(image.size) #(240, 240)
224 + #image = image.reshape(w, h)
200 225
226 + # image = np.array(image) * 255
227 + # image = image.astype('uint8')
228 + # image = Image.fromarray(image, mode = 'L')
229 +
230 +
201 if self.transform is not None: 231 if self.transform is not None:
202 #img = self.transform(img) 232 #img = self.transform(img)
203 - tensor_image = self.transform(image) 233 + # print("\ngetitem image max:\n", np.amax(np.array(image)), np.array(image).shape)
234 + #image [0, 255]
235 +
236 + tensor_image = self.transform(image) ##
237 +
238 + """
239 + range [0, 1] -> [0, 255]
240 + RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same
241 + # tensor_image = np.array(tensor_image) * 255
242 + # tensor_image = tensor_image.astype('uint8')
243 + # tensor_image = np.reshape(tensor_image, (32, 32))
244 + # tensor_image = Image.fromarray(tensor_image, mode = 'L')
245 + # tensor_image = np.reshape(tensor_image, (1, 32, 32))
246 + # tensor_image = tensor_image.astype('float')
247 + """
248 +
249 + #print("\ngetitem tensor_image max:\n", np.amax(np.array(tensor_image)), np.array(tensor_image).shape)
250 + # tensor_image range: [0, 1], shape: (1, 32, 32)
251 +
204 #return img, targets 252 #return img, targets
205 return tensor_image, targets 253 return tensor_image, targets
206 254
...@@ -273,7 +321,7 @@ def get_inf_dataloader(args, dataset): ...@@ -273,7 +321,7 @@ def get_inf_dataloader(args, dataset):
273 321
274 def get_train_transform(args, model, log_dir=None): 322 def get_train_transform(args, model, log_dir=None):
275 if args.fast_auto_augment: 323 if args.fast_auto_augment:
276 - assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet 324 + #assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet
277 325
278 from fast_auto_augment import fast_auto_augment 326 from fast_auto_augment import fast_auto_augment
279 if args.augment_path: 327 if args.augment_path:
...@@ -281,7 +329,7 @@ def get_train_transform(args, model, log_dir=None): ...@@ -281,7 +329,7 @@ def get_train_transform(args, model, log_dir=None):
281 os.system('cp {} {}'.format( 329 os.system('cp {} {}'.format(
282 args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) 330 args.augment_path, os.path.join(log_dir, 'augmentation.cp')))
283 else: 331 else:
284 - transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) 332 + transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) ##
285 if log_dir: 333 if log_dir:
286 cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) 334 cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb'))
287 335
...@@ -302,14 +350,14 @@ def get_train_transform(args, model, log_dir=None): ...@@ -302,14 +350,14 @@ def get_train_transform(args, model, log_dir=None):
302 transforms.ToTensor() 350 transforms.ToTensor()
303 ]) 351 ])
304 352
305 - elif args.dataset == 'BraTS': 353 + # elif args.dataset == 'BraTS':
306 - resize_h, resize_w = 256, 256 354 + # resize_h, resize_w = 256, 256
307 - transform = transforms.Compose([ 355 + # transform = transforms.Compose([
308 - transforms.Resize([resize_h, resize_w]), 356 + # transforms.Resize([resize_h, resize_w]),
309 - transforms.RandomCrop(model.img_size), 357 + # transforms.RandomCrop(model.img_size),
310 - transforms.RandomHorizontalFlip(), 358 + # transforms.RandomHorizontalFlip(),
311 - transforms.ToTensor() 359 + # transforms.ToTensor()
312 - ]) 360 + # ])
313 else: 361 else:
314 raise Exception('Unknown Dataset') 362 raise Exception('Unknown Dataset')
315 363
...@@ -393,7 +441,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): ...@@ -393,7 +441,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None):
393 infer_t = 0 441 infer_t = 0
394 442
395 with torch.no_grad(): 443 with torch.no_grad():
396 - for i, (images, target) in enumerate(valid_loader): 444 + for i, (images, target) in enumerate(valid_loader): ##
397 445
398 start_t = time.time() 446 start_t = time.time()
399 if device: 447 if device:
......