# train.py
from model import SEResNeXt

from keras.datasets import fashion_mnist
from keras import optimizers

import os
import sys

import tensorflow_datasets as tfds

import os

# Trained models are written under ./trained, relative to the working
# directory the script is launched from.
MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
# exist_ok=True replaces the racy "exists? then mkdir" pair: it is a no-op
# when the directory is already there and also creates missing parents.
os.makedirs(MODEL_SAVE_FOLDER_PATH, exist_ok=True)
# Join with os.path.join rather than a hard-coded '/' separator.
MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_FOLDER_PATH, 'seresnext_cifar.h5')

# Build the network for 32x32 RGB inputs. The 'cifar-10' argument presumably
# selects the CIFAR-10 classification head in model.py (assumed to end in a
# 10-way softmax) — TODO confirm against SEResNeXt.
model = SEResNeXt('cifar-10', (32, 32, 3))

# tfds yields integer class ids rather than one-hot vectors, so the sparse
# variant of categorical cross-entropy is the matching loss. With this loss,
# Keras resolves the 'accuracy' metric to sparse categorical accuracy.
model.compile(
    optimizer=optimizers.Adam(0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# The registered tfds dataset name is 'cifar10' (no hyphen). as_supervised=True
# makes every element an (image, label) tuple, the form model.fit unpacks
# directly from a tf.data.Dataset.
# NOTE: an ImageNet experiment previously used 'imagenet2012_corrupted' here.
ds_train = tfds.load(
    'cifar10', split='train', shuffle_files=True, as_supervised=True
)
ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)

# A tf.data.Dataset is not subscriptable (ds_train['image'] raises TypeError),
# so pass the dataset itself and let Keras split the (x, y) pairs.
# steps_per_epoch=30 caps each epoch at 30 batches of 128 images.
model.fit(ds_train, epochs=10, steps_per_epoch=30)

model.save(MODEL_SAVE_PATH)