4Moyede

[Mod] add target dataset cifar10

...@@ -17,30 +17,39 @@ import os ...@@ -17,30 +17,39 @@ import os
17 17
18 18
19 class SEResNeXt(Model): 19 class SEResNeXt(Model):
20 - def __init__(self, weight, input_shape=None, depth=[3, 8, 36, 3], cardinality=32, width=4, reduction_ratio=4, weight_decay=5e-4, classes=1000, channel_axis=None): 20 + def __init__(self, weight, input_shape=None):
21 ''' 21 '''
22 ResNext Model 22 ResNext Model
23 23
24 ## Args 24 ## Args
25 + + weight:
25 + input_shape: optional shape tuple 26 + input_shape: optional shape tuple
26 - + depth: number or layers in the each block, defined as a list
27 - + cardinality: the size of the set of transformations
28 - + width: multiplier to the ResNeXt width (number of filters)
29 - + redution_ratio: ratio of reducition in SE Block
30 - + weight_decay: weight decay (l2 norm)
31 - + classes: number of classes to classify images into
32 - + channel_axis: channel axis in keras.backend.image_data_format()
33 ''' 27 '''
28 +
29 + if weights not in {'cifar10', 'imagenet'}:
30 + raise ValueError
31 +
34 self.__weight = weight 32 self.__weight = weight
35 - self.__depth = depth 33 +
36 - self.__cardinality = cardinality 34 + if weight == 'cifar10':
37 - self.__width = width 35 + self.__depth = 29
38 - self.__reduction_ratio = reduction_ratio 36 + self.__cardinality = 8
39 - self.__weight_decay = weight_decay 37 + self.__width = 64
40 - self.__classes = classes 38 + self.__classes = 10
39 + else:
40 + self.__depth = [3, 8, 36, 3]
41 + self.__cardinality = 32
42 + self.__width = 4
43 + self.__classes = 1000
44 +
45 + self.__reduction_ratio = 4
46 + self.__weight_decay = 5e-4
41 self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1 47 self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1
42 48
43 - self.__input_shape = _obtain_input_shape(input_shape, default_size = 224, min_size = 112, data_format=K.image_data_format(), require_flatten=True) 49 + if weight == 'cifar10':
50 + self.__input_shape = _obtain_input_shape(input_shape, default_size=32, min_size=8, data_format=K.image_data_format(), require_flatten=True)
51 + else:
52 + self.__input_shape = _obtain_input_shape(input_shape, default_size=224, min_size=112, data_format=K.image_data_format(), require_flatten=True)
44 self.__img_input = Input(shape=self.__input_shape) 53 self.__img_input = Input(shape=self.__input_shape)
45 54
46 # Create model. 55 # Create model.
...@@ -50,14 +59,16 @@ class SEResNeXt(Model): ...@@ -50,14 +59,16 @@ class SEResNeXt(Model):
50 ''' 59 '''
51 Adds an initial conv block, with batch norm and relu for the inception resnext 60 Adds an initial conv block, with batch norm and relu for the inception resnext
52 ''' 61 '''
53 - channel_axis = -1 62 + if weight == 'cifar10':
54 - 63 + x = Conv2D(64, (3, 3), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(self.__img_input)
64 + x = BatchNormalization(axis=self.__channel_axis)(x)
65 + x = Activation('relu')(x)
66 + return x
67 + else:
55 x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input) 68 x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input)
56 - x = BatchNormalization(axis=channel_axis)(x) 69 + x = BatchNormalization(axis=self.__channel_axis)(x)
57 x = Activation('relu')(x) 70 x = Activation('relu')(x)
58 -
59 x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) 71 x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
60 -
61 return x 72 return x
62 73
63 def __grouped_convolution_block(self, input, grouped_channels, strides): 74 def __grouped_convolution_block(self, input, grouped_channels, strides):
...@@ -126,8 +137,11 @@ class SEResNeXt(Model): ...@@ -126,8 +137,11 @@ class SEResNeXt(Model):
126 ''' 137 '''
127 Creates a ResNeXt model with specified parameters 138 Creates a ResNeXt model with specified parameters
128 ''' 139 '''
129 - 140 + if type(self.__depth) is list or type(self.__depth) is tuple:
130 N = list(self.__depth) 141 N = list(self.__depth)
142 + else:
143 + N = [(self.__depth - 2) // 9 for _ in range(3)]
144 + print(N)
131 145
132 filters = self.__cardinality * self.__width 146 filters = self.__cardinality * self.__width
133 filters_list = [] 147 filters_list = []
......
...@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt ...@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
4 import os 4 import os
5 5
6 MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') 6 MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
7 -MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5' 7 +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar10.h5'
8 TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test') 8 TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test')
9 TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png' 9 TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png'
10 10
......
...@@ -13,14 +13,16 @@ import os ...@@ -13,14 +13,16 @@ import os
13 MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') 13 MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
14 if not os.path.exists(MODEL_SAVE_FOLDER_PATH): 14 if not os.path.exists(MODEL_SAVE_FOLDER_PATH):
15 os.mkdir(MODEL_SAVE_FOLDER_PATH) 15 os.mkdir(MODEL_SAVE_FOLDER_PATH)
16 -MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5' 16 +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar.h5'
17 17
18 -model = SEResNeXt((112, 112, 3)) 18 +model = SEResNeXt('cifar-10', (32, 32, 3))
19 19
20 model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy']) 20 model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy'])
21 21
22 -ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True) 22 +# ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True)
23 +ds_train = tfds.load('cifar-10', split='train', shuffle_files=True)
23 ds_train = ds_train.shuffle(1000).batch(128).prefetch(10) 24 ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)
25 +
24 model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30) 26 model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30)
25 27
26 model.save(MODEL_SAVE_PATH) 28 model.save(MODEL_SAVE_PATH)
......