Showing
9 changed files
with
576 additions
and
0 deletions
code/pytorch_vdsr/LICENSE
0 → 100644
1 | +The MIT License (MIT) | ||
2 | + | ||
3 | +Copyright (c) 2017- Jiu XU | ||
4 | +Copyright (c) 2017- Rakuten, Inc | ||
5 | +Copyright (c) 2017- Rakuten Institute of Technology | ||
6 | + | ||
7 | +Permission is hereby granted, free of charge, to any person obtaining a copy | ||
8 | +of this software and associated documentation files (the "Software"), to deal | ||
9 | +in the Software without restriction, including without limitation the rights | ||
10 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
11 | +copies of the Software, and to permit persons to whom the Software is | ||
12 | +furnished to do so, subject to the following conditions: | ||
13 | + | ||
14 | +The above copyright notice and this permission notice shall be included in all | ||
15 | +copies or substantial portions of the Software. | ||
16 | + | ||
17 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
18 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
19 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
20 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
21 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
22 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
23 | +SOFTWARE. | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/pytorch_vdsr/README.md
0 → 100644
1 | +# PyTorch VDSR | ||
2 | +Implementation of CVPR2016 Paper: "Accurate Image Super-Resolution Using | ||
3 | +Very Deep Convolutional Networks"(http://cv.snu.ac.kr/research/VDSR/) in PyTorch | ||
4 | + | ||
5 | +## Usage | ||
6 | +### Training | ||
7 | +``` | ||
8 | +usage: main_vdsr.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS] [--lr LR] | ||
9 | + [--step STEP] [--cuda] [--resume RESUME] | ||
10 | + [--start-epoch START_EPOCH] [--clip CLIP] [--threads THREADS] | ||
11 | + [--momentum MOMENTUM] [--weight-decay WEIGHT_DECAY] | ||
12 | + [--pretrained PRETRAINED] [--gpus GPUS] | ||
13 | + | ||
14 | +optional arguments: | ||
15 | + -h, --help Show this help message and exit | ||
16 | + --batchSize Training batch size | ||
17 | + --nEpochs Number of epochs to train for | ||
18 | + --lr Learning rate. Default=0.01 | ||
19 | + --step Learning rate decay, Default: n=10 epochs | ||
20 | + --cuda Use cuda | ||
21 | + --resume Path to checkpoint | ||
22 | + --clip Clipping Gradients. Default=0.4 | ||
23 | + --threads Number of threads for data loader to use Default=1 | ||
24 | + --momentum Momentum, Default: 0.9 | ||
25 | + --weight-decay Weight decay, Default: 1e-4 | ||
26 | + --pretrained PRETRAINED | ||
27 | + path to pretrained model (default: none) | ||
28 | + --gpus GPUS gpu ids (default: 0) | ||
29 | +``` | ||
30 | +An example of training usage is shown as follows: | ||
31 | +``` | ||
32 | +python main_vdsr.py --cuda --gpus 0 | ||
33 | +``` | ||
34 | + | ||
35 | +### Evaluation | ||
36 | +``` | ||
37 | +usage: eval.py [-h] [--cuda] [--model MODEL] [--dataset DATASET] | ||
38 | + [--scale SCALE] [--gpus GPUS] | ||
39 | + | ||
40 | +PyTorch VDSR Eval | ||
41 | + | ||
42 | +optional arguments: | ||
43 | + -h, --help show this help message and exit | ||
44 | + --cuda use cuda? | ||
45 | + --model MODEL model path | ||
46 | + --dataset DATASET dataset name, Default: Set5 | ||
47 | + --gpus GPUS gpu ids (default: 0) | ||
48 | +``` | ||
49 | +An example of training usage is shown as follows: | ||
50 | +``` | ||
51 | +python eval.py --cuda --dataset Set5 | ||
52 | +``` | ||
53 | + | ||
54 | +### Demo | ||
55 | +``` | ||
56 | +usage: demo.py [-h] [--cuda] [--model MODEL] [--image IMAGE] [--scale SCALE] [--gpus GPUS] | ||
57 | + | ||
58 | +optional arguments: | ||
59 | + -h, --help Show this help message and exit | ||
60 | + --cuda Use cuda | ||
61 | + --model Model path. Default=model/model_epoch_50.pth | ||
62 | + --image Image name. Default=butterfly_GT | ||
63 | + --scale Scale factor, Default: 4 | ||
64 | + --gpus GPUS gpu ids (default: 0) | ||
65 | +``` | ||
66 | +An example of usage is shown as follows: | ||
67 | +``` | ||
68 | +python eval.py --model model/model_epoch_50.pth --dataset Set5 --cuda | ||
69 | +``` | ||
70 | + | ||
71 | +### Prepare Training dataset | ||
72 | + - We provide a simple hdf5 format training sample in data folder with 'data' and 'label' keys, the training data is generated with Matlab Bicubic Interplotation, please refer [Code for Data Generation](https://github.com/twtygqyy/pytorch-vdsr/tree/master/data) for creating training files. | ||
73 | + | ||
74 | +### Performance | ||
75 | + - We provide a pretrained VDSR model trained on [291](https://drive.google.com/open?id=1Rt3asDLuMgLuJvPA1YrhyjWhb97Ly742) images with data augmentation | ||
76 | + - No bias is used in this implementation, and the gradient clipping's implementation is different from paper | ||
77 | + - Performance in PSNR on Set5 | ||
78 | + | ||
79 | +| Scale | VDSR Paper | VDSR PyTorch| | ||
80 | +| ------------- |:-------------:| -----:| | ||
81 | +| 2x | 37.53 | 37.65 | | ||
82 | +| 3x | 33.66 | 33.77| | ||
83 | +| 4x | 31.35 | 31.45 | | ||
84 | + | ||
85 | +### Result | ||
86 | +From left to right are ground truth, bicubic and vdsr | ||
87 | +<p> | ||
88 | + <img src='Set5/butterfly_GT.bmp' height='200' width='200'/> | ||
89 | + <img src='result/input.bmp' height='200' width='200'/> | ||
90 | + <img src='result/output.bmp' height='200' width='200'/> | ||
91 | +</p> |
code/pytorch_vdsr/data.py
0 → 100644
1 | +from torch.utils.data import Dataset | ||
2 | +from PIL import Image | ||
3 | +import os | ||
4 | +from glob import glob | ||
5 | +from torchvision import transforms | ||
6 | +from torch.utils.data.dataset import Dataset | ||
7 | +from torchvision import transforms | ||
8 | +import torch | ||
9 | +import pdb | ||
10 | +import math | ||
11 | +import numpy as np | ||
12 | +class FeatureDataset(Dataset): | ||
13 | + def __init__(self, data_path, datatype, rescale_factor,valid): | ||
14 | + self.data_path = data_path | ||
15 | + self.datatype = datatype | ||
16 | + self.rescale_factor = rescale_factor | ||
17 | + if not os.path.exists(data_path): | ||
18 | + raise Exception(f"[!] {self.data_path} not existed") | ||
19 | + if(valid): | ||
20 | + self.hr_path = os.path.join(self.data_path,'valid') | ||
21 | + self.hr_path = os.path.join(self.hr_path,self.datatype) | ||
22 | + else: | ||
23 | + self.hr_path = os.path.join(self.data_path,'LR_2') | ||
24 | + self.hr_path = os.path.join(self.hr_path,self.datatype) | ||
25 | + print(self.hr_path) | ||
26 | + self.hr_path = sorted(glob(os.path.join(self.hr_path, "*.*"))) | ||
27 | + self.hr_imgs = [] | ||
28 | + w,h = Image.open(self.hr_path[0]).size | ||
29 | + self.width = int(w/16) | ||
30 | + self.height = int(h/16) | ||
31 | + self.lwidth = int(self.width/self.rescale_factor) | ||
32 | + self.lheight = int(self.height/self.rescale_factor) | ||
33 | + print("lr: ({} {}), hr: ({} {})".format(self.lwidth,self.lheight,self.width,self.height)) | ||
34 | + for hr in self.hr_path: | ||
35 | + hr_image = Image.open(hr)#.convert('RGB')\ | ||
36 | + for i in range(16): | ||
37 | + for j in range(16): | ||
38 | + (left,upper,right,lower) = (i*self.width,j*self.height,(i+1)*self.width,(j+1)*self.height) | ||
39 | + crop = hr_image.crop((left,upper,right,lower)) | ||
40 | + self.hr_imgs.append(crop) | ||
41 | + | ||
42 | + def __getitem__(self, idx): | ||
43 | + hr_image = self.hr_imgs[idx] | ||
44 | + transform = transforms.Compose([ | ||
45 | + transforms.Resize((self.lheight,self.lwidth),3), | ||
46 | + transforms.ToTensor() | ||
47 | + ]) | ||
48 | + return transform(hr_image), transforms.ToTensor()(hr_image) | ||
49 | + | ||
50 | + def __len__(self): | ||
51 | + return len(self.hr_path*16*16) | ||
52 | + | ||
53 | +def get_data_loader(data_path, feature_type, rescale_factor, batch_size, num_workers): | ||
54 | + full_dataset = FeatureDataset(data_path,feature_type,rescale_factor,False) | ||
55 | + train_size = int(0.9 * len(full_dataset)) | ||
56 | + test_size = len(full_dataset) - train_size | ||
57 | + train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size]) | ||
58 | + torch.manual_seed(3334) | ||
59 | + train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers, pin_memory=False) | ||
60 | + test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers, pin_memory=True) | ||
61 | + | ||
62 | + return train_loader, test_loader | ||
63 | + | ||
64 | +def get_infer_dataloader(data_path, feature_type, rescale_factor, batch_size, num_workers): | ||
65 | + dataset = FeatureDataset(data_path,feature_type,rescale_factor,True) | ||
66 | + data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=False) | ||
67 | + return data_loader | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/pytorch_vdsr/dataset.py
0 → 100644
1 | +import torch.utils.data as data | ||
2 | +import torch | ||
3 | +import h5py | ||
4 | + | ||
5 | +class DatasetFromHdf5(data.Dataset): | ||
6 | + def __init__(self, file_path): | ||
7 | + super(DatasetFromHdf5, self).__init__() | ||
8 | + hf = h5py.File(file_path) | ||
9 | + self.data = hf.get('data') | ||
10 | + self.target = hf.get('label') | ||
11 | + | ||
12 | + def __getitem__(self, index): | ||
13 | + return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float() | ||
14 | + | ||
15 | + def __len__(self): | ||
16 | + return self.data.shape[0] | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/pytorch_vdsr/demo.py
0 → 100644
1 | +import argparse, os | ||
2 | +import torch | ||
3 | +from torch.autograd import Variable | ||
4 | +from scipy.ndimage import imread | ||
5 | +from PIL import Image | ||
6 | +import numpy as np | ||
7 | +import time, math | ||
8 | +import matplotlib.pyplot as plt | ||
9 | + | ||
10 | +parser = argparse.ArgumentParser(description="PyTorch VDSR Demo") | ||
11 | +parser.add_argument("--cuda", action="store_true", help="use cuda?") | ||
12 | +parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path") | ||
13 | +parser.add_argument("--image", default="butterfly_GT", type=str, help="image name") | ||
14 | +parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4") | ||
15 | +parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)") | ||
16 | + | ||
17 | +def PSNR(pred, gt, shave_border=0): | ||
18 | + height, width = pred.shape[:2] | ||
19 | + pred = pred[shave_border:height - shave_border, shave_border:width - shave_border] | ||
20 | + gt = gt[shave_border:height - shave_border, shave_border:width - shave_border] | ||
21 | + imdff = pred - gt | ||
22 | + rmse = math.sqrt(np.mean(imdff ** 2)) | ||
23 | + if rmse == 0: | ||
24 | + return 100 | ||
25 | + return 20 * math.log10(255.0 / rmse) | ||
26 | + | ||
27 | +def colorize(y, ycbcr): | ||
28 | + img = np.zeros((y.shape[0], y.shape[1], 3), np.uint8) | ||
29 | + img[:,:,0] = y | ||
30 | + img[:,:,1] = ycbcr[:,:,1] | ||
31 | + img[:,:,2] = ycbcr[:,:,2] | ||
32 | + img = Image.fromarray(img, "YCbCr").convert("RGB") | ||
33 | + return img | ||
34 | + | ||
35 | +opt = parser.parse_args() | ||
36 | +cuda = opt.cuda | ||
37 | + | ||
38 | +if cuda: | ||
39 | + print("=> use gpu id: '{}'".format(opt.gpus)) | ||
40 | + os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus | ||
41 | + if not torch.cuda.is_available(): | ||
42 | + raise Exception("No GPU found or Wrong gpu id, please run without --cuda") | ||
43 | + | ||
44 | + | ||
45 | +model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"] | ||
46 | + | ||
47 | +im_gt_ycbcr = imread("Set5/" + opt.image + ".bmp", mode="YCbCr") | ||
48 | +im_b_ycbcr = imread("Set5/"+ opt.image + "_scale_"+ str(opt.scale) + ".bmp", mode="YCbCr") | ||
49 | + | ||
50 | +im_gt_y = im_gt_ycbcr[:,:,0].astype(float) | ||
51 | +im_b_y = im_b_ycbcr[:,:,0].astype(float) | ||
52 | + | ||
53 | +psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=opt.scale) | ||
54 | + | ||
55 | +im_input = im_b_y/255. | ||
56 | + | ||
57 | +im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1]) | ||
58 | + | ||
59 | +if cuda: | ||
60 | + model = model.cuda() | ||
61 | + im_input = im_input.cuda() | ||
62 | +else: | ||
63 | + model = model.cpu() | ||
64 | + | ||
65 | +start_time = time.time() | ||
66 | +out = model(im_input) | ||
67 | +elapsed_time = time.time() - start_time | ||
68 | + | ||
69 | +out = out.cpu() | ||
70 | + | ||
71 | +im_h_y = out.data[0].numpy().astype(np.float32) | ||
72 | + | ||
73 | +im_h_y = im_h_y * 255. | ||
74 | +im_h_y[im_h_y < 0] = 0 | ||
75 | +im_h_y[im_h_y > 255.] = 255. | ||
76 | + | ||
77 | +psnr_predicted = PSNR(im_gt_y, im_h_y[0,:,:], shave_border=opt.scale) | ||
78 | + | ||
79 | +im_h = colorize(im_h_y[0,:,:], im_b_ycbcr) | ||
80 | +im_gt = Image.fromarray(im_gt_ycbcr, "YCbCr").convert("RGB") | ||
81 | +im_b = Image.fromarray(im_b_ycbcr, "YCbCr").convert("RGB") | ||
82 | + | ||
83 | +print("Scale=",opt.scale) | ||
84 | +print("PSNR_predicted=", psnr_predicted) | ||
85 | +print("PSNR_bicubic=", psnr_bicubic) | ||
86 | +print("It takes {}s for processing".format(elapsed_time)) | ||
87 | + | ||
88 | +fig = plt.figure() | ||
89 | +ax = plt.subplot("131") | ||
90 | +ax.imshow(im_gt) | ||
91 | +ax.set_title("GT") | ||
92 | + | ||
93 | +ax = plt.subplot("132") | ||
94 | +ax.imshow(im_b) | ||
95 | +ax.set_title("Input(bicubic)") | ||
96 | + | ||
97 | +ax = plt.subplot("133") | ||
98 | +ax.imshow(im_h) | ||
99 | +ax.set_title("Output(vdsr)") | ||
100 | +plt.show() |
code/pytorch_vdsr/eval.py
0 → 100644
1 | +import argparse, os | ||
2 | +import torch | ||
3 | +from torch.autograd import Variable | ||
4 | +import numpy as np | ||
5 | +import time, math, glob | ||
6 | +import scipy.io as sio | ||
7 | + | ||
8 | +parser = argparse.ArgumentParser(description="PyTorch VDSR Eval") | ||
9 | +parser.add_argument("--cuda", action="store_true", help="use cuda?") | ||
10 | +parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path") | ||
11 | +parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5") | ||
12 | +parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)") | ||
13 | + | ||
14 | +def PSNR(pred, gt, shave_border=0): | ||
15 | + height, width = pred.shape[:2] | ||
16 | + pred = pred[shave_border:height - shave_border, shave_border:width - shave_border] | ||
17 | + gt = gt[shave_border:height - shave_border, shave_border:width - shave_border] | ||
18 | + imdff = pred - gt | ||
19 | + rmse = math.sqrt(np.mean(imdff ** 2)) | ||
20 | + if rmse == 0: | ||
21 | + return 100 | ||
22 | + return 20 * math.log10(255.0 / rmse) | ||
23 | + | ||
24 | +opt = parser.parse_args() | ||
25 | +cuda = opt.cuda | ||
26 | + | ||
27 | +if cuda: | ||
28 | + print("=> use gpu id: '{}'".format(opt.gpus)) | ||
29 | + os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus | ||
30 | + if not torch.cuda.is_available(): | ||
31 | + raise Exception("No GPU found or Wrong gpu id, please run without --cuda") | ||
32 | + | ||
33 | +model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"] | ||
34 | + | ||
35 | +scales = [2,3,4] | ||
36 | + | ||
37 | +image_list = glob.glob(opt.dataset+"_mat/*.*") | ||
38 | + | ||
39 | +for scale in scales: | ||
40 | + avg_psnr_predicted = 0.0 | ||
41 | + avg_psnr_bicubic = 0.0 | ||
42 | + avg_elapsed_time = 0.0 | ||
43 | + count = 0.0 | ||
44 | + for image_name in image_list: | ||
45 | + if str(scale) in image_name: | ||
46 | + count += 1 | ||
47 | + print("Processing ", image_name) | ||
48 | + im_gt_y = sio.loadmat(image_name)['im_gt_y'] | ||
49 | + im_b_y = sio.loadmat(image_name)['im_b_y'] | ||
50 | + | ||
51 | + im_gt_y = im_gt_y.astype(float) | ||
52 | + im_b_y = im_b_y.astype(float) | ||
53 | + | ||
54 | + psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=scale) | ||
55 | + avg_psnr_bicubic += psnr_bicubic | ||
56 | + | ||
57 | + im_input = im_b_y/255. | ||
58 | + | ||
59 | + im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1]) | ||
60 | + | ||
61 | + if cuda: | ||
62 | + model = model.cuda() | ||
63 | + im_input = im_input.cuda() | ||
64 | + else: | ||
65 | + model = model.cpu() | ||
66 | + | ||
67 | + start_time = time.time() | ||
68 | + HR = model(im_input) | ||
69 | + elapsed_time = time.time() - start_time | ||
70 | + avg_elapsed_time += elapsed_time | ||
71 | + | ||
72 | + HR = HR.cpu() | ||
73 | + | ||
74 | + im_h_y = HR.data[0].numpy().astype(np.float32) | ||
75 | + | ||
76 | + im_h_y = im_h_y * 255. | ||
77 | + im_h_y[im_h_y < 0] = 0 | ||
78 | + im_h_y[im_h_y > 255.] = 255. | ||
79 | + im_h_y = im_h_y[0,:,:] | ||
80 | + | ||
81 | + psnr_predicted = PSNR(im_gt_y, im_h_y,shave_border=scale) | ||
82 | + avg_psnr_predicted += psnr_predicted | ||
83 | + | ||
84 | + print("Scale=", scale) | ||
85 | + print("Dataset=", opt.dataset) | ||
86 | + print("PSNR_predicted=", avg_psnr_predicted/count) | ||
87 | + print("PSNR_bicubic=", avg_psnr_bicubic/count) | ||
88 | + print("It takes average {}s for processing".format(avg_elapsed_time/count)) |
code/pytorch_vdsr/main_vdsr.py
0 → 100644
1 | +import argparse, os | ||
2 | +import torch | ||
3 | +import random | ||
4 | +import torch.backends.cudnn as cudnn | ||
5 | +import torch.nn as nn | ||
6 | +import torch.optim as optim | ||
7 | +from torch.autograd import Variable | ||
8 | +from torch.utils.data import DataLoader | ||
9 | +from vdsr import Net | ||
10 | +from dataset import DatasetFromHdf5 | ||
11 | +## Custom | ||
12 | +from data import FeatureDataset | ||
13 | + | ||
14 | +# Training settings | ||
15 | +parser = argparse.ArgumentParser(description="PyTorch VDSR") | ||
16 | +parser.add_argument("--batchSize", type=int, default=128, help="Training batch size") | ||
17 | +parser.add_argument("--nEpochs", type=int, default=50, help="Number of epochs to train for") | ||
18 | +parser.add_argument("--lr", type=float, default=0.1, help="Learning Rate. Default=0.1") | ||
19 | +parser.add_argument("--step", type=int, default=10, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=10") | ||
20 | +parser.add_argument("--cuda", action="store_true", help="Use cuda?") | ||
21 | +parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)") | ||
22 | +parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)") | ||
23 | +parser.add_argument("--clip", type=float, default=0.4, help="Clipping Gradients. Default=0.4") | ||
24 | +# 1->3 custom | ||
25 | +parser.add_argument("--threads", type=int, default=3, help="Number of threads for data loader to use, Default: 1") | ||
26 | +parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9") | ||
27 | +parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="Weight decay, Default: 1e-4") | ||
28 | +parser.add_argument('--pretrained', default='', type=str, help='path to pretrained model (default: none)') | ||
29 | +parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)") | ||
30 | +## custom | ||
31 | +parser.add_argument("--dataPath", type=str) | ||
32 | +parser.add_argument("--featureType", type=str, default="p2") | ||
33 | +parser.add_argument("--scaleFactor",type=int, default=4) | ||
34 | + | ||
35 | +# parser.add_argument("--trainingData", type=DataLoader) | ||
36 | + | ||
37 | + | ||
38 | +def main(): | ||
39 | + global opt, model | ||
40 | + opt = parser.parse_args() | ||
41 | + print(opt) | ||
42 | + | ||
43 | + cuda = opt.cuda | ||
44 | + if cuda: | ||
45 | + print("=> use gpu id: '{}'".format(opt.gpus)) | ||
46 | + os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus | ||
47 | + if not torch.cuda.is_available(): | ||
48 | + raise Exception("No GPU found or Wrong gpu id, please run without --cuda") | ||
49 | + | ||
50 | + opt.seed = random.randint(1, 10000) | ||
51 | + print("Random Seed: ", opt.seed) | ||
52 | + torch.manual_seed(opt.seed) | ||
53 | + if cuda: | ||
54 | + torch.cuda.manual_seed(opt.seed) | ||
55 | + | ||
56 | + cudnn.benchmark = True | ||
57 | + | ||
58 | + print("===> Loading datasets") | ||
59 | + | ||
60 | + if os.path.isfile('dataloader/training_data_loader.pth'): | ||
61 | + training_data_loader = torch.load('dataloader/training_data_loader.pth') | ||
62 | + else: | ||
63 | + train_set = FeatureDataset(opt.dataPath,opt.featureType,opt.scaleFactor,False) | ||
64 | + train_size = 100 #우선은 100개만 | ||
65 | + test_size = len(train_set) - train_size | ||
66 | + train_dataset, test_dataset = torch.utils.data.random_split(train_set, [train_size, test_size]) | ||
67 | + training_data_loader = DataLoader(dataset=train_dataset, num_workers=3, batch_size=8, shuffle=True, pin_memory=False) | ||
68 | + torch.save(training_data_loader, 'dataloader/training_data_loader.pth'.format(DataLoader)) | ||
69 | + | ||
70 | + print("===> Building model") | ||
71 | + model = Net(opt.scaleFactor) | ||
72 | + criterion = nn.MSELoss(size_average=False) | ||
73 | + | ||
74 | + print("===> Setting GPU") | ||
75 | + if cuda: | ||
76 | + model = model.cuda() | ||
77 | + criterion = criterion.cuda() | ||
78 | + | ||
79 | + # optionally resume from a checkpoint | ||
80 | + if opt.resume: | ||
81 | + if os.path.isfile(opt.resume): | ||
82 | + print("=> loading checkpoint '{}'".format(opt.resume)) | ||
83 | + checkpoint = torch.load(opt.resume) | ||
84 | + opt.start_epoch = checkpoint["epoch"] + 1 | ||
85 | + model.load_state_dict(checkpoint["model"].state_dict()) | ||
86 | + else: | ||
87 | + print("=> no checkpoint found at '{}'".format(opt.resume)) | ||
88 | + | ||
89 | + # optionally copy weights from a checkpoint | ||
90 | + if opt.pretrained: | ||
91 | + if os.path.isfile(opt.pretrained): | ||
92 | + print("=> loading model '{}'".format(opt.pretrained)) | ||
93 | + weights = torch.load(opt.pretrained) | ||
94 | + model.load_state_dict(weights['model'].state_dict()) | ||
95 | + else: | ||
96 | + print("=> no model found at '{}'".format(opt.pretrained)) | ||
97 | + | ||
98 | + print("===> Setting Optimizer") | ||
99 | + optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay) | ||
100 | + | ||
101 | + print("===> Training") | ||
102 | + for epoch in range(opt.start_epoch, opt.nEpochs + 1): | ||
103 | + train(training_data_loader, optimizer, model, criterion, epoch) | ||
104 | + save_checkpoint(model, epoch) | ||
105 | + | ||
106 | +def adjust_learning_rate(optimizer, epoch): | ||
107 | + """Sets the learning rate to the initial LR decayed by 10 every 10 epochs""" | ||
108 | + lr = opt.lr * (0.1 ** (epoch // opt.step)) | ||
109 | + return lr | ||
110 | + | ||
111 | +def train(training_data_loader, optimizer, model, criterion, epoch): | ||
112 | + lr = adjust_learning_rate(optimizer, epoch-1) | ||
113 | + | ||
114 | + for param_group in optimizer.param_groups: | ||
115 | + param_group["lr"] = lr | ||
116 | + | ||
117 | + print("Epoch = {}, lr = {}".format(epoch, optimizer.param_groups[0]["lr"])) | ||
118 | + | ||
119 | + model.train() | ||
120 | + | ||
121 | + for iteration, batch in enumerate(training_data_loader, 1): | ||
122 | + input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False) | ||
123 | + | ||
124 | + if opt.cuda: | ||
125 | + input = input.cuda() | ||
126 | + target = target.cuda() | ||
127 | + | ||
128 | + loss = criterion(model(input), target) | ||
129 | + optimizer.zero_grad() | ||
130 | + loss.backward() | ||
131 | + nn.utils.clip_grad_norm(model.parameters(),opt.clip) | ||
132 | + optimizer.step() | ||
133 | + | ||
134 | + if iteration%10 == 0: | ||
135 | + # loss.data[0] --> loss.data | ||
136 | + print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.data)) | ||
137 | + | ||
138 | +def save_checkpoint(model, epoch): | ||
139 | + model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch) | ||
140 | + state = {"epoch": epoch ,"model": model} | ||
141 | + if not os.path.exists("checkpoint/"): | ||
142 | + os.makedirs("checkpoint/") | ||
143 | + | ||
144 | + torch.save(state, model_out_path) | ||
145 | + | ||
146 | + print("Checkpoint saved to {}".format(model_out_path)) | ||
147 | + | ||
148 | +if __name__ == "__main__": | ||
149 | + main() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/pytorch_vdsr/vdsr.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +from math import sqrt | ||
4 | + | ||
5 | +class Conv_ReLU_Block(nn.Module): | ||
6 | + def __init__(self): | ||
7 | + super(Conv_ReLU_Block, self).__init__() | ||
8 | + self.conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) | ||
9 | + self.relu = nn.ReLU(inplace=True) | ||
10 | + | ||
11 | + def forward(self, x): | ||
12 | + return self.relu(self.conv(x)) | ||
13 | + | ||
14 | +class Net(nn.Module): | ||
15 | + def __init__(self,upscale_factor): | ||
16 | + super(Net, self).__init__() | ||
17 | + self.residual_layer = self.make_layer(Conv_ReLU_Block, 18) | ||
18 | + self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) | ||
19 | + self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) | ||
20 | + self.upsample = nn.Upsample(scale_factor=upscale_factor, mode='bicubic') | ||
21 | + self.relu = nn.ReLU(inplace=True) | ||
22 | + | ||
23 | + for m in self.modules(): | ||
24 | + if isinstance(m, nn.Conv2d): | ||
25 | + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
26 | + m.weight.data.normal_(0, sqrt(2. / n)) | ||
27 | + | ||
28 | + def make_layer(self, block, num_of_layer): | ||
29 | + layers = [] | ||
30 | + for _ in range(num_of_layer): | ||
31 | + layers.append(block()) | ||
32 | + return nn.Sequential(*layers) | ||
33 | + | ||
34 | + def forward(self, x): | ||
35 | + x = self.upsample(x) | ||
36 | + residual = x | ||
37 | + out = self.relu(self.input(x)) | ||
38 | + out = self.residual_layer(out) | ||
39 | + out = self.output(out) | ||
40 | + out = torch.add(out,residual) | ||
41 | + return out | ||
42 | + | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
면담확인서/캡스톤 디자인 2 면담확인서 10주차.docx
0 → 100644
No preview for this file type
-
Please register or login to post a comment