이혜리

train

Showing 1 changed file with 89 additions and 0 deletions
1 +#한번 한 다음에는 from kera.models import load_model 이용해서 모델 가지고와서 다시학습,,,,,그거 구현,,
2 +
3 +import os
4 +import numpy as np
5 +np.random.seed(123)
6 +from six.moves import cPickle
7 +
8 +from keras import backend as K
9 +from keras.models import Model
10 +from keras.layers import Input, Dense, Flatten
11 +from keras.layers import LSTM
12 +from keras.layers import TimeDistributed
13 +from keras.callbacks import LearningRateScheduler, ModelCheckpoint
14 +from keras.optimizers import Adam
15 +
16 +from prednet import PredNet
17 +from data_utils import SequenceGenerator
18 +from setting import *
19 +
20 +save_model = True # if weights will be saved
21 +weights_file = os.path.join(WEIGHTS_DIR, 'prednet_weights.hdf5') # where weights will be saved
22 +json_file = os.path.join(WEIGHTS_DIR, 'prednet_model.json')
23 +
24 +# Data files
25 +train_file = os.path.join(DATA_DIR, 'X_train.hkl')
26 +train_sources = os.path.join(DATA_DIR, 'sources_train.hkl')
27 +val_file = os.path.join(DATA_DIR, 'X_val.hkl')
28 +val_sources = os.path.join(DATA_DIR, 'sources_val.hkl')
29 +
30 +# Training parameters
31 +nb_epoch = 150
32 +batch_size = 10
33 +samples_per_epoch = 900
34 +N_seq_val = 100 # number of sequences to use for validation
35 +
36 +
37 +# Model parameters
38 +n_channels, im_height, im_width = (3, 128, 160)
39 +input_shape = (n_channels, im_height, im_width) if K.image_data_format() == 'channels_first' else (im_height, im_width, n_channels)
40 +stack_sizes = (n_channels, 48, 96, 192)
41 +R_stack_sizes = stack_sizes
42 +A_filt_sizes = (3, 3, 3)
43 +Ahat_filt_sizes = (3, 3, 3, 3)
44 +R_filt_sizes = (3, 3, 3, 3)
45 +layer_loss_weights = np.array([1., 0., 0., 0.]) # weighting for each layer in final loss; "L_0" model: [1, 0, 0, 0], "L_all": [1, 0.1, 0.1, 0.1]
46 +layer_loss_weights = np.expand_dims(layer_loss_weights, 1)
47 +nt = 10 # number of timesteps used for sequences in training
48 +time_loss_weights = 1./ (nt - 1) * np.ones((nt,1)) # equally weight all timesteps except the first
49 +time_loss_weights[0] = 0
50 +
51 +
52 +prednet = PredNet(stack_sizes, R_stack_sizes,
53 + A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
54 + output_mode='error', return_sequences=True)
55 +
56 +#모델만들기
57 +inputs = Input(shape=(nt,) + input_shape)
58 +errors = prednet(inputs) # errors will be (batch_size, nt, nb_layers)
59 +# TimeDistributed - LSTM이 many-to-many로 동작
60 +# 매 스텝마다 cost가 계산되고 하위 스텝으로 오류가 전파
61 +errors_by_time = TimeDistributed(Dense(1, trainable=False), weights=[layer_loss_weights, np.zeros(1)], trainable=False)(errors) # calculate weighted error by layer
62 +errors_by_time = Flatten()(errors_by_time) # will be (batch_size, nt)
63 +final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time) # weight errors by time
64 +model = Model(inputs=inputs, outputs=final_errors)
65 +model.compile(loss='mean_absolute_error', optimizer='adam')
66 +
67 +
68 +train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True)
69 +val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val)
70 +
71 +
72 +lr_schedule = lambda epoch: 0.001 if epoch < 75 else 0.0001 # start with lr of 0.001 and then drop to 0.0001 after 75 epochs
73 +# 여기 콜백함수기능: 일정 구간이나 특정 accuracy threshold를 초과할 때 모델에서 checkpointing
74 +callbacks = [LearningRateScheduler(lr_schedule)]
75 +if save_model:
76 + if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR)
77 + callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_weights_only=True))
78 +# ModelCheckpoint
79 +# 상대적으로 큰 데이터셋을 학습할 때, 빈번하게 모델의 체크포인트를 저장하는 것은 매우 중요
80 +# monitor='val_loss', val_loss 값이 개선되었을때 호출됩니다
81 +# save_best_only=True, 가장 best 값만 저장합니다
82 +
83 +history = model.fit_generator(train_generator, int(samples_per_epoch / batch_size), nb_epoch, callbacks=callbacks,
84 + validation_data=val_generator, validation_steps=int(N_seq_val / batch_size))
85 +
86 +if save_model:
87 + json_string = model.to_json()
88 + with open(json_file, "w") as f:
89 + f.write(json_string)
...\ No newline at end of file ...\ No newline at end of file