# train.py
# 768 Bytes
from model import SEResNeXt
from keras.datasets import fashion_mnist
from keras import optimizers
import os
import sys
import tensorflow_datasets as tfds
import os
# Train SE-ResNeXt on TFDS `imagenet2012_corrupted` and save the model under ./trained.
#
# NOTE(review): `tensorflow` is imported explicitly because the input pipeline
# needs tf ops (resize). The script already depended on it implicitly:
# `tfds.load` returns a `tf.data.Dataset`.
import tensorflow as tf

INPUT_SHAPE = (112, 112, 3)  # (height, width, channels) the model is built for.
BATCH_SIZE = 128
SHUFFLE_BUFFER = 1000
PREFETCH_BATCHES = 10
EPOCHS = 10
STEPS_PER_EPOCH = 30

MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
# exist_ok=True replaces the racy exists()/mkdir() pair and also creates parents.
os.makedirs(MODEL_SAVE_FOLDER_PATH, exist_ok=True)
MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_FOLDER_PATH, 'seresnext_imagenet.h5')


def _preprocess(image, label):
    """Resize one image to INPUT_SHAPE's spatial size and scale pixels to [0, 1].

    Resizing must happen before `batch()`: every element in a batch must share
    a shape, and the model only accepts INPUT_SHAPE inputs.
    """
    image = tf.image.resize(image, INPUT_SHAPE[:2])  # yields float32
    # NOTE(review): assumes SEResNeXt has no built-in rescaling layer — confirm in model.py.
    image = image / 255.0
    return image, label


model = SEResNeXt(INPUT_SHAPE)
# TFDS labels are integer class ids, so the sparse loss is required;
# 'categorical_crossentropy' expects one-hot targets.
model.compile(
    optimizer=optimizers.Adam(0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# as_supervised=True yields (image, label) tuples that `model.fit` consumes
# directly. The default yields feature dicts, and a Dataset cannot be indexed
# like ds['image'].
# NOTE(review): the TFDS catalog may list only a 'validation' split for
# imagenet2012_corrupted — confirm 'train' exists for the installed version.
ds_train = tfds.load(
    'imagenet2012_corrupted',
    split='train',
    shuffle_files=True,
    as_supervised=True,
)
ds_train = (
    ds_train
    .shuffle(SHUFFLE_BUFFER)
    .map(_preprocess)
    .batch(BATCH_SIZE)
    .repeat()  # steps_per_epoch bounds each epoch; repeat() avoids running out of data
    .prefetch(PREFETCH_BATCHES)
)

model.fit(ds_train, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH)
model.save(MODEL_SAVE_PATH)