서민정

docs: upload interview confirmation form and code

1 +The MIT License (MIT)
2 +
3 +Copyright (c) 2017- Jiu XU
4 +Copyright (c) 2017- Rakuten, Inc
5 +Copyright (c) 2017- Rakuten Institute of Technology
6 +
7 +Permission is hereby granted, free of charge, to any person obtaining a copy
8 +of this software and associated documentation files (the "Software"), to deal
9 +in the Software without restriction, including without limitation the rights
10 +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 +copies of the Software, and to permit persons to whom the Software is
12 +furnished to do so, subject to the following conditions:
13 +
14 +The above copyright notice and this permission notice shall be included in all
15 +copies or substantial portions of the Software.
16 +
17 +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 +SOFTWARE.
\ No newline at end of file
1 +# PyTorch VDSR
2 +Implementation of the CVPR 2016 paper "Accurate Image Super-Resolution Using
3 +Very Deep Convolutional Networks" (http://cv.snu.ac.kr/research/VDSR/) in PyTorch.
4 +
5 +## Usage
6 +### Training
7 +```
8 +usage: main_vdsr.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS] [--lr LR]
9 + [--step STEP] [--cuda] [--resume RESUME]
10 + [--start-epoch START_EPOCH] [--clip CLIP] [--threads THREADS]
11 + [--momentum MOMENTUM] [--weight-decay WEIGHT_DECAY]
12 + [--pretrained PRETRAINED] [--gpus GPUS]
13 +
14 +optional arguments:
15 + -h, --help Show this help message and exit
16 + --batchSize Training batch size
17 + --nEpochs Number of epochs to train for
18 + --lr Learning rate. Default=0.1
19 + --step Learning rate decay, Default: n=10 epochs
20 + --cuda Use cuda
21 + --resume Path to checkpoint
22 + --clip Clipping Gradients. Default=0.4
23 + --threads Number of threads for data loader to use, Default: 3
24 + --momentum Momentum, Default: 0.9
25 + --weight-decay Weight decay, Default: 1e-4
26 + --pretrained PRETRAINED
27 + path to pretrained model (default: none)
28 + --gpus GPUS gpu ids (default: 0)
29 +```
30 +An example of training usage is shown as follows:
31 +```
32 +python main_vdsr.py --cuda --gpus 0
33 +```
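
The `main_vdsr.py` in this repository also defines custom arguments for the feature-map dataset (`--dataPath`, `--featureType`, `--scaleFactor`); a run using them might look like the following, where the data path is only a placeholder:
```
# --dataPath below is a placeholder; point it at your feature-map folder
python main_vdsr.py --cuda --gpus 0 --dataPath ./feature_data --featureType p2 --scaleFactor 4
```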
34 +
35 +### Evaluation
36 +```
37 +usage: eval.py [-h] [--cuda] [--model MODEL] [--dataset DATASET]
38 + [--scale SCALE] [--gpus GPUS]
39 +
40 +PyTorch VDSR Eval
41 +
42 +optional arguments:
43 + -h, --help show this help message and exit
44 + --cuda use cuda?
45 + --model MODEL model path
46 + --dataset DATASET dataset name, Default: Set5
47 + --gpus GPUS gpu ids (default: 0)
48 +```
49 +An example of evaluation usage is shown as follows:
50 +```
51 +python eval.py --cuda --dataset Set5
52 +```
53 +
54 +### Demo
55 +```
56 +usage: demo.py [-h] [--cuda] [--model MODEL] [--image IMAGE] [--scale SCALE] [--gpus GPUS]
57 +
58 +optional arguments:
59 + -h, --help Show this help message and exit
60 + --cuda Use cuda
61 + --model Model path. Default=model/model_epoch_50.pth
62 + --image Image name. Default=butterfly_GT
63 + --scale Scale factor, Default: 4
64 + --gpus GPUS gpu ids (default: 0)
65 +```
66 +An example of usage is shown as follows:
67 +```
68 +python demo.py --model model/model_epoch_50.pth --image butterfly_GT --scale 4 --cuda
69 +```
70 +
71 +### Prepare Training dataset
72 + - We provide a simple HDF5-format training sample in the `data` folder with 'data' and 'label' keys. The training data is generated with Matlab bicubic interpolation; please refer to [Code for Data Generation](https://github.com/twtygqyy/pytorch-vdsr/tree/master/data) for creating training files (see the loading sketch below).
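
A minimal sketch of inspecting such a file with h5py; the file name `data/train.h5` is a placeholder for the provided sample:
```
import h5py

# Open the HDF5 training sample (the path below is a placeholder) and check
# the shapes of the 'data' (input) and 'label' (target) arrays.
with h5py.File("data/train.h5", "r") as hf:
    print(hf["data"].shape, hf["label"].shape)
```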
73 +
74 +### Performance
75 + - We provide a pretrained VDSR model trained on [291](https://drive.google.com/open?id=1Rt3asDLuMgLuJvPA1YrhyjWhb97Ly742) images with data augmentation
76 + - No bias is used in this implementation, and the gradient-clipping implementation differs from the paper's
77 + - Performance in PSNR on Set5
78 +
79 +| Scale | VDSR Paper | VDSR PyTorch |
80 +| ------------- |:-------------:| -----:|
81 +| 2x | 37.53 | 37.65 |
82 +| 3x | 33.66 | 33.77 |
83 +| 4x | 31.35 | 31.45 |
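
For reference, the PSNR reported above is computed on the Y channel with a border of `scale` pixels shaved off, as in the `PSNR()` function of `eval.py`:

$$\mathrm{PSNR} = 20\,\log_{10}\frac{255}{\mathrm{RMSE}}, \qquad \mathrm{RMSE} = \sqrt{\frac{1}{HW}\sum_{i,j}\left(\hat{y}_{ij} - y_{ij}\right)^2}$$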
84 +
85 +### Result
86 +From left to right: ground truth, bicubic, and VDSR.
87 +<p>
88 + <img src='Set5/butterfly_GT.bmp' height='200' width='200'/>
89 + <img src='result/input.bmp' height='200' width='200'/>
90 + <img src='result/output.bmp' height='200' width='200'/>
91 +</p>
1 +from torch.utils.data import Dataset
2 +from PIL import Image
3 +import os
4 +from glob import glob
5 +from torchvision import transforms
8 +import torch
12 +class FeatureDataset(Dataset):
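 """Tile each HR feature map into a 16x16 grid of patches; each item is a
 (bicubic-downscaled LR patch, original HR patch) pair of tensors."""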
13 + def __init__(self, data_path, datatype, rescale_factor,valid):
14 + self.data_path = data_path
15 + self.datatype = datatype
16 + self.rescale_factor = rescale_factor
17 + if not os.path.exists(data_path):
18 + raise Exception(f"[!] {self.data_path} does not exist")
19 + if valid:
20 + self.hr_path = os.path.join(self.data_path,'valid')
21 + self.hr_path = os.path.join(self.hr_path,self.datatype)
22 + else:
23 + self.hr_path = os.path.join(self.data_path,'LR_2')
24 + self.hr_path = os.path.join(self.hr_path,self.datatype)
25 + print(self.hr_path)
26 + self.hr_path = sorted(glob(os.path.join(self.hr_path, "*.*")))
27 + self.hr_imgs = []
28 + w,h = Image.open(self.hr_path[0]).size
29 + self.width = int(w/16)
30 + self.height = int(h/16)
31 + self.lwidth = int(self.width/self.rescale_factor)
32 + self.lheight = int(self.height/self.rescale_factor)
33 + print("lr: ({} {}), hr: ({} {})".format(self.lwidth,self.lheight,self.width,self.height))
34 + for hr in self.hr_path:
35 + hr_image = Image.open(hr)  # .convert('RGB')
36 + for i in range(16):
37 + for j in range(16):
38 + (left,upper,right,lower) = (i*self.width,j*self.height,(i+1)*self.width,(j+1)*self.height)
39 + crop = hr_image.crop((left,upper,right,lower))
40 + self.hr_imgs.append(crop)
41 +
42 + def __getitem__(self, idx):
43 + hr_image = self.hr_imgs[idx]
44 + transform = transforms.Compose([
45 + transforms.Resize((self.lheight, self.lwidth), interpolation=3),  # 3 = PIL bicubic
46 + transforms.ToTensor()
47 + ])
48 + return transform(hr_image), transforms.ToTensor()(hr_image)
49 +
50 + def __len__(self):
51 + return len(self.hr_path) * 16 * 16  # every source image contributes 16*16 patches
52 +
53 +def get_data_loader(data_path, feature_type, rescale_factor, batch_size, num_workers):
54 + full_dataset = FeatureDataset(data_path,feature_type,rescale_factor,False)
55 + train_size = int(0.9 * len(full_dataset))
56 + test_size = len(full_dataset) - train_size
57 + torch.manual_seed(3334)  # seed before the split so it is reproducible
58 + train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
59 + train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers, pin_memory=False)
60 + test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers, pin_memory=True)
61 +
62 + return train_loader, test_loader
63 +
64 +def get_infer_dataloader(data_path, feature_type, rescale_factor, batch_size, num_workers):
65 + dataset = FeatureDataset(data_path,feature_type,rescale_factor,True)
66 + data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=False)
67 + return data_loader
\ No newline at end of file
1 +import torch.utils.data as data
2 +import torch
3 +import h5py
4 +
5 +class DatasetFromHdf5(data.Dataset):
6 + def __init__(self, file_path):
7 + super(DatasetFromHdf5, self).__init__()
8 + hf = h5py.File(file_path, 'r')  # open read-only
9 + self.data = hf.get('data')
10 + self.target = hf.get('label')
11 +
12 + def __getitem__(self, index):
13 + return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
14 +
15 + def __len__(self):
16 + return self.data.shape[0]
\ No newline at end of file
1 +import argparse, os
2 +import torch
3 +from torch.autograd import Variable
4 +from scipy.ndimage import imread  # requires SciPy < 1.2; imread was removed in later releases
5 +from PIL import Image
6 +import numpy as np
7 +import time, math
8 +import matplotlib.pyplot as plt
9 +
10 +parser = argparse.ArgumentParser(description="PyTorch VDSR Demo")
11 +parser.add_argument("--cuda", action="store_true", help="use cuda?")
12 +parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
13 +parser.add_argument("--image", default="butterfly_GT", type=str, help="image name")
14 +parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4")
15 +parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
16 +
17 +def PSNR(pred, gt, shave_border=0):
18 + height, width = pred.shape[:2]
19 + pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
20 + gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
21 + imdff = pred - gt
22 + rmse = math.sqrt(np.mean(imdff ** 2))
23 + if rmse == 0:
24 + return 100
25 + return 20 * math.log10(255.0 / rmse)
26 +
27 +def colorize(y, ycbcr):
28 + img = np.zeros((y.shape[0], y.shape[1], 3), np.uint8)
29 + img[:,:,0] = y
30 + img[:,:,1] = ycbcr[:,:,1]
31 + img[:,:,2] = ycbcr[:,:,2]
32 + img = Image.fromarray(img, "YCbCr").convert("RGB")
33 + return img
34 +
35 +opt = parser.parse_args()
36 +cuda = opt.cuda
37 +
38 +if cuda:
39 + print("=> use gpu id: '{}'".format(opt.gpus))
40 + os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
41 + if not torch.cuda.is_available():
42 + raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
43 +
44 +
45 +model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
46 +
47 +im_gt_ycbcr = imread("Set5/" + opt.image + ".bmp", mode="YCbCr")
48 +im_b_ycbcr = imread("Set5/"+ opt.image + "_scale_"+ str(opt.scale) + ".bmp", mode="YCbCr")
49 +
50 +im_gt_y = im_gt_ycbcr[:,:,0].astype(float)
51 +im_b_y = im_b_ycbcr[:,:,0].astype(float)
52 +
53 +psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=opt.scale)
54 +
55 +im_input = im_b_y/255.
56 +
57 +im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1])
58 +
59 +if cuda:
60 + model = model.cuda()
61 + im_input = im_input.cuda()
62 +else:
63 + model = model.cpu()
64 +
65 +start_time = time.time()
66 +out = model(im_input)
67 +elapsed_time = time.time() - start_time
68 +
69 +out = out.cpu()
70 +
71 +im_h_y = out.data[0].numpy().astype(np.float32)
72 +
73 +im_h_y = im_h_y * 255.
74 +im_h_y[im_h_y < 0] = 0
75 +im_h_y[im_h_y > 255.] = 255.
76 +
77 +psnr_predicted = PSNR(im_gt_y, im_h_y[0,:,:], shave_border=opt.scale)
78 +
79 +im_h = colorize(im_h_y[0,:,:], im_b_ycbcr)
80 +im_gt = Image.fromarray(im_gt_ycbcr, "YCbCr").convert("RGB")
81 +im_b = Image.fromarray(im_b_ycbcr, "YCbCr").convert("RGB")
82 +
83 +print("Scale=",opt.scale)
84 +print("PSNR_predicted=", psnr_predicted)
85 +print("PSNR_bicubic=", psnr_bicubic)
86 +print("It takes {}s for processing".format(elapsed_time))
87 +
88 +fig = plt.figure()
89 +ax = plt.subplot(131)
90 +ax.imshow(im_gt)
91 +ax.set_title("GT")
92 +
93 +ax = plt.subplot(132)
94 +ax.imshow(im_b)
95 +ax.set_title("Input(bicubic)")
96 +
97 +ax = plt.subplot(133)
98 +ax.imshow(im_h)
99 +ax.set_title("Output(vdsr)")
100 +plt.show()
1 +import argparse, os
2 +import torch
3 +from torch.autograd import Variable
4 +import numpy as np
5 +import time, math, glob
6 +import scipy.io as sio
7 +
8 +parser = argparse.ArgumentParser(description="PyTorch VDSR Eval")
9 +parser.add_argument("--cuda", action="store_true", help="use cuda?")
10 +parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
11 +parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")
12 +parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
13 +
14 +def PSNR(pred, gt, shave_border=0):
15 + height, width = pred.shape[:2]
16 + pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
17 + gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
18 + imdff = pred - gt
19 + rmse = math.sqrt(np.mean(imdff ** 2))
20 + if rmse == 0:
21 + return 100
22 + return 20 * math.log10(255.0 / rmse)
23 +
24 +opt = parser.parse_args()
25 +cuda = opt.cuda
26 +
27 +if cuda:
28 + print("=> use gpu id: '{}'".format(opt.gpus))
29 + os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
30 + if not torch.cuda.is_available():
31 + raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
32 +
33 +model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
34 +
35 +scales = [2,3,4]
36 +
37 +image_list = glob.glob(opt.dataset+"_mat/*.*")
38 +
39 +for scale in scales:
40 + avg_psnr_predicted = 0.0
41 + avg_psnr_bicubic = 0.0
42 + avg_elapsed_time = 0.0
43 + count = 0.0
44 + for image_name in image_list:
45 + if str(scale) in image_name:
46 + count += 1
47 + print("Processing ", image_name)
48 + im_gt_y = sio.loadmat(image_name)['im_gt_y']
49 + im_b_y = sio.loadmat(image_name)['im_b_y']
50 +
51 + im_gt_y = im_gt_y.astype(float)
52 + im_b_y = im_b_y.astype(float)
53 +
54 + psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=scale)
55 + avg_psnr_bicubic += psnr_bicubic
56 +
57 + im_input = im_b_y/255.
58 +
59 + im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1])
60 +
61 + if cuda:
62 + model = model.cuda()
63 + im_input = im_input.cuda()
64 + else:
65 + model = model.cpu()
66 +
67 + start_time = time.time()
68 + HR = model(im_input)
69 + elapsed_time = time.time() - start_time
70 + avg_elapsed_time += elapsed_time
71 +
72 + HR = HR.cpu()
73 +
74 + im_h_y = HR.data[0].numpy().astype(np.float32)
75 +
76 + im_h_y = im_h_y * 255.
77 + im_h_y[im_h_y < 0] = 0
78 + im_h_y[im_h_y > 255.] = 255.
79 + im_h_y = im_h_y[0,:,:]
80 +
81 + psnr_predicted = PSNR(im_gt_y, im_h_y,shave_border=scale)
82 + avg_psnr_predicted += psnr_predicted
83 +
84 + print("Scale=", scale)
85 + print("Dataset=", opt.dataset)
86 + print("PSNR_predicted=", avg_psnr_predicted/count)
87 + print("PSNR_bicubic=", avg_psnr_bicubic/count)
88 + print("It takes average {}s for processing".format(avg_elapsed_time/count))
1 +import argparse, os
2 +import torch
3 +import random
4 +import torch.backends.cudnn as cudnn
5 +import torch.nn as nn
6 +import torch.optim as optim
7 +from torch.autograd import Variable
8 +from torch.utils.data import DataLoader
9 +from vdsr import Net
10 +from dataset import DatasetFromHdf5
11 +## Custom
12 +from data import FeatureDataset
13 +
14 +# Training settings
15 +parser = argparse.ArgumentParser(description="PyTorch VDSR")
16 +parser.add_argument("--batchSize", type=int, default=128, help="Training batch size")
17 +parser.add_argument("--nEpochs", type=int, default=50, help="Number of epochs to train for")
18 +parser.add_argument("--lr", type=float, default=0.1, help="Learning Rate. Default=0.1")
19 +parser.add_argument("--step", type=int, default=10, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=10")
20 +parser.add_argument("--cuda", action="store_true", help="Use cuda?")
21 +parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
22 +parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
23 +parser.add_argument("--clip", type=float, default=0.4, help="Clipping Gradients. Default=0.4")
24 +# 1->3 custom
25 +parser.add_argument("--threads", type=int, default=3, help="Number of threads for data loader to use, Default: 1")
26 +parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
27 +parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="Weight decay, Default: 1e-4")
28 +parser.add_argument('--pretrained', default='', type=str, help='path to pretrained model (default: none)')
29 +parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
30 +## custom
31 +parser.add_argument("--dataPath", type=str)
32 +parser.add_argument("--featureType", type=str, default="p2")
33 +parser.add_argument("--scaleFactor",type=int, default=4)
34 +
35 +# parser.add_argument("--trainingData", type=DataLoader)
36 +
37 +
38 +def main():
39 + global opt, model
40 + opt = parser.parse_args()
41 + print(opt)
42 +
43 + cuda = opt.cuda
44 + if cuda:
45 + print("=> use gpu id: '{}'".format(opt.gpus))
46 + os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
47 + if not torch.cuda.is_available():
48 + raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
49 +
50 + opt.seed = random.randint(1, 10000)
51 + print("Random Seed: ", opt.seed)
52 + torch.manual_seed(opt.seed)
53 + if cuda:
54 + torch.cuda.manual_seed(opt.seed)
55 +
56 + cudnn.benchmark = True
57 +
58 + print("===> Loading datasets")
59 +
60 + if os.path.isfile('dataloader/training_data_loader.pth'):
61 + training_data_loader = torch.load('dataloader/training_data_loader.pth')
62 + else:
63 + train_set = FeatureDataset(opt.dataPath,opt.featureType,opt.scaleFactor,False)
64 + train_size = 100  # use only 100 samples for now
65 + test_size = len(train_set) - train_size
66 + train_dataset, test_dataset = torch.utils.data.random_split(train_set, [train_size, test_size])
67 + training_data_loader = DataLoader(dataset=train_dataset, num_workers=3, batch_size=8, shuffle=True, pin_memory=False)
68 + torch.save(training_data_loader, 'dataloader/training_data_loader.pth')
69 +
70 + print("===> Building model")
71 + model = Net(opt.scaleFactor)
72 + criterion = nn.MSELoss(reduction='sum')  # size_average=False is deprecated; 'sum' is equivalent
73 +
74 + print("===> Setting GPU")
75 + if cuda:
76 + model = model.cuda()
77 + criterion = criterion.cuda()
78 +
79 + # optionally resume from a checkpoint
80 + if opt.resume:
81 + if os.path.isfile(opt.resume):
82 + print("=> loading checkpoint '{}'".format(opt.resume))
83 + checkpoint = torch.load(opt.resume)
84 + opt.start_epoch = checkpoint["epoch"] + 1
85 + model.load_state_dict(checkpoint["model"].state_dict())
86 + else:
87 + print("=> no checkpoint found at '{}'".format(opt.resume))
88 +
89 + # optionally copy weights from a checkpoint
90 + if opt.pretrained:
91 + if os.path.isfile(opt.pretrained):
92 + print("=> loading model '{}'".format(opt.pretrained))
93 + weights = torch.load(opt.pretrained)
94 + model.load_state_dict(weights['model'].state_dict())
95 + else:
96 + print("=> no model found at '{}'".format(opt.pretrained))
97 +
98 + print("===> Setting Optimizer")
99 + optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
100 +
101 + print("===> Training")
102 + for epoch in range(opt.start_epoch, opt.nEpochs + 1):
103 + train(training_data_loader, optimizer, model, criterion, epoch)
104 + save_checkpoint(model, epoch)
105 +
106 +def adjust_learning_rate(optimizer, epoch):
107 + """Sets the learning rate to the initial LR decayed by 10 every 10 epochs"""
108 + lr = opt.lr * (0.1 ** (epoch // opt.step))
109 + return lr
110 +
111 +def train(training_data_loader, optimizer, model, criterion, epoch):
112 + lr = adjust_learning_rate(optimizer, epoch-1)
113 +
114 + for param_group in optimizer.param_groups:
115 + param_group["lr"] = lr
116 +
117 + print("Epoch = {}, lr = {}".format(epoch, optimizer.param_groups[0]["lr"]))
118 +
119 + model.train()
120 +
121 + for iteration, batch in enumerate(training_data_loader, 1):
122 + input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
123 +
124 + if opt.cuda:
125 + input = input.cuda()
126 + target = target.cuda()
127 +
128 + loss = criterion(model(input), target)
129 + optimizer.zero_grad()
130 + loss.backward()
131 + nn.utils.clip_grad_norm_(model.parameters(), opt.clip)  # clip_grad_norm is deprecated in favor of clip_grad_norm_
132 + optimizer.step()
133 +
134 + if iteration%10 == 0:
135 + # loss.data[0] --> loss.item() in current PyTorch
136 + print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.item()))
137 +
138 +def save_checkpoint(model, epoch):
139 + model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch)
140 + state = {"epoch": epoch ,"model": model}
141 + if not os.path.exists("checkpoint/"):
142 + os.makedirs("checkpoint/")
143 +
144 + torch.save(state, model_out_path)
145 +
146 + print("Checkpoint saved to {}".format(model_out_path))
147 +
148 +if __name__ == "__main__":
149 + main()
\ No newline at end of file
1 +import torch
2 +import torch.nn as nn
3 +from math import sqrt
4 +
5 +class Conv_ReLU_Block(nn.Module):
6 + def __init__(self):
7 + super(Conv_ReLU_Block, self).__init__()
8 + self.conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
9 + self.relu = nn.ReLU(inplace=True)
10 +
11 + def forward(self, x):
12 + return self.relu(self.conv(x))
13 +
14 +class Net(nn.Module):
15 + def __init__(self,upscale_factor):
16 + super(Net, self).__init__()
17 + self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
18 + self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
19 + self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
20 + self.upsample = nn.Upsample(scale_factor=upscale_factor, mode='bicubic', align_corners=False)
21 + self.relu = nn.ReLU(inplace=True)
22 +
23 + for m in self.modules():
24 + if isinstance(m, nn.Conv2d):
25 + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
26 + m.weight.data.normal_(0, sqrt(2. / n))
27 +
28 + def make_layer(self, block, num_of_layer):
29 + layers = []
30 + for _ in range(num_of_layer):
31 + layers.append(block())
32 + return nn.Sequential(*layers)
33 +
34 + def forward(self, x):
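 # Unlike the original VDSR, which expects a bicubic-upscaled input, this variant
 # upscales the LR input inside the network and learns a residual on top of it.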
35 + x = self.upsample(x)
36 + residual = x
37 + out = self.relu(self.input(x))
38 + out = self.residual_layer(out)
39 + out = self.output(out)
40 + out = torch.add(out,residual)
41 + return out
42 +
\ No newline at end of file