Showing 1 changed file with 89 additions and 0 deletions
train.py
0 → 100644
# TODO: after the first run, load the saved model with "from keras.models import load_model" and continue training from it

import os
import numpy as np
np.random.seed(123)
from six.moves import cPickle

from keras import backend as K
from keras.models import Model
from keras.layers import Input, Dense, Flatten
from keras.layers import LSTM
from keras.layers import TimeDistributed
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import Adam

from prednet import PredNet
from data_utils import SequenceGenerator
from setting import *

save_model = True  # if weights will be saved
weights_file = os.path.join(WEIGHTS_DIR, 'prednet_weights.hdf5')  # where weights will be saved
json_file = os.path.join(WEIGHTS_DIR, 'prednet_model.json')

# Data files
train_file = os.path.join(DATA_DIR, 'X_train.hkl')
train_sources = os.path.join(DATA_DIR, 'sources_train.hkl')
val_file = os.path.join(DATA_DIR, 'X_val.hkl')
val_sources = os.path.join(DATA_DIR, 'sources_val.hkl')

# Training parameters
nb_epoch = 150
batch_size = 10
samples_per_epoch = 900
N_seq_val = 100  # number of sequences to use for validation


# Model parameters
n_channels, im_height, im_width = (3, 128, 160)
input_shape = (n_channels, im_height, im_width) if K.image_data_format() == 'channels_first' else (im_height, im_width, n_channels)
stack_sizes = (n_channels, 48, 96, 192)
R_stack_sizes = stack_sizes
A_filt_sizes = (3, 3, 3)
Ahat_filt_sizes = (3, 3, 3, 3)
R_filt_sizes = (3, 3, 3, 3)
layer_loss_weights = np.array([1., 0., 0., 0.])  # weighting for each layer in final loss; "L_0" model: [1, 0, 0, 0], "L_all": [1, 0.1, 0.1, 0.1]
layer_loss_weights = np.expand_dims(layer_loss_weights, 1)
nt = 10  # number of timesteps used for sequences in training
time_loss_weights = 1. / (nt - 1) * np.ones((nt, 1))  # equally weight all timesteps except the first
time_loss_weights[0] = 0


prednet = PredNet(stack_sizes, R_stack_sizes,
                  A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
                  output_mode='error', return_sequences=True)

# Build the model
inputs = Input(shape=(nt,) + input_shape)
errors = prednet(inputs)  # errors will be (batch_size, nt, nb_layers)
# TimeDistributed lets the recurrent model operate many-to-many:
# a cost is computed at every timestep and the error is propagated back to the earlier steps
errors_by_time = TimeDistributed(Dense(1, trainable=False), weights=[layer_loss_weights, np.zeros(1)], trainable=False)(errors)  # calculate weighted error by layer
errors_by_time = Flatten()(errors_by_time)  # will be (batch_size, nt)
final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time)  # weight errors by time
model = Model(inputs=inputs, outputs=final_errors)
model.compile(loss='mean_absolute_error', optimizer='adam')
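
# A minimal sketch (not part of the original script) of what the two frozen Dense layers above
# compute: the scalar the model is trained on is sum_t time_w[t] * sum_l layer_w[l] * E[t, l],
# i.e. the per-timestep error weighted by layer, then weighted by time and summed.
# The helper below is illustrative only and is never called during training.
def _weighted_error_sketch(errors_batch, layer_w=layer_loss_weights, time_w=time_loss_weights):
    # errors_batch: array of shape (batch_size, nt, nb_layers), like the PredNet 'error' output
    per_timestep = errors_batch.dot(layer_w)                         # (batch_size, nt, 1): weighted error by layer
    per_timestep = per_timestep.reshape(per_timestep.shape[0], -1)   # (batch_size, nt)
    return per_timestep.dot(time_w)                                  # (batch_size, 1): weighted error by time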


train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True)
val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val)
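# Assumption (based on the reference PredNet data_utils): each generator yields batches of
# shape (batch_size, nt) + input_shape together with dummy zero targets, since the model's
# output is already the training error being minimized.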


lr_schedule = lambda epoch: 0.001 if epoch < 75 else 0.0001  # start with lr of 0.001 and then drop to 0.0001 after 75 epochs
# Callbacks used here: checkpoint the model at regular intervals or when a certain accuracy threshold is exceeded
callbacks = [LearningRateScheduler(lr_schedule)]
if save_model:
    if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR)
    callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True, save_weights_only=True))
# ModelCheckpoint
# When training on a relatively large dataset, saving checkpoints of the model frequently is very important.
# monitor='val_loss': the checkpoint is written when val_loss improves.
# save_best_only=True: only the best weights so far are kept.

history = model.fit_generator(train_generator, steps_per_epoch=int(samples_per_epoch / batch_size), epochs=nb_epoch, callbacks=callbacks,
                              validation_data=val_generator, validation_steps=int(N_seq_val / batch_size))

if save_model:
    json_string = model.to_json()
    with open(json_file, "w") as f:
        f.write(json_string)
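
# A minimal sketch (an assumption, not part of the original training run) of how a later run
# could resume from the files saved above, as the TODO at the top of this file suggests.
# It rebuilds the architecture from the saved JSON (PredNet is a custom layer, so it must be
# passed via custom_objects) and then loads the checkpointed weights. Never called here.
def _resume_training_sketch():
    from keras.models import model_from_json
    with open(json_file, "r") as f:
        loaded = model_from_json(f.read(), custom_objects={'PredNet': PredNet})
    loaded.load_weights(weights_file)
    loaded.compile(loss='mean_absolute_error', optimizer='adam')
    # continue training with the same generators, e.g.:
    # loaded.fit_generator(train_generator, steps_per_epoch=int(samples_per_epoch / batch_size),
    #                      epochs=nb_epoch, callbacks=callbacks,
    #                      validation_data=val_generator, validation_steps=int(N_seq_val / batch_size))
    return loaded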