4Moyede
1 +.vscode/
2 +__pycache__/
3 +trained/
4 +src/
1 +from keras.models import Model
2 +
3 +from keras.layers import Input, Reshape
4 +from keras.layers.core import Dense, Lambda, Activation
5 +from keras.layers.convolutional import Conv2D
6 +from keras.layers.pooling import GlobalAveragePooling2D, MaxPooling2D
7 +from keras.layers.merge import concatenate, add, multiply
8 +from keras.layers.normalization import BatchNormalization
9 +
10 +from keras.regularizers import l2
11 +from keras.utils.data_utils import get_file
12 +from keras_applications.imagenet_utils import _obtain_input_shape
13 +
14 +import keras.backend as K
15 +
16 +import os
17 +
18 +
19 +class SEResNeXt(Model):
20 + def __init__(self, weight, input_shape=None):
21 + '''
22 + ResNext Model
23 +
24 + ## Args
25 + + weight:
26 + + input_shape: optional shape tuple
27 + '''
28 +
29 + if weights not in {'cifar10', 'imagenet'}:
30 + raise ValueError
31 +
32 + self.__weight = weight
33 +
34 + if weight == 'cifar10':
35 + self.__depth = 29
36 + self.__cardinality = 8
37 + self.__width = 64
38 + self.__classes = 10
39 + else:
40 + self.__depth = [3, 8, 36, 3]
41 + self.__cardinality = 32
42 + self.__width = 4
43 + self.__classes = 1000
44 +
45 + self.__reduction_ratio = 4
46 + self.__weight_decay = 5e-4
47 + self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1
48 +
49 + if weight == 'cifar10':
50 + self.__input_shape = _obtain_input_shape(input_shape, default_size=32, min_size=8, data_format=K.image_data_format(), require_flatten=True)
51 + else:
52 + self.__input_shape = _obtain_input_shape(input_shape, default_size=224, min_size=112, data_format=K.image_data_format(), require_flatten=True)
53 + self.__img_input = Input(shape=self.__input_shape)
54 +
55 + # Create model.
56 + super(SEResNeXt, self).__init__(self.__img_input, self.__create_res_next(), name='seresnext')
57 +
58 + def __initial_conv_block(self):
59 + '''
60 + Adds an initial conv block, with batch norm and relu for the inception resnext
61 + '''
62 + if weight == 'cifar10':
63 + x = Conv2D(64, (3, 3), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(self.__img_input)
64 + x = BatchNormalization(axis=self.__channel_axis)(x)
65 + x = Activation('relu')(x)
66 + return x
67 + else:
68 + x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input)
69 + x = BatchNormalization(axis=self.__channel_axis)(x)
70 + x = Activation('relu')(x)
71 + x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
72 + return x
73 +
74 + def __grouped_convolution_block(self, input, grouped_channels, strides):
75 + '''
76 + Adds a grouped convolution block. It is an equivalent block from the paper
77 +
78 + ## Args
79 + + input: input tensor
80 + + grouped_channels: grouped number of filters
81 + + strides: performs strided convolution for downscaling if > 1
82 +
83 + ## Returns
84 + a keras tensor
85 + '''
86 + init = input
87 +
88 + group_list = []
89 + for c in range(self.__cardinality):
90 + x = Lambda(lambda z: z[:, :, :, c * grouped_channels:(c + 1) * grouped_channels])(input)
91 + x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides), kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(x)
92 + group_list.append(x)
93 +
94 + group_merge = concatenate(group_list, axis=self.__channel_axis)
95 + x = BatchNormalization(axis=self.__channel_axis)(group_merge)
96 + x = Activation('relu')(x)
97 +
98 + return x
99 +
100 + def __bottleneck_block(self, input, filters=64, strides=1):
101 + '''
102 + Adds a bottleneck block
103 +
104 + ## Args
105 + + input: input tensor
106 + + filters: number of output filters
107 + + strides: performs strided convolution for downsampling if > 1
108 +
109 + ## Returns
110 + a keras tensor
111 + '''
112 + init = input
113 +
114 + grouped_channels = int(filters / self.__cardinality)
115 +
116 + # Check if input number of filters is same as 16 * k, else create convolution2d for this input
117 + if init._keras_shape[-1] != 2 * filters:
118 + init = Conv2D(filters * 2, (1, 1), padding='same', strides=(strides, strides), use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(init)
119 + init = BatchNormalization(axis=self.__channel_axis)(init)
120 +
121 + x = Conv2D(filters, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(input)
122 + x = BatchNormalization(axis=self.__channel_axis)(x)
123 + x = Activation('relu')(x)
124 + x = self.__squeeze_excitation_layer(x, x[0].get_shape()[self.__channel_axis])
125 +
126 + x = self.__grouped_convolution_block(x, grouped_channels, strides)
127 +
128 + x = Conv2D(filters * 2, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(x)
129 + x = BatchNormalization(axis=self.__channel_axis)(x)
130 +
131 + x = add([init, x])
132 + x = Activation('relu')(x)
133 +
134 + return x
135 +
136 + def __create_res_next(self):
137 + '''
138 + Creates a ResNeXt model with specified parameters
139 + '''
140 + if type(self.__depth) is list or type(self.__depth) is tuple:
141 + N = list(self.__depth)
142 + else:
143 + N = [(self.__depth - 2) // 9 for _ in range(3)]
144 + print(N)
145 +
146 + filters = self.__cardinality * self.__width
147 + filters_list = []
148 + for i in range(len(N)):
149 + filters_list.append(filters)
150 + filters *= 2 # double the size of the filters
151 +
152 + x = self.__initial_conv_block()
153 +
154 + # block 1 (no pooling)
155 + for i in range(N[0]):
156 + x = self.__bottleneck_block(x, filters_list[0], strides=1)
157 +
158 + N = N[1:] # remove the first block from block definition list
159 + filters_list = filters_list[1:] # remove the first filter from the filter list
160 +
161 + # block 2 to N
162 + for block_idx, n_i in enumerate(N):
163 + for i in range(n_i):
164 + if i == 0:
165 + x = self.__bottleneck_block(x, filters_list[block_idx], strides=2)
166 + else:
167 + x = self.__bottleneck_block(x, filters_list[block_idx], strides=1)
168 +
169 +
170 + x = GlobalAveragePooling2D()(x)
171 + x = Dense(self.__classes, use_bias=False, kernel_regularizer=l2(self.__weight_decay), kernel_initializer='he_normal', activation='softmax')(x)
172 +
173 + return x
174 +
175 + def __squeeze_excitation_layer(self, x, out_dim):
176 + '''
177 + SE Block Function
178 +
179 + ## Args
180 + + x : input feature map
181 + + out_dim : dimention of output channel
182 + '''
183 + squeeze = GlobalAveragePooling2D()(x)
184 +
185 + excitation = Dense(units=out_dim // self.__reduction_ratio)(squeeze)
186 + excitation = Activation('relu')(excitation)
187 + excitation = Dense(units=out_dim)(excitation)
188 + excitation = Activation('sigmoid')(excitation)
189 + excitation = Reshape((1,1,out_dim))(excitation)
190 +
191 + scale = multiply([x,excitation])
192 +
193 + return scale
194 +
195 +
196 +if __name__ == '__main__':
197 + model = SEResNeXt((112, 112, 3))
198 + model.summary()
1 +alabaster==0.7.12
2 +anaconda-client==1.7.2
3 +anaconda-navigator==1.9.12
4 +anaconda-project==0.8.3
5 +argh==0.26.2
6 +asn1crypto==1.3.0
7 +astroid==2.3.3
8 +astropy==4.0
9 +atomicwrites==1.3.0
10 +attrs==19.3.0
11 +autopep8==1.4.4
12 +Babel==2.8.0
13 +backcall==0.1.0
14 +backports.functools-lru-cache==1.6.1
15 +backports.shutil-get-terminal-size==1.0.0
16 +backports.tempfile==1.0
17 +backports.weakref==1.0.post1
18 +beautifulsoup4==4.8.2
19 +bitarray==1.2.1
20 +bkcharts==0.2
21 +bleach==3.1.0
22 +bokeh==1.4.0
23 +boto==2.49.0
24 +Bottleneck==1.3.2
25 +certifi==2019.11.28
26 +cffi==1.14.0
27 +chardet==3.0.4
28 +Click==7.0
29 +cloudpickle==1.3.0
30 +clyent==1.2.2
31 +colorama==0.4.3
32 +conda==4.8.3
33 +conda-build==3.18.11
34 +conda-package-handling==1.7.0
35 +conda-verify==3.4.2
36 +contextlib2==0.6.0.post1
37 +cryptography==2.8
38 +cycler==0.10.0
39 +Cython==0.29.15
40 +cytoolz==0.10.1
41 +dask==2.11.0
42 +decorator==4.4.1
43 +defusedxml==0.6.0
44 +diff-match-patch==20181111
45 +distributed==2.11.0
46 +docutils==0.16
47 +entrypoints==0.3
48 +et-xmlfile==1.0.1
49 +fastcache==1.1.0
50 +filelock==3.0.12
51 +flake8==3.7.9
52 +Flask==1.1.1
53 +fsspec==0.6.2
54 +future==0.18.2
55 +gevent==1.4.0
56 +glob2==0.7
57 +gmpy2==2.0.8
58 +greenlet==0.4.15
59 +h5py==2.10.0
60 +HeapDict==1.0.1
61 +html5lib==1.0.1
62 +hypothesis==5.5.4
63 +idna==2.8
64 +imageio==2.6.1
65 +imagesize==1.2.0
66 +importlib-metadata==1.5.0
67 +intervaltree==3.0.2
68 +ipykernel==5.1.4
69 +ipython==7.12.0
70 +ipython-genutils==0.2.0
71 +ipywidgets==7.5.1
72 +isort==4.3.21
73 +itsdangerous==1.1.0
74 +jdcal==1.4.1
75 +jedi==0.14.1
76 +jeepney==0.4.2
77 +Jinja2==2.11.1
78 +joblib==0.14.1
79 +json5==0.9.1
80 +jsonschema==3.2.0
81 +jupyter==1.0.0
82 +jupyter-client==5.3.4
83 +jupyter-console==6.1.0
84 +jupyter-core==4.6.1
85 +jupyterlab==1.2.6
86 +jupyterlab-server==1.0.6
87 +keyring==21.1.0
88 +kiwisolver==1.1.0
89 +lazy-object-proxy==1.4.3
90 +libarchive-c==2.8
91 +lief==0.9.0
92 +llvmlite==0.31.0
93 +locket==0.2.0
94 +lxml==4.5.0
95 +MarkupSafe==1.1.1
96 +matplotlib==3.1.3
97 +mccabe==0.6.1
98 +mistune==0.8.4
99 +mkl-fft==1.0.15
100 +mkl-random==1.1.0
101 +mkl-service==2.3.0
102 +mock==4.0.1
103 +more-itertools==8.2.0
104 +mpmath==1.1.0
105 +msgpack==0.6.1
106 +multipledispatch==0.6.0
107 +navigator-updater==0.2.1
108 +nbconvert==5.6.1
109 +nbformat==5.0.4
110 +networkx==2.4
111 +nltk==3.4.5
112 +nose==1.3.7
113 +notebook==6.0.3
114 +numba==0.48.0
115 +numexpr==2.7.1
116 +numpy==1.18.1
117 +numpydoc==0.9.2
118 +olefile==0.46
119 +openpyxl==3.0.3
120 +packaging==20.1
121 +pandas==1.0.1
122 +pandocfilters==1.4.2
123 +parso==0.5.2
124 +partd==1.1.0
125 +path==13.1.0
126 +pathlib2==2.3.5
127 +pathtools==0.1.2
128 +patsy==0.5.1
129 +pep8==1.7.1
130 +pexpect==4.8.0
131 +pickleshare==0.7.5
132 +Pillow==7.0.0
133 +pkginfo==1.5.0.1
134 +pluggy==0.13.1
135 +ply==3.11
136 +prometheus-client==0.7.1
137 +prompt-toolkit==3.0.3
138 +psutil==5.6.7
139 +ptyprocess==0.6.0
140 +py==1.8.1
141 +pycodestyle==2.5.0
142 +pycosat==0.6.3
143 +pycparser==2.19
144 +pycrypto==2.6.1
145 +pycurl==7.43.0.5
146 +pydocstyle==4.0.1
147 +pyflakes==2.1.1
148 +Pygments==2.5.2
149 +pylint==2.4.4
150 +pyodbc===4.0.0-unsupported
151 +pyOpenSSL==19.1.0
152 +pyparsing==2.4.6
153 +pyrsistent==0.15.7
154 +PySocks==1.7.1
155 +pytest==5.3.5
156 +pytest-arraydiff==0.3
157 +pytest-astropy==0.8.0
158 +pytest-astropy-header==0.1.2
159 +pytest-doctestplus==0.5.0
160 +pytest-openfiles==0.4.0
161 +pytest-remotedata==0.3.2
162 +python-dateutil==2.8.1
163 +python-jsonrpc-server==0.3.4
164 +python-language-server==0.31.7
165 +pytz==2019.3
166 +PyWavelets==1.1.1
167 +pyxdg==0.26
168 +PyYAML==5.3
169 +pyzmq==18.1.1
170 +QDarkStyle==2.8
171 +QtAwesome==0.6.1
172 +qtconsole==4.6.0
173 +QtPy==1.9.0
174 +requests==2.22.0
175 +rope==0.16.0
176 +Rtree==0.9.3
177 +ruamel-yaml==0.15.87
178 +scikit-image==0.16.2
179 +scikit-learn==0.22.1
180 +scipy==1.4.1
181 +seaborn==0.10.0
182 +SecretStorage==3.1.2
183 +Send2Trash==1.5.0
184 +simplegeneric==0.8.1
185 +singledispatch==3.4.0.3
186 +six==1.14.0
187 +snowballstemmer==2.0.0
188 +sortedcollections==1.1.2
189 +sortedcontainers==2.1.0
190 +soupsieve==1.9.5
191 +Sphinx==2.4.0
192 +sphinxcontrib-applehelp==1.0.1
193 +sphinxcontrib-devhelp==1.0.1
194 +sphinxcontrib-htmlhelp==1.0.2
195 +sphinxcontrib-jsmath==1.0.1
196 +sphinxcontrib-qthelp==1.0.2
197 +sphinxcontrib-serializinghtml==1.1.3
198 +sphinxcontrib-websupport==1.2.0
199 +spyder==4.0.1
200 +spyder-kernels==1.8.1
201 +SQLAlchemy==1.3.13
202 +statsmodels==0.11.0
203 +sympy==1.5.1
204 +tables==3.6.1
205 +tblib==1.6.0
206 +terminado==0.8.3
207 +testpath==0.4.4
208 +toolz==0.10.0
209 +tornado==6.0.3
210 +tqdm==4.42.1
211 +traitlets==4.3.3
212 +ujson==1.35
213 +unicodecsv==0.14.1
214 +urllib3==1.25.8
215 +watchdog==0.10.2
216 +wcwidth==0.1.8
217 +webencodings==0.5.1
218 +Werkzeug==1.0.0
219 +widgetsnbextension==3.5.1
220 +wrapt==1.11.2
221 +wurlitzer==2.0.0
222 +xlrd==1.2.0
223 +XlsxWriter==1.2.7
224 +xlwt==1.3.0
225 +xmltodict==0.12.0
226 +yapf==0.28.0
227 +zict==1.0.0
228 +zipp==2.2.0
1 +from keras.models import load_model
2 +from keras.datasets import fashion_mnist
3 +import matplotlib.pyplot as plt
4 +import os
5 +
6 +MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
7 +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar10.h5'
8 +TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test')
9 +TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png'
10 +
11 +model = load_model('MODEL_SAVE_PATH')
12 +(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
13 +model.predict(test_images[:1, :])
14 +model.predict_classes(test_images[:1, :], verbose=0)
15 +
16 +plt.imshow(test_images[0])
1 +from model import SEResNeXt
2 +
3 +from keras.datasets import fashion_mnist
4 +from keras import optimizers
5 +
6 +import os
7 +import sys
8 +
9 +import tensorflow_datasets as tfds
10 +
11 +import os
12 +
13 +MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
14 +if not os.path.exists(MODEL_SAVE_FOLDER_PATH):
15 + os.mkdir(MODEL_SAVE_FOLDER_PATH)
16 +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar.h5'
17 +
18 +model = SEResNeXt('cifar-10', (32, 32, 3))
19 +
20 +model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy'])
21 +
22 +# ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True)
23 +ds_train = tfds.load('cifar-10', split='train', shuffle_files=True)
24 +ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)
25 +
26 +model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30)
27 +
28 +model.save(MODEL_SAVE_PATH)