Merge branch 'code' of http://khuhub.khu.ac.kr/2020-1-capstone-design2/2014103189 into report
Showing
5 changed files
with
474 additions
and
0 deletions
소스코드/.gitignore
0 → 100644
소스코드/model.py
0 → 100644
1 | +from keras.models import Model | ||
2 | + | ||
3 | +from keras.layers import Input, Reshape | ||
4 | +from keras.layers.core import Dense, Lambda, Activation | ||
5 | +from keras.layers.convolutional import Conv2D | ||
6 | +from keras.layers.pooling import GlobalAveragePooling2D, MaxPooling2D | ||
7 | +from keras.layers.merge import concatenate, add, multiply | ||
8 | +from keras.layers.normalization import BatchNormalization | ||
9 | + | ||
10 | +from keras.regularizers import l2 | ||
11 | +from keras.utils.data_utils import get_file | ||
12 | +from keras_applications.imagenet_utils import _obtain_input_shape | ||
13 | + | ||
14 | +import keras.backend as K | ||
15 | + | ||
16 | +import os | ||
17 | + | ||
18 | + | ||
19 | +class SEResNeXt(Model): | ||
20 | + def __init__(self, weight, input_shape=None): | ||
21 | + ''' | ||
22 | + ResNext Model | ||
23 | + | ||
24 | + ## Args | ||
25 | + + weight: | ||
26 | + + input_shape: optional shape tuple | ||
27 | + ''' | ||
28 | + | ||
29 | + if weights not in {'cifar10', 'imagenet'}: | ||
30 | + raise ValueError | ||
31 | + | ||
32 | + self.__weight = weight | ||
33 | + | ||
34 | + if weight == 'cifar10': | ||
35 | + self.__depth = 29 | ||
36 | + self.__cardinality = 8 | ||
37 | + self.__width = 64 | ||
38 | + self.__classes = 10 | ||
39 | + else: | ||
40 | + self.__depth = [3, 8, 36, 3] | ||
41 | + self.__cardinality = 32 | ||
42 | + self.__width = 4 | ||
43 | + self.__classes = 1000 | ||
44 | + | ||
45 | + self.__reduction_ratio = 4 | ||
46 | + self.__weight_decay = 5e-4 | ||
47 | + self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1 | ||
48 | + | ||
49 | + if weight == 'cifar10': | ||
50 | + self.__input_shape = _obtain_input_shape(input_shape, default_size=32, min_size=8, data_format=K.image_data_format(), require_flatten=True) | ||
51 | + else: | ||
52 | + self.__input_shape = _obtain_input_shape(input_shape, default_size=224, min_size=112, data_format=K.image_data_format(), require_flatten=True) | ||
53 | + self.__img_input = Input(shape=self.__input_shape) | ||
54 | + | ||
55 | + # Create model. | ||
56 | + super(SEResNeXt, self).__init__(self.__img_input, self.__create_res_next(), name='seresnext') | ||
57 | + | ||
58 | + def __initial_conv_block(self): | ||
59 | + ''' | ||
60 | + Adds an initial conv block, with batch norm and relu for the inception resnext | ||
61 | + ''' | ||
62 | + if weight == 'cifar10': | ||
63 | + x = Conv2D(64, (3, 3), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(self.__img_input) | ||
64 | + x = BatchNormalization(axis=self.__channel_axis)(x) | ||
65 | + x = Activation('relu')(x) | ||
66 | + return x | ||
67 | + else: | ||
68 | + x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input) | ||
69 | + x = BatchNormalization(axis=self.__channel_axis)(x) | ||
70 | + x = Activation('relu')(x) | ||
71 | + x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) | ||
72 | + return x | ||
73 | + | ||
74 | + def __grouped_convolution_block(self, input, grouped_channels, strides): | ||
75 | + ''' | ||
76 | + Adds a grouped convolution block. It is an equivalent block from the paper | ||
77 | + | ||
78 | + ## Args | ||
79 | + + input: input tensor | ||
80 | + + grouped_channels: grouped number of filters | ||
81 | + + strides: performs strided convolution for downscaling if > 1 | ||
82 | + | ||
83 | + ## Returns | ||
84 | + a keras tensor | ||
85 | + ''' | ||
86 | + init = input | ||
87 | + | ||
88 | + group_list = [] | ||
89 | + for c in range(self.__cardinality): | ||
90 | + x = Lambda(lambda z: z[:, :, :, c * grouped_channels:(c + 1) * grouped_channels])(input) | ||
91 | + x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides), kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(x) | ||
92 | + group_list.append(x) | ||
93 | + | ||
94 | + group_merge = concatenate(group_list, axis=self.__channel_axis) | ||
95 | + x = BatchNormalization(axis=self.__channel_axis)(group_merge) | ||
96 | + x = Activation('relu')(x) | ||
97 | + | ||
98 | + return x | ||
99 | + | ||
100 | + def __bottleneck_block(self, input, filters=64, strides=1): | ||
101 | + ''' | ||
102 | + Adds a bottleneck block | ||
103 | + | ||
104 | + ## Args | ||
105 | + + input: input tensor | ||
106 | + + filters: number of output filters | ||
107 | + + strides: performs strided convolution for downsampling if > 1 | ||
108 | + | ||
109 | + ## Returns | ||
110 | + a keras tensor | ||
111 | + ''' | ||
112 | + init = input | ||
113 | + | ||
114 | + grouped_channels = int(filters / self.__cardinality) | ||
115 | + | ||
116 | + # Check if input number of filters is same as 16 * k, else create convolution2d for this input | ||
117 | + if init._keras_shape[-1] != 2 * filters: | ||
118 | + init = Conv2D(filters * 2, (1, 1), padding='same', strides=(strides, strides), use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(init) | ||
119 | + init = BatchNormalization(axis=self.__channel_axis)(init) | ||
120 | + | ||
121 | + x = Conv2D(filters, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(input) | ||
122 | + x = BatchNormalization(axis=self.__channel_axis)(x) | ||
123 | + x = Activation('relu')(x) | ||
124 | + x = self.__squeeze_excitation_layer(x, x[0].get_shape()[self.__channel_axis]) | ||
125 | + | ||
126 | + x = self.__grouped_convolution_block(x, grouped_channels, strides) | ||
127 | + | ||
128 | + x = Conv2D(filters * 2, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(x) | ||
129 | + x = BatchNormalization(axis=self.__channel_axis)(x) | ||
130 | + | ||
131 | + x = add([init, x]) | ||
132 | + x = Activation('relu')(x) | ||
133 | + | ||
134 | + return x | ||
135 | + | ||
136 | + def __create_res_next(self): | ||
137 | + ''' | ||
138 | + Creates a ResNeXt model with specified parameters | ||
139 | + ''' | ||
140 | + if type(self.__depth) is list or type(self.__depth) is tuple: | ||
141 | + N = list(self.__depth) | ||
142 | + else: | ||
143 | + N = [(self.__depth - 2) // 9 for _ in range(3)] | ||
144 | + print(N) | ||
145 | + | ||
146 | + filters = self.__cardinality * self.__width | ||
147 | + filters_list = [] | ||
148 | + for i in range(len(N)): | ||
149 | + filters_list.append(filters) | ||
150 | + filters *= 2 # double the size of the filters | ||
151 | + | ||
152 | + x = self.__initial_conv_block() | ||
153 | + | ||
154 | + # block 1 (no pooling) | ||
155 | + for i in range(N[0]): | ||
156 | + x = self.__bottleneck_block(x, filters_list[0], strides=1) | ||
157 | + | ||
158 | + N = N[1:] # remove the first block from block definition list | ||
159 | + filters_list = filters_list[1:] # remove the first filter from the filter list | ||
160 | + | ||
161 | + # block 2 to N | ||
162 | + for block_idx, n_i in enumerate(N): | ||
163 | + for i in range(n_i): | ||
164 | + if i == 0: | ||
165 | + x = self.__bottleneck_block(x, filters_list[block_idx], strides=2) | ||
166 | + else: | ||
167 | + x = self.__bottleneck_block(x, filters_list[block_idx], strides=1) | ||
168 | + | ||
169 | + | ||
170 | + x = GlobalAveragePooling2D()(x) | ||
171 | + x = Dense(self.__classes, use_bias=False, kernel_regularizer=l2(self.__weight_decay), kernel_initializer='he_normal', activation='softmax')(x) | ||
172 | + | ||
173 | + return x | ||
174 | + | ||
175 | + def __squeeze_excitation_layer(self, x, out_dim): | ||
176 | + ''' | ||
177 | + SE Block Function | ||
178 | + | ||
179 | + ## Args | ||
180 | + + x : input feature map | ||
181 | + + out_dim : dimention of output channel | ||
182 | + ''' | ||
183 | + squeeze = GlobalAveragePooling2D()(x) | ||
184 | + | ||
185 | + excitation = Dense(units=out_dim // self.__reduction_ratio)(squeeze) | ||
186 | + excitation = Activation('relu')(excitation) | ||
187 | + excitation = Dense(units=out_dim)(excitation) | ||
188 | + excitation = Activation('sigmoid')(excitation) | ||
189 | + excitation = Reshape((1,1,out_dim))(excitation) | ||
190 | + | ||
191 | + scale = multiply([x,excitation]) | ||
192 | + | ||
193 | + return scale | ||
194 | + | ||
195 | + | ||
196 | +if __name__ == '__main__': | ||
197 | + model = SEResNeXt((112, 112, 3)) | ||
198 | + model.summary() |
소스코드/requirements.txt
0 → 100644
1 | +alabaster==0.7.12 | ||
2 | +anaconda-client==1.7.2 | ||
3 | +anaconda-navigator==1.9.12 | ||
4 | +anaconda-project==0.8.3 | ||
5 | +argh==0.26.2 | ||
6 | +asn1crypto==1.3.0 | ||
7 | +astroid==2.3.3 | ||
8 | +astropy==4.0 | ||
9 | +atomicwrites==1.3.0 | ||
10 | +attrs==19.3.0 | ||
11 | +autopep8==1.4.4 | ||
12 | +Babel==2.8.0 | ||
13 | +backcall==0.1.0 | ||
14 | +backports.functools-lru-cache==1.6.1 | ||
15 | +backports.shutil-get-terminal-size==1.0.0 | ||
16 | +backports.tempfile==1.0 | ||
17 | +backports.weakref==1.0.post1 | ||
18 | +beautifulsoup4==4.8.2 | ||
19 | +bitarray==1.2.1 | ||
20 | +bkcharts==0.2 | ||
21 | +bleach==3.1.0 | ||
22 | +bokeh==1.4.0 | ||
23 | +boto==2.49.0 | ||
24 | +Bottleneck==1.3.2 | ||
25 | +certifi==2019.11.28 | ||
26 | +cffi==1.14.0 | ||
27 | +chardet==3.0.4 | ||
28 | +Click==7.0 | ||
29 | +cloudpickle==1.3.0 | ||
30 | +clyent==1.2.2 | ||
31 | +colorama==0.4.3 | ||
32 | +conda==4.8.3 | ||
33 | +conda-build==3.18.11 | ||
34 | +conda-package-handling==1.7.0 | ||
35 | +conda-verify==3.4.2 | ||
36 | +contextlib2==0.6.0.post1 | ||
37 | +cryptography==2.8 | ||
38 | +cycler==0.10.0 | ||
39 | +Cython==0.29.15 | ||
40 | +cytoolz==0.10.1 | ||
41 | +dask==2.11.0 | ||
42 | +decorator==4.4.1 | ||
43 | +defusedxml==0.6.0 | ||
44 | +diff-match-patch==20181111 | ||
45 | +distributed==2.11.0 | ||
46 | +docutils==0.16 | ||
47 | +entrypoints==0.3 | ||
48 | +et-xmlfile==1.0.1 | ||
49 | +fastcache==1.1.0 | ||
50 | +filelock==3.0.12 | ||
51 | +flake8==3.7.9 | ||
52 | +Flask==1.1.1 | ||
53 | +fsspec==0.6.2 | ||
54 | +future==0.18.2 | ||
55 | +gevent==1.4.0 | ||
56 | +glob2==0.7 | ||
57 | +gmpy2==2.0.8 | ||
58 | +greenlet==0.4.15 | ||
59 | +h5py==2.10.0 | ||
60 | +HeapDict==1.0.1 | ||
61 | +html5lib==1.0.1 | ||
62 | +hypothesis==5.5.4 | ||
63 | +idna==2.8 | ||
64 | +imageio==2.6.1 | ||
65 | +imagesize==1.2.0 | ||
66 | +importlib-metadata==1.5.0 | ||
67 | +intervaltree==3.0.2 | ||
68 | +ipykernel==5.1.4 | ||
69 | +ipython==7.12.0 | ||
70 | +ipython-genutils==0.2.0 | ||
71 | +ipywidgets==7.5.1 | ||
72 | +isort==4.3.21 | ||
73 | +itsdangerous==1.1.0 | ||
74 | +jdcal==1.4.1 | ||
75 | +jedi==0.14.1 | ||
76 | +jeepney==0.4.2 | ||
77 | +Jinja2==2.11.1 | ||
78 | +joblib==0.14.1 | ||
79 | +json5==0.9.1 | ||
80 | +jsonschema==3.2.0 | ||
81 | +jupyter==1.0.0 | ||
82 | +jupyter-client==5.3.4 | ||
83 | +jupyter-console==6.1.0 | ||
84 | +jupyter-core==4.6.1 | ||
85 | +jupyterlab==1.2.6 | ||
86 | +jupyterlab-server==1.0.6 | ||
87 | +keyring==21.1.0 | ||
88 | +kiwisolver==1.1.0 | ||
89 | +lazy-object-proxy==1.4.3 | ||
90 | +libarchive-c==2.8 | ||
91 | +lief==0.9.0 | ||
92 | +llvmlite==0.31.0 | ||
93 | +locket==0.2.0 | ||
94 | +lxml==4.5.0 | ||
95 | +MarkupSafe==1.1.1 | ||
96 | +matplotlib==3.1.3 | ||
97 | +mccabe==0.6.1 | ||
98 | +mistune==0.8.4 | ||
99 | +mkl-fft==1.0.15 | ||
100 | +mkl-random==1.1.0 | ||
101 | +mkl-service==2.3.0 | ||
102 | +mock==4.0.1 | ||
103 | +more-itertools==8.2.0 | ||
104 | +mpmath==1.1.0 | ||
105 | +msgpack==0.6.1 | ||
106 | +multipledispatch==0.6.0 | ||
107 | +navigator-updater==0.2.1 | ||
108 | +nbconvert==5.6.1 | ||
109 | +nbformat==5.0.4 | ||
110 | +networkx==2.4 | ||
111 | +nltk==3.4.5 | ||
112 | +nose==1.3.7 | ||
113 | +notebook==6.0.3 | ||
114 | +numba==0.48.0 | ||
115 | +numexpr==2.7.1 | ||
116 | +numpy==1.18.1 | ||
117 | +numpydoc==0.9.2 | ||
118 | +olefile==0.46 | ||
119 | +openpyxl==3.0.3 | ||
120 | +packaging==20.1 | ||
121 | +pandas==1.0.1 | ||
122 | +pandocfilters==1.4.2 | ||
123 | +parso==0.5.2 | ||
124 | +partd==1.1.0 | ||
125 | +path==13.1.0 | ||
126 | +pathlib2==2.3.5 | ||
127 | +pathtools==0.1.2 | ||
128 | +patsy==0.5.1 | ||
129 | +pep8==1.7.1 | ||
130 | +pexpect==4.8.0 | ||
131 | +pickleshare==0.7.5 | ||
132 | +Pillow==7.0.0 | ||
133 | +pkginfo==1.5.0.1 | ||
134 | +pluggy==0.13.1 | ||
135 | +ply==3.11 | ||
136 | +prometheus-client==0.7.1 | ||
137 | +prompt-toolkit==3.0.3 | ||
138 | +psutil==5.6.7 | ||
139 | +ptyprocess==0.6.0 | ||
140 | +py==1.8.1 | ||
141 | +pycodestyle==2.5.0 | ||
142 | +pycosat==0.6.3 | ||
143 | +pycparser==2.19 | ||
144 | +pycrypto==2.6.1 | ||
145 | +pycurl==7.43.0.5 | ||
146 | +pydocstyle==4.0.1 | ||
147 | +pyflakes==2.1.1 | ||
148 | +Pygments==2.5.2 | ||
149 | +pylint==2.4.4 | ||
150 | +pyodbc===4.0.0-unsupported | ||
151 | +pyOpenSSL==19.1.0 | ||
152 | +pyparsing==2.4.6 | ||
153 | +pyrsistent==0.15.7 | ||
154 | +PySocks==1.7.1 | ||
155 | +pytest==5.3.5 | ||
156 | +pytest-arraydiff==0.3 | ||
157 | +pytest-astropy==0.8.0 | ||
158 | +pytest-astropy-header==0.1.2 | ||
159 | +pytest-doctestplus==0.5.0 | ||
160 | +pytest-openfiles==0.4.0 | ||
161 | +pytest-remotedata==0.3.2 | ||
162 | +python-dateutil==2.8.1 | ||
163 | +python-jsonrpc-server==0.3.4 | ||
164 | +python-language-server==0.31.7 | ||
165 | +pytz==2019.3 | ||
166 | +PyWavelets==1.1.1 | ||
167 | +pyxdg==0.26 | ||
168 | +PyYAML==5.3 | ||
169 | +pyzmq==18.1.1 | ||
170 | +QDarkStyle==2.8 | ||
171 | +QtAwesome==0.6.1 | ||
172 | +qtconsole==4.6.0 | ||
173 | +QtPy==1.9.0 | ||
174 | +requests==2.22.0 | ||
175 | +rope==0.16.0 | ||
176 | +Rtree==0.9.3 | ||
177 | +ruamel-yaml==0.15.87 | ||
178 | +scikit-image==0.16.2 | ||
179 | +scikit-learn==0.22.1 | ||
180 | +scipy==1.4.1 | ||
181 | +seaborn==0.10.0 | ||
182 | +SecretStorage==3.1.2 | ||
183 | +Send2Trash==1.5.0 | ||
184 | +simplegeneric==0.8.1 | ||
185 | +singledispatch==3.4.0.3 | ||
186 | +six==1.14.0 | ||
187 | +snowballstemmer==2.0.0 | ||
188 | +sortedcollections==1.1.2 | ||
189 | +sortedcontainers==2.1.0 | ||
190 | +soupsieve==1.9.5 | ||
191 | +Sphinx==2.4.0 | ||
192 | +sphinxcontrib-applehelp==1.0.1 | ||
193 | +sphinxcontrib-devhelp==1.0.1 | ||
194 | +sphinxcontrib-htmlhelp==1.0.2 | ||
195 | +sphinxcontrib-jsmath==1.0.1 | ||
196 | +sphinxcontrib-qthelp==1.0.2 | ||
197 | +sphinxcontrib-serializinghtml==1.1.3 | ||
198 | +sphinxcontrib-websupport==1.2.0 | ||
199 | +spyder==4.0.1 | ||
200 | +spyder-kernels==1.8.1 | ||
201 | +SQLAlchemy==1.3.13 | ||
202 | +statsmodels==0.11.0 | ||
203 | +sympy==1.5.1 | ||
204 | +tables==3.6.1 | ||
205 | +tblib==1.6.0 | ||
206 | +terminado==0.8.3 | ||
207 | +testpath==0.4.4 | ||
208 | +toolz==0.10.0 | ||
209 | +tornado==6.0.3 | ||
210 | +tqdm==4.42.1 | ||
211 | +traitlets==4.3.3 | ||
212 | +ujson==1.35 | ||
213 | +unicodecsv==0.14.1 | ||
214 | +urllib3==1.25.8 | ||
215 | +watchdog==0.10.2 | ||
216 | +wcwidth==0.1.8 | ||
217 | +webencodings==0.5.1 | ||
218 | +Werkzeug==1.0.0 | ||
219 | +widgetsnbextension==3.5.1 | ||
220 | +wrapt==1.11.2 | ||
221 | +wurlitzer==2.0.0 | ||
222 | +xlrd==1.2.0 | ||
223 | +XlsxWriter==1.2.7 | ||
224 | +xlwt==1.3.0 | ||
225 | +xmltodict==0.12.0 | ||
226 | +yapf==0.28.0 | ||
227 | +zict==1.0.0 | ||
228 | +zipp==2.2.0 |
소스코드/test.py
0 → 100644
1 | +from keras.models import load_model | ||
2 | +from keras.datasets import fashion_mnist | ||
3 | +import matplotlib.pyplot as plt | ||
4 | +import os | ||
5 | + | ||
6 | +MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') | ||
7 | +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar10.h5' | ||
8 | +TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test') | ||
9 | +TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png' | ||
10 | + | ||
11 | +model = load_model('MODEL_SAVE_PATH') | ||
12 | +(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() | ||
13 | +model.predict(test_images[:1, :]) | ||
14 | +model.predict_classes(test_images[:1, :], verbose=0) | ||
15 | + | ||
16 | +plt.imshow(test_images[0]) |
소스코드/train.py
0 → 100644
1 | +from model import SEResNeXt | ||
2 | + | ||
3 | +from keras.datasets import fashion_mnist | ||
4 | +from keras import optimizers | ||
5 | + | ||
6 | +import os | ||
7 | +import sys | ||
8 | + | ||
9 | +import tensorflow_datasets as tfds | ||
10 | + | ||
11 | +import os | ||
12 | + | ||
13 | +MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') | ||
14 | +if not os.path.exists(MODEL_SAVE_FOLDER_PATH): | ||
15 | + os.mkdir(MODEL_SAVE_FOLDER_PATH) | ||
16 | +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar.h5' | ||
17 | + | ||
18 | +model = SEResNeXt('cifar-10', (32, 32, 3)) | ||
19 | + | ||
20 | +model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy']) | ||
21 | + | ||
22 | +# ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True) | ||
23 | +ds_train = tfds.load('cifar-10', split='train', shuffle_files=True) | ||
24 | +ds_train = ds_train.shuffle(1000).batch(128).prefetch(10) | ||
25 | + | ||
26 | +model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30) | ||
27 | + | ||
28 | +model.save(MODEL_SAVE_PATH) |
-
Please register or login to post a comment