Showing
3 changed files
with
46 additions
and
30 deletions
... | @@ -17,30 +17,39 @@ import os | ... | @@ -17,30 +17,39 @@ import os |
17 | 17 | ||
18 | 18 | ||
19 | class SEResNeXt(Model): | 19 | class SEResNeXt(Model): |
20 | - def __init__(self, weight, input_shape=None, depth=[3, 8, 36, 3], cardinality=32, width=4, reduction_ratio=4, weight_decay=5e-4, classes=1000, channel_axis=None): | 20 | + def __init__(self, weight, input_shape=None): |
21 | ''' | 21 | ''' |
22 | ResNext Model | 22 | ResNext Model |
23 | 23 | ||
24 | ## Args | 24 | ## Args |
25 | + + weight: | ||
25 | + input_shape: optional shape tuple | 26 | + input_shape: optional shape tuple |
26 | - + depth: number or layers in the each block, defined as a list | ||
27 | - + cardinality: the size of the set of transformations | ||
28 | - + width: multiplier to the ResNeXt width (number of filters) | ||
29 | - + redution_ratio: ratio of reducition in SE Block | ||
30 | - + weight_decay: weight decay (l2 norm) | ||
31 | - + classes: number of classes to classify images into | ||
32 | - + channel_axis: channel axis in keras.backend.image_data_format() | ||
33 | ''' | 27 | ''' |
28 | + | ||
29 | + if weights not in {'cifar10', 'imagenet'}: | ||
30 | + raise ValueError | ||
31 | + | ||
34 | self.__weight = weight | 32 | self.__weight = weight |
35 | - self.__depth = depth | 33 | + |
36 | - self.__cardinality = cardinality | 34 | + if weight == 'cifar10': |
37 | - self.__width = width | 35 | + self.__depth = 29 |
38 | - self.__reduction_ratio = reduction_ratio | 36 | + self.__cardinality = 8 |
39 | - self.__weight_decay = weight_decay | 37 | + self.__width = 64 |
40 | - self.__classes = classes | 38 | + self.__classes = 10 |
39 | + else: | ||
40 | + self.__depth = [3, 8, 36, 3] | ||
41 | + self.__cardinality = 32 | ||
42 | + self.__width = 4 | ||
43 | + self.__classes = 1000 | ||
44 | + | ||
45 | + self.__reduction_ratio = 4 | ||
46 | + self.__weight_decay = 5e-4 | ||
41 | self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1 | 47 | self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1 |
42 | 48 | ||
43 | - self.__input_shape = _obtain_input_shape(input_shape, default_size = 224, min_size = 112, data_format=K.image_data_format(), require_flatten=True) | 49 | + if weight == 'cifar10': |
50 | + self.__input_shape = _obtain_input_shape(input_shape, default_size=32, min_size=8, data_format=K.image_data_format(), require_flatten=True) | ||
51 | + else: | ||
52 | + self.__input_shape = _obtain_input_shape(input_shape, default_size=224, min_size=112, data_format=K.image_data_format(), require_flatten=True) | ||
44 | self.__img_input = Input(shape=self.__input_shape) | 53 | self.__img_input = Input(shape=self.__input_shape) |
45 | 54 | ||
46 | # Create model. | 55 | # Create model. |
... | @@ -50,15 +59,17 @@ class SEResNeXt(Model): | ... | @@ -50,15 +59,17 @@ class SEResNeXt(Model): |
50 | ''' | 59 | ''' |
51 | Adds an initial conv block, with batch norm and relu for the inception resnext | 60 | Adds an initial conv block, with batch norm and relu for the inception resnext |
52 | ''' | 61 | ''' |
53 | - channel_axis = -1 | 62 | + if weight == 'cifar10': |
54 | - | 63 | + x = Conv2D(64, (3, 3), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(self.__img_input) |
55 | - x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input) | 64 | + x = BatchNormalization(axis=self.__channel_axis)(x) |
56 | - x = BatchNormalization(axis=channel_axis)(x) | 65 | + x = Activation('relu')(x) |
57 | - x = Activation('relu')(x) | 66 | + return x |
58 | - | 67 | + else: |
59 | - x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) | 68 | + x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input) |
60 | - | 69 | + x = BatchNormalization(axis=self.__channel_axis)(x) |
61 | - return x | 70 | + x = Activation('relu')(x) |
71 | + x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) | ||
72 | + return x | ||
62 | 73 | ||
63 | def __grouped_convolution_block(self, input, grouped_channels, strides): | 74 | def __grouped_convolution_block(self, input, grouped_channels, strides): |
64 | ''' | 75 | ''' |
... | @@ -126,8 +137,11 @@ class SEResNeXt(Model): | ... | @@ -126,8 +137,11 @@ class SEResNeXt(Model): |
126 | ''' | 137 | ''' |
127 | Creates a ResNeXt model with specified parameters | 138 | Creates a ResNeXt model with specified parameters |
128 | ''' | 139 | ''' |
129 | - | 140 | + if type(self.__depth) is list or type(self.__depth) is tuple: |
130 | - N = list(self.__depth) | 141 | + N = list(self.__depth) |
142 | + else: | ||
143 | + N = [(self.__depth - 2) // 9 for _ in range(3)] | ||
144 | + print(N) | ||
131 | 145 | ||
132 | filters = self.__cardinality * self.__width | 146 | filters = self.__cardinality * self.__width |
133 | filters_list = [] | 147 | filters_list = [] | ... | ... |
... | @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt | ... | @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt |
4 | import os | 4 | import os |
5 | 5 | ||
6 | MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') | 6 | MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') |
7 | -MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5' | 7 | +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar10.h5' |
8 | TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test') | 8 | TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test') |
9 | TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png' | 9 | TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png' |
10 | 10 | ... | ... |
... | @@ -13,14 +13,16 @@ import os | ... | @@ -13,14 +13,16 @@ import os |
13 | MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') | 13 | MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') |
14 | if not os.path.exists(MODEL_SAVE_FOLDER_PATH): | 14 | if not os.path.exists(MODEL_SAVE_FOLDER_PATH): |
15 | os.mkdir(MODEL_SAVE_FOLDER_PATH) | 15 | os.mkdir(MODEL_SAVE_FOLDER_PATH) |
16 | -MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5' | 16 | +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar.h5' |
17 | 17 | ||
18 | -model = SEResNeXt((112, 112, 3)) | 18 | +model = SEResNeXt('cifar-10', (32, 32, 3)) |
19 | 19 | ||
20 | model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy']) | 20 | model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy']) |
21 | 21 | ||
22 | -ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True) | 22 | +# ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True) |
23 | +ds_train = tfds.load('cifar-10', split='train', shuffle_files=True) | ||
23 | ds_train = ds_train.shuffle(1000).batch(128).prefetch(10) | 24 | ds_train = ds_train.shuffle(1000).batch(128).prefetch(10) |
25 | + | ||
24 | model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30) | 26 | model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30) |
25 | 27 | ||
26 | model.save(MODEL_SAVE_PATH) | 28 | model.save(MODEL_SAVE_PATH) | ... | ... |
-
Please register or login to post a comment