Showing
1 changed file
with
139 additions
and
0 deletions
1 | +import json | ||
2 | + | ||
3 | +from dataset import PAC2019, PAC20192D | ||
4 | +from model import Model, VGGBasedModel, VGGBasedModel2D | ||
5 | +from model_resnet import resnet18 | ||
6 | + | ||
7 | +import torch | ||
8 | +from torch.autograd import Variable | ||
9 | +import torch.nn as nn | ||
10 | +from torch.utils.data import DataLoader | ||
11 | +import numpy as np | ||
12 | + | ||
13 | +import medicaltorch.transforms as mt_transforms | ||
14 | +import torchvision as tv | ||
15 | +import torchvision.utils as vutils | ||
16 | + | ||
17 | +import matplotlib.pyplot as plt | ||
18 | +from collections import defaultdict, Counter | ||
19 | + | ||
20 | +from tqdm import * | ||
21 | + | ||
22 | + | ||
23 | +with open("config.json") as fid: | ||
24 | + ctx = json.load(fid) | ||
25 | + | ||
26 | +val_set = PAC2019(ctx, set='val', split=0.8) | ||
27 | + | ||
28 | +val_loader = DataLoader(val_set, shuffle=False, drop_last=False, | ||
29 | + num_workers=8, batch_size=1) | ||
30 | + | ||
31 | +model = resnet18() | ||
32 | +model.cuda() | ||
33 | +# model.load_state_dict(torch.load('models/lr0.0006_rampup20.pt')) | ||
34 | +model.load_state_dict(torch.load('models/2d.pt')) | ||
35 | +model.eval() | ||
36 | + | ||
37 | +portion = 0.8 | ||
38 | +errors = [] | ||
39 | +error_per_age = defaultdict(list) | ||
40 | +error_per_age_per_slice = defaultdict(lambda: defaultdict(list)) | ||
41 | +errors_val = [] | ||
42 | +for i, data in enumerate(tqdm(val_loader)): | ||
43 | + gm_image = Variable(data["gm"]).float().cuda() | ||
44 | + wm_image = Variable(data["wm"]).float().cuda() | ||
45 | + # print(input_image.shape) | ||
46 | + | ||
47 | + | ||
48 | + slices = [] | ||
49 | + start = int((1.-portion)*gm_image.shape[1]) | ||
50 | + end = int(portion*gm_image.shape[1]) | ||
51 | + gm_image = gm_image[0,start:end,:,:] | ||
52 | + wm_image = wm_image[0,start:end,:,:] | ||
53 | + # print(gm_image.shape) | ||
54 | + for slice_idx in range(gm_image.shape[0]): | ||
55 | + slice_gm = gm_image[slice_idx,:,:] | ||
56 | + slice_gm = slice_gm.unsqueeze(0) | ||
57 | + slice_wm = wm_image[slice_idx,:,:] | ||
58 | + slice_wm = slice_wm.unsqueeze(0) | ||
59 | + slice = torch.cat([slice_gm, slice_wm], dim=0) | ||
60 | + # print(slice.shape) | ||
61 | + slices.append({ | ||
62 | + 'image': slice, | ||
63 | + 'label': data['label'] | ||
64 | + }) | ||
65 | + # print('Slice: ', slice.shape) | ||
66 | + | ||
67 | + error = [] | ||
68 | + for idx, slice in enumerate(slices): | ||
69 | + age = int(slice['label'].item()) | ||
70 | + slice['image'] = slice['image'].unsqueeze(0) | ||
71 | + # print(slice['image'].shape) | ||
72 | + output = model(slice['image']) | ||
73 | + # print(output[0], slice['label']) | ||
74 | + error.append(np.abs(output[0].item() - slice['label'].item())) | ||
75 | + error_per_age_per_slice[idx][age].append(np.abs(output[0].item() - slice['label'].item())) | ||
76 | + # print(error) | ||
77 | + errors.append(error) | ||
78 | + errors_val.append(np.mean(error)) | ||
79 | + error_per_age[int(slice['label'].item())].append(np.mean(error)) | ||
80 | + | ||
81 | +print('Validation error: ', np.mean(errors_val)) | ||
82 | +min_slice = 0 | ||
83 | +# print(error_per_age_per_slice.keys()) | ||
84 | +max_slice = len(error_per_age_per_slice.keys()) | ||
85 | +min_age = min(error_per_age_per_slice[0].keys()) | ||
86 | +max_age = max(error_per_age_per_slice[0].keys())+1 | ||
87 | +# print('Min/max: ', min_age, max_age) | ||
88 | +heatmap = np.zeros((max_age, max_slice)) | ||
89 | +# print(error_per_age_per_slice.keys()) | ||
90 | +# print(error_per_age_per_slice[0].keys()) | ||
91 | +# print(list(sorted(error_per_age_per_slice[0].keys()))) | ||
92 | +for slice_idx in sorted(error_per_age_per_slice.keys()): | ||
93 | + # print('here') | ||
94 | + for age in range(0, 75): | ||
95 | + # print('age: here') | ||
96 | + # print('Slice/Age: %d/%d --> ' % (slice_idx, age), error_per_age_per_slice[slice_idx][age]) | ||
97 | + mean = np.mean(error_per_age_per_slice[slice_idx][age]) | ||
98 | + if not np.isnan(mean): | ||
99 | + heatmap[age,slice_idx] = mean | ||
100 | + # print('mean: ', np.mean(error_per_age_per_slice[slice_idx][age])) | ||
101 | +plt.imshow(heatmap, cmap='viridis') | ||
102 | +plt.colorbar() | ||
103 | +plt.ylabel('Age') | ||
104 | +plt.xlabel('Slice') | ||
105 | +# plt.grid() | ||
106 | +plt.show() | ||
107 | +# raise | ||
108 | +# print(error_per_age) | ||
109 | +sorted_values = [] | ||
110 | +keys = [] | ||
111 | +for k in sorted(error_per_age.keys()): | ||
112 | + sorted_values.append(error_per_age[k]) | ||
113 | + keys.append(k) | ||
114 | + | ||
115 | +fig = plt.figure(1, figsize=(9, 6)) | ||
116 | +ax = fig.add_subplot(111) | ||
117 | + | ||
118 | +ax.boxplot(sorted_values) | ||
119 | +ax.set_xticklabels(keys) | ||
120 | +plt.show() | ||
121 | + | ||
122 | + | ||
123 | +errors = np.array(errors) | ||
124 | +# print(errors.shape) | ||
125 | +mean_errors = np.mean(errors, axis=0) | ||
126 | +# plt.plot(mean_errors) | ||
127 | +fig, (ax,ax2) = plt.subplots(nrows=2, sharex=True) | ||
128 | +x = np.linspace(0, errors.shape[1]) | ||
129 | +extent = [x[0]-(x[1]-x[0])/2., x[-1]+(x[1]-x[0])/2.,0,1] | ||
130 | +ax.imshow(mean_errors[np.newaxis,:], cmap="viridis", aspect="auto", extent=extent) | ||
131 | +# print(mean_errors.shape) | ||
132 | +# print(x.shape) | ||
133 | +ax2.plot(np.arange(mean_errors.shape[0]),mean_errors) | ||
134 | + | ||
135 | +plt.ylabel('Mean Absolute Error (MAE)') | ||
136 | +plt.xlabel('Slice index') | ||
137 | + | ||
138 | +plt.show() | ||
139 | +# print(mean_errors) |
-
Please register or login to post a comment