Showing
1 changed file
with
311 additions
and
0 deletions
prednet.py
0 → 100644
1 | +import numpy as np | ||
2 | + | ||
3 | +from keras import backend as K | ||
4 | +from keras import activations | ||
5 | +from keras.layers import Recurrent | ||
6 | +from keras.layers import Conv2D, UpSampling2D, MaxPooling2D | ||
7 | +from keras.engine import InputSpec | ||
8 | +from keras_utils import legacy_prednet_support | ||
9 | + | ||
class PredNet(Recurrent):
    '''PredNet architecture - Lotter 2016.
        Stacked convolutional LSTM inspired by predictive coding principles.

    # Arguments
        stack_sizes: number of channels in targets (A) and predictions (Ahat) in each layer of the architecture.
            Length is the number of layers in the architecture.
            First element is the number of channels in the input.
            Ex. (3, 16, 32) would correspond to a 3 layer architecture that takes in RGB images and has 16 and 32
                channels in the second and third layers, respectively.
        R_stack_sizes: number of channels in the representation (R) modules.
            Length must equal length of stack_sizes, but the number of channels per layer can be different.
        A_filt_sizes: filter sizes for the target (A) modules.
            Has length of len(stack_sizes) - 1.
            Ex. (3, 3) would mean that targets for layers 2 and 3 are computed by a 3x3 convolution of the errors (E)
                from the layer below (followed by max-pooling)
        Ahat_filt_sizes: filter sizes for the prediction (Ahat) modules.
            Has length equal to length of stack_sizes.
            Ex. (3, 3, 3) would mean that the predictions for each layer are computed by a 3x3 convolution of the
                representation (R) modules at each layer.
        R_filt_sizes: filter sizes for the representation (R) modules.
            Has length equal to length of stack_sizes.
            Corresponds to the filter sizes for all convolutions in the LSTM.
        pixel_max: the maximum pixel value.
            Used to clip the pixel-layer prediction.
        error_activation: activation function for the error (E) units.
        A_activation: activation function for the target (A) and prediction (A_hat) units.
        LSTM_activation: activation function for the cell and hidden states of the LSTM.
        LSTM_inner_activation: activation function for the gates in the LSTM.
        output_mode: either 'error', 'prediction', 'all' or layer specification (ex. R2, see below).
            Controls what is outputted by the PredNet.
            If 'error', the mean response of the error (E) units of each layer will be outputted.
                That is, the output shape will be (batch_size, nb_layers).
            If 'prediction', the frame prediction will be outputted.
            If 'all', the output will be the frame prediction concatenated with the mean layer errors.
                The frame prediction is flattened before concatenation.
                Nomenclature of 'all' is kept for backwards compatibility, but should not be confused with returning all of the layers of the model
            For returning the features of a particular layer, output_mode should be of the form unit_type + layer_number.
                For instance, to return the features of the LSTM "representational" units in the lowest layer, output_mode should be specified as 'R0'.
                The possible unit types are 'R', 'Ahat', 'A', and 'E' corresponding to the 'representation', 'prediction', 'target', and 'error' units respectively.
        extrap_start_time: time step for which model will start extrapolating.
            Starting at this time step, the prediction from the previous time step will be treated as the "actual" input.
        data_format: 'channels_first' or 'channels_last'.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.

    # References
        - [Deep predictive coding networks for video prediction and unsupervised learning](https://arxiv.org/abs/1605.08104)
        - [Long short-term memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf)
        - [Convolutional LSTM network: a machine learning approach for precipitation nowcasting](http://arxiv.org/abs/1506.04214)
        - [Predictive coding in the visual cortex: a functional interpretation of some extra-classical receptive-field effects](http://www.nature.com/neuro/journal/v2/n1/pdf/nn0199_79.pdf)
    '''
    @legacy_prednet_support
    def __init__(self, stack_sizes, R_stack_sizes,
                 A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
                 pixel_max=1., error_activation='relu', A_activation='relu',
                 LSTM_activation='tanh', LSTM_inner_activation='hard_sigmoid',
                 output_mode='error', extrap_start_time=None,
                 data_format=K.image_data_format(), **kwargs):
        '''Validate and store the architecture hyper-parameters.

        See the class docstring for the meaning of each argument.
        Remaining keyword arguments are forwarded to the `Recurrent` base class.

        # Raises
            AssertionError: if the per-layer size lists have inconsistent lengths,
                or if `output_mode` / `data_format` is not a recognised value.
        '''
        self.stack_sizes = stack_sizes
        self.nb_layers = len(stack_sizes)
        assert len(R_stack_sizes) == self.nb_layers, 'len(R_stack_sizes) must equal len(stack_sizes)'
        self.R_stack_sizes = R_stack_sizes
        assert len(A_filt_sizes) == (self.nb_layers - 1), 'len(A_filt_sizes) must equal len(stack_sizes) - 1'
        self.A_filt_sizes = A_filt_sizes
        assert len(Ahat_filt_sizes) == self.nb_layers, 'len(Ahat_filt_sizes) must equal len(stack_sizes)'
        self.Ahat_filt_sizes = Ahat_filt_sizes
        assert len(R_filt_sizes) == (self.nb_layers), 'len(R_filt_sizes) must equal len(stack_sizes)'
        self.R_filt_sizes = R_filt_sizes

        self.pixel_max = pixel_max
        self.error_activation = activations.get(error_activation)
        self.A_activation = activations.get(A_activation)
        self.LSTM_activation = activations.get(LSTM_activation)
        self.LSTM_inner_activation = activations.get(LSTM_inner_activation)

        default_output_modes = ['prediction', 'error', 'all']
        # Map each per-layer mode string (e.g. 'Ahat2') to its (unit type, layer index) pair.
        # Looking the pair up, rather than slicing off the last character of the string,
        # keeps the parsing correct when the layer index has more than one digit.
        layer_output_modes = {unit + str(n): (unit, n)
                              for n in range(self.nb_layers)
                              for unit in ['R', 'E', 'A', 'Ahat']}
        assert output_mode in default_output_modes + list(layer_output_modes), 'Invalid output_mode: ' + str(output_mode)
        self.output_mode = output_mode
        if self.output_mode in layer_output_modes:
            self.output_layer_type, self.output_layer_num = layer_output_modes[self.output_mode]
        else:
            self.output_layer_type = None
            self.output_layer_num = None
        self.extrap_start_time = extrap_start_time

        assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {channels_last, channels_first}'
        self.data_format = data_format
        # Negative axis indices so they apply both with and without a leading time dimension.
        self.channel_axis = -3 if data_format == 'channels_first' else -1
        self.row_axis = -2 if data_format == 'channels_first' else -3
        self.column_axis = -1 if data_format == 'channels_first' else -2
        super(PredNet, self).__init__(**kwargs)
        self.input_spec = [InputSpec(ndim=5)]

    def compute_output_shape(self, input_shape):
        '''Return the output shape for the configured `output_mode`.

        # Arguments
            input_shape: shape tuple; the first two entries are kept as the leading
                (batch, time) entries of the result and the rest is the frame shape.

        # Returns
            Shape tuple: `(input_shape[0], input_shape[1]) + out_shape` when
            `return_sequences` is set, otherwise `(input_shape[0],) + out_shape`.
        '''
        if self.output_mode == 'prediction':
            out_shape = input_shape[2:]
        elif self.output_mode == 'error':
            out_shape = (self.nb_layers,)
        elif self.output_mode == 'all':
            out_shape = (np.prod(input_shape[2:]) + self.nb_layers,)
        else:
            stack_str = 'R_stack_sizes' if self.output_layer_type == 'R' else 'stack_sizes'
            # Error units concatenate two rectified halves, hence twice the channels.
            stack_mult = 2 if self.output_layer_type == 'E' else 1
            out_stack_size = stack_mult * getattr(self, stack_str)[self.output_layer_num]
            # Floor division keeps the spatial dimensions integral (true division would
            # produce floats on Python 3) and matches get_initial_state() and build().
            out_nb_row = input_shape[self.row_axis] // 2**self.output_layer_num
            out_nb_col = input_shape[self.column_axis] // 2**self.output_layer_num
            if self.data_format == 'channels_first':
                out_shape = (out_stack_size, out_nb_row, out_nb_col)
            else:
                out_shape = (out_nb_row, out_nb_col, out_stack_size)

        if self.return_sequences:
            return (input_shape[0], input_shape[1]) + out_shape
        else:
            return (input_shape[0],) + out_shape

    def get_initial_state(self, x):
        '''Build the list of all-zero initial states for input tensor `x`.

        The states are ordered as all 'r' states, then all 'c' states, then all 'e'
        states (one per layer each). When `extrap_start_time` is set, one 'ahat' state
        for layer 0 and an integer timestep variable are appended.

        # Arguments
            x: input tensor of shape (samples, timesteps) + image_shape.

        # Returns
            List of initial state tensors.
        '''
        input_shape = self.input_spec[0].shape
        init_nb_row = input_shape[self.row_axis]
        init_nb_col = input_shape[self.column_axis]

        # Collapse a zero tensor shaped like x down to (samples, nb_channels), so that
        # the states built from it below inherit the symbolic batch dimension of x.
        base_initial_state = K.zeros_like(x)  # (samples, timesteps) + image_shape
        non_channel_axis = -1 if self.data_format == 'channels_first' else -2
        for _ in range(2):
            base_initial_state = K.sum(base_initial_state, axis=non_channel_axis)
        base_initial_state = K.sum(base_initial_state, axis=1)  # (samples, nb_channels)

        initial_states = []
        states_to_pass = ['r', 'c', 'e']
        nlayers_to_pass = {u: self.nb_layers for u in states_to_pass}
        if self.extrap_start_time is not None:
            states_to_pass.append('ahat')  # pass prediction in states so can use as actual for t+1 when extrapolating
            nlayers_to_pass['ahat'] = 1
        for u in states_to_pass:
            for l in range(nlayers_to_pass[u]):
                # Each layer up halves the spatial resolution.
                ds_factor = 2 ** l
                nb_row = init_nb_row // ds_factor
                nb_col = init_nb_col // ds_factor
                if u in ['r', 'c']:
                    stack_size = self.R_stack_sizes[l]
                elif u == 'e':
                    stack_size = 2 * self.stack_sizes[l]
                elif u == 'ahat':
                    stack_size = self.stack_sizes[l]
                output_size = stack_size * nb_row * nb_col  # flattened size

                # Multiplying by an all-zero matrix yields a zero state of the required width.
                reducer = K.zeros((input_shape[self.channel_axis], output_size))  # (nb_channels, output_size)
                initial_state = K.dot(base_initial_state, reducer)  # (samples, output_size)
                if self.data_format == 'channels_first':
                    output_shp = (-1, stack_size, nb_row, nb_col)
                else:
                    output_shp = (-1, nb_row, nb_col, stack_size)
                initial_state = K.reshape(initial_state, output_shp)
                initial_states += [initial_state]

        if K.backend() == 'theano':
            from theano import tensor as T
            # There is a known issue in the Theano scan op when dealing with inputs whose shape is 1 along a dimension.
            # In our case, this is a problem when training on grayscale images, and the below line fixes it.
            initial_states = [T.unbroadcast(init_state, 0, 1) for init_state in initial_states]

        if self.extrap_start_time is not None:
            initial_states += [K.variable(0, int if K.backend() != 'tensorflow' else 'int32')]  # the last state will correspond to the current timestep
        return initial_states

    def build(self, input_shape):
        '''Create and build the convolution, up-sampling and pooling sub-layers.

        Conv layers are stored in `self.conv_layers`, keyed by 'i', 'f', 'c', 'o'
        (LSTM convolutions), 'a' (target) and 'ahat' (prediction), one entry per layer
        ('a' has one fewer). Their trainable weights are collected into
        `self.trainable_weights`.

        # Arguments
            input_shape: full input shape; its last three entries are read as the
                frame shape according to `data_format`.
        '''
        self.input_spec = [InputSpec(shape=input_shape)]
        self.conv_layers = {c: [] for c in ['i', 'f', 'c', 'o', 'a', 'ahat']}

        for l in range(self.nb_layers):
            for c in ['i', 'f', 'c', 'o']:
                act = self.LSTM_activation if c == 'c' else self.LSTM_inner_activation
                self.conv_layers[c].append(Conv2D(self.R_stack_sizes[l], self.R_filt_sizes[l], padding='same', activation=act, data_format=self.data_format))

            # The layer-0 (pixel) prediction always uses 'relu', regardless of A_activation.
            act = 'relu' if l == 0 else self.A_activation
            self.conv_layers['ahat'].append(Conv2D(self.stack_sizes[l], self.Ahat_filt_sizes[l], padding='same', activation=act, data_format=self.data_format))

            if l < self.nb_layers - 1:
                self.conv_layers['a'].append(Conv2D(self.stack_sizes[l+1], self.A_filt_sizes[l], padding='same', activation=self.A_activation, data_format=self.data_format))

        self.upsample = UpSampling2D(data_format=self.data_format)
        self.pool = MaxPooling2D(data_format=self.data_format)

        self.trainable_weights = []
        nb_row, nb_col = (input_shape[-2], input_shape[-1]) if self.data_format == 'channels_first' else (input_shape[-3], input_shape[-2])
        # Sorted keys give a deterministic order for the collected weights.
        for c in sorted(self.conv_layers.keys()):
            for l in range(len(self.conv_layers[c])):
                ds_factor = 2 ** l
                if c == 'ahat':
                    nb_channels = self.R_stack_sizes[l]
                elif c == 'a':
                    nb_channels = 2 * self.stack_sizes[l]
                else:
                    # LSTM convolutions see the concatenation of E (2x channels), R and,
                    # below the top layer, the up-sampled R of the layer above.
                    nb_channels = self.stack_sizes[l] * 2 + self.R_stack_sizes[l]
                    if l < self.nb_layers - 1:
                        nb_channels += self.R_stack_sizes[l+1]
                in_shape = (input_shape[0], nb_channels, nb_row // ds_factor, nb_col // ds_factor)
                if self.data_format == 'channels_last': in_shape = (in_shape[0], in_shape[2], in_shape[3], in_shape[1])
                with K.name_scope('layer_' + c + '_' + str(l)):
                    self.conv_layers[c][l].build(in_shape)
                self.trainable_weights += self.conv_layers[c][l].trainable_weights

        self.states = [None] * self.nb_layers*3

        if self.extrap_start_time is not None:
            self.t_extrap = K.variable(self.extrap_start_time, int if K.backend() != 'tensorflow' else 'int32')
            self.states += [None] * 2  # [previous frame prediction, timestep]

    def step(self, a, states):
        '''Advance the network by one time step.

        # Arguments
            a: the actual input frame for this time step.
            states: list laid out as the R states, then the C states, then the E states
                of every layer; when `extrap_start_time` is set it is followed by the
                previous frame prediction and the current timestep.

        # Returns
            Tuple `(output, states)` where `output` depends on `output_mode` and
            `states` is the updated state list in the same layout as the input.
        '''
        r_tm1 = states[:self.nb_layers]
        c_tm1 = states[self.nb_layers:2*self.nb_layers]
        e_tm1 = states[2*self.nb_layers:3*self.nb_layers]

        if self.extrap_start_time is not None:
            t = states[-1]
            a = K.switch(t >= self.t_extrap, states[-2], a)  # if past self.extrap_start_time, the previous prediction will be treated as the actual

        c = []
        r = []
        e = []

        # Update R units starting from the top
        for l in reversed(range(self.nb_layers)):
            inputs = [r_tm1[l], e_tm1[l]]
            if l < self.nb_layers - 1:
                # r_up was set by the previous iteration (the layer above).
                inputs.append(r_up)

            inputs = K.concatenate(inputs, axis=self.channel_axis)
            i = self.conv_layers['i'][l].call(inputs)
            f = self.conv_layers['f'][l].call(inputs)
            o = self.conv_layers['o'][l].call(inputs)
            _c = f * c_tm1[l] + i * self.conv_layers['c'][l].call(inputs)
            _r = o * self.LSTM_activation(_c)
            # Insert at the front so the lists end up ordered from layer 0 upwards.
            c.insert(0, _c)
            r.insert(0, _r)

            if l > 0:
                r_up = self.upsample.call(_r)

        # Update feedforward path starting from the bottom
        for l in range(self.nb_layers):
            ahat = self.conv_layers['ahat'][l].call(r[l])
            if l == 0:
                # Clip the pixel-layer prediction at the maximum pixel value.
                ahat = K.minimum(ahat, self.pixel_max)
                frame_prediction = ahat

            # compute errors
            e_up = self.error_activation(ahat - a)
            e_down = self.error_activation(a - ahat)

            e.append(K.concatenate((e_up, e_down), axis=self.channel_axis))

            if self.output_layer_num == l:
                if self.output_layer_type == 'A':
                    output = a
                elif self.output_layer_type == 'Ahat':
                    output = ahat
                elif self.output_layer_type == 'R':
                    output = r[l]
                elif self.output_layer_type == 'E':
                    output = e[l]

            if l < self.nb_layers - 1:
                a = self.conv_layers['a'][l].call(e[l])
                a = self.pool.call(a)  # target for next layer

        if self.output_layer_type is None:
            if self.output_mode == 'prediction':
                output = frame_prediction
            else:
                # One mean error value per layer, concatenated into (samples, nb_layers).
                for l in range(self.nb_layers):
                    layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True)
                    all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1)
                if self.output_mode == 'error':
                    output = all_error
                else:
                    output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1)

        states = r + c + e
        if self.extrap_start_time is not None:
            states += [frame_prediction, t + 1]
        return output, states

    def get_config(self):
        '''Return the layer configuration as a dict.

        The PredNet-specific entries are merged over the base class configuration;
        activations are stored by their function `__name__`.
        '''
        config = {'stack_sizes': self.stack_sizes,
                  'R_stack_sizes': self.R_stack_sizes,
                  'A_filt_sizes': self.A_filt_sizes,
                  'Ahat_filt_sizes': self.Ahat_filt_sizes,
                  'R_filt_sizes': self.R_filt_sizes,
                  'pixel_max': self.pixel_max,
                  'error_activation': self.error_activation.__name__,
                  'A_activation': self.A_activation.__name__,
                  'LSTM_activation': self.LSTM_activation.__name__,
                  'LSTM_inner_activation': self.LSTM_inner_activation.__name__,
                  'data_format': self.data_format,
                  'extrap_start_time': self.extrap_start_time,
                  'output_mode': self.output_mode}
        base_config = super(PredNet, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
-
Please register or login to post a comment