Hyunji

sliced whole validation

1 +import json
2 +
3 +from dataset import PAC2019, PAC20192D
4 +from model import Model, VGGBasedModel, VGGBasedModel2D
5 +from model_resnet import resnet18
6 +
7 +import torch
8 +from torch.autograd import Variable
9 +import torch.nn as nn
10 +from torch.utils.data import DataLoader
11 +import numpy as np
12 +
13 +import medicaltorch.transforms as mt_transforms
14 +import torchvision as tv
15 +import torchvision.utils as vutils
16 +
17 +import matplotlib.pyplot as plt
18 +from collections import defaultdict, Counter
19 +
20 +from tqdm import *
21 +
22 +
23 +with open("config.json") as fid:
24 + ctx = json.load(fid)
25 +
26 +val_set = PAC2019(ctx, set='val', split=0.8)
27 +
28 +val_loader = DataLoader(val_set, shuffle=False, drop_last=False,
29 + num_workers=8, batch_size=1)
30 +
31 +model = resnet18()
32 +model.cuda()
33 +# model.load_state_dict(torch.load('models/lr0.0006_rampup20.pt'))
34 +model.load_state_dict(torch.load('models/2d.pt'))
35 +model.eval()
36 +
37 +portion = 0.8
38 +errors = []
39 +error_per_age = defaultdict(list)
40 +error_per_age_per_slice = defaultdict(lambda: defaultdict(list))
41 +errors_val = []
42 +for i, data in enumerate(tqdm(val_loader)):
43 + gm_image = Variable(data["gm"]).float().cuda()
44 + wm_image = Variable(data["wm"]).float().cuda()
45 + # print(input_image.shape)
46 +
47 +
48 + slices = []
49 + start = int((1.-portion)*gm_image.shape[1])
50 + end = int(portion*gm_image.shape[1])
51 + gm_image = gm_image[0,start:end,:,:]
52 + wm_image = wm_image[0,start:end,:,:]
53 + # print(gm_image.shape)
54 + for slice_idx in range(gm_image.shape[0]):
55 + slice_gm = gm_image[slice_idx,:,:]
56 + slice_gm = slice_gm.unsqueeze(0)
57 + slice_wm = wm_image[slice_idx,:,:]
58 + slice_wm = slice_wm.unsqueeze(0)
59 + slice = torch.cat([slice_gm, slice_wm], dim=0)
60 + # print(slice.shape)
61 + slices.append({
62 + 'image': slice,
63 + 'label': data['label']
64 + })
65 + # print('Slice: ', slice.shape)
66 +
67 + error = []
68 + for idx, slice in enumerate(slices):
69 + age = int(slice['label'].item())
70 + slice['image'] = slice['image'].unsqueeze(0)
71 + # print(slice['image'].shape)
72 + output = model(slice['image'])
73 + # print(output[0], slice['label'])
74 + error.append(np.abs(output[0].item() - slice['label'].item()))
75 + error_per_age_per_slice[idx][age].append(np.abs(output[0].item() - slice['label'].item()))
76 + # print(error)
77 + errors.append(error)
78 + errors_val.append(np.mean(error))
79 + error_per_age[int(slice['label'].item())].append(np.mean(error))
80 +
81 +print('Validation error: ', np.mean(errors_val))
82 +min_slice = 0
83 +# print(error_per_age_per_slice.keys())
84 +max_slice = len(error_per_age_per_slice.keys())
85 +min_age = min(error_per_age_per_slice[0].keys())
86 +max_age = max(error_per_age_per_slice[0].keys())+1
87 +# print('Min/max: ', min_age, max_age)
88 +heatmap = np.zeros((max_age, max_slice))
89 +# print(error_per_age_per_slice.keys())
90 +# print(error_per_age_per_slice[0].keys())
91 +# print(list(sorted(error_per_age_per_slice[0].keys())))
92 +for slice_idx in sorted(error_per_age_per_slice.keys()):
93 + # print('here')
94 + for age in range(0, 75):
95 + # print('age: here')
96 + # print('Slice/Age: %d/%d --> ' % (slice_idx, age), error_per_age_per_slice[slice_idx][age])
97 + mean = np.mean(error_per_age_per_slice[slice_idx][age])
98 + if not np.isnan(mean):
99 + heatmap[age,slice_idx] = mean
100 + # print('mean: ', np.mean(error_per_age_per_slice[slice_idx][age]))
101 +plt.imshow(heatmap, cmap='viridis')
102 +plt.colorbar()
103 +plt.ylabel('Age')
104 +plt.xlabel('Slice')
105 +# plt.grid()
106 +plt.show()
107 +# raise
108 +# print(error_per_age)
109 +sorted_values = []
110 +keys = []
111 +for k in sorted(error_per_age.keys()):
112 + sorted_values.append(error_per_age[k])
113 + keys.append(k)
114 +
115 +fig = plt.figure(1, figsize=(9, 6))
116 +ax = fig.add_subplot(111)
117 +
118 +ax.boxplot(sorted_values)
119 +ax.set_xticklabels(keys)
120 +plt.show()
121 +
122 +
123 +errors = np.array(errors)
124 +# print(errors.shape)
125 +mean_errors = np.mean(errors, axis=0)
126 +# plt.plot(mean_errors)
127 +fig, (ax,ax2) = plt.subplots(nrows=2, sharex=True)
128 +x = np.linspace(0, errors.shape[1])
129 +extent = [x[0]-(x[1]-x[0])/2., x[-1]+(x[1]-x[0])/2.,0,1]
130 +ax.imshow(mean_errors[np.newaxis,:], cmap="viridis", aspect="auto", extent=extent)
131 +# print(mean_errors.shape)
132 +# print(x.shape)
133 +ax2.plot(np.arange(mean_errors.shape[0]),mean_errors)
134 +
135 +plt.ylabel('Mean Absolute Error (MAE)')
136 +plt.xlabel('Slice index')
137 +
138 +plt.show()
139 +# print(mean_errors)