이혜리

prednet model

import numpy as np

from keras import backend as K
from keras import activations
from keras.layers import Recurrent
from keras.layers import Conv2D, UpSampling2D, MaxPooling2D
from keras.engine import InputSpec
from keras_utils import legacy_prednet_support


class PredNet(Recurrent):
    '''PredNet architecture - Lotter 2016.
    Stacked convolutional LSTM inspired by predictive coding principles.

    # Arguments
        stack_sizes: number of channels in targets (A) and predictions (Ahat) in each layer of the architecture.
            Length is the number of layers in the architecture.
            First element is the number of channels in the input.
            Ex. (3, 16, 32) would correspond to a 3-layer architecture that takes in RGB images and has 16 and 32
                channels in the second and third layers, respectively.
        R_stack_sizes: number of channels in the representation (R) modules.
            Length must equal the length of stack_sizes, but the number of channels per layer can be different.
        A_filt_sizes: filter sizes for the target (A) modules.
            Has length of len(stack_sizes) - 1.
            Ex. (3, 3) would mean that targets for layers 2 and 3 are computed by a 3x3 convolution of the errors (E)
                from the layer below (followed by max-pooling).
        Ahat_filt_sizes: filter sizes for the prediction (Ahat) modules.
            Has length equal to the length of stack_sizes.
            Ex. (3, 3, 3) would mean that the predictions for each layer are computed by a 3x3 convolution of the
                representation (R) modules at each layer.
        R_filt_sizes: filter sizes for the representation (R) modules.
            Has length equal to the length of stack_sizes.
            Corresponds to the filter sizes for all convolutions in the LSTM.
        pixel_max: the maximum pixel value.
            Used to clip the pixel-layer prediction.
        error_activation: activation function for the error (E) units.
        A_activation: activation function for the target (A) and prediction (Ahat) units.
        LSTM_activation: activation function for the cell and hidden states of the LSTM.
        LSTM_inner_activation: activation function for the gates in the LSTM.
        output_mode: either 'error', 'prediction', 'all', or a layer specification (ex. 'R2', see below).
            Controls what is output by the PredNet.
            If 'error', the mean response of the error (E) units of each layer will be output.
                That is, the output shape will be (batch_size, nb_layers).
            If 'prediction', the frame prediction will be output.
            If 'all', the output will be the frame prediction concatenated with the mean layer errors.
                The frame prediction is flattened before concatenation.
                The nomenclature of 'all' is kept for backwards compatibility, but it should not be confused with returning all of the layers of the model.
            To return the features of a particular layer, output_mode should be of the form unit_type + layer_number.
                For instance, to return the features of the LSTM "representational" units in the lowest layer, output_mode should be specified as 'R0'.
                The possible unit types are 'R', 'Ahat', 'A', and 'E', corresponding to the 'representation', 'prediction', 'target', and 'error' units, respectively.
        extrap_start_time: time step at which the model will start extrapolating.
            Starting at this time step, the prediction from the previous time step will be treated as the "actual" input.
        data_format: 'channels_first' or 'channels_last'.
            Defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
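
    # Example
        A minimal construction sketch for the arguments above. The sizes are
        illustrative, and the snippet assumes a Keras 2.0-era functional API:

        ```python
        from keras.layers import Input
        from keras.models import Model

        # 10-frame RGB clips at 128x160 (channels_last); a 3-layer PredNet.
        inputs = Input(shape=(10, 128, 160, 3))
        prednet = PredNet(stack_sizes=(3, 48, 96), R_stack_sizes=(3, 48, 96),
                          A_filt_sizes=(3, 3), Ahat_filt_sizes=(3, 3, 3),
                          R_filt_sizes=(3, 3, 3), output_mode='error',
                          return_sequences=True)
        errors = prednet(inputs)  # (batch, 10, 3): mean E response per layer, per timestep
        model = Model(inputs=inputs, outputs=errors)
        ```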

    # References
        - [Deep predictive coding networks for video prediction and unsupervised learning](https://arxiv.org/abs/1605.08104)
        - [Long short-term memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf)
        - [Convolutional LSTM network: a machine learning approach for precipitation nowcasting](http://arxiv.org/abs/1506.04214)
        - [Predictive coding in the visual cortex: a functional interpretation of some extra-classical receptive-field effects](http://www.nature.com/neuro/journal/v2/n1/pdf/nn0199_79.pdf)
    '''
    @legacy_prednet_support
    def __init__(self, stack_sizes, R_stack_sizes,
                 A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
                 pixel_max=1., error_activation='relu', A_activation='relu',
                 LSTM_activation='tanh', LSTM_inner_activation='hard_sigmoid',
                 output_mode='error', extrap_start_time=None,
                 data_format=K.image_data_format(), **kwargs):
        self.stack_sizes = stack_sizes
        self.nb_layers = len(stack_sizes)
        assert len(R_stack_sizes) == self.nb_layers, 'len(R_stack_sizes) must equal len(stack_sizes)'
        self.R_stack_sizes = R_stack_sizes
        assert len(A_filt_sizes) == self.nb_layers - 1, 'len(A_filt_sizes) must equal len(stack_sizes) - 1'
        self.A_filt_sizes = A_filt_sizes
        assert len(Ahat_filt_sizes) == self.nb_layers, 'len(Ahat_filt_sizes) must equal len(stack_sizes)'
        self.Ahat_filt_sizes = Ahat_filt_sizes
        assert len(R_filt_sizes) == self.nb_layers, 'len(R_filt_sizes) must equal len(stack_sizes)'
        self.R_filt_sizes = R_filt_sizes

        self.pixel_max = pixel_max
        self.error_activation = activations.get(error_activation)
        self.A_activation = activations.get(A_activation)
        self.LSTM_activation = activations.get(LSTM_activation)
        self.LSTM_inner_activation = activations.get(LSTM_inner_activation)

        default_output_modes = ['prediction', 'error', 'all']
        layer_output_modes = [layer + str(n) for n in range(self.nb_layers) for layer in ['R', 'E', 'A', 'Ahat']]
        assert output_mode in default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode)
        self.output_mode = output_mode
        if self.output_mode in layer_output_modes:
            self.output_layer_type = self.output_mode[:-1]
            self.output_layer_num = int(self.output_mode[-1])
        else:
            self.output_layer_type = None
            self.output_layer_num = None
        self.extrap_start_time = extrap_start_time

        assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {channels_last, channels_first}'
        self.data_format = data_format
        self.channel_axis = -3 if data_format == 'channels_first' else -1
        self.row_axis = -2 if data_format == 'channels_first' else -3
        self.column_axis = -1 if data_format == 'channels_first' else -2
        super(PredNet, self).__init__(**kwargs)
        self.input_spec = [InputSpec(ndim=5)]
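        # Axis bookkeeping example (illustrative): with data_format='channels_last'
        # and 5D input (batch, time, rows, cols, channels), channel_axis is -1,
        # row_axis is -3, and column_axis is -2; with 'channels_first' the
        # channels move to -3 and rows/cols shift to -2/-1.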

    def compute_output_shape(self, input_shape):
        if self.output_mode == 'prediction':
            out_shape = input_shape[2:]
        elif self.output_mode == 'error':
            out_shape = (self.nb_layers,)
        elif self.output_mode == 'all':
            out_shape = (np.prod(input_shape[2:]) + self.nb_layers,)
        else:
            stack_str = 'R_stack_sizes' if self.output_layer_type == 'R' else 'stack_sizes'
            stack_mult = 2 if self.output_layer_type == 'E' else 1
            out_stack_size = stack_mult * getattr(self, stack_str)[self.output_layer_num]
            out_nb_row = input_shape[self.row_axis] // 2 ** self.output_layer_num  # integer division: each layer halves the resolution
            out_nb_col = input_shape[self.column_axis] // 2 ** self.output_layer_num
            if self.data_format == 'channels_first':
                out_shape = (out_stack_size, out_nb_row, out_nb_col)
            else:
                out_shape = (out_nb_row, out_nb_col, out_stack_size)

        if self.return_sequences:
            return (input_shape[0], input_shape[1]) + out_shape
        else:
            return (input_shape[0],) + out_shape
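    # Shape example for the layer-output modes above (channels_last, illustrative):
    # for input (batch, nt, 128, 160, 3), stack_sizes=(3, 48, 96), and
    # output_mode='E1', stack_mult is 2 (E units hold positive and negative error
    # parts) and layer 1 runs at half resolution, so the per-timestep output
    # shape is (64, 80, 2 * 48) = (64, 80, 96).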

    def get_initial_state(self, x):
        input_shape = self.input_spec[0].shape
        init_nb_row = input_shape[self.row_axis]
        init_nb_col = input_shape[self.column_axis]

        base_initial_state = K.zeros_like(x)  # (samples, timesteps) + image_shape
        non_channel_axis = -1 if self.data_format == 'channels_first' else -2
        for _ in range(2):
            base_initial_state = K.sum(base_initial_state, axis=non_channel_axis)
        base_initial_state = K.sum(base_initial_state, axis=1)  # (samples, nb_channels)

        initial_states = []
        states_to_pass = ['r', 'c', 'e']
        nlayers_to_pass = {u: self.nb_layers for u in states_to_pass}
        if self.extrap_start_time is not None:
            states_to_pass.append('ahat')  # pass the prediction in the states so it can be used as the actual at t+1 when extrapolating
            nlayers_to_pass['ahat'] = 1
        for u in states_to_pass:
            for l in range(nlayers_to_pass[u]):
                ds_factor = 2 ** l
                nb_row = init_nb_row // ds_factor
                nb_col = init_nb_col // ds_factor
                if u in ['r', 'c']:
                    stack_size = self.R_stack_sizes[l]
                elif u == 'e':
                    stack_size = 2 * self.stack_sizes[l]
                elif u == 'ahat':
                    stack_size = self.stack_sizes[l]
                output_size = stack_size * nb_row * nb_col  # flattened size

                reducer = K.zeros((input_shape[self.channel_axis], output_size))  # (nb_channels, output_size)
                initial_state = K.dot(base_initial_state, reducer)  # (samples, output_size)
                if self.data_format == 'channels_first':
                    output_shp = (-1, stack_size, nb_row, nb_col)
                else:
                    output_shp = (-1, nb_row, nb_col, stack_size)
                initial_state = K.reshape(initial_state, output_shp)
                initial_states += [initial_state]

        if K.backend() == 'theano':
            from theano import tensor as T
            # There is a known issue in the Theano scan op when dealing with inputs whose shape is 1 along a dimension.
            # In our case, this is a problem when training on grayscale images, and the below line fixes it.
            initial_states = [T.unbroadcast(init_state, 0, 1) for init_state in initial_states]

        if self.extrap_start_time is not None:
            initial_states += [K.variable(0, int if K.backend() != 'tensorflow' else 'int32')]  # the last state will correspond to the current timestep
        return initial_states
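    # State-layout note: get_initial_state returns the states ordered as
    # [r_0..r_{L-1}, c_0..c_{L-1}, e_0..e_{L-1}] (plus [ahat_0, t] when
    # extrapolating), matching exactly how step() slices `states` apart below.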

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        self.conv_layers = {c: [] for c in ['i', 'f', 'c', 'o', 'a', 'ahat']}

        for l in range(self.nb_layers):
            for c in ['i', 'f', 'c', 'o']:
                act = self.LSTM_activation if c == 'c' else self.LSTM_inner_activation
                self.conv_layers[c].append(Conv2D(self.R_stack_sizes[l], self.R_filt_sizes[l], padding='same', activation=act, data_format=self.data_format))

            act = 'relu' if l == 0 else self.A_activation
            self.conv_layers['ahat'].append(Conv2D(self.stack_sizes[l], self.Ahat_filt_sizes[l], padding='same', activation=act, data_format=self.data_format))

            if l < self.nb_layers - 1:
                self.conv_layers['a'].append(Conv2D(self.stack_sizes[l + 1], self.A_filt_sizes[l], padding='same', activation=self.A_activation, data_format=self.data_format))

        self.upsample = UpSampling2D(data_format=self.data_format)
        self.pool = MaxPooling2D(data_format=self.data_format)

        self.trainable_weights = []
        nb_row, nb_col = (input_shape[-2], input_shape[-1]) if self.data_format == 'channels_first' else (input_shape[-3], input_shape[-2])
        for c in sorted(self.conv_layers.keys()):
            for l in range(len(self.conv_layers[c])):
                ds_factor = 2 ** l
                if c == 'ahat':  # predictions are computed from the R units
                    nb_channels = self.R_stack_sizes[l]
                elif c == 'a':  # targets are computed from the E units (positive and negative parts)
                    nb_channels = 2 * self.stack_sizes[l]
                else:  # LSTM convolutions see [E, R] at this layer, plus the upsampled R from the layer above
                    nb_channels = self.stack_sizes[l] * 2 + self.R_stack_sizes[l]
                    if l < self.nb_layers - 1:
                        nb_channels += self.R_stack_sizes[l + 1]
                in_shape = (input_shape[0], nb_channels, nb_row // ds_factor, nb_col // ds_factor)
                if self.data_format == 'channels_last':
                    in_shape = (in_shape[0], in_shape[2], in_shape[3], in_shape[1])
                with K.name_scope('layer_' + c + '_' + str(l)):
                    self.conv_layers[c][l].build(in_shape)
                self.trainable_weights += self.conv_layers[c][l].trainable_weights

        self.states = [None] * self.nb_layers * 3

        if self.extrap_start_time is not None:
            self.t_extrap = K.variable(self.extrap_start_time, int if K.backend() != 'tensorflow' else 'int32')
            self.states += [None] * 2  # [previous frame prediction, timestep]
    def step(self, a, states):
        r_tm1 = states[:self.nb_layers]
        c_tm1 = states[self.nb_layers:2 * self.nb_layers]
        e_tm1 = states[2 * self.nb_layers:3 * self.nb_layers]

        if self.extrap_start_time is not None:
            t = states[-1]
            a = K.switch(t >= self.t_extrap, states[-2], a)  # if past self.extrap_start_time, the previous prediction will be treated as the actual

        c = []
        r = []
        e = []

        # Update R units starting from the top
        for l in reversed(range(self.nb_layers)):
            inputs = [r_tm1[l], e_tm1[l]]
            if l < self.nb_layers - 1:
                inputs.append(r_up)

            inputs = K.concatenate(inputs, axis=self.channel_axis)
            i = self.conv_layers['i'][l].call(inputs)
            f = self.conv_layers['f'][l].call(inputs)
            o = self.conv_layers['o'][l].call(inputs)
            _c = f * c_tm1[l] + i * self.conv_layers['c'][l].call(inputs)
            _r = o * self.LSTM_activation(_c)
            c.insert(0, _c)
            r.insert(0, _r)

            if l > 0:
                r_up = self.upsample.call(_r)

        # Update feedforward path starting from the bottom
        for l in range(self.nb_layers):
            ahat = self.conv_layers['ahat'][l].call(r[l])
            if l == 0:
                ahat = K.minimum(ahat, self.pixel_max)
                frame_prediction = ahat

            # compute errors
            e_up = self.error_activation(ahat - a)
            e_down = self.error_activation(a - ahat)

            e.append(K.concatenate((e_up, e_down), axis=self.channel_axis))

            if self.output_layer_num == l:
                if self.output_layer_type == 'A':
                    output = a
                elif self.output_layer_type == 'Ahat':
                    output = ahat
                elif self.output_layer_type == 'R':
                    output = r[l]
                elif self.output_layer_type == 'E':
                    output = e[l]

            if l < self.nb_layers - 1:
                a = self.conv_layers['a'][l].call(e[l])
                a = self.pool.call(a)  # target for next layer

        if self.output_layer_type is None:
            if self.output_mode == 'prediction':
                output = frame_prediction
            else:
                for l in range(self.nb_layers):
                    layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True)
                    all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1)
                if self.output_mode == 'error':
                    output = all_error
                else:
                    output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1)

        states = r + c + e
        if self.extrap_start_time is not None:
            states += [frame_prediction, t + 1]
        return output, states

    def get_config(self):
        config = {'stack_sizes': self.stack_sizes,
                  'R_stack_sizes': self.R_stack_sizes,
                  'A_filt_sizes': self.A_filt_sizes,
                  'Ahat_filt_sizes': self.Ahat_filt_sizes,
                  'R_filt_sizes': self.R_filt_sizes,
                  'pixel_max': self.pixel_max,
                  'error_activation': self.error_activation.__name__,
                  'A_activation': self.A_activation.__name__,
                  'LSTM_activation': self.LSTM_activation.__name__,
                  'LSTM_inner_activation': self.LSTM_inner_activation.__name__,
                  'data_format': self.data_format,
                  'extrap_start_time': self.extrap_start_time,
                  'output_mode': self.output_mode}
        base_config = super(PredNet, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
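

if __name__ == '__main__':
    # Smoke-test sketch: the sizes here are illustrative, and it assumes a
    # channels_last `image_data_format` and a Keras 2.0-era functional API.
    from keras.layers import Input
    from keras.models import Model

    nt, rows, cols = 5, 32, 32  # rows/cols must be divisible by 2**(nb_layers - 1)
    inputs = Input(shape=(nt, rows, cols, 3))
    prednet = PredNet(stack_sizes=(3, 16, 32), R_stack_sizes=(3, 16, 32),
                      A_filt_sizes=(3, 3), Ahat_filt_sizes=(3, 3, 3),
                      R_filt_sizes=(3, 3, 3), output_mode='prediction',
                      return_sequences=True)
    model = Model(inputs=inputs, outputs=prednet(inputs))
    x = np.random.random((2, nt, rows, cols, 3)).astype('float32')
    print(model.predict(x).shape)  # expected: (2, 5, 32, 32, 3)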