# train.py
# 864 Bytes
from model import SEResNeXt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from keras import optimizers
from keras.datasets import cifar10
from keras.utils import np_utils
import os
import sys
# Training script: builds an SEResNeXt model, fits it on the CIFAR-10 data
# shipped with Keras, and saves the trained model under ./trained/.

# Output folder, resolved against the current working directory.
MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained')
# exist_ok=True creates the folder only when it is missing, without a separate
# existence check that another process could invalidate in between.
os.makedirs(MODEL_SAVE_FOLDER_PATH, exist_ok=True)
MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_FOLDER_PATH, 'seresnext_cifar.h5')

# Build and compile the network.
# NOTE(review): (32, 32, 3) is presumably the input image shape and 'cifar10'
# a dataset/config selector -- confirm against model.SEResNeXt.
model = SEResNeXt('cifar10', (32, 32, 3))
model.compile(
    optimizer=optimizers.Adam(0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Load the data and one-hot encode the integer labels into 10 classes, as
# required by the categorical cross-entropy loss configured above.
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
Y_train = np_utils.to_categorical(y_train, 10)
Y_test = np_utils.to_categorical(y_test, 10)

# NOTE(review): the images are passed to the model exactly as loaded (no
# scaling is applied in this script) -- confirm SEResNeXt normalizes inputs.
# NOTE(review): the test split doubles as validation data here, so the
# validation metrics reported during training are not a held-out estimate.
model.fit(
    X_train,
    Y_train,
    epochs=50,
    verbose=1,
    shuffle=True,
    validation_data=(X_test, Y_test),
)

# The model is written once, after the final epoch completes.
model.save(MODEL_SAVE_PATH)