Showing
2 changed files
with
28 additions
and
19 deletions
| ... | @@ -33,11 +33,12 @@ def eval(model_path): | ... | @@ -33,11 +33,12 @@ def eval(model_path): |
| 33 | print('\n[+] Load dataset') | 33 | print('\n[+] Load dataset') |
| 34 | test_transform = get_valid_transform(args, model) | 34 | test_transform = get_valid_transform(args, model) |
| 35 | test_dataset = get_dataset(args, test_transform, 'test') | 35 | test_dataset = get_dataset(args, test_transform, 'test') |
| 36 | - test_loader = iter(get_dataloader(args, test_dataset)) | 36 | + print("len(dataset): ", len(test_dataset), type(test_dataset)) # 590 |
| 37 | + | ||
| 38 | + test_loader = iter(get_dataloader(args, test_dataset)) ### | ||
| 37 | 39 | ||
| 38 | print('\n[+] Start testing') | 40 | print('\n[+] Start testing') |
| 39 | - log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs', model_name) | 41 | + writer = SummaryWriter(log_dir=model_path) |
| 40 | - writer = SummaryWriter(log_dir=log_dir) | ||
| 41 | _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer) | 42 | _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer) |
| 42 | 43 | ||
| 43 | print('\n[+] Valid results') | 44 | print('\n[+] Valid results') | ... | ... |
| ... | @@ -55,7 +55,7 @@ def split_dataset(args, dataset, k): | ... | @@ -55,7 +55,7 @@ def split_dataset(args, dataset, k): |
| 55 | 55 | ||
| 56 | return Dm_indexes, Da_indexes | 56 | return Dm_indexes, Da_indexes |
| 57 | 57 | ||
| 58 | - | 58 | +#(images[j], first[j]), global_step=step) |
| 59 | def concat_image_features(image, features, max_features=3): | 59 | def concat_image_features(image, features, max_features=3): |
| 60 | _, h, w = image.shape | 60 | _, h, w = image.shape |
| 61 | 61 | ||
| ... | @@ -93,6 +93,7 @@ def concat_image_features(image, features, max_features=3): | ... | @@ -93,6 +93,7 @@ def concat_image_features(image, features, max_features=3): |
| 93 | 93 | ||
| 94 | 94 | ||
| 95 | image_feature = torch.cat((image_feature, feature), 2) ### dim = 2 | 95 | image_feature = torch.cat((image_feature, feature), 2) ### dim = 2 |
| 96 | + #print("\nimg feature size: ", image_feature.size()) #[1, 240, 720] | ||
| 96 | 97 | ||
| 97 | return image_feature | 98 | return image_feature |
| 98 | 99 | ||
| ... | @@ -149,10 +150,6 @@ def parse_args(kwargs): | ... | @@ -149,10 +150,6 @@ def parse_args(kwargs): |
| 149 | 150 | ||
| 150 | 151 | ||
| 151 | def select_model(args): | 152 | def select_model(args): |
| 152 | - # resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(), | ||
| 153 | - # 'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()} | ||
| 154 | - | ||
| 155 | - | ||
| 156 | # grayResNet2 | 153 | # grayResNet2 |
| 157 | resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | 154 | resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), |
| 158 | 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | 155 | 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} |
| ... | @@ -285,7 +282,7 @@ def get_dataset(args, transform, split='train'): | ... | @@ -285,7 +282,7 @@ def get_dataset(args, transform, split='train'): |
| 285 | elif args.dataset == 'BraTS': | 282 | elif args.dataset == 'BraTS': |
| 286 | if split in ['train']: | 283 | if split in ['train']: |
| 287 | dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform) | 284 | dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform) |
| 288 | - else: | 285 | + else: #test |
| 289 | dataset = CustomDataset(VAL_DATASET_PATH, transform=transform) | 286 | dataset = CustomDataset(VAL_DATASET_PATH, transform=transform) |
| 290 | 287 | ||
| 291 | 288 | ||
| ... | @@ -382,7 +379,7 @@ def get_valid_transform(args, model): | ... | @@ -382,7 +379,7 @@ def get_valid_transform(args, model): |
| 382 | transforms.ToTensor() | 379 | transforms.ToTensor() |
| 383 | ]) | 380 | ]) |
| 384 | elif args.dataset == 'BraTS': | 381 | elif args.dataset == 'BraTS': |
| 385 | - resize_h, resize_w = 256, 256 | 382 | + resize_h, resize_w = 240, 240 |
| 386 | val_transform = transforms.Compose([ | 383 | val_transform = transforms.Compose([ |
| 387 | transforms.Resize([resize_h, resize_w]), | 384 | transforms.Resize([resize_h, resize_w]), |
| 388 | transforms.ToTensor() | 385 | transforms.ToTensor() |
| ... | @@ -426,13 +423,14 @@ def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer | ... | @@ -426,13 +423,14 @@ def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer |
| 426 | 423 | ||
| 427 | if writer and step % args.print_step == 0: | 424 | if writer and step % args.print_step == 0: |
| 428 | n_imgs = min(images.size(0), 10) | 425 | n_imgs = min(images.size(0), 10) |
| 426 | + tag = 'train/' + str(step) | ||
| 429 | for j in range(n_imgs): | 427 | for j in range(n_imgs): |
| 430 | - writer.add_image('train/input_image', | 428 | + writer.add_image(tag, |
| 431 | concat_image_features(images[j], first[j]), global_step=step) | 429 | concat_image_features(images[j], first[j]), global_step=step) |
| 432 | 430 | ||
| 433 | return acc1, acc5, loss, forward_t, backward_t | 431 | return acc1, acc5, loss, forward_t, backward_t |
| 434 | 432 | ||
| 435 | - | 433 | +# validate(args, model, criterion, test_loader, step=0, writer=writer) |
| 436 | def validate(args, model, criterion, valid_loader, step, writer, device=None): | 434 | def validate(args, model, criterion, valid_loader, step, writer, device=None): |
| 437 | # switch to evaluate mode | 435 | # switch to evaluate mode |
| 438 | model.eval() | 436 | model.eval() |
| ... | @@ -441,19 +439,24 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): | ... | @@ -441,19 +439,24 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): |
| 441 | samples = 0 | 439 | samples = 0 |
| 442 | infer_t = 0 | 440 | infer_t = 0 |
| 443 | 441 | ||
| 442 | + img_count = 0 | ||
| 443 | + | ||
| 444 | with torch.no_grad(): | 444 | with torch.no_grad(): |
| 445 | - for i, (images, target) in enumerate(valid_loader): ## | 445 | + for i, (images, target) in enumerate(valid_loader): ## loop [0, 148] |
| 446 | 446 | ||
| 447 | + #print("\n1 images size: ", images.size()) #[4, 1, 240, 240] | ||
| 447 | start_t = time.time() | 448 | start_t = time.time() |
| 448 | if device: | 449 | if device: |
| 449 | images = images.to(device) | 450 | images = images.to(device) |
| 450 | target = target.to(device) | 451 | target = target.to(device) |
| 451 | 452 | ||
| 452 | - elif args.use_cuda is not None: | 453 | + elif args.use_cuda is not None: # |
| 453 | images = images.cuda(non_blocking=True) | 454 | images = images.cuda(non_blocking=True) |
| 454 | target = target.cuda(non_blocking=True) | 455 | target = target.cuda(non_blocking=True) |
| 456 | + #print("\n2 images size: ", images.size()) #[4, 1, 240, 240] | ||
| 455 | 457 | ||
| 456 | # compute output | 458 | # compute output |
| 459 | + # first = nn.Sequential(*list(backbone.children())[:1]) | ||
| 457 | output, first = model(images) | 460 | output, first = model(images) |
| 458 | loss = criterion(output, target) | 461 | loss = criterion(output, target) |
| 459 | infer_t += time.time() - start_t | 462 | infer_t += time.time() - start_t |
| ... | @@ -464,14 +467,19 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): | ... | @@ -464,14 +467,19 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): |
| 464 | acc5 += _acc5 | 467 | acc5 += _acc5 |
| 465 | samples += images.size(0) | 468 | samples += images.size(0) |
| 466 | 469 | ||
| 470 | + if writer: | ||
| 471 | + # print("\n3 images.size(0): ", images.size(0)) | ||
| 472 | + n_imgs = min(images.size(0), 10) | ||
| 473 | + for j in range(n_imgs): | ||
| 474 | + tag = 'valid/' + str(img_count) | ||
| 475 | + writer.add_image(tag, | ||
| 476 | + concat_image_features(images[j], first[j]), global_step=step) | ||
| 477 | + img_count = img_count + 1 | ||
| 478 | + | ||
| 467 | acc1 /= samples | 479 | acc1 /= samples |
| 468 | acc5 /= samples | 480 | acc5 /= samples |
| 469 | 481 | ||
| 470 | - if writer: | 482 | + |
| 471 | - n_imgs = min(images.size(0), 10) | ||
| 472 | - for j in range(n_imgs): | ||
| 473 | - writer.add_image('valid/input_image', | ||
| 474 | - concat_image_features(images[j], first[j]), global_step=step) | ||
| 475 | 483 | ||
| 476 | return acc1, acc5, loss, infer_t | 484 | return acc1, acc5, loss, infer_t |
| 477 | 485 | ... | ... |
-
Please register or login to post a comment