Toggle navigation
Toggle navigation
This project
Loading...
Sign in
2020-1-capstone-design2
/
2014103189
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
4Moyede
2020-06-25 06:52:11 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
857f89c76ed8d7847b331d424bb7bfe026049bb8
857f89c7
1 parent
b7b22255
[Mod] add target dataset cifar10
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
25 deletions
소스코드/model.py
소스코드/test.py
소스코드/train.py
소스코드/model.py
View file @
857f89c
...
...
@@ -17,30 +17,39 @@ import os
class
SEResNeXt
(
Model
):
def
__init__
(
self
,
weight
,
input_shape
=
None
,
depth
=
[
3
,
8
,
36
,
3
],
cardinality
=
32
,
width
=
4
,
reduction_ratio
=
4
,
weight_decay
=
5e-4
,
classes
=
1000
,
channel_axis
=
None
):
def
__init__
(
self
,
weight
,
input_shape
=
None
):
'''
ResNext Model
## Args
+ weight:
+ input_shape: optional shape tuple
+ depth: number or layers in the each block, defined as a list
+ cardinality: the size of the set of transformations
+ width: multiplier to the ResNeXt width (number of filters)
+ redution_ratio: ratio of reducition in SE Block
+ weight_decay: weight decay (l2 norm)
+ classes: number of classes to classify images into
+ channel_axis: channel axis in keras.backend.image_data_format()
'''
if
weights
not
in
{
'cifar10'
,
'imagenet'
}:
raise
ValueError
self
.
__weight
=
weight
self
.
__depth
=
depth
self
.
__cardinality
=
cardinality
self
.
__width
=
width
self
.
__reduction_ratio
=
reduction_ratio
self
.
__weight_decay
=
weight_decay
self
.
__classes
=
classes
if
weight
==
'cifar10'
:
self
.
__depth
=
29
self
.
__cardinality
=
8
self
.
__width
=
64
self
.
__classes
=
10
else
:
self
.
__depth
=
[
3
,
8
,
36
,
3
]
self
.
__cardinality
=
32
self
.
__width
=
4
self
.
__classes
=
1000
self
.
__reduction_ratio
=
4
self
.
__weight_decay
=
5e-4
self
.
__channel_axis
=
1
if
K
.
image_data_format
()
==
"channels_first"
else
-
1
self
.
__input_shape
=
_obtain_input_shape
(
input_shape
,
default_size
=
224
,
min_size
=
112
,
data_format
=
K
.
image_data_format
(),
require_flatten
=
True
)
if
weight
==
'cifar10'
:
self
.
__input_shape
=
_obtain_input_shape
(
input_shape
,
default_size
=
32
,
min_size
=
8
,
data_format
=
K
.
image_data_format
(),
require_flatten
=
True
)
else
:
self
.
__input_shape
=
_obtain_input_shape
(
input_shape
,
default_size
=
224
,
min_size
=
112
,
data_format
=
K
.
image_data_format
(),
require_flatten
=
True
)
self
.
__img_input
=
Input
(
shape
=
self
.
__input_shape
)
# Create model.
...
...
@@ -50,14 +59,16 @@ class SEResNeXt(Model):
'''
Adds an initial conv block, with batch norm and relu for the inception resnext
'''
channel_axis
=
-
1
if
weight
==
'cifar10'
:
x
=
Conv2D
(
64
,
(
3
,
3
),
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
l2
(
self
.
__weight_decay
))(
self
.
__img_input
)
x
=
BatchNormalization
(
axis
=
self
.
__channel_axis
)(
x
)
x
=
Activation
(
'relu'
)(
x
)
return
x
else
:
x
=
Conv2D
(
64
,
(
7
,
7
),
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
'he_normal'
,
kernel_regularizer
=
l2
(
self
.
__weight_decay
),
strides
=
(
2
,
2
))(
self
.
__img_input
)
x
=
BatchNormalization
(
axis
=
channel_axis
)(
x
)
x
=
BatchNormalization
(
axis
=
self
.
__
channel_axis
)(
x
)
x
=
Activation
(
'relu'
)(
x
)
x
=
MaxPooling2D
((
3
,
3
),
strides
=
(
2
,
2
),
padding
=
'same'
)(
x
)
return
x
def
__grouped_convolution_block
(
self
,
input
,
grouped_channels
,
strides
):
...
...
@@ -126,8 +137,11 @@ class SEResNeXt(Model):
'''
Creates a ResNeXt model with specified parameters
'''
if
type
(
self
.
__depth
)
is
list
or
type
(
self
.
__depth
)
is
tuple
:
N
=
list
(
self
.
__depth
)
else
:
N
=
[(
self
.
__depth
-
2
)
//
9
for
_
in
range
(
3
)]
print
(
N
)
filters
=
self
.
__cardinality
*
self
.
__width
filters_list
=
[]
...
...
소스코드/test.py
View file @
857f89c
...
...
@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
import
os
MODEL_SAVE_FOLDER_PATH
=
os
.
path
.
join
(
os
.
getcwd
(),
'trained'
)
MODEL_SAVE_PATH
=
MODEL_SAVE_FOLDER_PATH
+
'/seresnext_
imagenet
.h5'
MODEL_SAVE_PATH
=
MODEL_SAVE_FOLDER_PATH
+
'/seresnext_
cifar10
.h5'
TEST_IMAGE_FOLDER_PATH
=
os
.
path
.
join
(
os
.
getcwd
(),
'test'
)
TEST_IMAGE_PATH
=
TEST_IMAGE_FOLDER_PATH
+
'/test01.png'
...
...
소스코드/train.py
View file @
857f89c
...
...
@@ -13,14 +13,16 @@ import os
MODEL_SAVE_FOLDER_PATH
=
os
.
path
.
join
(
os
.
getcwd
(),
'trained'
)
if
not
os
.
path
.
exists
(
MODEL_SAVE_FOLDER_PATH
):
os
.
mkdir
(
MODEL_SAVE_FOLDER_PATH
)
MODEL_SAVE_PATH
=
MODEL_SAVE_FOLDER_PATH
+
'/seresnext_
imagenet
.h5'
MODEL_SAVE_PATH
=
MODEL_SAVE_FOLDER_PATH
+
'/seresnext_
cifar
.h5'
model
=
SEResNeXt
(
(
112
,
11
2
,
3
))
model
=
SEResNeXt
(
'cifar-10'
,
(
32
,
3
2
,
3
))
model
.
compile
(
optimizer
=
optimizers
.
Adam
(
0.001
),
loss
=
'categorical_crossentropy'
,
metrics
=
[
'accuracy'
])
ds_train
=
tfds
.
load
(
'imagenet2012_corrupted'
,
split
=
'train'
,
shuffle_files
=
True
)
# ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True)
ds_train
=
tfds
.
load
(
'cifar-10'
,
split
=
'train'
,
shuffle_files
=
True
)
ds_train
=
ds_train
.
shuffle
(
1000
)
.
batch
(
128
)
.
prefetch
(
10
)
model
.
fit
(
ds_train
[
'image'
],
ds_train
[
'label'
],
epochs
=
10
,
steps_per_epoch
=
30
)
model
.
save
(
MODEL_SAVE_PATH
)
...
...
Please
register
or
login
to post a comment