# train.py
# (846 bytes)
from model import SEResNeXt
from keras.datasets import fashion_mnist
from keras import optimizers
import os
import sys
import tensorflow_datasets as tfds
import os
# Train SEResNeXt on the CIFAR-10 training split and save the model as HDF5.

# Output location: <cwd>/trained/seresnext_cifar.h5
MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
# exist_ok=True replaces the exists()/mkdir() pair and avoids its race.
os.makedirs(MODEL_SAVE_FOLDER_PATH, exist_ok=True)
MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_FOLDER_PATH, 'seresnext_cifar.h5')

LEARNING_RATE = 0.001
BATCH_SIZE = 128
SHUFFLE_BUFFER = 1000
PREFETCH_BUFFER = 10
EPOCHS = 10
# Each epoch runs only this many batches (30 * 128 = 3,840 images), which is
# not necessarily a full pass over the training split.
STEPS_PER_EPOCH = 30

model = SEResNeXt('cifar-10', (32, 32, 3))
# tfds yields integer class ids, not one-hot vectors, so the sparse variant
# of cross-entropy is required here.
model.compile(optimizer=optimizers.Adam(LEARNING_RATE),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# The TFDS catalog name is 'cifar10' (no hyphen). as_supervised=True yields
# (image, label) tuples that model.fit consumes directly. Without it, the
# elements are feature dicts, and a Dataset cannot be indexed by key anyway.
ds_train = tfds.load('cifar10', split='train', shuffle_files=True,
                     as_supervised=True)
# NOTE(review): images are fed as raw uint8 pixels with no rescaling here;
# confirm whether SEResNeXt normalizes its inputs internally.
ds_train = (ds_train
            .shuffle(SHUFFLE_BUFFER)
            .batch(BATCH_SIZE)
            .prefetch(PREFETCH_BUFFER))

model.fit(ds_train, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH)
model.save(MODEL_SAVE_PATH)