Hyunji

main code

1 +# -*- coding: utf-8 -*-
2 +"""mainnn.ipynb
3 +
4 +Automatically generated by Colaboratory.
5 +
6 +Original file is located at
7 + https://colab.research.google.com/drive/1tGd53i4_WlVJHCjEzaD2oiL0okf70bOT
8 +"""
9 +
10 +import json
11 +
12 +from dataset import PAC2019, PAC20192D, PAC20193D
13 +from model import Model, VGGBasedModel, VGGBasedModel2D, Model3D
14 +from model_resnet import ResNet, resnet18, resnset34, resnet50
15 +
16 +import torch
17 +from torch.autograd import Variable
18 +import torch.nn as nn
19 +from torch.utils.data import DataLoader
20 +import numpy as np
21 +
22 +from tqdm import *
23 +
24 +import gc
25 +gc.collect()
26 +torch.cuda.empty_cache()
27 +
28 +def cosine_rampdown(current, rampdown_length):
29 + """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
30 + assert 0 <= current <= rampdown_length
31 + return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
32 +
33 +def cosine_lr(current_epoch, num_epochs, initial_lr):
34 + return initial_lr * cosine_rampdown(current_epoch, num_epochs)
35 +
36 +def sigmoid_rampup(current, rampup_length):
37 + if rampup_length == 0:
38 + return 1.0
39 + else:
40 + current = np.clip(current, 0.0, rampup_length)
41 + phase = 1.0 - current / rampup_length
42 + return float(np.exp(-5.0 * phase * phase))
43 +
44 +
45 +with open("config.json") as fid:
46 + ctx = json.load(fid)
47 +
48 +if ctx["3d"]:
49 + train_set = PAC20193D(ctx, set='train')
50 + val_set = PAC20193D(ctx, set='valid')
51 + test_set = PAC20193D(ctx, set='test')
52 +
53 + model = Model3D
54 + #model = VGGBasedModel()
55 +
56 + optimizer = torch.optim.SGD(model.parameters(), lr=ctx["learning_rate"],
57 + momentum=0.9, weight_decay=ctx["weight_decay"])
58 + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97)
59 +
60 +else:
61 + train_set = PAC20192D(ctx, set='train')
62 + val_set = PAC20192D(ctx, set='val')
63 + test_set = PAC20192D(ctx, set='test')
64 +
65 + model = resnet18()
66 + #model = resnet34()
67 + #model = resnet50()
68 +
69 + optimizer = torch.optim.Adam(model.parameters(), lr=ctx["learning_rate"],
70 + weight_decay=ctx["weight_decay"])
71 +
72 +
73 +train_loader = DataLoader(train_set, shuffle=False, drop_last=False,
74 + num_workers=8, batch_size=ctx["batch_size"])
75 +val_loader = DataLoader(val_set, shuffle=False, drop_last=False,
76 + num_workers=8, batch_size=ctx["batch_size"])
77 +test_loader = DataLoader(test_set, shuffle=False, drop_last=False,
78 + num_workers=8, batch_size=ctx["batch_size"])
79 +
80 +mse_loss = nn.MSELoss()
81 +mae_loss = nn.L1Loss()
82 +model.cuda()
83 +
84 +best = np.inf
85 +for e in tqdm(range(1, ctx["epochs"]+1), desc="Epochs"):
86 + model.train()
87 + last_50 = []
88 +
89 + if ctx["3d"]:
90 + scheduler.step()
91 + tqdm.write('Learning Rate: {:.6f}'.format(scheduler.get_lr()[0]))
92 + else:
93 + if e <= ctx["initial_lr_rampup"]:
94 + lr = ctx["learning_rate"] * sigmoid_rampup(e, ctx["initial_lr_rampup"])
95 + else:
96 + lr = cosine_lr(e-ctx["initial_lr_rampup"],
97 + ctx["epochs"]-ctx["initial_lr_rampup"],
98 + ctx["learning_rate"])
99 +
100 + for param_group in optimizer.param_groups:
101 + tqdm.write("Learning Rate: {:.6f}".format(lr))
102 + param_group['lr'] = lr
103 +
104 +
105 + for i, data in enumerate(train_loader):
106 + if ctx["mixup"]:
107 + lam = np.random.beta(ctx["mixup_alpha"], ctx["mixup_alpha"])
108 +
109 + length_data = data["input"].size(0)//2
110 + data1_x = data["input"][0:length_data]
111 + data1_y = data["label"][0:length_data]
112 + data2_x = data["input"][length_data:]
113 + data2_y = data["label"][length_data:]
114 +
115 + data["input"] = lam*data1_x + (1.-lam)*data2_x
116 + data["label"] = lam*data1_y + (1.-lam)*data2_y
117 +
118 + input_image = Variable(data["input"], requires_grad=True).float().cuda()
119 + if ctx["3d"]:
120 + input_image = input_image.squeeze(1)
121 + output = model(input_image)
122 + label = Variable(data["label"].float()).cuda()
123 +
124 + loss = mae_loss(output.squeeze(), label)
125 + optimizer.zero_grad()
126 + loss.backward()
127 + optimizer.step()
128 +
129 + last_50.append(loss.data)
130 + if (i+1) % 50 == 0:
131 + tqdm.write('Training Loss: %f' % torch.mean(torch.stack(last_50)).item())
132 + last_50 = []
133 +
134 +
135 + # tqdm.write('Validation...')
136 + model.eval()
137 + # val_mse_loss = []
138 + val_mae_loss = []
139 + for i, data in enumerate(val_loader):
140 + input_image = Variable(data["input"]).float().cuda()
141 + input_image = input_image.squeeze(1) #
142 + output = model(input_image)
143 + label = Variable(data["label"].float()).cuda()
144 + #print(output)
145 + #print(label)
146 + loss = mae_loss(output.squeeze(), label)
147 + val_mae_loss.append(loss.data)
148 +
149 +
150 + # loss = torch.mean(torch.abs(output.squeeze() - label))
151 + # val_mae_loss.append(loss.data)
152 +
153 + torch.save(model.state_dict(), ctx["save_path"]) #
154 +
155 +# print('Validation Loss (MSE): ', torch.mean(torch.stack(val_mse_loss)))
156 + tqdm.write('Validation Loss (MAE): %f' % torch.mean(torch.stack(val_mae_loss)).item())
157 +
158 + if torch.mean(torch.stack(val_mae_loss)) < best:
159 + best = torch.mean(torch.stack(val_mae_loss))
160 + tqdm.write('model saved')
161 +
162 + print("<<<<< training set >>>>>")
163 + # tqdm.write('Training...')
164 + model.eval()
165 + # training_mse_loss = []
166 + train_mae_loss = []
167 + for i, data in enumerate(train_loader):
168 + input_image = Variable(data["input"]).float().cuda()
169 + if ctx["3d"]:
170 + input_image = input_image.squeeze(1)
171 + output = model(input_image)
172 + label = Variable(data["label"].float()).cuda()
173 + print(output)
174 + print(label)
175 + loss = mae_loss(output.squeeze(), label)
176 + train_mae_loss.append(loss.data)
177 + tqdm.write('Training Loss (MAE): %f' % torch.mean(torch.stack(train_mae_loss)).item())
178 +
179 + print("<<<<< validation set >>>>>")
180 + # tqdm.write('Validation...')
181 + model.eval()
182 + # val_mse_loss = []
183 + val_mae_loss = []
184 + for i, data in enumerate(val_loader):
185 + input_image = Variable(data["input"]).float().cuda()
186 + if ctx["3d"]:
187 + input_image = input_image.squeeze(1)
188 + output = model(input_image)
189 + label = Variable(data["label"].float()).cuda()
190 + print(output)
191 + print(label)
192 + loss = mae_loss(output.squeeze(), label)
193 + val_mae_loss.append(loss.data)
194 + tqdm.write('Validation Loss (MAE): %f' % torch.mean(torch.stack(val_mae_loss)).item())
195 +
196 + print("<<<<< test set >>>>>")
197 + # tqdm.write('Test...')
198 + model.eval()
199 + # test_mse_loss = []
200 + test_mae_loss = []
201 + for i, data in enumerate(test_loader):
202 + input_image = Variable(data["input"]).float().cuda()
203 + if ctx["3d"]:
204 + input_image = input_image.squeeze(1)
205 + output = model(input_image)
206 + label = Variable(data["label"].float()).cuda()
207 + print(output)
208 + print(label)
209 + loss = mae_loss(output.squeeze(), label)
210 + test_mae_loss.append(loss.data)
211 + tqdm.write('Test Loss (MAE): %f' % torch.mean(torch.stack(test_mae_loss)).item())
...\ No newline at end of file ...\ No newline at end of file