서민정

docs&code: revise the interim report and upload code

1 -The MIT License (MIT)
2 -
3 -Copyright (c) 2017- Jiu XU
4 -Copyright (c) 2017- Rakuten, Inc
5 -Copyright (c) 2017- Rakuten Institute of Technology
6 -
7 -Permission is hereby granted, free of charge, to any person obtaining a copy
8 -of this software and associated documentation files (the "Software"), to deal
9 -in the Software without restriction, including without limitation the rights
10 -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 -copies of the Software, and to permit persons to whom the Software is
12 -furnished to do so, subject to the following conditions:
13 -
14 -The above copyright notice and this permission notice shall be included in all
15 -copies or substantial portions of the Software.
16 -
17 -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 -SOFTWARE.
\ No newline at end of file
1 -# PyTorch VDSR
2 -Implementation of the CVPR 2016 paper "Accurate Image Super-Resolution Using
3 -Very Deep Convolutional Networks" (http://cv.snu.ac.kr/research/VDSR/) in PyTorch
4 -
5 -## Usage
6 -### Training
7 -```
8 -usage: main_vdsr.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS] [--lr LR]
9 - [--step STEP] [--cuda] [--resume RESUME]
10 - [--start-epoch START_EPOCH] [--clip CLIP] [--threads THREADS]
11 - [--momentum MOMENTUM] [--weight-decay WEIGHT_DECAY]
12 - [--pretrained PRETRAINED] [--gpus GPUS]
13 -
14 -optional arguments:
15 - -h, --help Show this help message and exit
16 - --batchSize Training batch size
17 - --nEpochs Number of epochs to train for
18 - --lr Learning rate. Default=0.01
19 - --step Learning rate decay, Default: n=10 epochs
20 - --cuda Use cuda
21 - --resume Path to checkpoint
22 - --clip Clipping Gradients. Default=0.4
23 - --threads Number of threads for data loader to use, Default=1
24 - --momentum Momentum, Default: 0.9
25 - --weight-decay Weight decay, Default: 1e-4
26 - --pretrained PRETRAINED
27 - path to pretrained model (default: none)
28 - --gpus GPUS gpu ids (default: 0)
29 -```
30 -An example of training usage is shown as follows:
31 -```
32 -python main_vdsr.py --cuda --gpus 0
33 -```
34 -
35 -### Evaluation
36 -```
37 -usage: eval.py [-h] [--cuda] [--model MODEL] [--dataset DATASET]
38 - [--scale SCALE] [--gpus GPUS]
39 -
40 -PyTorch VDSR Eval
41 -
42 -optional arguments:
43 - -h, --help show this help message and exit
44 - --cuda use cuda?
45 - --model MODEL model path
46 - --dataset DATASET dataset name, Default: Set5
47 - --gpus GPUS gpu ids (default: 0)
48 -```
49 -An example of evaluation usage is shown as follows:
50 -```
51 -python eval.py --cuda --dataset Set5
52 -```
53 -
54 -### Demo
55 -```
56 -usage: demo.py [-h] [--cuda] [--model MODEL] [--image IMAGE] [--scale SCALE] [--gpus GPUS]
57 -
58 -optional arguments:
59 - -h, --help Show this help message and exit
60 - --cuda Use cuda
61 - --model Model path. Default=model/model_epoch_50.pth
62 - --image Image name. Default=butterfly_GT
63 - --scale Scale factor, Default: 4
64 - --gpus GPUS gpu ids (default: 0)
65 -```
66 -An example of usage is shown as follows:
67 -```
68 -python demo.py --model model/model_epoch_50.pth --image butterfly_GT --scale 4 --cuda
69 -```
70 -
71 -### Prepare Training dataset
72 - - We provide a simple HDF5-format training sample in the data folder with 'data' and 'label' keys. The training data is generated with MATLAB bicubic interpolation; please refer to [Code for Data Generation](https://github.com/twtygqyy/pytorch-vdsr/tree/master/data) for creating training files.
73 -
74 -### Performance
75 - - We provide a pretrained VDSR model trained on [291](https://drive.google.com/open?id=1Rt3asDLuMgLuJvPA1YrhyjWhb97Ly742) images with data augmentation
76 - - No bias is used in this implementation, and the gradient-clipping implementation differs from the paper's
77 - - Performance in PSNR on Set5
78 -
79 -| Scale | VDSR Paper | VDSR PyTorch|
80 -| ------------- |:-------------:| -----:|
81 -| 2x | 37.53 | 37.65 |
82 -| 3x | 33.66 | 33.77 |
83 -| 4x | 31.35 | 31.45 |
84 -
85 -### Result
86 -From left to right: ground truth, bicubic, and VDSR
87 -<p>
88 - <img src='Set5/butterfly_GT.bmp' height='200' width='200'/>
89 - <img src='result/input.bmp' height='200' width='200'/>
90 - <img src='result/output.bmp' height='200' width='200'/>
91 -</p>
1 -import torch.utils.data as data
2 -import torch
3 -import h5py
4 -
5 -class DatasetFromHdf5(data.Dataset):
6 - def __init__(self, file_path):
7 - super(DatasetFromHdf5, self).__init__()
8 - hf = h5py.File(file_path)
9 - self.data = hf.get('data')
10 - self.target = hf.get('label')
11 -
12 - def __getitem__(self, index):
13 - return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
14 -
15 - def __len__(self):
16 - return self.data.shape[0]
\ No newline at end of file
1 -import argparse, os
2 -import torch
3 -from torch.autograd import Variable
4 -from scipy.ndimage import imread
5 -from PIL import Image
6 -import numpy as np
7 -import time, math
8 -import matplotlib.pyplot as plt
9 -
10 -parser = argparse.ArgumentParser(description="PyTorch VDSR Demo")
11 -parser.add_argument("--cuda", action="store_true", help="use cuda?")
12 -parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
13 -parser.add_argument("--image", default="butterfly_GT", type=str, help="image name")
14 -parser.add_argument("--scale", default=4, type=int, help="scale factor, Default: 4")
15 -parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
16 -
17 -def PSNR(pred, gt, shave_border=0):
18 - height, width = pred.shape[:2]
19 - pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
20 - gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
21 - imdff = pred - gt
22 - rmse = math.sqrt(np.mean(imdff ** 2))
23 - if rmse == 0:
24 - return 100
25 - return 20 * math.log10(255.0 / rmse)
26 -
27 -def colorize(y, ycbcr):
28 - img = np.zeros((y.shape[0], y.shape[1], 3), np.uint8)
29 - img[:,:,0] = y
30 - img[:,:,1] = ycbcr[:,:,1]
31 - img[:,:,2] = ycbcr[:,:,2]
32 - img = Image.fromarray(img, "YCbCr").convert("RGB")
33 - return img
34 -
35 -opt = parser.parse_args()
36 -cuda = opt.cuda
37 -
38 -if cuda:
39 - print("=> use gpu id: '{}'".format(opt.gpus))
40 - os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
41 - if not torch.cuda.is_available():
42 - raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
43 -
44 -
45 -model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
46 -
47 -im_gt_ycbcr = imread("Set5/" + opt.image + ".bmp", mode="YCbCr")
48 -im_b_ycbcr = imread("Set5/"+ opt.image + "_scale_"+ str(opt.scale) + ".bmp", mode="YCbCr")
49 -
50 -im_gt_y = im_gt_ycbcr[:,:,0].astype(float)
51 -im_b_y = im_b_ycbcr[:,:,0].astype(float)
52 -
53 -psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=opt.scale)
54 -
55 -im_input = im_b_y/255.
56 -
57 -im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1])
58 -
59 -if cuda:
60 - model = model.cuda()
61 - im_input = im_input.cuda()
62 -else:
63 - model = model.cpu()
64 -
65 -start_time = time.time()
66 -out = model(im_input)
67 -elapsed_time = time.time() - start_time
68 -
69 -out = out.cpu()
70 -
71 -im_h_y = out.data[0].numpy().astype(np.float32)
72 -
73 -im_h_y = im_h_y * 255.
74 -im_h_y[im_h_y < 0] = 0
75 -im_h_y[im_h_y > 255.] = 255.
76 -
77 -psnr_predicted = PSNR(im_gt_y, im_h_y[0,:,:], shave_border=opt.scale)
78 -
79 -im_h = colorize(im_h_y[0,:,:], im_b_ycbcr)
80 -im_gt = Image.fromarray(im_gt_ycbcr, "YCbCr").convert("RGB")
81 -im_b = Image.fromarray(im_b_ycbcr, "YCbCr").convert("RGB")
82 -
83 -print("Scale=",opt.scale)
84 -print("PSNR_predicted=", psnr_predicted)
85 -print("PSNR_bicubic=", psnr_bicubic)
86 -print("It takes {}s for processing".format(elapsed_time))
87 -
88 -fig = plt.figure()
89 -ax = plt.subplot("131")
90 -ax.imshow(im_gt)
91 -ax.set_title("GT")
92 -
93 -ax = plt.subplot("132")
94 -ax.imshow(im_b)
95 -ax.set_title("Input(bicubic)")
96 -
97 -ax = plt.subplot("133")
98 -ax.imshow(im_h)
99 -ax.set_title("Output(vdsr)")
100 -plt.show()
1 -import argparse, os
2 -import torch
3 -from torch.autograd import Variable
4 -import numpy as np
5 -import time, math, glob
6 -import scipy.io as sio
7 -
8 -parser = argparse.ArgumentParser(description="PyTorch VDSR Eval")
9 -parser.add_argument("--cuda", action="store_true", help="use cuda?")
10 -parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
11 -parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")
12 -parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
13 -
14 -def PSNR(pred, gt, shave_border=0):
15 - height, width = pred.shape[:2]
16 - pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
17 - gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
18 - imdff = pred - gt
19 - rmse = math.sqrt(np.mean(imdff ** 2))
20 - if rmse == 0:
21 - return 100
22 - return 20 * math.log10(255.0 / rmse)
23 -
24 -opt = parser.parse_args()
25 -cuda = opt.cuda
26 -
27 -if cuda:
28 - print("=> use gpu id: '{}'".format(opt.gpus))
29 - os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
30 - if not torch.cuda.is_available():
31 - raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
32 -
33 -model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
34 -
35 -scales = [2,3,4]
36 -
37 -image_list = glob.glob(opt.dataset+"_mat/*.*")
38 -
39 -for scale in scales:
40 - avg_psnr_predicted = 0.0
41 - avg_psnr_bicubic = 0.0
42 - avg_elapsed_time = 0.0
43 - count = 0.0
44 - for image_name in image_list:
45 - if str(scale) in image_name:
46 - count += 1
47 - print("Processing ", image_name)
48 - im_gt_y = sio.loadmat(image_name)['im_gt_y']
49 - im_b_y = sio.loadmat(image_name)['im_b_y']
50 -
51 - im_gt_y = im_gt_y.astype(float)
52 - im_b_y = im_b_y.astype(float)
53 -
54 - psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=scale)
55 - avg_psnr_bicubic += psnr_bicubic
56 -
57 - im_input = im_b_y/255.
58 -
59 - im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1])
60 -
61 - if cuda:
62 - model = model.cuda()
63 - im_input = im_input.cuda()
64 - else:
65 - model = model.cpu()
66 -
67 - start_time = time.time()
68 - HR = model(im_input)
69 - elapsed_time = time.time() - start_time
70 - avg_elapsed_time += elapsed_time
71 -
72 - HR = HR.cpu()
73 -
74 - im_h_y = HR.data[0].numpy().astype(np.float32)
75 -
76 - im_h_y = im_h_y * 255.
77 - im_h_y[im_h_y < 0] = 0
78 - im_h_y[im_h_y > 255.] = 255.
79 - im_h_y = im_h_y[0,:,:]
80 -
81 - psnr_predicted = PSNR(im_gt_y, im_h_y,shave_border=scale)
82 - avg_psnr_predicted += psnr_predicted
83 -
84 - print("Scale=", scale)
85 - print("Dataset=", opt.dataset)
86 - print("PSNR_predicted=", avg_psnr_predicted/count)
87 - print("PSNR_bicubic=", avg_psnr_bicubic/count)
88 - print("It takes average {}s for processing".format(avg_elapsed_time/count))
1 +# Feature SR
2 +
3 +1. Train the SR network on extracted feature maps:
4 +
5 +`!python main.py --dataRoot /content/drive/MyDrive/feature/HR_trainset/features --scaleFactor 4 --featureType p6 --batchSize 16 --cuda --nEpochs 20`
6 +
7 +2. Run inference to super-resolve LR feature maps:
8 +
9 +`!python inference.py --cuda --model "model.pth" --dataset "/content/drive/MyDrive/feature/features/LR_2" --featureType "p3" --scaleFactor 4`
10 +
11 +3. Calculate detection mAP on the reconstructed features (Colab setup shown in the block below):
12 +
13 +```
14 +# [1]
15 +# install dependencies:
16 +!pip install pyyaml==5.1
17 +import torch, torchvision
18 +print(torch.__version__, torch.cuda.is_available())
19 +!gcc --version
20 +# opencv is pre-installed on colab
21 +
22 +# [2]
23 +# install detectron2: (Colab has CUDA 10.1 + torch 1.8)
24 +# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
25 +import torch
26 +assert torch.__version__.startswith("1.8") # need to manually install torch 1.8 if Colab changes its default version
27 +!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.8/index.html
28 +# exit(0) # After installation, you need to "restart runtime" in Colab. This line can also restart runtime
29 +
30 +# [3]
31 +# Some basic setup:
32 +# Setup detectron2 logger
33 +import detectron2
34 +from detectron2.utils.logger import setup_logger
35 +setup_logger()
36 +
37 +!python calculate_mAP.py --valid_data_path /content/drive/MyDrive/dataset/validset_100/ --model_name VDSR --loss_type MSE --batch_size 16
38 +```
\ No newline at end of file
1 +# # # [1]
2 +# # # install dependencies:
3 +# !pip install pyyaml==5.1
4 +# import torch, torchvision
5 +# print(torch.__version__, torch.cuda.is_available())
6 +# !gcc --version
7 +# # opencv is pre-installed on colab
8 +
9 +# # # [2]
10 +# # # install detectron2: (Colab has CUDA 10.1 + torch 1.8)
11 +# # # See https://detectron2.readthedocs.io/tutorials/install.html for instructions
12 +# import torch
13 +# assert torch.__version__.startswith("1.8") # need to manually install torch 1.8 if Colab changes its default version
14 +# !pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.8/index.html
15 +# # exit(0) # After installation, you need to "restart runtime" in Colab. This line can also restart runtime
16 +
17 +# # # [3]
18 +# # # Some basic setup:
19 +# # # Setup detectron2 logger
20 +# import detectron2
21 +# from detectron2.utils.logger import setup_logger
22 +# setup_logger()
23 +
24 +# import some common libraries
25 +import torch
26 +import numpy as np
27 +import os, json, cv2, random, math
28 +from PIL import Image
29 +from torch.nn.utils.rnn import pad_sequence
30 +
31 +# import some common detectron2 utilities
32 +from detectron2 import model_zoo
33 +from detectron2.engine import DefaultPredictor
34 +from detectron2.config import get_cfg
35 +from detectron2.utils.visualizer import Visualizer
36 +from detectron2.data import MetadataCatalog, DatasetCatalog
37 +from detectron2.modeling import build_model, build_backbone
38 +from detectron2.checkpoint import DetectionCheckpointer
39 +from detectron2.utils.visualizer import Visualizer
40 +import detectron2.data.transforms as T
41 +
42 +from pycocotools.coco import COCO
43 +from pycocotools.cocoeval import COCOeval
44 +from pycocotools.mask import encode
45 +import argparse
46 +
47 +parser = argparse.ArgumentParser(description="Calculate mAP for feature SR")
48 +parser.add_argument("--data_path", type=str, default="/home/ubuntu/JH/exp1/dataset")
49 +parser.add_argument("--valid_data_path", type=str)
50 +parser.add_argument("--rescale_factor", type=int, default=4, help="rescale factor used in training")
51 +parser.add_argument("--model_name", type=str, choices=["VDSR", "CARN", "SRRN", "FRGAN"], default='CARN', help="model used in training")
52 +parser.add_argument("--loss_type", type=str, choices=["MSE", "L1", "SmoothL1", "vgg_loss", "ssim_loss", "adv_loss", "lpips"], default='MSE', help="loss type used in training")
53 +parser.add_argument('--batch_size', type=int, default=256)
54 +opt = parser.parse_args()
55 +print(opt)
56 +
57 +
58 +def myRound(x):  # round symmetrically about zero for positive and negative values
59 + abs_x = abs(x)
60 + val = np.int16(abs_x + 0.5)
61 + val2 = np.choose(
62 + x < 0,
63 + [
64 + val, val*(-1)
65 + ]
66 + )
67 + return val2
68 +
69 +def myClip(x, maxV):
70 + val = np.choose(
71 + x > maxV,
72 + [
73 + x, maxV
74 + ]
75 + )
76 + return val
77 +
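A quick sanity check of the two helpers (a sketch using the `myRound`/`myClip` defined above): `myRound` rounds half away from zero, symmetrically for negative and positive inputs, and `myClip` caps values at `maxV`.

```python
import numpy as np

x = np.array([-1.5, -0.4, 0.4, 1.5, 300.0])
print(myRound(x))               # [ -2   0   0   2 300], rounds half away from zero
print(myClip(myRound(x), 255))  # [ -2   0   0   2 255], capped at maxV
```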
78 +image_idx = 0
79 +cfg = get_cfg()
80 +# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
81 +cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
82 +cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
83 +# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
84 +cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
85 +
86 +model = build_model(cfg)
87 +DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
88 +model.eval()
89 +
90 +image_idx = 0
91 +anns = 0
92 +
93 +# Original_8bit
94 +
95 +image_files = ['001000', '002153', '008021', '009769', '009891', '015335', '017627', '018150', '018837', '022589']
96 +image_files.extend(['022935', '023230', '024610', '025560', '025593', '027620', '155341', '161397', '165336', '166287'])
97 +image_files.extend(['166642', '169996', '172330', '172648', '176606', '176701', '179765', '180101', '186296', '250758'])
98 +image_files.extend(['259382', '267191', '287545', '287649', '289741', '293245', '308328', '309452', '335529', '337987'])
99 +image_files.extend(['338625', '344029', '350122', '389933', '393226', '395343', '395633', '401862', '402473', '402992'])
100 +image_files.extend(['404568', '406997', '408112', '410650', '414385', '414795', '415194', '415536', '416104', '416758'])
101 +image_files.extend(['427055', '428562', '430073', '433204', '447200', '447313', '448448', '452321', '453001', '458755'])
102 +image_files.extend(['462904', '463522', '464089', '468965', '469192', '469246', '471450', '474078', '474881', '475678'])
103 +image_files.extend(['475779', '537802', '542625', '543043', '543300', '543528', '547502', '550691', '553669', '567740'])
104 +image_files.extend(['570688', '570834', '571943', '573391', '574315', '575372', '575970', '578093', '579158', '581100'])
105 +
106 +
107 +for iter in range(0, 100):
108 +
109 + image_file_number = image_files[image_idx]
110 + aug = T.ResizeShortestEdge(
111 + # [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
112 + # [480, 480], cfg.INPUT.MAX_SIZE_TEST
113 + [768, 768], cfg.INPUT.MAX_SIZE_TEST
114 + )
115 + image_prefix = "COCO_val2017_"
116 + image = cv2.imread(opt.valid_data_path + '000000'+ image_file_number +'.jpg')
117 + # image = cv2.imread('./dataset/validset_100/000000'+ image_file_number +'.jpg')
118 + height, width = image.shape[:2]
119 + image = aug.get_transform(image).apply_image(image)
120 + image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
121 + inputs = [{"image": image, "height": height, "width": width}]
122 + with torch.no_grad():
123 + images = model.preprocess_image(inputs) # don't forget to preprocess
124 + features = model.backbone(images.tensor) # set of cnn features
125 +
126 +
127 + p2_feature_original = features['p2'].to("cpu")
128 + p3_feature_original = features['p3'].to("cpu")
129 + p4_feature_original = features['p4'].to("cpu")
130 +
131 + bitDepth = 8
132 + maxRange = [0, 0, 0, 0, 0]
133 +
134 + def maxVal(x):
135 + return pow(2, x)
136 + def offsetVal(x):
137 + return pow(2, x-1)
138 +
139 + def maxRange_layer(x):
140 + absolute_arr = torch.abs(x) * 2
141 + max_arr = torch.max(absolute_arr)
142 + return torch.ceil(max_arr)
143 +
144 +
145 + act2 = p2_feature_original.squeeze()
146 + maxRange[0] = maxRange_layer(act2)
147 +
148 + act3 = p3_feature_original.squeeze()
149 + maxRange[1] = maxRange_layer(act3)
150 +
151 + act4 = p4_feature_original.squeeze()
152 + maxRange[2] = maxRange_layer(act4)
153 +
154 + globals()['maxRange_{}'.format(image_file_number)] = maxRange
155 +
156 + # p2_feature_img = Image.open('./original/qp32/COCO_val2014_000000'+ image_file_number +'_p2.png'
157 + # p2_feature_img = Image.open('./result/{}/inference/{}_p2x{}/SR_{}.png'.format(opt.loss_type,opt.model_name,opt.rescale_factor,str(iter)))
158 + p2_feature_img = Image.open('/content/drive/MyDrive/result/inference/LR_2/p2/' + image_prefix + '000000' + image_file_number + '_p2' +'.png')
159 + # # y_p2, cb, cr = p2_feature_img.split()
160 + p2_feature_arr = np.array(p2_feature_img)
161 + p2_feature_arr_round = myRound(p2_feature_arr)
162 +
163 + # p3_feature_img = Image.open('./original/qp32/COCO_val2014_000000'+ image_file_number +'_p3.png')
164 +
165 + # p3_feature_img = Image.open('./result/{}/inference/{}_p3x{}/SR_{}.png'.format(opt.loss_type,opt.model_name,opt.rescale_factor,str(iter)))
166 + p3_feature_img = Image.open('/content/drive/MyDrive/result/inference/LR_2/p3/' + image_prefix + '000000' + image_file_number + '_p3' +'.png')
167 + # # y_p3, cb2, cr2 = p3_feature_img.split()
168 + p3_feature_arr = np.array(p3_feature_img)
169 + p3_feature_arr_round = myRound(p3_feature_arr)
170 +
171 + # p4_feature_img = Image.open('./original/qp32/COCO_val2014_000000'+ image_file_number +'_p4.png')
172 + # p4_feature_img = Image.open('./result/{}/inference/{}_p4x{}/SR_{}.png'.format(opt.loss_type,opt.model_name,opt.rescale_factor,str(iter)))
173 + p4_feature_img = Image.open('/content/drive/MyDrive/result/inference/LR_2/p4/' + image_prefix + '000000' + image_file_number + '_p4' +'.png')
174 + # y_p4, cb3, cr3 = p4_feature_img.split()
175 + p4_feature_arr = np.array(p4_feature_img)
176 + p4_feature_arr_round = myRound(p4_feature_arr)
177 +
178 +
179 +    # reconstruction (dequantize the 8-bit PNGs back to float features)
180 + recon_p2 = (((p2_feature_arr_round - offsetVal(bitDepth)) / maxVal(bitDepth)) * maxRange[0].numpy())
181 + recon_p3 = (((p3_feature_arr_round - offsetVal(bitDepth)) / maxVal(bitDepth)) * maxRange[1].numpy())
182 + recon_p4 = (((p4_feature_arr_round - offsetVal(bitDepth)) / maxVal(bitDepth)) * maxRange[2].numpy())
183 +
184 + tensor_value = recon_p2
185 + tensor_value2 = recon_p3
186 + tensor_value3 = recon_p4
187 +
188 +    # # end of MSB code
189 +
190 +    # lsb and original code
191 +    # reconstruction
192 + # recon_p2 = (((p2_feature_arr_round - offsetVal(bitDepth)) / maxVal(bitDepth)) * maxRange[0].numpy())
193 + # recon_p3 = (((p3_feature_arr_round - offsetVal(bitDepth)) / maxVal(bitDepth)) * maxRange[1].numpy())
194 + # recon_p4 = (((p4_feature_arr_round - offsetVal(bitDepth)) / maxVal(bitDepth)) * maxRange[2].numpy())
195 + # recon_p5 = (((p5_feature_arr_round - offsetVal(bitDepth)) / maxVal(bitDepth)) * maxRange[3].numpy())
196 + # recon_p6 = (((p6_feature_arr_round - offsetVal(bitDepth)) / maxVal(bitDepth)) * maxRange[4].numpy())
197 +
198 + tensor_value = torch.as_tensor(recon_p2.astype("float32"))
199 + tensor_value2 = torch.as_tensor(recon_p3.astype("float32"))
200 + tensor_value3 = torch.as_tensor(recon_p4.astype("float32"))
201 +    # end of lsb and original code
202 +
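For reference, the `recon_*` lines above compute the inverse of a forward quantization of this form (a sketch, assuming the per-level `maxRange` and the `maxVal`/`offsetVal` helpers defined above; `quantize` and `dequantize` are hypothetical names, not in this file):

```python
def quantize(feat, max_range, bit_depth=8):
    # map floats in [-max_range/2, +max_range/2] to integers in [0, 2**bit_depth)
    return myRound(feat / max_range * maxVal(bit_depth) + offsetVal(bit_depth))

def dequantize(q, max_range, bit_depth=8):
    # the inverse actually used by recon_p2 / recon_p3 / recon_p4 above
    return (q - offsetVal(bit_depth)) / maxVal(bit_depth) * max_range
```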
203 + t = [None] * 16
204 + t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15] = torch.chunk(tensor_value, 16, dim=0)
205 + p2 = [None] * 256
206 +
207 + t2 = [None] * 16
208 + t2[0], t2[1], t2[2], t2[3], t2[4], t2[5], t2[6], t2[7], t2[8], t2[9], t2[10], t2[11], t2[12], t2[13], t2[14], t2[15] = torch.chunk(tensor_value2, 16, dim=0)
209 + p3 = [None] * 256
210 +
211 + t3 = [None] * 16
212 + t3[0], t3[1], t3[2], t3[3], t3[4], t3[5], t3[6], t3[7], t3[8], t3[9], t3[10], t3[11], t3[12], t3[13], t3[14], t3[15] = torch.chunk(tensor_value3, 16, dim=0)
213 + p4 = [None] * 256
214 +
215 + p2[0], p2[1], p2[2], p2[3], p2[4], p2[5], p2[6], p2[7], p2[8], p2[9], p2[10], p2[11], p2[12], p2[13], p2[14], p2[15] = torch.chunk(t[0], 16, dim=1)
216 + p2[16], p2[17], p2[18], p2[19], p2[20], p2[21], p2[22], p2[23], p2[24], p2[25], p2[26], p2[27], p2[28], p2[29], p2[30], p2[31] = torch.chunk(t[1], 16, dim=1)
217 + p2[32], p2[33], p2[34], p2[35], p2[36], p2[37], p2[38], p2[39], p2[40], p2[41], p2[42], p2[43], p2[44], p2[45], p2[46], p2[47] = torch.chunk(t[2], 16, dim=1)
218 + p2[48], p2[49], p2[50], p2[51], p2[52], p2[53], p2[54], p2[55], p2[56], p2[57], p2[58], p2[59], p2[60], p2[61], p2[62], p2[63] = torch.chunk(t[3], 16, dim=1)
219 + p2[64], p2[65], p2[66], p2[67], p2[68], p2[69], p2[70], p2[71], p2[72], p2[73], p2[74], p2[75], p2[76], p2[77], p2[78], p2[79] = torch.chunk(t[4], 16, dim=1)
220 + p2[80], p2[81], p2[82], p2[83], p2[84], p2[85], p2[86], p2[87], p2[88], p2[89], p2[90], p2[91], p2[92], p2[93], p2[94], p2[95] = torch.chunk(t[5], 16, dim=1)
221 + p2[96], p2[97], p2[98], p2[99], p2[100], p2[101], p2[102], p2[103], p2[104], p2[105], p2[106], p2[107], p2[108], p2[109], p2[110], p2[111] = torch.chunk(t[6], 16, dim=1)
222 + p2[112], p2[113], p2[114], p2[115], p2[116], p2[117], p2[118], p2[119], p2[120], p2[121], p2[122], p2[123], p2[124], p2[125], p2[126], p2[127] = torch.chunk(t[7], 16, dim=1)
223 + p2[128], p2[129], p2[130], p2[131], p2[132], p2[133], p2[134], p2[135], p2[136], p2[137], p2[138], p2[139], p2[140], p2[141], p2[142], p2[143] = torch.chunk(t[8], 16, dim=1)
224 + p2[144], p2[145], p2[146], p2[147], p2[148], p2[149], p2[150], p2[151], p2[152], p2[153], p2[154], p2[155], p2[156], p2[157], p2[158], p2[159] = torch.chunk(t[9], 16, dim=1)
225 + p2[160], p2[161], p2[162], p2[163], p2[164], p2[165], p2[166], p2[167], p2[168], p2[169], p2[170], p2[171], p2[172], p2[173], p2[174], p2[175] = torch.chunk(t[10], 16, dim=1)
226 + p2[176], p2[177], p2[178], p2[179], p2[180], p2[181], p2[182], p2[183], p2[184], p2[185], p2[186], p2[187], p2[188], p2[189], p2[190], p2[191] = torch.chunk(t[11], 16, dim=1)
227 + p2[192], p2[193], p2[194], p2[195], p2[196], p2[197], p2[198], p2[199], p2[200], p2[201], p2[202], p2[203], p2[204], p2[205], p2[206], p2[207] = torch.chunk(t[12], 16, dim=1)
228 + p2[208], p2[209], p2[210], p2[211], p2[212], p2[213], p2[214], p2[215], p2[216], p2[217], p2[218], p2[219], p2[220], p2[221], p2[222], p2[223] = torch.chunk(t[13], 16, dim=1)
229 + p2[224], p2[225], p2[226], p2[227], p2[228], p2[229], p2[230], p2[231], p2[232], p2[233], p2[234], p2[235], p2[236], p2[237], p2[238], p2[239] = torch.chunk(t[14], 16, dim=1)
230 + p2[240], p2[241], p2[242], p2[243], p2[244], p2[245], p2[246], p2[247], p2[248], p2[249], p2[250], p2[251], p2[252], p2[253], p2[254], p2[255] = torch.chunk(t[15], 16, dim=1)
231 +
232 + p3[0], p3[1], p3[2], p3[3], p3[4], p3[5], p3[6], p3[7], p3[8], p3[9], p3[10], p3[11], p3[12], p3[13], p3[14], p3[15] = torch.chunk(t2[0], 16, dim=1)
233 + p3[16], p3[17], p3[18], p3[19], p3[20], p3[21], p3[22], p3[23], p3[24], p3[25], p3[26], p3[27], p3[28], p3[29], p3[30], p3[31] = torch.chunk(t2[1], 16, dim=1)
234 + p3[32], p3[33], p3[34], p3[35], p3[36], p3[37], p3[38], p3[39], p3[40], p3[41], p3[42], p3[43], p3[44], p3[45], p3[46], p3[47] = torch.chunk(t2[2], 16, dim=1)
235 + p3[48], p3[49], p3[50], p3[51], p3[52], p3[53], p3[54], p3[55], p3[56], p3[57], p3[58], p3[59], p3[60], p3[61], p3[62], p3[63] = torch.chunk(t2[3], 16, dim=1)
236 + p3[64], p3[65], p3[66], p3[67], p3[68], p3[69], p3[70], p3[71], p3[72], p3[73], p3[74], p3[75], p3[76], p3[77], p3[78], p3[79] = torch.chunk(t2[4], 16, dim=1)
237 + p3[80], p3[81], p3[82], p3[83], p3[84], p3[85], p3[86], p3[87], p3[88], p3[89], p3[90], p3[91], p3[92], p3[93], p3[94], p3[95] = torch.chunk(t2[5], 16, dim=1)
238 + p3[96], p3[97], p3[98], p3[99], p3[100], p3[101], p3[102], p3[103], p3[104], p3[105], p3[106], p3[107], p3[108], p3[109], p3[110], p3[111] = torch.chunk(t2[6], 16, dim=1)
239 + p3[112], p3[113], p3[114], p3[115], p3[116], p3[117], p3[118], p3[119], p3[120], p3[121], p3[122], p3[123], p3[124], p3[125], p3[126], p3[127] = torch.chunk(t2[7], 16, dim=1)
240 + p3[128], p3[129], p3[130], p3[131], p3[132], p3[133], p3[134], p3[135], p3[136], p3[137], p3[138], p3[139], p3[140], p3[141], p3[142], p3[143] = torch.chunk(t2[8], 16, dim=1)
241 + p3[144], p3[145], p3[146], p3[147], p3[148], p3[149], p3[150], p3[151], p3[152], p3[153], p3[154], p3[155], p3[156], p3[157], p3[158], p3[159] = torch.chunk(t2[9], 16, dim=1)
242 + p3[160], p3[161], p3[162], p3[163], p3[164], p3[165], p3[166], p3[167], p3[168], p3[169], p3[170], p3[171], p3[172], p3[173], p3[174], p3[175] = torch.chunk(t2[10], 16, dim=1)
243 + p3[176], p3[177], p3[178], p3[179], p3[180], p3[181], p3[182], p3[183], p3[184], p3[185], p3[186], p3[187], p3[188], p3[189], p3[190], p3[191] = torch.chunk(t2[11], 16, dim=1)
244 + p3[192], p3[193], p3[194], p3[195], p3[196], p3[197], p3[198], p3[199], p3[200], p3[201], p3[202], p3[203], p3[204], p3[205], p3[206], p3[207] = torch.chunk(t2[12], 16, dim=1)
245 + p3[208], p3[209], p3[210], p3[211], p3[212], p3[213], p3[214], p3[215], p3[216], p3[217], p3[218], p3[219], p3[220], p3[221], p3[222], p3[223] = torch.chunk(t2[13], 16, dim=1)
246 + p3[224], p3[225], p3[226], p3[227], p3[228], p3[229], p3[230], p3[231], p3[232], p3[233], p3[234], p3[235], p3[236], p3[237], p3[238], p3[239] = torch.chunk(t2[14], 16, dim=1)
247 + p3[240], p3[241], p3[242], p3[243], p3[244], p3[245], p3[246], p3[247], p3[248], p3[249], p3[250], p3[251], p3[252], p3[253], p3[254], p3[255] = torch.chunk(t2[15], 16, dim=1)
248 +
249 + p4[0], p4[1], p4[2], p4[3], p4[4], p4[5], p4[6], p4[7], p4[8], p4[9], p4[10], p4[11], p4[12], p4[13], p4[14], p4[15] = torch.chunk(t3[0], 16, dim=1)
250 + p4[16], p4[17], p4[18], p4[19], p4[20], p4[21], p4[22], p4[23], p4[24], p4[25], p4[26], p4[27], p4[28], p4[29], p4[30], p4[31] = torch.chunk(t3[1], 16, dim=1)
251 + p4[32], p4[33], p4[34], p4[35], p4[36], p4[37], p4[38], p4[39], p4[40], p4[41], p4[42], p4[43], p4[44], p4[45], p4[46], p4[47] = torch.chunk(t3[2], 16, dim=1)
252 + p4[48], p4[49], p4[50], p4[51], p4[52], p4[53], p4[54], p4[55], p4[56], p4[57], p4[58], p4[59], p4[60], p4[61], p4[62], p4[63] = torch.chunk(t3[3], 16, dim=1)
253 + p4[64], p4[65], p4[66], p4[67], p4[68], p4[69], p4[70], p4[71], p4[72], p4[73], p4[74], p4[75], p4[76], p4[77], p4[78], p4[79] = torch.chunk(t3[4], 16, dim=1)
254 + p4[80], p4[81], p4[82], p4[83], p4[84], p4[85], p4[86], p4[87], p4[88], p4[89], p4[90], p4[91], p4[92], p4[93], p4[94], p4[95] = torch.chunk(t3[5], 16, dim=1)
255 + p4[96], p4[97], p4[98], p4[99], p4[100], p4[101], p4[102], p4[103], p4[104], p4[105], p4[106], p4[107], p4[108], p4[109], p4[110], p4[111] = torch.chunk(t3[6], 16, dim=1)
256 + p4[112], p4[113], p4[114], p4[115], p4[116], p4[117], p4[118], p4[119], p4[120], p4[121], p4[122], p4[123], p4[124], p4[125], p4[126], p4[127] = torch.chunk(t3[7], 16, dim=1)
257 + p4[128], p4[129], p4[130], p4[131], p4[132], p4[133], p4[134], p4[135], p4[136], p4[137], p4[138], p4[139], p4[140], p4[141], p4[142], p4[143] = torch.chunk(t3[8], 16, dim=1)
258 + p4[144], p4[145], p4[146], p4[147], p4[148], p4[149], p4[150], p4[151], p4[152], p4[153], p4[154], p4[155], p4[156], p4[157], p4[158], p4[159] = torch.chunk(t3[9], 16, dim=1)
259 + p4[160], p4[161], p4[162], p4[163], p4[164], p4[165], p4[166], p4[167], p4[168], p4[169], p4[170], p4[171], p4[172], p4[173], p4[174], p4[175] = torch.chunk(t3[10], 16, dim=1)
260 + p4[176], p4[177], p4[178], p4[179], p4[180], p4[181], p4[182], p4[183], p4[184], p4[185], p4[186], p4[187], p4[188], p4[189], p4[190], p4[191] = torch.chunk(t3[11], 16, dim=1)
261 + p4[192], p4[193], p4[194], p4[195], p4[196], p4[197], p4[198], p4[199], p4[200], p4[201], p4[202], p4[203], p4[204], p4[205], p4[206], p4[207] = torch.chunk(t3[12], 16, dim=1)
262 + p4[208], p4[209], p4[210], p4[211], p4[212], p4[213], p4[214], p4[215], p4[216], p4[217], p4[218], p4[219], p4[220], p4[221], p4[222], p4[223] = torch.chunk(t3[13], 16, dim=1)
263 + p4[224], p4[225], p4[226], p4[227], p4[228], p4[229], p4[230], p4[231], p4[232], p4[233], p4[234], p4[235], p4[236], p4[237], p4[238], p4[239] = torch.chunk(t3[14], 16, dim=1)
264 + p4[240], p4[241], p4[242], p4[243], p4[244], p4[245], p4[246], p4[247], p4[248], p4[249], p4[250], p4[251], p4[252], p4[253], p4[254], p4[255] = torch.chunk(t3[15], 16, dim=1)
265 +
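The three 256-way unpackings above are row-major splits of a 16 × 16 mosaic; a loop-based equivalent, as a sketch (`split_tiles` is a hypothetical helper):

```python
def split_tiles(mosaic):
    tiles = []
    for row in torch.chunk(mosaic, 16, dim=0):     # 16 horizontal strips
        tiles.extend(torch.chunk(row, 16, dim=1))  # 16 tiles per strip
    return tiles

# p2 = split_tiles(tensor_value); p3 = split_tiles(tensor_value2); p4 = split_tiles(tensor_value3)
```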
266 + p2_tensor = pad_sequence(p2, batch_first=True)
267 + p3_tensor = pad_sequence(p3, batch_first=True)
268 + p4_tensor = pad_sequence(p4, batch_first=True)
269 +
270 + cc = p2_tensor.unsqueeze(0)
271 + cc2 = p3_tensor.unsqueeze(0)
272 + cc3 = p4_tensor.unsqueeze(0)
273 +
274 + p2_cuda = cc.to(torch.device("cuda"))
275 + p3_cuda = cc2.to(torch.device("cuda"))
276 + p4_cuda = cc3.to(torch.device("cuda"))
277 +
278 + aug = T.ResizeShortestEdge(
279 + # [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
280 + # [480, 480], cfg.INPUT.MAX_SIZE_TEST
281 + [768, 768], cfg.INPUT.MAX_SIZE_TEST
282 + )
283 + image = cv2.imread(opt.valid_data_path + '000000'+ image_file_number +'.jpg')
284 + height, width = image.shape[:2]
285 + image = aug.get_transform(image).apply_image(image)
286 + image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
287 + inputs = [{"image": image, "height": height, "width": width}]
288 +
289 + with torch.no_grad():
290 + images = model.preprocess_image(inputs) # don't forget to preprocess
291 + features = model.backbone(images.tensor) # set of cnn features
292 + features['p2'] = p2_cuda
293 + features['p3'] = p3_cuda
294 + features['p4'] = p4_cuda
295 +
296 + proposals, _ = model.proposal_generator(images, features, None) # RPN
297 +
298 + features_ = [features[f] for f in model.roi_heads.box_in_features]
299 + box_features = model.roi_heads.box_pooler(features_, [x.proposal_boxes for x in proposals])
300 + box_features = model.roi_heads.box_head(box_features) # features of all 1k candidates
301 + predictions = model.roi_heads.box_predictor(box_features)
302 + pred_instances, pred_inds = model.roi_heads.box_predictor.inference(predictions, proposals)
303 + pred_instances = model.roi_heads.forward_with_given_boxes(features, pred_instances)
304 +
305 + # output boxes, masks, scores, etc
306 + pred_instances = model._postprocess(pred_instances, inputs, images.image_sizes) # scale box to orig size
307 + # features of the proposed boxes
308 + feats = box_features[pred_inds]
309 +
310 + pred_category = pred_instances[0]["instances"].pred_classes.to("cpu")
311 + pred_segmentation = pred_instances[0]["instances"].pred_masks.to("cpu")
312 + pred_score = pred_instances[0]["instances"].scores.to("cpu")
313 +
314 + xxx = pred_category
315 + xxx = xxx.numpy()
316 +
317 + xxx = xxx + 1
318 +
319 + for idx in range(len(xxx)):
320 + if -1 < int(xxx[idx]) < 12:
321 + xxx[idx] = xxx[idx]
322 + elif 11 < int(xxx[idx]) < 25:
323 + xxx[idx] = xxx[idx] + 1
324 + elif 24 < int(xxx[idx]) < 27:
325 + xxx[idx] = xxx[idx] + 2
326 + elif 26 < int(xxx[idx]) < 41:
327 + xxx[idx] = xxx[idx] + 4
328 + elif 40 < int(xxx[idx]) < 61:
329 + xxx[idx] = xxx[idx] + 5
330 + elif 60 < int(xxx[idx]) < 62:
331 + xxx[idx] = 67
332 + elif 61 < int(xxx[idx]) < 63:
333 + xxx[idx] = 70
334 + elif 62 < int(xxx[idx]) < 74:
335 + xxx[idx] = xxx[idx] + 9
336 + else:
337 + xxx[idx] = xxx[idx] + 10
338 +
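The if/elif ladder above hand-codes the gaps in COCO's 91-id category space (detectron2 predicts 80 contiguous ids). The same mapping can be read from detectron2's metadata; a sketch, assuming the standard `coco_2017_val` metadata is registered:

```python
meta = MetadataCatalog.get("coco_2017_val")
# invert {coco_id: contiguous_id} into {contiguous_id: coco_id}
contiguous_to_coco = {v: k for k, v in meta.thing_dataset_id_to_contiguous_id.items()}
coco_ids = [contiguous_to_coco[int(c)] for c in pred_category.numpy()]
```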
339 + imgID = int(image_file_number)
340 + if image_idx == 0:
341 + anns = []
342 + else:
343 + anns = anns
344 +
345 + for idx in range(len(pred_category.numpy())):
346 +
347 + anndata = {}
348 + anndata['image_id'] = imgID
349 + anndata['category_id'] = int(xxx[idx])
350 +
351 + anndata['segmentation'] = encode(np.asfortranarray(pred_segmentation[idx].numpy()))
352 + anndata['score'] = float(pred_score[idx].numpy())
353 + anns.append(anndata)
354 +
355 + image_idx = image_idx + 1
356 + # print("###image###:{}".format(image_idx))
357 +
358 +annType = ['segm','bbox','keypoints']
359 +annType = annType[0] #specify type here
360 +prefix = 'instances'
361 +print('Running demo for *%s* results.'%(annType))
362 +# imgIds = [560474]
363 +
364 +annFile = './instances_val2017_dataset100.json'
365 +cocoGt=COCO(annFile)
366 +
367 +#initialize COCO detections api
368 +resFile = anns
369 +cocoDt=cocoGt.loadRes(resFile)
370 +
371 +# running evaluation
372 +cocoEval = COCOeval(cocoGt,cocoDt,annType)
373 +# cocoEval.params.imgIds = imgIds
374 +# topmost line
375 +cocoEval.evaluate()
376 +cocoEval.accumulate()
377 +cocoEval.summarize()
\ No newline at end of file
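One caveat on the evaluation loop above: `pycocotools.mask.encode` expects a Fortran-ordered `uint8` array, while `pred_masks` is boolean, so an explicit cast is the safer form (a sketch of the same call):

```python
anndata['segmentation'] = encode(
    np.asfortranarray(pred_segmentation[idx].numpy().astype(np.uint8)))
```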
1 +import os
2 +from PIL import Image
3 +
4 +def crop_feature(datapath, feature_type, scale_factor, print_message=False):
5 + data_path = datapath
6 + datatype = feature_type
7 + rescale_factor = scale_factor
8 + if not os.path.exists(data_path):
10 +        raise Exception(f"[!] {data_path} does not exist")
10 +
11 + hr_imgs = []
12 + w, h = Image.open(datapath).size
13 + width = int(w / 16)
14 + height = int(h / 16)
15 + lwidth = int(width / rescale_factor)
16 + lheight = int(height / rescale_factor)
17 + if print_message:
18 + print("lr: ({} {}), hr: ({} {})".format(lwidth, lheight, width, height))
19 + hr_image = Image.open(datapath) # .convert('RGB')\
20 + for i in range(16):
21 + for j in range(16):
22 + (left, upper, right, lower) = (
23 + i * width, j * height, (i + 1) * width, (j + 1) * height)
24 + crop = hr_image.crop((left, upper, right, lower))
25 + crop = crop.resize((lwidth,lheight), Image.BICUBIC)
26 + crop = crop.resize((width, height), Image.BICUBIC)
27 + hr_imgs.append(crop)
28 +
29 + return hr_imgs
\ No newline at end of file
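Hypothetical usage of `crop_feature` (the path is illustrative): it splits one packed feature PNG, a 16 × 16 tile mosaic, into 256 tiles, each bicubic-downscaled by `scale_factor` and resized back, i.e. the degraded inputs for SR.

```python
# illustrative path; any packed feature PNG works
tiles = crop_feature("features/p3/COCO_val2017_000000001000_p3.png", "p3", 4)
print(len(tiles), tiles[0].size)  # 256 tiles, each (w // 16, h // 16)
```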
@@ -4,64 +4,87 @@ import os
 from glob import glob
 from torchvision import transforms
 from torch.utils.data.dataset import Dataset
-from torchvision import transforms
 import torch
 import pdb
 import math
 import numpy as np
+
+
 class FeatureDataset(Dataset):
-    def __init__(self, data_path, datatype, rescale_factor,valid):
+    def __init__(self, data_path, datatype, rescale_factor, valid):
         self.data_path = data_path
         self.datatype = datatype
         self.rescale_factor = rescale_factor
         if not os.path.exists(data_path):
             raise Exception(f"[!] {self.data_path} not existed")
-        if(valid):
-            self.hr_path = os.path.join(self.data_path,'valid')
-            self.hr_path = os.path.join(self.hr_path,self.datatype)
+        if (valid):
+            self.hr_path = os.path.join(self.data_path, 'valid')
+            self.hr_path = os.path.join(self.hr_path, self.datatype)
         else:
-            self.hr_path = os.path.join(self.data_path,'LR_2')
-            self.hr_path = os.path.join(self.hr_path,self.datatype)
+            self.hr_path = os.path.join(self.data_path, 'LR_2')
+            self.hr_path = os.path.join(self.hr_path, self.datatype)
         print(self.hr_path)
         self.hr_path = sorted(glob(os.path.join(self.hr_path, "*.*")))
         self.hr_imgs = []
-        w,h = Image.open(self.hr_path[0]).size
-        self.width = int(w/16)
-        self.height = int(h/16)
-        self.lwidth = int(self.width/self.rescale_factor)
-        self.lheight = int(self.height/self.rescale_factor)
-        print("lr: ({} {}), hr: ({} {})".format(self.lwidth,self.lheight,self.width,self.height))
-        for hr in self.hr_path:
-            hr_image = Image.open(hr)#.convert('RGB')\
+        w, h = Image.open(self.hr_path[0]).size
+        self.width = int(w / 16)
+        self.height = int(h / 16)
+        self.lwidth = int(self.width / self.rescale_factor)  # shrink by rescale_factor
+        self.lheight = int(self.height / self.rescale_factor)
+        print("lr: ({} {}), hr: ({} {})".format(self.lwidth, self.lheight, self.width, self.height))
+        for hr in self.hr_path:  # split each mosaic into 256 feature tiles
+            hr_image = Image.open(hr)  # .convert('RGB')\
             for i in range(16):
                 for j in range(16):
-                    (left,upper,right,lower) = (i*self.width,j*self.height,(i+1)*self.width,(j+1)*self.height)
-                    crop = hr_image.crop((left,upper,right,lower))
+                    (left, upper, right, lower) = (
+                        i * self.width, j * self.height, (i + 1) * self.width, (j + 1) * self.height)
+                    crop = hr_image.crop((left, upper, right, lower))
                     self.hr_imgs.append(crop)
 
     def __getitem__(self, idx):
         hr_image = self.hr_imgs[idx]
         transform = transforms.Compose([
-            transforms.Resize((self.lheight,self.lwidth),3),
+            transforms.Resize((self.lheight, self.lwidth), Image.BICUBIC),
+            transforms.Resize((self.height, self.width), Image.BICUBIC),
             transforms.ToTensor()
         ])
-        return transform(hr_image), transforms.ToTensor()(hr_image)
+        return transform(hr_image), transforms.ToTensor()(hr_image)  # return the degraded tile and the original tile as Tensors
 
     def __len__(self):
-        return len(self.hr_path*16*16)
+        return len(self.hr_path * 16 * 16)
+
+
+def get_data_loader_test_version(data_path, feature_type, rescale_factor, batch_size, num_workers):
+    full_dataset = FeatureDataset(data_path, feature_type, rescale_factor, False)
+    print("dataset size is {}".format(len(full_dataset)))
+    for f in full_dataset:
+        print(type(f))
+
 
 def get_data_loader(data_path, feature_type, rescale_factor, batch_size, num_workers):
-    full_dataset = FeatureDataset(data_path,feature_type,rescale_factor,False)
+    full_dataset = FeatureDataset(data_path, feature_type, rescale_factor, False)
     train_size = int(0.9 * len(full_dataset))
     test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
     torch.manual_seed(3334)
-    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers, pin_memory=False)
-    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers, pin_memory=True)
+    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
+                                               num_workers=num_workers, pin_memory=False)
+    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False,
+                                              num_workers=num_workers, pin_memory=True)
 
     return train_loader, test_loader
 
+
+def get_training_data_loader(data_path, feature_type, rescale_factor, batch_size, num_workers):
+    full_dataset = FeatureDataset(data_path, feature_type, rescale_factor, False)
+    torch.manual_seed(3334)
+    train_loader = torch.utils.data.DataLoader(dataset=full_dataset, batch_size=batch_size, shuffle=True,
+                                               num_workers=num_workers, pin_memory=False)
+    return train_loader
+
+
 def get_infer_dataloader(data_path, feature_type, rescale_factor, batch_size, num_workers):
-    dataset = FeatureDataset(data_path,feature_type,rescale_factor,True)
-    data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=False)
+    dataset = FeatureDataset(data_path, feature_type, rescale_factor, True)
+    data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False,
+                                              num_workers=num_workers, pin_memory=False)
     return data_loader
\ No newline at end of file
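Minimal use of the loaders above, as a sketch (the path is a placeholder; tiles are assumed single-channel feature PNGs, matching the 1-channel VDSR net):

```python
train_loader = get_training_data_loader("/path/to/features", "p3", 4, 16, 1)
lr, hr = next(iter(train_loader))
print(lr.shape, hr.shape)  # e.g. both (16, 1, H, W): bicubic-degraded vs. original tiles
```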
1 +import argparse, os
2 +import torch
3 +from torch.autograd import Variable
4 +import numpy as np
5 +import time, math, glob
6 +import scipy.io as sio
7 +from crop_feature import crop_feature
8 +from PIL import Image
9 +import cv2
10 +from matplotlib import pyplot as plt
11 +from math import log10, sqrt
12 +
13 +parser = argparse.ArgumentParser(description="PyTorch VDSR Eval")
14 +parser.add_argument("--cuda", action="store_true", help="use cuda?")
15 +parser.add_argument("--model", default="model/model_epoch_50.pth", type=str, help="model path")
16 +parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")
17 +parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
18 +parser.add_argument("--featureType", default="p3", type=str)
19 +parser.add_argument("--scaleFactor", default=4, type=int, help="scale factor")
20 +parser.add_argument("--singleImage", type=str, default="N", help="if it is a single image, enter \"y\"")
21 +
22 +
23 +def PSNR(pred, gt, shave_border=0):
24 + height, width = pred.shape[:2]
25 + pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
26 + gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
27 + imdff = pred - gt
28 + rmse = math.sqrt(np.mean(imdff ** 2))
29 + if rmse == 0:
30 + return 100
31 + return 20 * math.log10(255.0 / rmse)
32 +
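A quick sanity check of `PSNR` above: identical inputs hit the 100 dB cap, and a uniform offset of 5 gray levels gives 20·log10(255/5) ≈ 34.15 dB.

```python
a = np.zeros((8, 8))
b = a + 5.0
print(PSNR(a, a))  # 100 (rmse == 0 branch)
print(PSNR(a, b))  # ~34.15
```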
33 +def concatFeatures(features, image_name, bicubic=False):
34 + features_0 = features[:16]
35 + features_1 = features[16:32]
36 + features_2 = features[32:48]
37 + features_3 = features[48:64]
38 + features_4 = features[64:80]
39 + features_5 = features[80:96]
40 + features_6 = features[96:112]
41 + features_7 = features[112:128]
42 + features_8 = features[128:144]
43 + features_9 = features[144:160]
44 + features_10 = features[160:176]
45 + features_11 = features[176:192]
46 + features_12 = features[192:208]
47 + features_13 = features[208:224]
48 + features_14 = features[224:240]
49 + features_15 = features[240:256]
50 +
51 + features_new = list()
52 + features_new.extend([
53 + concat_vertical(features_0),
54 + concat_vertical(features_1),
55 + concat_vertical(features_2),
56 + concat_vertical(features_3),
57 + concat_vertical(features_4),
58 + concat_vertical(features_5),
59 + concat_vertical(features_6),
60 + concat_vertical(features_7),
61 + concat_vertical(features_8),
62 + concat_vertical(features_9),
63 + concat_vertical(features_10),
64 + concat_vertical(features_11),
65 + concat_vertical(features_12),
66 + concat_vertical(features_13),
67 + concat_vertical(features_14),
68 + concat_vertical(features_15)
69 + ])
70 +
71 + final_concat_feature = concat_horizontal(features_new)
72 +
73 + if bicubic:
74 + save_path = "features/LR_2/LR/" + opt.featureType + "/" + image_name
75 + if not os.path.exists("features/"):
76 + os.makedirs("features/")
77 + if not os.path.exists("features/LR_2/"):
78 + os.makedirs("features/LR_2/")
79 + if not os.path.exists("features/LR_2/LR/"):
80 + os.makedirs("features/LR_2/LR/")
81 + if not os.path.exists("features/LR_2/LR/" + opt.featureType):
82 + os.makedirs("features/LR_2/LR/" + opt.featureType)
83 + cv2.imwrite(save_path, final_concat_feature)
84 + else:
85 + save_path = "features/LR_2/" + opt.featureType + "/" + image_name
86 + if not os.path.exists("features/"):
87 + os.makedirs("features/")
88 + if not os.path.exists("features/LR_2/"):
89 + os.makedirs("features/LR_2/")
90 + if not os.path.exists("features/LR_2/" + opt.featureType):
91 + os.makedirs("features/LR_2/" + opt.featureType)
92 + cv2.imwrite(save_path, final_concat_feature)
93 +
94 +def concat_horizontal(feature):
95 + result = cv2.hconcat([feature[0], feature[1]])
96 + for i in range(2, len(feature)):
97 + result = cv2.hconcat([result, feature[i]])
98 + return result
99 +
100 +def concat_vertical(feature):
101 + result = cv2.vconcat([feature[0], feature[1]])
102 + for i in range(2, len(feature)):
103 + result = cv2.vconcat([result, feature[i]])
104 + return result
105 +
106 +
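The 16 manual slices in `concatFeatures` are equivalent to this loop, sketched here with a hypothetical helper name: each group of 16 tiles is stacked vertically into a column, and the 16 columns are joined horizontally.

```python
def concat_tiles(features):
    cols = [concat_vertical(features[i * 16:(i + 1) * 16]) for i in range(16)]
    return concat_horizontal(cols)
```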
107 +opt = parser.parse_args()
108 +cuda = opt.cuda
109 +
110 +if cuda:
111 + print("=> use gpu id: '{}'".format(opt.gpus))
112 + os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
113 + if not torch.cuda.is_available():
114 + raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
115 +
116 +model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
117 +
118 +scales = [opt.scaleFactor]
119 +
120 +# image_list = glob.glob(opt.dataset+"/*.*")
121 +if opt.singleImage == "Y":
122 +    # treat opt.dataset as one file: (directory, [filename]) for the loop below
123 +    image_path, image_list = os.path.dirname(opt.dataset), [os.path.basename(opt.dataset)]
124 +else:
125 + image_path = os.path.join(opt.dataset, opt.featureType)
126 + image_list = os.listdir(image_path)
127 + print(image_path)
128 + print(image_list)
129 +
130 +
131 +for scale in scales:
132 + for image in image_list:
133 + avg_psnr_predicted = 0.0
134 + avg_psnr_bicubic = 0.0
135 + avg_elapsed_time = 0.0
136 + count = 0.0
137 + image_name_cropped = crop_feature(os.path.join(image_path, image), opt.featureType, opt.scaleFactor)
138 + features = []
139 + features_bicubic = []
140 + for image_name in image_name_cropped:
141 + count += 1
142 + f_gt = image_name
143 + w, h = image_name.size
144 + f_bi = image_name.resize((w//scale,h//scale), Image.BICUBIC)
145 + f_bi = f_bi.resize((w,h), Image.BICUBIC)
146 +
147 + f_gt = np.array(f_gt)
148 + f_bi = np.array(f_bi)
149 + f_gt = f_gt.astype(float)
150 + f_bi = f_bi.astype(float)
151 + features_bicubic.append(f_bi)
152 + psnr_bicubic = PSNR(f_bi, f_gt, shave_border=scale)
153 + # psnr_bicubic = PSNR_ver2(cv2.imread(f_gt), cv2.imread(f_bi))
154 + avg_psnr_bicubic += psnr_bicubic
155 +
156 + f_input = f_bi/255.
157 + f_input = Variable(torch.from_numpy(f_input).float()).view(1, -1, f_input.shape[0], f_input.shape[1])
158 +
159 + if cuda:
160 + model = model.cuda()
161 + f_input = f_input.cuda()
162 + else:
163 + model = model.cpu()
164 +
165 + start_time = time.time()
166 + SR = model(f_input)
167 + elapsed_time = time.time() - start_time
168 + avg_elapsed_time += elapsed_time
169 +
170 + SR = SR.cpu()
171 +
172 + f_sr = SR.data[0].numpy().astype(np.float32)
173 +
174 + f_sr = f_sr * 255
175 + f_sr[f_sr<0] = 0
176 + f_sr[f_sr>255.] = 255.
177 + f_sr = f_sr[0,:,:]
178 +
179 + psnr_predicted = PSNR(f_sr, f_gt, shave_border=scale)
180 + # psnr_predicted = PSNR_ver2(cv2.imread(f_gt), cv2.imread(f_sr))
181 + avg_psnr_predicted += psnr_predicted
182 + features.append(f_sr)
183 +
184 + concatFeatures(features, image)
185 + concatFeatures(features_bicubic, image, True)
186 + print("Scale=", scale)
187 + print("Dataset=", opt.dataset)
188 + print("Average PSNR_predicted=", avg_psnr_predicted/count)
189 + print("Average PSNR_bicubic=", avg_psnr_bicubic/count)
190 +
191 +
192 +# Show graph
193 +# f_gt = Image.fromarray(f_gt)
194 +# f_b = Image.fromarray(f_bi)
195 +# f_sr = Image.fromarray(f_sr)
196 +
197 +# fig = plt.figure(figsize=(18, 16), dpi= 80)
198 +# ax = plt.subplot("131")
199 +# ax.imshow(f_gt)
200 +# ax.set_title("GT")
201 +
202 +# ax = plt.subplot("132")
203 +# ax.imshow(f_bi)
204 +# ax.set_title("Input(bicubic)")
205 +
206 +# ax = plt.subplot("133")
207 +# ax.imshow(f_sr)
208 +# ax.set_title("Output(vdsr)")
209 +# plt.show()
 import argparse, os
+from datasets import get_data_loader
 import torch
 import random
 import torch.backends.cudnn as cudnn
@@ -7,33 +8,37 @@ import torch.optim as optim
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
 from vdsr import Net
-from dataset import DatasetFromHdf5
-## Custom
-from data import FeatureDataset
+from datasets import get_training_data_loader
+# from datasets import get_data_loader_test_version
+# from feature_dataset import get_training_data_loader
+# from make_dataset import make_dataset
+import numpy as np
+from dataFromH5 import Read_dataset_h5
+import matplotlib.pyplot as plt
+import math
 
 # Training settings
 parser = argparse.ArgumentParser(description="PyTorch VDSR")
-parser.add_argument("--batchSize", type=int, default=128, help="Training batch size")
-parser.add_argument("--nEpochs", type=int, default=50, help="Number of epochs to train for")
-parser.add_argument("--lr", type=float, default=0.1, help="Learning Rate. Default=0.1")
+parser.add_argument("--dataRoot", type=str)
+parser.add_argument("--featureType", type=str)
+parser.add_argument("--scaleFactor", type=int, default=4)
+parser.add_argument("--batchSize", type=int, default=64, help="Training batch size")
+parser.add_argument("--nEpochs", type=int, default=20, help="Number of epochs to train for")
+parser.add_argument("--lr", type=float, default=0.001, help="Learning Rate. Default=0.001")
 parser.add_argument("--step", type=int, default=10, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=10")
 parser.add_argument("--cuda", action="store_true", help="Use cuda?")
 parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
 parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
 parser.add_argument("--clip", type=float, default=0.4, help="Clipping Gradients. Default=0.4")
-# 1->3 custom
-parser.add_argument("--threads", type=int, default=3, help="Number of threads for data loader to use, Default: 1")
+parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1")
 parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
 parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="Weight decay, Default: 1e-4")
 parser.add_argument('--pretrained', default='', type=str, help='path to pretrained model (default: none)')
 parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
-## custom
-parser.add_argument("--dataPath", type=str)
-parser.add_argument("--featureType", type=str, default="p2")
-parser.add_argument("--scaleFactor",type=int, default=4)
 
-# parser.add_argument("--trainingData", type=DataLoader)
 
+total_loss_for_plot = list()
+total_psnr = list()
 
 def main():
     global opt, model
@@ -55,21 +60,17 @@ def main():
 
     cudnn.benchmark = True
 
-    print("===> Loading datasets")
+################## Loading Datasets ##########################
 
-    if os.path.isfile('dataloader/training_data_loader.pth'):
-        training_data_loader = torch.load('dataloader/training_data_loader.pth')
-    else:
-        train_set = FeatureDataset(opt.dataPath,opt.featureType,opt.scaleFactor,False)
-        train_size = 100  # only 100 samples for now
-        test_size = len(train_set) - train_size
-        train_dataset, test_dataset = torch.utils.data.random_split(train_set, [train_size, test_size])
-        training_data_loader = DataLoader(dataset=train_dataset, num_workers=3, batch_size=8, shuffle=True, pin_memory=False)
-        torch.save(training_data_loader, 'dataloader/training_data_loader.pth'.format(DataLoader))
+    print("===> Loading datasets")
+    # train_set = DatasetFromHdf5("data/train.h5")
+    # training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
+    training_data_loader = get_training_data_loader(opt.dataRoot, opt.featureType, opt.scaleFactor, opt.batchSize, opt.threads)
+    # training_data_loader = make_dataset(opt.dataRoot, opt.featureType, opt.scaleFactor, opt.batchSize, opt.threads)
 
     print("===> Building model")
-    model = Net(opt.scaleFactor)
-    criterion = nn.MSELoss(size_average=False)
+    model = Net()
+    criterion = nn.MSELoss(reduction='sum')
 
     print("===> Setting GPU")
     if cuda:
@@ -91,6 +92,7 @@ def main():
         if os.path.isfile(opt.pretrained):
             print("=> loading model '{}'".format(opt.pretrained))
             weights = torch.load(opt.pretrained)
+            opt.start_epoch = weights["epoch"] + 1
             model.load_state_dict(weights['model'].state_dict())
         else:
             print("=> no model found at '{}'".format(opt.pretrained))
@@ -99,15 +101,21 @@ def main():
     optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
 
     print("===> Training")
+
     for epoch in range(opt.start_epoch, opt.nEpochs + 1):
         train(training_data_loader, optimizer, model, criterion, epoch)
-        save_checkpoint(model, epoch)
+        save_checkpoint(model, epoch, optimizer)
 
 def adjust_learning_rate(optimizer, epoch):
     """Sets the learning rate to the initial LR decayed by 10 every 10 epochs"""
     lr = opt.lr * (0.1 ** (epoch // opt.step))
     return lr
 
+def PSNR(loss):
+    psnr = 10 * np.log10(1 / (loss + 1e-10))
+    # psnr = 20 * math.log10(255.0 / (math.sqrt(loss)))
+    return psnr
+
 def train(training_data_loader, optimizer, model, criterion, epoch):
     lr = adjust_learning_rate(optimizer, epoch-1)
 
@@ -119,25 +127,31 @@ def train(training_data_loader, optimizer, model, criterion, epoch):
     model.train()
 
+    total_loss = 0  # accumulated over the whole epoch (initialized once, before the loop)
     for iteration, batch in enumerate(training_data_loader, 1):
-        input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
-
+        optimizer.zero_grad()
+        input, target = Variable(batch[0], requires_grad=False), Variable(batch[1], requires_grad=False)
         if opt.cuda:
             input = input.cuda()
             target = target.cuda()
 
         loss = criterion(model(input), target)
-        optimizer.zero_grad()
+        total_loss += loss.item()
        loss.backward()
-        nn.utils.clip_grad_norm(model.parameters(),opt.clip)
+        nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
         optimizer.step()
 
-        if iteration%10 == 0:
-            # loss.data[0] --> loss.data
-            print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.data))
-
-def save_checkpoint(model, epoch):
-    model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch)
-    state = {"epoch": epoch ,"model": model}
+    epoch_loss = total_loss / len(training_data_loader)  # per-batch mean of the summed MSE
+    total_loss_for_plot.append(epoch_loss)
+    psnr = PSNR(epoch_loss)
+    total_psnr.append(psnr)
+    print("===> Epoch[{}]: loss: {:.10f}, PSNR: {:.10f}".format(epoch, epoch_loss, psnr))
+    # if iteration % 100 == 0:
+    #     print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.item()))
+
+def save_checkpoint(model, epoch, optimizer):
+    model_out_path = "checkpoint/" + "model_epoch_{}_{}.pth".format(epoch, opt.featureType)
+    state = {"epoch": epoch, "model": model, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
+             "loss": total_loss_for_plot, "psnr": total_psnr}
     if not os.path.exists("checkpoint/"):
         os.makedirs("checkpoint/")
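A note on the training loop above: `PSNR(loss)` assumes a per-pixel mean MSE on [0, 1]-scaled data, while the criterion is sum-reduced, so the printed PSNR is on a shifted scale. A consistent variant, as a sketch (`psnr_from_mse` is a hypothetical helper):

```python
import numpy as np

def psnr_from_mse(mse, peak=1.0):
    # valid when mse is a per-element mean over data in [0, peak]
    return 10 * np.log10(peak ** 2 / (mse + 1e-10))

# e.g. with criterion = nn.MSELoss(reduction='mean'):
# psnr = psnr_from_mse(epoch_loss)
```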
@@ -12,12 +12,11 @@ class Conv_ReLU_Block(nn.Module):
         return self.relu(self.conv(x))
 
 class Net(nn.Module):
-    def __init__(self,upscale_factor):
+    def __init__(self):
         super(Net, self).__init__()
         self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
         self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
         self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
-        self.upsample = nn.Upsample(scale_factor=upscale_factor, mode='bicubic')
         self.relu = nn.ReLU(inplace=True)
 
         for m in self.modules():
@@ -32,7 +31,6 @@ class Net(nn.Module):
         return nn.Sequential(*layers)
 
     def forward(self, x):
-        x = self.upsample(x)
         residual = x
         out = self.relu(self.input(x))
         out = self.residual_layer(out)
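With the internal `nn.Upsample` removed, `Net` is a pure residual refiner: callers must feed bicubic-upscaled input, which `datasets.py` and `inference.py` now do. Minimal usage (shapes illustrative):

```python
import torch

net = Net()
lr_up = torch.randn(1, 1, 64, 64)  # bicubic-upscaled, single-channel feature tile
sr = net(lr_up)                    # same spatial size, residual-refined
print(sr.shape)                    # torch.Size([1, 1, 64, 64])
```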