Showing 18 changed files with 376 additions and 383 deletions
code/pytorch_vdsr/LICENSE
deleted 100644 → 0
-The MIT License (MIT)
-
-Copyright (c) 2017- Jiu XU
-Copyright (c) 2017- Rakuten, Inc
-Copyright (c) 2017- Rakuten Institute of Technology
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
\ No newline at end of file
code/pytorch_vdsr/README.md
deleted 100644 → 0
-# PyTorch VDSR
-Implementation of the CVPR 2016 paper "Accurate Image Super-Resolution Using
-Very Deep Convolutional Networks" (http://cv.snu.ac.kr/research/VDSR/) in PyTorch
-
-## Usage
-### Training
-```
-usage: main_vdsr.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS] [--lr LR]
-                    [--step STEP] [--cuda] [--resume RESUME]
-                    [--start-epoch START_EPOCH] [--clip CLIP] [--threads THREADS]
-                    [--momentum MOMENTUM] [--weight-decay WEIGHT_DECAY]
-                    [--pretrained PRETRAINED] [--gpus GPUS]
-
-optional arguments:
-  -h, --help       Show this help message and exit
-  --batchSize      Training batch size
-  --nEpochs        Number of epochs to train for
-  --lr             Learning rate. Default=0.01
-  --step           Learning rate decay step. Default: n=10 epochs
-  --cuda           Use cuda
-  --resume         Path to checkpoint
-  --clip           Clipping gradients. Default=0.4
-  --threads        Number of threads for the data loader to use. Default=1
-  --momentum       Momentum. Default: 0.9
-  --weight-decay   Weight decay. Default: 1e-4
-  --pretrained PRETRAINED
-                   Path to pretrained model (default: none)
-  --gpus GPUS      gpu ids (default: 0)
-```
-An example of training usage is shown as follows:
-```
-python main_vdsr.py --cuda --gpus 0
-```
-
-### Evaluation
-```
-usage: eval.py [-h] [--cuda] [--model MODEL] [--dataset DATASET]
-               [--scale SCALE] [--gpus GPUS]
-
-PyTorch VDSR Eval
-
-optional arguments:
-  -h, --help         show this help message and exit
-  --cuda             use cuda?
-  --model MODEL      model path
-  --dataset DATASET  dataset name. Default: Set5
-  --gpus GPUS        gpu ids (default: 0)
-```
-An example of evaluation usage is shown as follows:
-```
-python eval.py --cuda --dataset Set5
-```
-
-### Demo
-```
-usage: demo.py [-h] [--cuda] [--model MODEL] [--image IMAGE] [--scale SCALE] [--gpus GPUS]
-
-optional arguments:
-  -h, --help   Show this help message and exit
-  --cuda       Use cuda
-  --model      Model path. Default=model/model_epoch_50.pth
-  --image      Image name. Default=butterfly_GT
-  --scale      Scale factor. Default: 4
-  --gpus GPUS  gpu ids (default: 0)
-```
-An example of demo usage is shown as follows:
-```
-python demo.py --cuda --model model/model_epoch_50.pth --image butterfly_GT --scale 4
-```
-
-### Prepare Training dataset
- - We provide a simple HDF5-format training sample in the data folder with 'data' and 'label' keys. The training data is generated with Matlab bicubic interpolation; please refer to [Code for Data Generation](https://github.com/twtygqyy/pytorch-vdsr/tree/master/data) for creating training files. (See the sketch after dataset.py below.)
-
-### Performance
- - We provide a pretrained VDSR model trained on [291](https://drive.google.com/open?id=1Rt3asDLuMgLuJvPA1YrhyjWhb97Ly742) images with data augmentation.
- - No bias is used in this implementation, and the gradient clipping implementation differs from the paper.
- - Performance in PSNR on Set5:
-
-| Scale | VDSR Paper | VDSR PyTorch |
-|:-----:|:----------:|:------------:|
-| 2x    | 37.53      | 37.65        |
-| 3x    | 33.66      | 33.77        |
-| 4x    | 31.35      | 31.45        |
-
-### Result
-From left to right: ground truth, bicubic, and VDSR.
-<p>
-  <img src='Set5/butterfly_GT.bmp' height='200' width='200'/>
-  <img src='result/input.bmp' height='200' width='200'/>
-  <img src='result/output.bmp' height='200' width='200'/>
-</p>
code/pytorch_vdsr/dataset.py
deleted 100644 → 0
-import torch.utils.data as data
-import torch
-import h5py
-
-class DatasetFromHdf5(data.Dataset):
-    def __init__(self, file_path):
-        super(DatasetFromHdf5, self).__init__()
-        hf = h5py.File(file_path, 'r')  # read-only; recent h5py versions require an explicit mode
-        self.data = hf.get('data')
-        self.target = hf.get('label')
-
-    def __getitem__(self, index):
-        return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
-
-    def __len__(self):
-        return self.data.shape[0]
\ No newline at end of file
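Both the README note above and this loader assume an HDF5 file holding `data` and `label` datasets of shape N×C×H×W. As a hedged sketch only — the patch size and contents here are illustrative stand-ins for the Matlab-generated files, not the project's actual data:

```python
import h5py
import numpy as np

# Toy stand-in for the Matlab-generated training file: 8 pairs of
# single-channel 41x41 patches (bicubic-upscaled input in 'data',
# ground truth in 'label'), both stored as N x C x H x W float32.
data = np.random.rand(8, 1, 41, 41).astype(np.float32)
label = np.random.rand(8, 1, 41, 41).astype(np.float32)

with h5py.File("train.h5", "w") as hf:
    hf.create_dataset("data", data=data)
    hf.create_dataset("label", data=label)

# DatasetFromHdf5("train.h5") can then be wrapped in a DataLoader for training.
```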
code/pytorch_vdsr/demo.py
deleted 100644 → 0
-import argparse, os
-import torch
-from torch.autograd import Variable
-from scipy.ndimage import imread  # requires scipy < 1.2; scipy.ndimage.imread was removed later
-from PIL import Image
-import numpy as np
-import time, math
-import matplotlib.pyplot as plt
-
-parser = argparse.ArgumentParser(description="PyTorch VDSR Demo")
-parser.add_argument("--cuda", action="store_true", help="use cuda?")
-parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
-parser.add_argument("--image", default="butterfly_GT", type=str, help="image name")
-parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4")
-parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
-
-def PSNR(pred, gt, shave_border=0):
-    height, width = pred.shape[:2]
-    pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
-    gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
-    imdff = pred - gt
-    rmse = math.sqrt(np.mean(imdff ** 2))
-    if rmse == 0:
-        return 100
-    return 20 * math.log10(255.0 / rmse)
-
-def colorize(y, ycbcr):
-    img = np.zeros((y.shape[0], y.shape[1], 3), np.uint8)
-    img[:,:,0] = y
-    img[:,:,1] = ycbcr[:,:,1]
-    img[:,:,2] = ycbcr[:,:,2]
-    img = Image.fromarray(img, "YCbCr").convert("RGB")
-    return img
-
-opt = parser.parse_args()
-cuda = opt.cuda
-
-if cuda:
-    print("=> use gpu id: '{}'".format(opt.gpus))
-    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
-    if not torch.cuda.is_available():
-        raise Exception("No GPU found or wrong gpu id, please run without --cuda")
-
-model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
-
-im_gt_ycbcr = imread("Set5/" + opt.image + ".bmp", mode="YCbCr")
-im_b_ycbcr = imread("Set5/" + opt.image + "_scale_" + str(opt.scale) + ".bmp", mode="YCbCr")
-
-im_gt_y = im_gt_ycbcr[:,:,0].astype(float)
-im_b_y = im_b_ycbcr[:,:,0].astype(float)
-
-psnr_bicubic = PSNR(im_gt_y, im_b_y, shave_border=opt.scale)
-
-im_input = im_b_y/255.
-
-im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1])
-
-if cuda:
-    model = model.cuda()
-    im_input = im_input.cuda()
-else:
-    model = model.cpu()
-
-start_time = time.time()
-out = model(im_input)
-elapsed_time = time.time() - start_time
-
-out = out.cpu()
-
-im_h_y = out.data[0].numpy().astype(np.float32)
-
-im_h_y = im_h_y * 255.
-im_h_y[im_h_y < 0] = 0
-im_h_y[im_h_y > 255.] = 255.
-
-psnr_predicted = PSNR(im_gt_y, im_h_y[0,:,:], shave_border=opt.scale)
-
-im_h = colorize(im_h_y[0,:,:], im_b_ycbcr)
-im_gt = Image.fromarray(im_gt_ycbcr, "YCbCr").convert("RGB")
-im_b = Image.fromarray(im_b_ycbcr, "YCbCr").convert("RGB")
-
-print("Scale=", opt.scale)
-print("PSNR_predicted=", psnr_predicted)
-print("PSNR_bicubic=", psnr_bicubic)
-print("It takes {}s for processing".format(elapsed_time))
-
-fig = plt.figure()
-ax = plt.subplot("131")
-ax.imshow(im_gt)
-ax.set_title("GT")
-
-ax = plt.subplot("132")
-ax.imshow(im_b)
-ax.set_title("Input(bicubic)")
-
-ax = plt.subplot("133")
-ax.imshow(im_h)
-ax.set_title("Output(vdsr)")
-plt.show()
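As a quick sanity check of the PSNR definition used above (20·log10(255/RMSE) on 8-bit intensities), a hypothetical prediction that is uniformly off by 16 gray levels works out to about 24 dB:

```python
import math
import numpy as np

gt = np.zeros((32, 32))
pred = gt + 16.0  # constant error of 16 gray levels, so RMSE = 16
rmse = math.sqrt(np.mean((pred - gt) ** 2))
print(20 * math.log10(255.0 / rmse))  # ~24.05 dB
```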
code/pytorch_vdsr/eval.py
deleted 100644 → 0
-import argparse, os
-import torch
-from torch.autograd import Variable
-import numpy as np
-import time, math, glob
-import scipy.io as sio
-
-parser = argparse.ArgumentParser(description="PyTorch VDSR Eval")
-parser.add_argument("--cuda", action="store_true", help="use cuda?")
-parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
-parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")
-parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
-
-def PSNR(pred, gt, shave_border=0):
-    height, width = pred.shape[:2]
-    pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
-    gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
-    imdff = pred - gt
-    rmse = math.sqrt(np.mean(imdff ** 2))
-    if rmse == 0:
-        return 100
-    return 20 * math.log10(255.0 / rmse)
-
-opt = parser.parse_args()
-cuda = opt.cuda
-
-if cuda:
-    print("=> use gpu id: '{}'".format(opt.gpus))
-    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
-    if not torch.cuda.is_available():
-        raise Exception("No GPU found or wrong gpu id, please run without --cuda")
-
-model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
-
-scales = [2, 3, 4]
-
-image_list = glob.glob(opt.dataset + "_mat/*.*")
-
-for scale in scales:
-    avg_psnr_predicted = 0.0
-    avg_psnr_bicubic = 0.0
-    avg_elapsed_time = 0.0
-    count = 0.0
-    for image_name in image_list:
-        if str(scale) in image_name:
-            count += 1
-            print("Processing ", image_name)
-            im_gt_y = sio.loadmat(image_name)['im_gt_y']
-            im_b_y = sio.loadmat(image_name)['im_b_y']
-
-            im_gt_y = im_gt_y.astype(float)
-            im_b_y = im_b_y.astype(float)
-
-            psnr_bicubic = PSNR(im_gt_y, im_b_y, shave_border=scale)
-            avg_psnr_bicubic += psnr_bicubic
-
-            im_input = im_b_y/255.
-
-            im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1])
-
-            if cuda:
-                model = model.cuda()
-                im_input = im_input.cuda()
-            else:
-                model = model.cpu()
-
-            start_time = time.time()
-            HR = model(im_input)
-            elapsed_time = time.time() - start_time
-            avg_elapsed_time += elapsed_time
-
-            HR = HR.cpu()
-
-            im_h_y = HR.data[0].numpy().astype(np.float32)
-
-            im_h_y = im_h_y * 255.
-            im_h_y[im_h_y < 0] = 0
-            im_h_y[im_h_y > 255.] = 255.
-            im_h_y = im_h_y[0,:,:]
-
-            psnr_predicted = PSNR(im_gt_y, im_h_y, shave_border=scale)
-            avg_psnr_predicted += psnr_predicted
-
-    print("Scale=", scale)
-    print("Dataset=", opt.dataset)
-    print("PSNR_predicted=", avg_psnr_predicted/count)
-    print("PSNR_bicubic=", avg_psnr_bicubic/count)
-    print("It takes average {}s for processing".format(avg_elapsed_time/count))
code/vdsr/README.md
0 → 100644
+# Feature SR
+
+1. Train
+
+`!python main.py --dataRoot /content/drive/MyDrive/feature/HR_trainset/features --scaleFactor 4 --featureType p6 --batchSize 16 --cuda --nEpochs 20`
+
+2. Inference
+
+`!python inference.py --cuda --model "model.pth" --dataset "/content/drive/MyDrive/feature/features/LR_2" --featureType "p3" --scaleFactor 4`
+
+3. Calculate mAP
+
+```
+# [1]
+# Install dependencies:
+!pip install pyyaml==5.1
+import torch, torchvision
+print(torch.__version__, torch.cuda.is_available())
+!gcc --version
+# opencv is pre-installed on Colab
+
+# [2]
+# Install detectron2 (Colab has CUDA 10.1 + torch 1.8).
+# See https://detectron2.readthedocs.io/tutorials/install.html for instructions.
+import torch
+assert torch.__version__.startswith("1.8")  # need to manually install torch 1.8 if Colab changes its default version
+!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.8/index.html
+# exit(0)  # After installation, you need to "restart runtime" in Colab. This line can also restart the runtime.
+
+# [3]
+# Basic setup: set up the detectron2 logger.
+import detectron2
+from detectron2.utils.logger import setup_logger
+setup_logger()
+
+!python calculate_mAP.py --valid_data_path /content/drive/MyDrive/dataset/validset_100/ --model_name VDSR --loss_type MSE --batch_size 16
+```
\ No newline at end of file
code/vdsr/calculate_mAP.py
0 → 100644
(diff collapsed; not shown)
code/vdsr/crop_feature.py
0 → 100644
+import os
+from PIL import Image
+
+def crop_feature(datapath, feature_type, scale_factor, print_message=False):
+    data_path = datapath
+    datatype = feature_type
+    rescale_factor = scale_factor
+    if not os.path.exists(data_path):
+        raise Exception(f"[!] {data_path} does not exist")
+
+    hr_imgs = []
+    w, h = Image.open(datapath).size
+    width = int(w / 16)
+    height = int(h / 16)
+    lwidth = int(width / rescale_factor)
+    lheight = int(height / rescale_factor)
+    if print_message:
+        print("lr: ({} {}), hr: ({} {})".format(lwidth, lheight, width, height))
+    hr_image = Image.open(datapath)  # .convert('RGB')
+    # Split the stitched feature map into a 16x16 grid of patches, then
+    # degrade each patch by bicubic down- and re-upscaling.
+    for i in range(16):
+        for j in range(16):
+            (left, upper, right, lower) = (
+                i * width, j * height, (i + 1) * width, (j + 1) * height)
+            crop = hr_image.crop((left, upper, right, lower))
+            crop = crop.resize((lwidth, lheight), Image.BICUBIC)
+            crop = crop.resize((width, height), Image.BICUBIC)
+            hr_imgs.append(crop)
+
+    return hr_imgs
\ No newline at end of file
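A hedged usage sketch — the file path is hypothetical; any stitched 16×16 feature-map image laid out as main.py and inference.py expect would do:

```python
from crop_feature import crop_feature

# "features/p3/img_001.png" is a placeholder for one stitched feature map.
patches = crop_feature("features/p3/img_001.png", "p3", 4, print_message=True)
print(len(patches))     # 256 bicubic-degraded patches (16 x 16 grid)
print(patches[0].size)  # (W/16, H/16) of the source image
```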
code/vdsr/datasets.py
@@ -4,64 +4,87 @@ import os
 from glob import glob
 from torchvision import transforms
 from torch.utils.data.dataset import Dataset
-from torchvision import transforms
 import torch
 import pdb
 import math
 import numpy as np
+
+
 class FeatureDataset(Dataset):
-    def __init__(self, data_path, datatype, rescale_factor,valid):
+    def __init__(self, data_path, datatype, rescale_factor, valid):
         self.data_path = data_path
         self.datatype = datatype
         self.rescale_factor = rescale_factor
         if not os.path.exists(data_path):
             raise Exception(f"[!] {self.data_path} does not exist")
-        if(valid):
-            self.hr_path = os.path.join(self.data_path,'valid')
-            self.hr_path = os.path.join(self.hr_path,self.datatype)
+        if (valid):
+            self.hr_path = os.path.join(self.data_path, 'valid')
+            self.hr_path = os.path.join(self.hr_path, self.datatype)
         else:
-            self.hr_path = os.path.join(self.data_path,'LR_2')
-            self.hr_path = os.path.join(self.hr_path,self.datatype)
+            self.hr_path = os.path.join(self.data_path, 'LR_2')
+            self.hr_path = os.path.join(self.hr_path, self.datatype)
         print(self.hr_path)
         self.hr_path = sorted(glob(os.path.join(self.hr_path, "*.*")))
         self.hr_imgs = []
-        w,h = Image.open(self.hr_path[0]).size
-        self.width = int(w/16)
-        self.height = int(h/16)
-        self.lwidth = int(self.width/self.rescale_factor)
-        self.lheight = int(self.height/self.rescale_factor)
-        print("lr: ({} {}), hr: ({} {})".format(self.lwidth,self.lheight,self.width,self.height))
-        for hr in self.hr_path:
-            hr_image = Image.open(hr)#.convert('RGB')\
+        w, h = Image.open(self.hr_path[0]).size
+        self.width = int(w / 16)
+        self.height = int(h / 16)
+        self.lwidth = int(self.width / self.rescale_factor)  # shrink by the rescale factor
+        self.lheight = int(self.height / self.rescale_factor)
+        print("lr: ({} {}), hr: ({} {})".format(self.lwidth, self.lheight, self.width, self.height))
+        for hr in self.hr_path:  # split each feature map into 256 patches
+            hr_image = Image.open(hr)  # .convert('RGB')
             for i in range(16):
                 for j in range(16):
-                    (left,upper,right,lower) = (i*self.width,j*self.height,(i+1)*self.width,(j+1)*self.height)
-                    crop = hr_image.crop((left,upper,right,lower))
+                    (left, upper, right, lower) = (
+                        i * self.width, j * self.height, (i + 1) * self.width, (j + 1) * self.height)
+                    crop = hr_image.crop((left, upper, right, lower))
                     self.hr_imgs.append(crop)
 
     def __getitem__(self, idx):
         hr_image = self.hr_imgs[idx]
         transform = transforms.Compose([
-            transforms.Resize((self.lheight,self.lwidth),3),
+            transforms.Resize((self.lheight, self.lwidth), Image.BICUBIC),
+            transforms.Resize((self.height, self.width), Image.BICUBIC),
             transforms.ToTensor()
         ])
-        return transform(hr_image), transforms.ToTensor()(hr_image)
+        return transform(hr_image), transforms.ToTensor()(hr_image)  # return the degraded and the original patch as tensors
 
     def __len__(self):
-        return len(self.hr_path*16*16)
+        return len(self.hr_path * 16 * 16)
+
+
+def get_data_loader_test_version(data_path, feature_type, rescale_factor, batch_size, num_workers):
+    full_dataset = FeatureDataset(data_path, feature_type, rescale_factor, False)
+    print("dataset size: {}".format(len(full_dataset)))
+    for f in full_dataset:
+        print(type(f))
+
 
 def get_data_loader(data_path, feature_type, rescale_factor, batch_size, num_workers):
-    full_dataset = FeatureDataset(data_path,feature_type,rescale_factor,False)
+    full_dataset = FeatureDataset(data_path, feature_type, rescale_factor, False)
     train_size = int(0.9 * len(full_dataset))
     test_size = len(full_dataset) - train_size
     train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
     torch.manual_seed(3334)
-    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers, pin_memory=False)
-    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers, pin_memory=True)
+    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
+                                               num_workers=num_workers, pin_memory=False)
+    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False,
+                                              num_workers=num_workers, pin_memory=True)
 
     return train_loader, test_loader
 
+
+def get_training_data_loader(data_path, feature_type, rescale_factor, batch_size, num_workers):
+    full_dataset = FeatureDataset(data_path, feature_type, rescale_factor, False)
+    torch.manual_seed(3334)
+    train_loader = torch.utils.data.DataLoader(dataset=full_dataset, batch_size=batch_size, shuffle=True,
+                                               num_workers=num_workers, pin_memory=False)
+    return train_loader
+
+
 def get_infer_dataloader(data_path, feature_type, rescale_factor, batch_size, num_workers):
-    dataset = FeatureDataset(data_path,feature_type,rescale_factor,True)
-    data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=False)
+    dataset = FeatureDataset(data_path, feature_type, rescale_factor, True)
+    data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False,
+                                              num_workers=num_workers, pin_memory=False)
     return data_loader
\ No newline at end of file
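A minimal sketch of how main.py consumes this module, assuming a hypothetical root directory whose `LR_2/p3/` subfolder holds the stitched feature-map images (the layout the constructor above builds its paths from):

```python
from datasets import get_training_data_loader

loader = get_training_data_loader("/path/to/features", "p3",
                                  rescale_factor=4, batch_size=16, num_workers=1)
degraded, target = next(iter(loader))
# For single-channel feature maps, both tensors are (16, 1, H/16, W/16):
# the bicubic down/up-scaled patch and its ground-truth counterpart.
print(degraded.shape, target.shape)
```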
code/vdsr/inference.py
0 → 100644
+import argparse, os
+import torch
+from torch.autograd import Variable
+import numpy as np
+import time, math, glob
+import scipy.io as sio
+from crop_feature import crop_feature
+from PIL import Image
+import cv2
+from matplotlib import pyplot as plt
+from math import log10, sqrt
+
+parser = argparse.ArgumentParser(description="PyTorch VDSR Eval")
+parser.add_argument("--cuda", action="store_true", help="use cuda?")
+parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
+parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")
+parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
+parser.add_argument("--featureType", default="p3", type=str)
+parser.add_argument("--scaleFactor", default=4, type=int, help="scale factor")
+parser.add_argument("--singleImage", type=str, default="N", help="if it is a single image, enter \"Y\"")
+
+
+def PSNR(pred, gt, shave_border=0):
+    height, width = pred.shape[:2]
+    pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
+    gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
+    imdff = pred - gt
+    rmse = math.sqrt(np.mean(imdff ** 2))
+    if rmse == 0:
+        return 100
+    return 20 * math.log10(255.0 / rmse)
+
+def concatFeatures(features, image_name, bicubic=False):
+    # Stitch the 256 patches back together: every 16 consecutive patches
+    # form one column (vertical concat), then the 16 columns are joined
+    # horizontally, mirroring the crop order in crop_feature.py.
+    columns = [concat_vertical(features[i * 16:(i + 1) * 16]) for i in range(16)]
+    final_concat_feature = concat_horizontal(columns)
+
+    if bicubic:
+        save_dir = os.path.join("features/LR_2/LR", opt.featureType)
+    else:
+        save_dir = os.path.join("features/LR_2", opt.featureType)
+    os.makedirs(save_dir, exist_ok=True)
+    cv2.imwrite(os.path.join(save_dir, image_name), final_concat_feature)
+
+def concat_horizontal(feature):
+    result = cv2.hconcat([feature[0], feature[1]])
+    for i in range(2, len(feature)):
+        result = cv2.hconcat([result, feature[i]])
+    return result
+
+def concat_vertical(feature):
+    result = cv2.vconcat([feature[0], feature[1]])
+    for i in range(2, len(feature)):
+        result = cv2.vconcat([result, feature[i]])
+    return result
+
+
+opt = parser.parse_args()
+cuda = opt.cuda
+
+if cuda:
+    print("=> use gpu id: '{}'".format(opt.gpus))
+    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
+    if not torch.cuda.is_available():
+        raise Exception("No GPU found or wrong gpu id, please run without --cuda")
+
+model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
+
+scales = [opt.scaleFactor]
+
+# image_list = glob.glob(opt.dataset+"/*.*")
+if opt.singleImage == "Y":
+    # Treat --dataset as the path to a single stitched feature-map image,
+    # so the loop below can join image_path and the file name as usual.
+    image_path, single_name = os.path.split(opt.dataset)
+    image_list = [single_name]
+else:
+    image_path = os.path.join(opt.dataset, opt.featureType)
+    image_list = os.listdir(image_path)
+    print(image_path)
+    print(image_list)
+
+
+for scale in scales:
+    for image in image_list:
+        avg_psnr_predicted = 0.0
+        avg_psnr_bicubic = 0.0
+        avg_elapsed_time = 0.0
+        count = 0.0
+        image_name_cropped = crop_feature(os.path.join(image_path, image), opt.featureType, opt.scaleFactor)
+        features = []
+        features_bicubic = []
+        for image_name in image_name_cropped:
+            count += 1
+            f_gt = image_name
+            w, h = image_name.size
+            f_bi = image_name.resize((w//scale, h//scale), Image.BICUBIC)
+            f_bi = f_bi.resize((w, h), Image.BICUBIC)
+
+            f_gt = np.array(f_gt)
+            f_bi = np.array(f_bi)
+            f_gt = f_gt.astype(float)
+            f_bi = f_bi.astype(float)
+            features_bicubic.append(f_bi)
+            psnr_bicubic = PSNR(f_bi, f_gt, shave_border=scale)
+            # psnr_bicubic = PSNR_ver2(cv2.imread(f_gt), cv2.imread(f_bi))
+            avg_psnr_bicubic += psnr_bicubic
+
+            f_input = f_bi/255.
+            f_input = Variable(torch.from_numpy(f_input).float()).view(1, -1, f_input.shape[0], f_input.shape[1])
+
+            if cuda:
+                model = model.cuda()
+                f_input = f_input.cuda()
+            else:
+                model = model.cpu()
+
+            start_time = time.time()
+            SR = model(f_input)
+            elapsed_time = time.time() - start_time
+            avg_elapsed_time += elapsed_time
+
+            SR = SR.cpu()
+
+            f_sr = SR.data[0].numpy().astype(np.float32)
+
+            f_sr = f_sr * 255
+            f_sr[f_sr < 0] = 0
+            f_sr[f_sr > 255.] = 255.
+            f_sr = f_sr[0,:,:]
+
+            psnr_predicted = PSNR(f_sr, f_gt, shave_border=scale)
+            # psnr_predicted = PSNR_ver2(cv2.imread(f_gt), cv2.imread(f_sr))
+            avg_psnr_predicted += psnr_predicted
+            features.append(f_sr)
+
+        concatFeatures(features, image)
+        concatFeatures(features_bicubic, image, True)
+        print("Scale=", scale)
+        print("Dataset=", opt.dataset)
+        print("Average PSNR_predicted=", avg_psnr_predicted/count)
+        print("Average PSNR_bicubic=", avg_psnr_bicubic/count)
+
+
+# Show graph
+# f_gt = Image.fromarray(f_gt)
+# f_b = Image.fromarray(f_bi)
+# f_sr = Image.fromarray(f_sr)
+
+# fig = plt.figure(figsize=(18, 16), dpi=80)
+# ax = plt.subplot("131")
+# ax.imshow(f_gt)
+# ax.set_title("GT")
+
+# ax = plt.subplot("132")
+# ax.imshow(f_bi)
+# ax.set_title("Input(bicubic)")
+
+# ax = plt.subplot("133")
+# ax.imshow(f_sr)
+# ax.set_title("Output(vdsr)")
+# plt.show()
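The stitching order matters: crop_feature iterates columns (i) in the outer loop and rows (j) in the inner loop, and concatFeatures must mirror that. A toy 2×2-grid version of the same logic round-trips an array exactly:

```python
import cv2
import numpy as np

# 4x4 image split into a 2x2 grid in the same column-major order as
# crop_feature (outer loop over columns i, inner loop over rows j).
img = np.arange(16, dtype=np.uint8).reshape(4, 4)
patches = [img[j*2:(j+1)*2, i*2:(i+1)*2] for i in range(2) for j in range(2)]

columns = [cv2.vconcat(patches[i*2:(i+1)*2]) for i in range(2)]
restored = cv2.hconcat(columns)
assert (restored == img).all()
```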
code/vdsr/instances_val2017_dataset100.json
0 → 100644
(diff too large to display)
code/vdsr/main.py
@@ -1,4 +1,5 @@
 import argparse, os
+from datasets import get_data_loader
 import torch
 import random
 import torch.backends.cudnn as cudnn
@@ -7,33 +8,37 @@ import torch.optim as optim
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
 from vdsr import Net
-from dataset import DatasetFromHdf5
-## Custom
-from data import FeatureDataset
+from datasets import get_training_data_loader
+# from datasets import get_data_loader_test_version
+# from feature_dataset import get_training_data_loader
+# from make_dataset import make_dataset
+import numpy as np
+from dataFromH5 import Read_dataset_h5
+import matplotlib.pyplot as plt
+import math
 
 # Training settings
 parser = argparse.ArgumentParser(description="PyTorch VDSR")
-parser.add_argument("--batchSize", type=int, default=128, help="Training batch size")
-parser.add_argument("--nEpochs", type=int, default=50, help="Number of epochs to train for")
-parser.add_argument("--lr", type=float, default=0.1, help="Learning Rate. Default=0.1")
+parser.add_argument("--dataRoot", type=str)
+parser.add_argument("--featureType", type=str)
+parser.add_argument("--scaleFactor", type=int, default=4)
+parser.add_argument("--batchSize", type=int, default=64, help="Training batch size")
+parser.add_argument("--nEpochs", type=int, default=20, help="Number of epochs to train for")
+parser.add_argument("--lr", type=float, default=0.001, help="Learning Rate. Default=0.001")
 parser.add_argument("--step", type=int, default=10, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=10")
 parser.add_argument("--cuda", action="store_true", help="Use cuda?")
 parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
 parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
 parser.add_argument("--clip", type=float, default=0.4, help="Clipping Gradients. Default=0.4")
-# 1->3 custom
-parser.add_argument("--threads", type=int, default=3, help="Number of threads for data loader to use, Default: 1")
+parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1")
 parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
 parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="Weight decay, Default: 1e-4")
 parser.add_argument('--pretrained', default='', type=str, help='path to pretrained model (default: none)')
 parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
-## custom
-parser.add_argument("--dataPath", type=str)
-parser.add_argument("--featureType", type=str, default="p2")
-parser.add_argument("--scaleFactor",type=int, default=4)
 
-# parser.add_argument("--trainingData", type=DataLoader)
 
+total_loss_for_plot = list()
+total_pnsr = list()
 
 def main():
     global opt, model
@@ -55,21 +60,17 @@ def main():
 
     cudnn.benchmark = True
 
-    print("===> Loading datasets")
+################## Loading Datasets ##########################
 
-    if os.path.isfile('dataloader/training_data_loader.pth'):
-        training_data_loader = torch.load('dataloader/training_data_loader.pth')
-    else:
-        train_set = FeatureDataset(opt.dataPath,opt.featureType,opt.scaleFactor,False)
-        train_size = 100  # only 100 samples for now
-        test_size = len(train_set) - train_size
-        train_dataset, test_dataset = torch.utils.data.random_split(train_set, [train_size, test_size])
-        training_data_loader = DataLoader(dataset=train_dataset, num_workers=3, batch_size=8, shuffle=True, pin_memory=False)
-        torch.save(training_data_loader, 'dataloader/training_data_loader.pth'.format(DataLoader))
+    print("===> Loading datasets")
+    # train_set = DatasetFromHdf5("data/train.h5")
+    # training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
+    training_data_loader = get_training_data_loader(opt.dataRoot, opt.featureType, opt.scaleFactor, opt.batchSize, opt.threads)
+    # training_data_loader = make_dataset(opt.dataRoot, opt.featureType, opt.scaleFactor, opt.batchSize, opt.threads)
 
     print("===> Building model")
-    model = Net(opt.scaleFactor)
-    criterion = nn.MSELoss(size_average=False)
+    model = Net()
+    criterion = nn.MSELoss(reduction='sum')
 
     print("===> Setting GPU")
     if cuda:
@@ -91,6 +92,7 @@ def main():
         if os.path.isfile(opt.pretrained):
             print("=> loading model '{}'".format(opt.pretrained))
             weights = torch.load(opt.pretrained)
+            opt.start_epoch = weights["epoch"] + 1
             model.load_state_dict(weights['model'].state_dict())
         else:
             print("=> no model found at '{}'".format(opt.pretrained))
@@ -99,15 +101,21 @@ def main():
     optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
 
     print("===> Training")
+
     for epoch in range(opt.start_epoch, opt.nEpochs + 1):
         train(training_data_loader, optimizer, model, criterion, epoch)
-        save_checkpoint(model, epoch)
+        save_checkpoint(model, epoch, optimizer)
 
 def adjust_learning_rate(optimizer, epoch):
     """Sets the learning rate to the initial LR decayed by 10 every 10 epochs"""
     lr = opt.lr * (0.1 ** (epoch // opt.step))
     return lr
 
+def PSNR(loss):
+    psnr = 10 * np.log10(1 / (loss + 1e-10))
+    # psnr = 20 * math.log10(255.0 / (math.sqrt(loss)))
+    return psnr
+
 def train(training_data_loader, optimizer, model, criterion, epoch):
     lr = adjust_learning_rate(optimizer, epoch-1)
 
@@ -119,25 +127,31 @@ def train(training_data_loader, optimizer, model, criterion, epoch):
     model.train()
 
+    total_loss = 0  # accumulated over the whole epoch, so the average below is correct
     for iteration, batch in enumerate(training_data_loader, 1):
-        input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
-
+        optimizer.zero_grad()
+        input, target = Variable(batch[0], requires_grad=False), Variable(batch[1], requires_grad=False)
         if opt.cuda:
             input = input.cuda()
             target = target.cuda()
 
         loss = criterion(model(input), target)
-        optimizer.zero_grad()
+        total_loss += loss.item()
         loss.backward()
-        nn.utils.clip_grad_norm(model.parameters(),opt.clip)
+        nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
         optimizer.step()
 
-        if iteration%10 == 0:
-            # loss.data[0] --> loss.data
-            print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.data))
-
-def save_checkpoint(model, epoch):
-    model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch)
-    state = {"epoch": epoch ,"model": model}
+    epoch_loss = total_loss / len(training_data_loader)
+    total_loss_for_plot.append(epoch_loss)
+    psnr = PSNR(epoch_loss)
+    total_pnsr.append(psnr)
+    print("===> Epoch[{}]: loss : {:.10f} ,PSNR : {:.10f}".format(epoch, epoch_loss, psnr))
+    # if iteration%100 == 0:
+    #     print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.item()))
+
+def save_checkpoint(model, epoch, optimizer):
+    model_out_path = "checkpoint/" + "model_epoch_{}_{}.pth".format(epoch, opt.featureType)
+    state = {"epoch": epoch, "model": model, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
+             "loss": total_loss_for_plot, "psnr": total_pnsr}
     if not os.path.exists("checkpoint/"):
         os.makedirs("checkpoint/")
 
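The PSNR helper added above inverts the mean-squared-error relation for intensities in [0, 1]: PSNR = 10·log10(1/MSE), so an MSE of 1e-3 reads as 30 dB. Note this reading assumes `loss` is a per-pixel mean; with `reduction='sum'`, dividing the summed loss by the number of batches leaves a scale factor, so the printed value is best treated as a relative training signal:

```python
import numpy as np

def psnr_from_mse(mse):
    # epsilon guards against log of zero on a perfect reconstruction
    return 10 * np.log10(1 / (mse + 1e-10))

print(psnr_from_mse(1e-3))  # ~30.0 dB
print(psnr_from_mse(1e-4))  # ~40.0 dB
```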
code/vdsr/vdsr.py
@@ -12,12 +12,11 @@ class Conv_ReLU_Block(nn.Module):
         return self.relu(self.conv(x))
 
 class Net(nn.Module):
-    def __init__(self,upscale_factor):
+    def __init__(self):
         super(Net, self).__init__()
         self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
         self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
         self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
-        self.upsample = nn.Upsample(scale_factor=upscale_factor, mode='bicubic')
         self.relu = nn.ReLU(inplace=True)
 
         for m in self.modules():
@@ -32,7 +31,6 @@ class Net(nn.Module):
         return nn.Sequential(*layers)
 
     def forward(self, x):
-        x = self.upsample(x)
         residual = x
         out = self.relu(self.input(x))
         out = self.residual_layer(out)
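With the upsample layer removed, Net now expects input that is already bicubic-upscaled to the target size and learns only the residual correction. A minimal shape check, assuming the Net definition above:

```python
import torch
from vdsr import Net

net = Net()
x = torch.randn(1, 1, 64, 64)  # one single-channel patch, already upscaled
y = net(x)
print(y.shape)  # torch.Size([1, 1, 64, 64]); all convs preserve spatial size
```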
면담확인서/캡스톤 디자인 2 면담확인서 12주차.docx (Capstone Design 2 consultation confirmation form, week 12)
0 → 100644
(binary file; no preview)
주간보고서/캡스톤 디자인 2 주간보고서 6월.docx (Capstone Design 2 weekly report, June)
0 → 100644
(binary file; no preview)
중간보고서/중간보고서.docx (interim report)
deleted 100644 → 0
(binary file; no preview)
중간보고서/중간보고서.hwp (interim report, HWP format)
deleted 100644 → 0
(binary file; no preview)
중간보고서/중간보고서_2017103084_서민정.docx (interim report, student ID 2017103084)
0 → 100644
(binary file; no preview)