Toggle navigation
Toggle navigation
This project
Loading...
Sign in
2021-1-capstone-design1
/
BSH_Project3
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
김재형
2021-04-07 14:22:44 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
0dfbe0f88c7c28f75baad284b2e06ca3ed4741e5
0dfbe0f8
1 parent
2db55916
CARN 테스트 코드 추가
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
101 additions
and
0 deletions
carn/carn/test.py
carn/carn/test.py
0 → 100644
View file @
0dfbe0f
import
os
import
json
import
time
import
importlib
import
argparse
import
numpy
as
np
from
collections
import
OrderedDict
import
torch
import
torch.nn
as
nn
import
torch.utils.data
as
data
from
glob
import
glob
from
torch.autograd
import
Variable
from
PIL
import
Image
import
torchvision.transforms
as
transforms
from
tqdm
import
tqdm
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model"
,
type
=
str
)
parser
.
add_argument
(
"--ckpt_path"
,
type
=
str
)
parser
.
add_argument
(
"--group"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--sample_dir"
,
type
=
str
)
parser
.
add_argument
(
"--test_data_dir"
,
type
=
str
,
default
=
"dataset/Urban100"
)
parser
.
add_argument
(
"--cuda"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--scale"
,
type
=
int
,
default
=
4
)
parser
.
add_argument
(
"--shave"
,
type
=
int
,
default
=
20
)
return
parser
.
parse_args
()
def
save_image
(
tensor
,
filename
):
tensor
=
tensor
.
cpu
()
ndarr
=
tensor
.
mul
(
255
)
.
clamp
(
0
,
255
)
.
byte
()
.
permute
(
1
,
2
,
0
)
.
numpy
()
im
=
Image
.
fromarray
(
ndarr
)
im
.
save
(
filename
)
class
TestDataset
(
data
.
Dataset
):
def
__init__
(
self
,
dirname
,
scale
):
super
(
TestDataset
,
self
)
.
__init__
()
self
.
lr
=
glob
(
os
.
path
.
join
(
dirname
,
"*.png"
))
self
.
lr
.
sort
()
self
.
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
()
])
def
__getitem__
(
self
,
index
):
lr
=
Image
.
open
(
self
.
lr
[
index
])
lr
=
lr
.
convert
(
"RGB"
)
filename
=
self
.
lr
[
index
]
.
split
(
"/"
)[
-
1
]
return
self
.
transform
(
lr
),
filename
def
__len__
(
self
):
return
len
(
self
.
lr
)
def
sample
(
net
,
device
,
dataset
,
cfg
):
scale
=
cfg
.
scale
for
lr
,
name
in
tqdm
(
dataset
):
t1
=
time
.
time
()
lr
=
lr
.
unsqueeze
(
0
)
.
to
(
device
)
sr
=
net
(
lr
,
cfg
.
scale
)
.
detach
()
.
squeeze
(
0
)
lr
=
lr
.
squeeze
(
0
)
t2
=
time
.
time
()
sr_dir
=
os
.
path
.
join
(
cfg
.
sample_dir
,
cfg
.
test_data_dir
.
split
(
"/"
)[
-
1
])
os
.
makedirs
(
sr_dir
,
exist_ok
=
True
)
sr_im_path
=
os
.
path
.
join
(
sr_dir
,
name
)
save_image
(
sr
,
sr_im_path
)
def
main
(
cfg
):
module
=
importlib
.
import_module
(
"model.{}"
.
format
(
cfg
.
model
))
net
=
module
.
Net
(
multi_scale
=
True
,
group
=
cfg
.
group
)
print
(
json
.
dumps
(
vars
(
cfg
),
indent
=
4
,
sort_keys
=
True
))
state_dict
=
torch
.
load
(
cfg
.
ckpt_path
)
new_state_dict
=
OrderedDict
()
for
k
,
v
in
state_dict
.
items
():
name
=
k
# name = k[7:] # remove "module."
new_state_dict
[
name
]
=
v
net
.
load_state_dict
(
new_state_dict
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
net
=
net
.
to
(
device
)
dataset
=
TestDataset
(
cfg
.
test_data_dir
,
cfg
.
scale
)
sample
(
net
,
device
,
dataset
,
cfg
)
if
__name__
==
"__main__"
:
cfg
=
parse_args
()
main
(
cfg
)
Please
register
or
login
to post a comment