이혜리

test

1 +'''
2 +Evaluate trained PredNet on KITTI sequences.
3 +Calculates mean-squared error and plots predictions.
4 +'''
5 +
6 +import os
7 +import numpy as np
8 +from six.moves import cPickle
9 +import matplotlib
10 +matplotlib.use('Agg')
11 +import matplotlib.pyplot as plt
12 +import matplotlib.gridspec as gridspec
13 +
14 +from keras import backend as K
15 +from keras.models import Model, model_from_json
16 +from keras.layers import Input, Dense, Flatten
17 +
18 +from prednet import PredNet
19 +from data_utils import SequenceGenerator
20 +from setting import *
21 +
22 +
23 +n_plot = 18
24 +batch_size = 6
25 +nt = 11
26 +
27 +weights_file = os.path.join(WEIGHTS_DIR, 'prednet_weights.hdf5') #★★★★★★★★★★★★★★★
28 +json_file = os.path.join(WEIGHTS_DIR, 'prednet_model.json') #★★★★★★★★★★★★★★★
29 +test_file = os.path.join(DATA_DIR, 'X_test.hkl')
30 +test_sources = os.path.join(DATA_DIR, 'sources_test.hkl')
31 +
32 +# Load trained model
33 +f = open(json_file, 'r')
34 +json_string = f.read()
35 +f.close()
36 +train_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})
37 +train_model.load_weights(weights_file)
38 +
39 +# Create testing model (to output predictions)
40 +layer_config = train_model.layers[1].get_config()
41 +layer_config['output_mode'] = 'prediction'
42 +data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering']
43 +test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)
44 +input_shape = list(train_model.layers[0].batch_input_shape[1:])
45 +input_shape[0] = nt
46 +inputs = Input(shape=tuple(input_shape))
47 +predictions = test_prednet(inputs)
48 +test_model = Model(inputs=inputs, outputs=predictions)
49 +
50 +test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', data_format=data_format)
51 +X_test = test_generator.create_all()
52 +X_hat = test_model.predict(X_test, batch_size)
53 +if data_format == 'channels_first':
54 + X_test = np.transpose(X_test, (0, 1, 3, 4, 2))
55 + X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2))
56 +
57 +# Compare MSE of PredNet predictions vs. using last frame. Write results to prediction_scores.txt
58 +mse_model = np.mean( (X_test[:, 1:] - X_hat[:, 1:])**2 ) # look at all timesteps except the first
59 +mse_prev = np.mean( (X_test[:, :-1] - X_test[:, 1:])**2 )
60 +if not os.path.exists(RESULTS_SAVE_DIR): os.mkdir(RESULTS_SAVE_DIR)
61 +f = open(RESULTS_SAVE_DIR + 'prediction_scores_222.txt', 'w')
62 +f.write("Model MSE: %f\n" % mse_model)
63 +f.write("Previous Frame MSE: %f" % mse_prev)
64 +f.close()
65 +
66 +# Plot some predictions
67 +aspect_ratio = float(X_hat.shape[2]) / X_hat.shape[3]
68 +plt.figure(figsize = (nt, 2*aspect_ratio))
69 +gs = gridspec.GridSpec(2, nt)
70 +gs.update(wspace=0., hspace=0.)
71 +plot_save_dir = os.path.join(RESULTS_SAVE_DIR, 'prediction_plots222/')
72 +if not os.path.exists(plot_save_dir): os.mkdir(plot_save_dir)
73 +plot_idx = np.random.permutation(X_test.shape[0])[:n_plot]
74 +for i in plot_idx:
75 + for t in range(nt):
76 + plt.subplot(gs[t])
77 + X_test[i,t] = X_test[i,t]*255
78 + plt.imshow(X_test[i,t], interpolation='none')
79 + plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off')
80 + if t==0: plt.ylabel('Actual', fontsize=10)
81 +
82 + plt.subplot(gs[t + nt])
83 + X_hat[i,t] = X_hat[i,t]*1000
84 + plt.imshow(X_hat[i,t], interpolation='none')
85 + plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off')
86 + if t==0: plt.ylabel('Predicted', fontsize=10)
87 +
88 + plt.savefig(plot_save_dir + 'plot_' + str(i) + '.png')
89 + plt.clf()