강형구

[Add] add SEResNeXt model

1 +.vscode/
2 +__pycache__/
3 +trained/
4 +src/
1 +from keras.models import Model
2 +
3 +from keras.layers import Input, Reshape
4 +from keras.layers.core import Dense, Lambda, Activation
5 +from keras.layers.convolutional import Conv2D
6 +from keras.layers.pooling import GlobalAveragePooling2D, MaxPooling2D
7 +from keras.layers.merge import concatenate, add, multiply
8 +from keras.layers.normalization import BatchNormalization
9 +
10 +from keras.regularizers import l2
11 +from keras.utils.data_utils import get_file
12 +from keras_applications.imagenet_utils import _obtain_input_shape
13 +
14 +import keras.backend as K
15 +
16 +import os
17 +
18 +
19 +class SEResNeXt(Model):
20 + def __init__(self, weight, input_shape=None, depth=[3, 8, 36, 3], cardinality=32, width=4, reduction_ratio=4, weight_decay=5e-4, classes=1000, channel_axis=None):
21 + '''
22 + ResNext Model
23 +
24 + ## Args
25 + + input_shape: optional shape tuple
26 + + depth: number or layers in the each block, defined as a list
27 + + cardinality: the size of the set of transformations
28 + + width: multiplier to the ResNeXt width (number of filters)
29 + + redution_ratio: ratio of reducition in SE Block
30 + + weight_decay: weight decay (l2 norm)
31 + + classes: number of classes to classify images into
32 + + channel_axis: channel axis in keras.backend.image_data_format()
33 + '''
34 + self.__weight = weight
35 + self.__depth = depth
36 + self.__cardinality = cardinality
37 + self.__width = width
38 + self.__reduction_ratio = reduction_ratio
39 + self.__weight_decay = weight_decay
40 + self.__classes = classes
41 + self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1
42 +
43 + self.__input_shape = _obtain_input_shape(input_shape, default_size = 224, min_size = 112, data_format=K.image_data_format(), require_flatten=True)
44 + self.__img_input = Input(shape=self.__input_shape)
45 +
46 + # Create model.
47 + super(SEResNeXt, self).__init__(self.__img_input, self.__create_res_next(), name='seresnext')
48 +
49 + def __initial_conv_block(self):
50 + '''
51 + Adds an initial conv block, with batch norm and relu for the inception resnext
52 + '''
53 + channel_axis = -1
54 +
55 + x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input)
56 + x = BatchNormalization(axis=channel_axis)(x)
57 + x = Activation('relu')(x)
58 +
59 + x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
60 +
61 + return x
62 +
63 + def __grouped_convolution_block(self, input, grouped_channels, strides):
64 + '''
65 + Adds a grouped convolution block. It is an equivalent block from the paper
66 +
67 + ## Args
68 + + input: input tensor
69 + + grouped_channels: grouped number of filters
70 + + strides: performs strided convolution for downscaling if > 1
71 +
72 + ## Returns
73 + a keras tensor
74 + '''
75 + init = input
76 +
77 + group_list = []
78 + for c in range(self.__cardinality):
79 + x = Lambda(lambda z: z[:, :, :, c * grouped_channels:(c + 1) * grouped_channels])(input)
80 + x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides), kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(x)
81 + group_list.append(x)
82 +
83 + group_merge = concatenate(group_list, axis=self.__channel_axis)
84 + x = BatchNormalization(axis=self.__channel_axis)(group_merge)
85 + x = Activation('relu')(x)
86 +
87 + return x
88 +
89 + def __bottleneck_block(self, input, filters=64, strides=1):
90 + '''
91 + Adds a bottleneck block
92 +
93 + ## Args
94 + + input: input tensor
95 + + filters: number of output filters
96 + + strides: performs strided convolution for downsampling if > 1
97 +
98 + ## Returns
99 + a keras tensor
100 + '''
101 + init = input
102 +
103 + grouped_channels = int(filters / self.__cardinality)
104 +
105 + # Check if input number of filters is same as 16 * k, else create convolution2d for this input
106 + if init._keras_shape[-1] != 2 * filters:
107 + init = Conv2D(filters * 2, (1, 1), padding='same', strides=(strides, strides), use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(init)
108 + init = BatchNormalization(axis=self.__channel_axis)(init)
109 +
110 + x = Conv2D(filters, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(input)
111 + x = BatchNormalization(axis=self.__channel_axis)(x)
112 + x = Activation('relu')(x)
113 + x = self.__squeeze_excitation_layer(x, x[0].get_shape()[self.__channel_axis])
114 +
115 + x = self.__grouped_convolution_block(x, grouped_channels, strides)
116 +
117 + x = Conv2D(filters * 2, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(x)
118 + x = BatchNormalization(axis=self.__channel_axis)(x)
119 +
120 + x = add([init, x])
121 + x = Activation('relu')(x)
122 +
123 + return x
124 +
125 + def __create_res_next(self):
126 + '''
127 + Creates a ResNeXt model with specified parameters
128 + '''
129 +
130 + N = list(self.__depth)
131 +
132 + filters = self.__cardinality * self.__width
133 + filters_list = []
134 + for i in range(len(N)):
135 + filters_list.append(filters)
136 + filters *= 2 # double the size of the filters
137 +
138 + x = self.__initial_conv_block()
139 +
140 + # block 1 (no pooling)
141 + for i in range(N[0]):
142 + x = self.__bottleneck_block(x, filters_list[0], strides=1)
143 +
144 + N = N[1:] # remove the first block from block definition list
145 + filters_list = filters_list[1:] # remove the first filter from the filter list
146 +
147 + # block 2 to N
148 + for block_idx, n_i in enumerate(N):
149 + for i in range(n_i):
150 + if i == 0:
151 + x = self.__bottleneck_block(x, filters_list[block_idx], strides=2)
152 + else:
153 + x = self.__bottleneck_block(x, filters_list[block_idx], strides=1)
154 +
155 +
156 + x = GlobalAveragePooling2D()(x)
157 + x = Dense(self.__classes, use_bias=False, kernel_regularizer=l2(self.__weight_decay), kernel_initializer='he_normal', activation='softmax')(x)
158 +
159 + return x
160 +
161 + def __squeeze_excitation_layer(self, x, out_dim):
162 + '''
163 + SE Block Function
164 +
165 + ## Args
166 + + x : input feature map
167 + + out_dim : dimention of output channel
168 + '''
169 + squeeze = GlobalAveragePooling2D()(x)
170 +
171 + excitation = Dense(units=out_dim // self.__reduction_ratio)(squeeze)
172 + excitation = Activation('relu')(excitation)
173 + excitation = Dense(units=out_dim)(excitation)
174 + excitation = Activation('sigmoid')(excitation)
175 + excitation = Reshape((1,1,out_dim))(excitation)
176 +
177 + scale = multiply([x,excitation])
178 +
179 + return scale
180 +
181 +
182 +if __name__ == '__main__':
183 + model = SEResNeXt((112, 112, 3))
184 + model.summary()
1 +alabaster==0.7.12
2 +anaconda-client==1.7.2
3 +anaconda-navigator==1.9.12
4 +anaconda-project==0.8.3
5 +argh==0.26.2
6 +asn1crypto==1.3.0
7 +astroid==2.3.3
8 +astropy==4.0
9 +atomicwrites==1.3.0
10 +attrs==19.3.0
11 +autopep8==1.4.4
12 +Babel==2.8.0
13 +backcall==0.1.0
14 +backports.functools-lru-cache==1.6.1
15 +backports.shutil-get-terminal-size==1.0.0
16 +backports.tempfile==1.0
17 +backports.weakref==1.0.post1
18 +beautifulsoup4==4.8.2
19 +bitarray==1.2.1
20 +bkcharts==0.2
21 +bleach==3.1.0
22 +bokeh==1.4.0
23 +boto==2.49.0
24 +Bottleneck==1.3.2
25 +certifi==2019.11.28
26 +cffi==1.14.0
27 +chardet==3.0.4
28 +Click==7.0
29 +cloudpickle==1.3.0
30 +clyent==1.2.2
31 +colorama==0.4.3
32 +conda==4.8.3
33 +conda-build==3.18.11
34 +conda-package-handling==1.7.0
35 +conda-verify==3.4.2
36 +contextlib2==0.6.0.post1
37 +cryptography==2.8
38 +cycler==0.10.0
39 +Cython==0.29.15
40 +cytoolz==0.10.1
41 +dask==2.11.0
42 +decorator==4.4.1
43 +defusedxml==0.6.0
44 +diff-match-patch==20181111
45 +distributed==2.11.0
46 +docutils==0.16
47 +entrypoints==0.3
48 +et-xmlfile==1.0.1
49 +fastcache==1.1.0
50 +filelock==3.0.12
51 +flake8==3.7.9
52 +Flask==1.1.1
53 +fsspec==0.6.2
54 +future==0.18.2
55 +gevent==1.4.0
56 +glob2==0.7
57 +gmpy2==2.0.8
58 +greenlet==0.4.15
59 +h5py==2.10.0
60 +HeapDict==1.0.1
61 +html5lib==1.0.1
62 +hypothesis==5.5.4
63 +idna==2.8
64 +imageio==2.6.1
65 +imagesize==1.2.0
66 +importlib-metadata==1.5.0
67 +intervaltree==3.0.2
68 +ipykernel==5.1.4
69 +ipython==7.12.0
70 +ipython-genutils==0.2.0
71 +ipywidgets==7.5.1
72 +isort==4.3.21
73 +itsdangerous==1.1.0
74 +jdcal==1.4.1
75 +jedi==0.14.1
76 +jeepney==0.4.2
77 +Jinja2==2.11.1
78 +joblib==0.14.1
79 +json5==0.9.1
80 +jsonschema==3.2.0
81 +jupyter==1.0.0
82 +jupyter-client==5.3.4
83 +jupyter-console==6.1.0
84 +jupyter-core==4.6.1
85 +jupyterlab==1.2.6
86 +jupyterlab-server==1.0.6
87 +keyring==21.1.0
88 +kiwisolver==1.1.0
89 +lazy-object-proxy==1.4.3
90 +libarchive-c==2.8
91 +lief==0.9.0
92 +llvmlite==0.31.0
93 +locket==0.2.0
94 +lxml==4.5.0
95 +MarkupSafe==1.1.1
96 +matplotlib==3.1.3
97 +mccabe==0.6.1
98 +mistune==0.8.4
99 +mkl-fft==1.0.15
100 +mkl-random==1.1.0
101 +mkl-service==2.3.0
102 +mock==4.0.1
103 +more-itertools==8.2.0
104 +mpmath==1.1.0
105 +msgpack==0.6.1
106 +multipledispatch==0.6.0
107 +navigator-updater==0.2.1
108 +nbconvert==5.6.1
109 +nbformat==5.0.4
110 +networkx==2.4
111 +nltk==3.4.5
112 +nose==1.3.7
113 +notebook==6.0.3
114 +numba==0.48.0
115 +numexpr==2.7.1
116 +numpy==1.18.1
117 +numpydoc==0.9.2
118 +olefile==0.46
119 +openpyxl==3.0.3
120 +packaging==20.1
121 +pandas==1.0.1
122 +pandocfilters==1.4.2
123 +parso==0.5.2
124 +partd==1.1.0
125 +path==13.1.0
126 +pathlib2==2.3.5
127 +pathtools==0.1.2
128 +patsy==0.5.1
129 +pep8==1.7.1
130 +pexpect==4.8.0
131 +pickleshare==0.7.5
132 +Pillow==7.0.0
133 +pkginfo==1.5.0.1
134 +pluggy==0.13.1
135 +ply==3.11
136 +prometheus-client==0.7.1
137 +prompt-toolkit==3.0.3
138 +psutil==5.6.7
139 +ptyprocess==0.6.0
140 +py==1.8.1
141 +pycodestyle==2.5.0
142 +pycosat==0.6.3
143 +pycparser==2.19
144 +pycrypto==2.6.1
145 +pycurl==7.43.0.5
146 +pydocstyle==4.0.1
147 +pyflakes==2.1.1
148 +Pygments==2.5.2
149 +pylint==2.4.4
150 +pyodbc===4.0.0-unsupported
151 +pyOpenSSL==19.1.0
152 +pyparsing==2.4.6
153 +pyrsistent==0.15.7
154 +PySocks==1.7.1
155 +pytest==5.3.5
156 +pytest-arraydiff==0.3
157 +pytest-astropy==0.8.0
158 +pytest-astropy-header==0.1.2
159 +pytest-doctestplus==0.5.0
160 +pytest-openfiles==0.4.0
161 +pytest-remotedata==0.3.2
162 +python-dateutil==2.8.1
163 +python-jsonrpc-server==0.3.4
164 +python-language-server==0.31.7
165 +pytz==2019.3
166 +PyWavelets==1.1.1
167 +pyxdg==0.26
168 +PyYAML==5.3
169 +pyzmq==18.1.1
170 +QDarkStyle==2.8
171 +QtAwesome==0.6.1
172 +qtconsole==4.6.0
173 +QtPy==1.9.0
174 +requests==2.22.0
175 +rope==0.16.0
176 +Rtree==0.9.3
177 +ruamel-yaml==0.15.87
178 +scikit-image==0.16.2
179 +scikit-learn==0.22.1
180 +scipy==1.4.1
181 +seaborn==0.10.0
182 +SecretStorage==3.1.2
183 +Send2Trash==1.5.0
184 +simplegeneric==0.8.1
185 +singledispatch==3.4.0.3
186 +six==1.14.0
187 +snowballstemmer==2.0.0
188 +sortedcollections==1.1.2
189 +sortedcontainers==2.1.0
190 +soupsieve==1.9.5
191 +Sphinx==2.4.0
192 +sphinxcontrib-applehelp==1.0.1
193 +sphinxcontrib-devhelp==1.0.1
194 +sphinxcontrib-htmlhelp==1.0.2
195 +sphinxcontrib-jsmath==1.0.1
196 +sphinxcontrib-qthelp==1.0.2
197 +sphinxcontrib-serializinghtml==1.1.3
198 +sphinxcontrib-websupport==1.2.0
199 +spyder==4.0.1
200 +spyder-kernels==1.8.1
201 +SQLAlchemy==1.3.13
202 +statsmodels==0.11.0
203 +sympy==1.5.1
204 +tables==3.6.1
205 +tblib==1.6.0
206 +terminado==0.8.3
207 +testpath==0.4.4
208 +toolz==0.10.0
209 +tornado==6.0.3
210 +tqdm==4.42.1
211 +traitlets==4.3.3
212 +ujson==1.35
213 +unicodecsv==0.14.1
214 +urllib3==1.25.8
215 +watchdog==0.10.2
216 +wcwidth==0.1.8
217 +webencodings==0.5.1
218 +Werkzeug==1.0.0
219 +widgetsnbextension==3.5.1
220 +wrapt==1.11.2
221 +wurlitzer==2.0.0
222 +xlrd==1.2.0
223 +XlsxWriter==1.2.7
224 +xlwt==1.3.0
225 +xmltodict==0.12.0
226 +yapf==0.28.0
227 +zict==1.0.0
228 +zipp==2.2.0
1 +from keras.models import load_model
2 +from keras.datasets import fashion_mnist
3 +import matplotlib.pyplot as plt
4 +import os
5 +
6 +MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
7 +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5'
8 +TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test')
9 +TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png'
10 +
11 +model = load_model('MODEL_SAVE_PATH')
12 +(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
13 +model.predict(test_images[:1, :])
14 +model.predict_classes(test_images[:1, :], verbose=0)
15 +
16 +plt.imshow(test_images[0])
1 +from model import SEResNeXt
2 +
3 +from keras.datasets import fashion_mnist
4 +from keras import optimizers
5 +
6 +import os
7 +import sys
8 +
9 +import tensorflow_datasets as tfds
10 +
11 +import os
12 +
13 +MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
14 +if not os.path.exists(MODEL_SAVE_FOLDER_PATH):
15 + os.mkdir(MODEL_SAVE_FOLDER_PATH)
16 +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5'
17 +
18 +model = SEResNeXt((112, 112, 3))
19 +
20 +model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy'])
21 +
22 +ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True)
23 +ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)
24 +model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30)
25 +
26 +model.save(MODEL_SAVE_PATH)