Showing 1 changed file with 67 additions and 19 deletions
@@ -28,8 +28,6 @@ DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_fr
 TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
 VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
 
-TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv'
-VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv'
 
 current_epoch = 0
 
@@ -65,14 +63,36 @@ def concat_image_features(image, features, max_features=3):
     image_feature = image.clone()
 
     for i in range(max_features):
+        # features: torch.Size([64, 16, 16])
+
         feature = features[i:i+1]
+        # torch.Size([1, 16, 16])
+
         _min, _max = torch.min(feature), torch.max(feature)
         feature = (feature - _min) / (_max - _min + 1e-6)
-        feature = torch.cat([feature]*3, 0)
-        feature = feature.view(1, 3, feature.size(1), feature.size(2))
+        # torch.Size([1, 16, 16])
+
+        feature = torch.cat([feature]*1, 0)
+        #feature = torch.cat([feature]*3, 0)
+        # torch.Size([3, 16, 16]) -> [1, 16, 16]
+
+        feature = feature.view(1, 1, feature.size(1), feature.size(2))
+        #feature = feature.view(1, 3, feature.size(1), feature.size(2))
+        # torch.Size([1, 3, 16, 16]) -> [1, 1, 16, 16]
+
         feature = F.upsample(feature, size=(h,w), mode="bilinear")
-        feature = feature.view(3, h, w)
-        image_feature = torch.cat((image_feature, feature), 2)
+        # torch.Size([1, 3, 32, 32]) -> [1, 1, 32, 32]
+
+        feature = feature.view(1, h, w)  # was (3, h, w): gave an "input of size 3072" mismatch
+        # torch.Size([3, 32, 32]) -> [1, 32, 32]
+
+        print("img_feature & feature size:\n", image_feature.size(), "\n", feature.size())
+        # img_feature & feature size:
+        # torch.Size([1, 32, 32]) -> [1, 32, 64]
+        # torch.Size([3, 32, 32]) -> [1, 32, 32]
+
+
+        image_feature = torch.cat((image_feature, feature), 2)  # dim = 2: concatenate along width
 
     return image_feature
 
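Note on the hunk above: the loop now keeps each feature map single-channel instead of tiling it to three channels before concatenation. Below is a minimal self-contained sketch of that single-channel version, assuming image is a [1, h, w] tensor and features is an [N, 16, 16] tensor as the shape comments indicate; it uses F.interpolate in place of the deprecated F.upsample.

import torch
import torch.nn.functional as F

def concat_image_features_gray(image, features, max_features=3):
    # Sketch only: image is [1, h, w]; features is [N, fh, fw].
    h, w = image.size(1), image.size(2)
    image_feature = image.clone()
    for i in range(max_features):
        feature = features[i:i+1]                          # [1, fh, fw]
        _min, _max = torch.min(feature), torch.max(feature)
        feature = (feature - _min) / (_max - _min + 1e-6)  # normalize to [0, 1]
        feature = feature.view(1, 1, feature.size(1), feature.size(2))
        feature = F.interpolate(feature, size=(h, w),
                                mode="bilinear", align_corners=False)
        feature = feature.view(1, h, w)
        image_feature = torch.cat((image_feature, feature), 2)  # grow along width
    return image_feature

# Example: a 32x32 grayscale image plus three 16x16 feature maps -> [1, 32, 128]
out = concat_image_features_gray(torch.rand(1, 32, 32), torch.rand(64, 16, 16))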
@@ -148,7 +168,7 @@ def select_model(args):
     Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
     model = Net(args)
 
-    print(model)
+    #print(model)  # print model architecture
     return model
 
 
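For context, select_model resolves the network class by name at runtime. A hypothetical sketch of that importlib pattern ('resnet' is an illustrative stand-in for args.network; it assumes a networks/resnet.py module that defines a Net class):

import importlib

network_name = 'resnet'  # illustrative stand-in for args.network
Net = getattr(importlib.import_module('networks.{}'.format(network_name)), 'Net')
# then instantiated with the project's argument object: model = Net(args)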
@@ -197,10 +217,38 @@ class CustomDataset(Dataset):
         targets = self.targets[idx]
         #img = self.img[idx]
         image = Image.open(img_loc)
+        #print("Image:\n", image)
+        #print("type of img:\n", type(image))  # <class 'PIL.PngImagePlugin.PngImageFile'>
+        #w, h = image.size
+        #print(image.size)  # (240, 240)
+        #image = image.reshape(w, h)
 
+        # image = np.array(image) * 255
+        # image = image.astype('uint8')
+        # image = Image.fromarray(image, mode = 'L')
+
+
         if self.transform is not None:
             #img = self.transform(img)
-            tensor_image = self.transform(image)
+            # print("\ngetitem image max:\n", np.amax(np.array(image)), np.array(image).shape)
+            # image range: [0, 255]
+
+            tensor_image = self.transform(image)  ##
+
+            """
+            range [0, 1] -> [0, 255]
+            RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same
+            # tensor_image = np.array(tensor_image) * 255
+            # tensor_image = tensor_image.astype('uint8')
+            # tensor_image = np.reshape(tensor_image, (32, 32))
+            # tensor_image = Image.fromarray(tensor_image, mode = 'L')
+            # tensor_image = np.reshape(tensor_image, (1, 32, 32))
+            # tensor_image = tensor_image.astype('float')
+            """
+
+            #print("\ngetitem tensor_image max:\n", np.amax(np.array(tensor_image)), np.array(tensor_image).shape)
+            # tensor_image range: [0, 1], shape: (1, 32, 32)
+
         #return img, targets
         return tensor_image, targets
 
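The docstring added above records a failed fix attempt: round-tripping the transformed image through NumPy promotes it to float64, which becomes a torch.cuda.DoubleTensor and no longer matches the float32 model weights. A hedged sketch of the conventional fix (the helper name is made up for illustration, not part of this diff):

import numpy as np
import torch

def ensure_float32(sample):
    # NumPy round-trips default to float64 (DoubleTensor in torch);
    # cast back to float32 so the input dtype matches the model weights.
    if isinstance(sample, np.ndarray):
        sample = torch.from_numpy(sample)
    return sample.float()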
@@ -273,7 +321,7 @@ def get_inf_dataloader(args, dataset):
 
 def get_train_transform(args, model, log_dir=None):
     if args.fast_auto_augment:
-        assert args.dataset == 'BraTS'  # TODO: FastAutoAugment for Imagenet
+        #assert args.dataset == 'BraTS'  # TODO: FastAutoAugment for Imagenet
 
         from fast_auto_augment import fast_auto_augment
         if args.augment_path:
@@ -281,7 +329,7 @@ def get_train_transform(args, model, log_dir=None):
             os.system('cp {} {}'.format(
                 args.augment_path, os.path.join(log_dir, 'augmentation.cp')))
         else:
-            transform = fast_auto_augment(args, model, K=4, B=1, num_process=4)
+            transform = fast_auto_augment(args, model, K=4, B=1, num_process=4)  ##
         if log_dir:
             cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb'))
 
@@ -302,14 +350,14 @@ def get_train_transform(args, model, log_dir=None):
             transforms.ToTensor()
         ])
 
-    elif args.dataset == 'BraTS':
-        resize_h, resize_w = 256, 256
-        transform = transforms.Compose([
-            transforms.Resize([resize_h, resize_w]),
-            transforms.RandomCrop(model.img_size),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor()
-        ])
+    # elif args.dataset == 'BraTS':
+    #     resize_h, resize_w = 256, 256
+    #     transform = transforms.Compose([
+    #         transforms.Resize([resize_h, resize_w]),
+    #         transforms.RandomCrop(model.img_size),
+    #         transforms.RandomHorizontalFlip(),
+    #         transforms.ToTensor()
+    #     ])
     else:
         raise Exception('Unknown Dataset')
 
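The BraTS branch above is commented out, and its replacement is not part of this diff. Purely as a sketch, a single-channel variant of the same pipeline could look like the following (img_size = 32 is an assumption drawn from the shape comments earlier, not from this file; the removed branch used model.img_size):

from torchvision import transforms

img_size = 32  # assumed; the removed branch used model.img_size
transform = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.RandomCrop(img_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # 'L'-mode PIL image -> float32 tensor in [0, 1], shape [1, H, W]
])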
@@ -393,7 +441,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None):
     infer_t = 0
 
     with torch.no_grad():
-        for i, (images, target) in enumerate(valid_loader):
+        for i, (images, target) in enumerate(valid_loader):  ##
 
             start_t = time.time()
             if device:
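For context, the ##-marked loop runs inside a timed, gradient-free validation pass. A minimal sketch of that pattern with illustrative names (not the project's actual validate()):

import time
import torch

def timed_validation(model, valid_loader, device=None):
    model.eval()
    infer_t = 0.0
    with torch.no_grad():  # skip autograd bookkeeping during evaluation
        for images, target in valid_loader:
            if device:
                images = images.to(device).float()  # guard against DoubleTensor inputs
                target = target.to(device)
            start_t = time.time()
            _ = model(images)
            infer_t += time.time() - start_t  # accumulate pure inference time
    return infer_t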