Showing
2 changed files
with
65 additions
and
0 deletions
keras_utils.py
0 → 100644
1 | +import os | ||
2 | +import numpy as np | ||
3 | + | ||
4 | +from keras import backend as K | ||
5 | +from keras.legacy.interfaces import generate_legacy_interface, recurrent_args_preprocessor | ||
6 | +from keras.models import model_from_json | ||
7 | + | ||
8 | +legacy_prednet_support = generate_legacy_interface( | ||
9 | + allowed_positional_args=['stack_sizes', 'R_stack_sizes', | ||
10 | + 'A_filt_sizes', 'Ahat_filt_sizes', 'R_filt_sizes'], | ||
11 | + conversions=[('dim_ordering', 'data_format'), | ||
12 | + ('consume_less', 'implementation')], | ||
13 | + value_conversions={'dim_ordering': {'tf': 'channels_last', | ||
14 | + 'th': 'channels_first', | ||
15 | + 'default': None}, | ||
16 | + 'consume_less': {'cpu': 0, | ||
17 | + 'mem': 1, | ||
18 | + 'gpu': 2}}, | ||
19 | + preprocessor=recurrent_args_preprocessor) | ||
20 | + | ||
21 | +# Convert old Keras (1.2) json models and weights to Keras 2.0 | ||
22 | +def convert_model_to_keras2(old_json_file, old_weights_file, new_json_file, new_weights_file): | ||
23 | + from prednet import PredNet | ||
24 | + # If using tensorflow, it doesn't allow you to load the old weights. | ||
25 | + if K.backend() != 'theano': | ||
26 | + os.environ['KERAS_BACKEND'] = backend | ||
27 | + reload(K) | ||
28 | + | ||
29 | + f = open(old_json_file, 'r') | ||
30 | + json_string = f.read() | ||
31 | + f.close() | ||
32 | + model = model_from_json(json_string, custom_objects = {'PredNet': PredNet}) | ||
33 | + model.load_weights(old_weights_file) | ||
34 | + | ||
35 | + weights = model.layers[1].get_weights() | ||
36 | + if weights[0].shape[0] == model.layers[1].stack_sizes[1]: | ||
37 | + for i, w in enumerate(weights): | ||
38 | + if w.ndim == 4: | ||
39 | + weights[i] = np.transpose(w, (2, 3, 1, 0)) | ||
40 | + model.set_weights(weights) | ||
41 | + | ||
42 | + model.save_weights(new_weights_file) | ||
43 | + json_string = model.to_json() | ||
44 | + with open(new_json_file, "w") as f: | ||
45 | + f.write(json_string) | ||
46 | + | ||
47 | + | ||
48 | +if __name__ == '__main__': | ||
49 | + old_dir = './model_data/' | ||
50 | + new_dir = './model_data_keras2/' | ||
51 | + if not os.path.exists(new_dir): | ||
52 | + os.mkdir(new_dir) | ||
53 | + for w_tag in ['', '-Lall', '-extrapfinetuned']: | ||
54 | + m_tag = '' if w_tag == '-Lall' else w_tag | ||
55 | + convert_model_to_keras2(old_dir + 'prednet_kitti_model' + m_tag + '.json', | ||
56 | + old_dir + 'prednet_kitti_weights' + w_tag + '.hdf5', | ||
57 | + new_dir + 'prednet_kitti_model' + m_tag + '.json', | ||
58 | + new_dir + 'prednet_kitti_weights' + w_tag + '.hdf5') |
setting.py
0 → 100644
1 | +DATA_DIR = './data3/' | ||
2 | + | ||
3 | +# Where model weights and config will be saved if you run train.py | ||
4 | +WEIGHTS_DIR = './model_data_keras2/' | ||
5 | + | ||
6 | +# Where results (prediction plots and evaluation file) will be saved. | ||
7 | +RESULTS_SAVE_DIR = './results/' | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment