Showing
1 changed file
with
211 additions
and
0 deletions
3DCNN_VGGNet_2DResNet/main.py
0 → 100644
1 | +# -*- coding: utf-8 -*- | ||
2 | +"""mainnn.ipynb | ||
3 | + | ||
4 | +Automatically generated by Colaboratory. | ||
5 | + | ||
6 | +Original file is located at | ||
7 | + https://colab.research.google.com/drive/1tGd53i4_WlVJHCjEzaD2oiL0okf70bOT | ||
8 | +""" | ||
9 | + | ||
10 | +import json | ||
11 | + | ||
12 | +from dataset import PAC2019, PAC20192D, PAC20193D | ||
13 | +from model import Model, VGGBasedModel, VGGBasedModel2D, Model3D | ||
14 | +from model_resnet import ResNet, resnet18, resnset34, resnet50 | ||
15 | + | ||
16 | +import torch | ||
17 | +from torch.autograd import Variable | ||
18 | +import torch.nn as nn | ||
19 | +from torch.utils.data import DataLoader | ||
20 | +import numpy as np | ||
21 | + | ||
22 | +from tqdm import * | ||
23 | + | ||
24 | +import gc | ||
25 | +gc.collect() | ||
26 | +torch.cuda.empty_cache() | ||
27 | + | ||
28 | +def cosine_rampdown(current, rampdown_length): | ||
29 | + """Cosine rampdown from https://arxiv.org/abs/1608.03983""" | ||
30 | + assert 0 <= current <= rampdown_length | ||
31 | + return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) | ||
32 | + | ||
33 | +def cosine_lr(current_epoch, num_epochs, initial_lr): | ||
34 | + return initial_lr * cosine_rampdown(current_epoch, num_epochs) | ||
35 | + | ||
36 | +def sigmoid_rampup(current, rampup_length): | ||
37 | + if rampup_length == 0: | ||
38 | + return 1.0 | ||
39 | + else: | ||
40 | + current = np.clip(current, 0.0, rampup_length) | ||
41 | + phase = 1.0 - current / rampup_length | ||
42 | + return float(np.exp(-5.0 * phase * phase)) | ||
43 | + | ||
44 | + | ||
45 | +with open("config.json") as fid: | ||
46 | + ctx = json.load(fid) | ||
47 | + | ||
48 | +if ctx["3d"]: | ||
49 | + train_set = PAC20193D(ctx, set='train') | ||
50 | + val_set = PAC20193D(ctx, set='valid') | ||
51 | + test_set = PAC20193D(ctx, set='test') | ||
52 | + | ||
53 | + model = Model3D | ||
54 | + #model = VGGBasedModel() | ||
55 | + | ||
56 | + optimizer = torch.optim.SGD(model.parameters(), lr=ctx["learning_rate"], | ||
57 | + momentum=0.9, weight_decay=ctx["weight_decay"]) | ||
58 | + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97) | ||
59 | + | ||
60 | +else: | ||
61 | + train_set = PAC20192D(ctx, set='train') | ||
62 | + val_set = PAC20192D(ctx, set='val') | ||
63 | + test_set = PAC20192D(ctx, set='test') | ||
64 | + | ||
65 | + model = resnet18() | ||
66 | + #model = resnet34() | ||
67 | + #model = resnet50() | ||
68 | + | ||
69 | + optimizer = torch.optim.Adam(model.parameters(), lr=ctx["learning_rate"], | ||
70 | + weight_decay=ctx["weight_decay"]) | ||
71 | + | ||
72 | + | ||
73 | +train_loader = DataLoader(train_set, shuffle=False, drop_last=False, | ||
74 | + num_workers=8, batch_size=ctx["batch_size"]) | ||
75 | +val_loader = DataLoader(val_set, shuffle=False, drop_last=False, | ||
76 | + num_workers=8, batch_size=ctx["batch_size"]) | ||
77 | +test_loader = DataLoader(test_set, shuffle=False, drop_last=False, | ||
78 | + num_workers=8, batch_size=ctx["batch_size"]) | ||
79 | + | ||
80 | +mse_loss = nn.MSELoss() | ||
81 | +mae_loss = nn.L1Loss() | ||
82 | +model.cuda() | ||
83 | + | ||
84 | +best = np.inf | ||
85 | +for e in tqdm(range(1, ctx["epochs"]+1), desc="Epochs"): | ||
86 | + model.train() | ||
87 | + last_50 = [] | ||
88 | + | ||
89 | + if ctx["3d"]: | ||
90 | + scheduler.step() | ||
91 | + tqdm.write('Learning Rate: {:.6f}'.format(scheduler.get_lr()[0])) | ||
92 | + else: | ||
93 | + if e <= ctx["initial_lr_rampup"]: | ||
94 | + lr = ctx["learning_rate"] * sigmoid_rampup(e, ctx["initial_lr_rampup"]) | ||
95 | + else: | ||
96 | + lr = cosine_lr(e-ctx["initial_lr_rampup"], | ||
97 | + ctx["epochs"]-ctx["initial_lr_rampup"], | ||
98 | + ctx["learning_rate"]) | ||
99 | + | ||
100 | + for param_group in optimizer.param_groups: | ||
101 | + tqdm.write("Learning Rate: {:.6f}".format(lr)) | ||
102 | + param_group['lr'] = lr | ||
103 | + | ||
104 | + | ||
105 | + for i, data in enumerate(train_loader): | ||
106 | + if ctx["mixup"]: | ||
107 | + lam = np.random.beta(ctx["mixup_alpha"], ctx["mixup_alpha"]) | ||
108 | + | ||
109 | + length_data = data["input"].size(0)//2 | ||
110 | + data1_x = data["input"][0:length_data] | ||
111 | + data1_y = data["label"][0:length_data] | ||
112 | + data2_x = data["input"][length_data:] | ||
113 | + data2_y = data["label"][length_data:] | ||
114 | + | ||
115 | + data["input"] = lam*data1_x + (1.-lam)*data2_x | ||
116 | + data["label"] = lam*data1_y + (1.-lam)*data2_y | ||
117 | + | ||
118 | + input_image = Variable(data["input"], requires_grad=True).float().cuda() | ||
119 | + if ctx["3d"]: | ||
120 | + input_image = input_image.squeeze(1) | ||
121 | + output = model(input_image) | ||
122 | + label = Variable(data["label"].float()).cuda() | ||
123 | + | ||
124 | + loss = mae_loss(output.squeeze(), label) | ||
125 | + optimizer.zero_grad() | ||
126 | + loss.backward() | ||
127 | + optimizer.step() | ||
128 | + | ||
129 | + last_50.append(loss.data) | ||
130 | + if (i+1) % 50 == 0: | ||
131 | + tqdm.write('Training Loss: %f' % torch.mean(torch.stack(last_50)).item()) | ||
132 | + last_50 = [] | ||
133 | + | ||
134 | + | ||
135 | + # tqdm.write('Validation...') | ||
136 | + model.eval() | ||
137 | + # val_mse_loss = [] | ||
138 | + val_mae_loss = [] | ||
139 | + for i, data in enumerate(val_loader): | ||
140 | + input_image = Variable(data["input"]).float().cuda() | ||
141 | + input_image = input_image.squeeze(1) # | ||
142 | + output = model(input_image) | ||
143 | + label = Variable(data["label"].float()).cuda() | ||
144 | + #print(output) | ||
145 | + #print(label) | ||
146 | + loss = mae_loss(output.squeeze(), label) | ||
147 | + val_mae_loss.append(loss.data) | ||
148 | + | ||
149 | + | ||
150 | + # loss = torch.mean(torch.abs(output.squeeze() - label)) | ||
151 | + # val_mae_loss.append(loss.data) | ||
152 | + | ||
153 | + torch.save(model.state_dict(), ctx["save_path"]) # | ||
154 | + | ||
155 | +# print('Validation Loss (MSE): ', torch.mean(torch.stack(val_mse_loss))) | ||
156 | + tqdm.write('Validation Loss (MAE): %f' % torch.mean(torch.stack(val_mae_loss)).item()) | ||
157 | + | ||
158 | + if torch.mean(torch.stack(val_mae_loss)) < best: | ||
159 | + best = torch.mean(torch.stack(val_mae_loss)) | ||
160 | + tqdm.write('model saved') | ||
161 | + | ||
162 | + print("<<<<< training set >>>>>") | ||
163 | + # tqdm.write('Training...') | ||
164 | + model.eval() | ||
165 | + # training_mse_loss = [] | ||
166 | + train_mae_loss = [] | ||
167 | + for i, data in enumerate(train_loader): | ||
168 | + input_image = Variable(data["input"]).float().cuda() | ||
169 | + if ctx["3d"]: | ||
170 | + input_image = input_image.squeeze(1) | ||
171 | + output = model(input_image) | ||
172 | + label = Variable(data["label"].float()).cuda() | ||
173 | + print(output) | ||
174 | + print(label) | ||
175 | + loss = mae_loss(output.squeeze(), label) | ||
176 | + train_mae_loss.append(loss.data) | ||
177 | + tqdm.write('Training Loss (MAE): %f' % torch.mean(torch.stack(train_mae_loss)).item()) | ||
178 | + | ||
179 | + print("<<<<< validation set >>>>>") | ||
180 | + # tqdm.write('Validation...') | ||
181 | + model.eval() | ||
182 | + # val_mse_loss = [] | ||
183 | + val_mae_loss = [] | ||
184 | + for i, data in enumerate(val_loader): | ||
185 | + input_image = Variable(data["input"]).float().cuda() | ||
186 | + if ctx["3d"]: | ||
187 | + input_image = input_image.squeeze(1) | ||
188 | + output = model(input_image) | ||
189 | + label = Variable(data["label"].float()).cuda() | ||
190 | + print(output) | ||
191 | + print(label) | ||
192 | + loss = mae_loss(output.squeeze(), label) | ||
193 | + val_mae_loss.append(loss.data) | ||
194 | + tqdm.write('Validation Loss (MAE): %f' % torch.mean(torch.stack(val_mae_loss)).item()) | ||
195 | + | ||
196 | + print("<<<<< test set >>>>>") | ||
197 | + # tqdm.write('Test...') | ||
198 | + model.eval() | ||
199 | + # test_mse_loss = [] | ||
200 | + test_mae_loss = [] | ||
201 | + for i, data in enumerate(test_loader): | ||
202 | + input_image = Variable(data["input"]).float().cuda() | ||
203 | + if ctx["3d"]: | ||
204 | + input_image = input_image.squeeze(1) | ||
205 | + output = model(input_image) | ||
206 | + label = Variable(data["label"].float()).cuda() | ||
207 | + print(output) | ||
208 | + print(label) | ||
209 | + loss = mae_loss(output.squeeze(), label) | ||
210 | + test_mae_loss.append(loss.data) | ||
211 | + tqdm.write('Test Loss (MAE): %f' % torch.mean(torch.stack(test_mae_loss)).item()) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment