4Moyede

[Mod] add target dataset cifar10

......@@ -17,30 +17,39 @@ import os
class SEResNeXt(Model):
def __init__(self, weight, input_shape=None, depth=[3, 8, 36, 3], cardinality=32, width=4, reduction_ratio=4, weight_decay=5e-4, classes=1000, channel_axis=None):
def __init__(self, weight, input_shape=None):
'''
ResNext Model
## Args
+ weight:
+ input_shape: optional shape tuple
+ depth: number or layers in the each block, defined as a list
+ cardinality: the size of the set of transformations
+ width: multiplier to the ResNeXt width (number of filters)
+ redution_ratio: ratio of reducition in SE Block
+ weight_decay: weight decay (l2 norm)
+ classes: number of classes to classify images into
+ channel_axis: channel axis in keras.backend.image_data_format()
'''
if weights not in {'cifar10', 'imagenet'}:
raise ValueError
self.__weight = weight
self.__depth = depth
self.__cardinality = cardinality
self.__width = width
self.__reduction_ratio = reduction_ratio
self.__weight_decay = weight_decay
self.__classes = classes
if weight == 'cifar10':
self.__depth = 29
self.__cardinality = 8
self.__width = 64
self.__classes = 10
else:
self.__depth = [3, 8, 36, 3]
self.__cardinality = 32
self.__width = 4
self.__classes = 1000
self.__reduction_ratio = 4
self.__weight_decay = 5e-4
self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1
self.__input_shape = _obtain_input_shape(input_shape, default_size = 224, min_size = 112, data_format=K.image_data_format(), require_flatten=True)
if weight == 'cifar10':
self.__input_shape = _obtain_input_shape(input_shape, default_size=32, min_size=8, data_format=K.image_data_format(), require_flatten=True)
else:
self.__input_shape = _obtain_input_shape(input_shape, default_size=224, min_size=112, data_format=K.image_data_format(), require_flatten=True)
self.__img_input = Input(shape=self.__input_shape)
# Create model.
......@@ -50,15 +59,17 @@ class SEResNeXt(Model):
'''
Adds an initial conv block, with batch norm and relu for the inception resnext
'''
channel_axis = -1
x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input)
x = BatchNormalization(axis=channel_axis)(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
return x
if weight == 'cifar10':
x = Conv2D(64, (3, 3), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(self.__img_input)
x = BatchNormalization(axis=self.__channel_axis)(x)
x = Activation('relu')(x)
return x
else:
x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input)
x = BatchNormalization(axis=self.__channel_axis)(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
return x
def __grouped_convolution_block(self, input, grouped_channels, strides):
'''
......@@ -126,8 +137,11 @@ class SEResNeXt(Model):
'''
Creates a ResNeXt model with specified parameters
'''
N = list(self.__depth)
if type(self.__depth) is list or type(self.__depth) is tuple:
N = list(self.__depth)
else:
N = [(self.__depth - 2) // 9 for _ in range(3)]
print(N)
filters = self.__cardinality * self.__width
filters_list = []
......
......@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
import os
MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5'
MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar10.h5'
TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test')
TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png'
......
......@@ -13,14 +13,16 @@ import os
MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
if not os.path.exists(MODEL_SAVE_FOLDER_PATH):
os.mkdir(MODEL_SAVE_FOLDER_PATH)
MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5'
MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar.h5'
model = SEResNeXt((112, 112, 3))
model = SEResNeXt('cifar-10', (32, 32, 3))
model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy'])
ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True)
# ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True)
ds_train = tfds.load('cifar-10', split='train', shuffle_files=True)
ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)
model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30)
model.save(MODEL_SAVE_PATH)
......