Showing
1 changed file
with
89 additions
and
0 deletions
action_test.py
0 → 100644
1 | +''' | ||
2 | +Evaluate trained PredNet on KITTI sequences. | ||
3 | +Calculates mean-squared error and plots predictions. | ||
4 | +''' | ||
5 | + | ||
6 | +import os | ||
7 | +import numpy as np | ||
8 | +from six.moves import cPickle | ||
9 | +import matplotlib | ||
10 | +matplotlib.use('Agg') | ||
11 | +import matplotlib.pyplot as plt | ||
12 | +import matplotlib.gridspec as gridspec | ||
13 | + | ||
14 | +from keras import backend as K | ||
15 | +from keras.models import Model, model_from_json | ||
16 | +from keras.layers import Input, Dense, Flatten | ||
17 | + | ||
18 | +from prednet import PredNet | ||
19 | +from data_utils import SequenceGenerator | ||
20 | +from setting import * | ||
21 | + | ||
22 | + | ||
23 | +n_plot = 18 | ||
24 | +batch_size = 6 | ||
25 | +nt = 11 | ||
26 | + | ||
27 | +weights_file = os.path.join(WEIGHTS_DIR, 'prednet_weights.hdf5') #★★★★★★★★★★★★★★★ | ||
28 | +json_file = os.path.join(WEIGHTS_DIR, 'prednet_model.json') #★★★★★★★★★★★★★★★ | ||
29 | +test_file = os.path.join(DATA_DIR, 'X_test.hkl') | ||
30 | +test_sources = os.path.join(DATA_DIR, 'sources_test.hkl') | ||
31 | + | ||
32 | +# Load trained model | ||
33 | +f = open(json_file, 'r') | ||
34 | +json_string = f.read() | ||
35 | +f.close() | ||
36 | +train_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet}) | ||
37 | +train_model.load_weights(weights_file) | ||
38 | + | ||
39 | +# Create testing model (to output predictions) | ||
40 | +layer_config = train_model.layers[1].get_config() | ||
41 | +layer_config['output_mode'] = 'prediction' | ||
42 | +data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering'] | ||
43 | +test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config) | ||
44 | +input_shape = list(train_model.layers[0].batch_input_shape[1:]) | ||
45 | +input_shape[0] = nt | ||
46 | +inputs = Input(shape=tuple(input_shape)) | ||
47 | +predictions = test_prednet(inputs) | ||
48 | +test_model = Model(inputs=inputs, outputs=predictions) | ||
49 | + | ||
50 | +test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', data_format=data_format) | ||
51 | +X_test = test_generator.create_all() | ||
52 | +X_hat = test_model.predict(X_test, batch_size) | ||
53 | +if data_format == 'channels_first': | ||
54 | + X_test = np.transpose(X_test, (0, 1, 3, 4, 2)) | ||
55 | + X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2)) | ||
56 | + | ||
57 | +# Compare MSE of PredNet predictions vs. using last frame. Write results to prediction_scores.txt | ||
58 | +mse_model = np.mean( (X_test[:, 1:] - X_hat[:, 1:])**2 ) # look at all timesteps except the first | ||
59 | +mse_prev = np.mean( (X_test[:, :-1] - X_test[:, 1:])**2 ) | ||
60 | +if not os.path.exists(RESULTS_SAVE_DIR): os.mkdir(RESULTS_SAVE_DIR) | ||
61 | +f = open(RESULTS_SAVE_DIR + 'prediction_scores_222.txt', 'w') | ||
62 | +f.write("Model MSE: %f\n" % mse_model) | ||
63 | +f.write("Previous Frame MSE: %f" % mse_prev) | ||
64 | +f.close() | ||
65 | + | ||
66 | +# Plot some predictions | ||
67 | +aspect_ratio = float(X_hat.shape[2]) / X_hat.shape[3] | ||
68 | +plt.figure(figsize = (nt, 2*aspect_ratio)) | ||
69 | +gs = gridspec.GridSpec(2, nt) | ||
70 | +gs.update(wspace=0., hspace=0.) | ||
71 | +plot_save_dir = os.path.join(RESULTS_SAVE_DIR, 'prediction_plots222/') | ||
72 | +if not os.path.exists(plot_save_dir): os.mkdir(plot_save_dir) | ||
73 | +plot_idx = np.random.permutation(X_test.shape[0])[:n_plot] | ||
74 | +for i in plot_idx: | ||
75 | + for t in range(nt): | ||
76 | + plt.subplot(gs[t]) | ||
77 | + X_test[i,t] = X_test[i,t]*255 | ||
78 | + plt.imshow(X_test[i,t], interpolation='none') | ||
79 | + plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off') | ||
80 | + if t==0: plt.ylabel('Actual', fontsize=10) | ||
81 | + | ||
82 | + plt.subplot(gs[t + nt]) | ||
83 | + X_hat[i,t] = X_hat[i,t]*1000 | ||
84 | + plt.imshow(X_hat[i,t], interpolation='none') | ||
85 | + plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off') | ||
86 | + if t==0: plt.ylabel('Predicted', fontsize=10) | ||
87 | + | ||
88 | + plt.savefig(plot_save_dir + 'plot_' + str(i) + '.png') | ||
89 | + plt.clf() |
-
Please register or login to post a comment