서민정

feat: 최종 업로드

...@@ -36,10 +36,6 @@ parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="We ...@@ -36,10 +36,6 @@ parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="We
36 parser.add_argument('--pretrained', default='', type=str, help='path to pretrained model (default: none)') 36 parser.add_argument('--pretrained', default='', type=str, help='path to pretrained model (default: none)')
37 parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)") 37 parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
38 38
39 -
40 -total_loss_for_plot = list()
41 -total_pnsr = list()
42 -
43 def main(): 39 def main():
44 global opt, model 40 global opt, model
45 opt = parser.parse_args() 41 opt = parser.parse_args()
...@@ -141,17 +137,14 @@ def train(training_data_loader, optimizer, model, criterion, epoch): ...@@ -141,17 +137,14 @@ def train(training_data_loader, optimizer, model, criterion, epoch):
141 optimizer.step() 137 optimizer.step()
142 138
143 epoch_loss = total_loss / len(training_data_loader) 139 epoch_loss = total_loss / len(training_data_loader)
144 - total_loss_for_plot.append(epoch_loss)
145 psnr = PSNR(epoch_loss) 140 psnr = PSNR(epoch_loss)
146 - total_pnsr.append(psnr)
147 print("===> Epoch[{}]: loss : {:.10f} ,PSNR : {:.10f}".format(epoch, epoch_loss, psnr)) 141 print("===> Epoch[{}]: loss : {:.10f} ,PSNR : {:.10f}".format(epoch, epoch_loss, psnr))
148 # if iteration%100 == 0: 142 # if iteration%100 == 0:
149 # print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.item())) 143 # print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.item()))
150 144
151 def save_checkpoint(model, epoch, optimizer): 145 def save_checkpoint(model, epoch, optimizer):
152 model_out_path = "checkpoint/" + "model_epoch_{}_{}.pth".format(epoch, opt.featureType) 146 model_out_path = "checkpoint/" + "model_epoch_{}_{}.pth".format(epoch, opt.featureType)
153 - state = {"epoch": epoch ,"model": model, "model_state_dict":model.state_dict(), "optimizer_state_dict":optimizer.state_dict(), 147 + state = {"epoch": epoch ,"model": model, "model_state_dict":model.state_dict(), "optimizer_state_dict":optimizer.state_dict()}
154 - "loss": total_loss_for_plot, "psnr":total_pnsr}
155 if not os.path.exists("checkpoint/"): 148 if not os.path.exists("checkpoint/"):
156 os.makedirs("checkpoint/") 149 os.makedirs("checkpoint/")
157 150
......
No preview for this file type