Showing
26 changed files
with
2978 additions
and
1 deletions
youtube-8m @ e6f6bf68
1 | -Subproject commit e6f6bf682d20bb21904ea9c081c15e070809d914 |
yt8m/__init__.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. |
yt8m/average_precision_calculator.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Calculate or keep track of the interpolated average precision. | ||
15 | + | ||
16 | +It provides an interface for calculating interpolated average precision for an | ||
17 | +entire list or the top-n ranked items. For the definition of the | ||
18 | +(non-)interpolated average precision: | ||
19 | +http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf | ||
20 | + | ||
21 | +Example usages: | ||
22 | +1) Use it as a static function call to directly calculate average precision for | ||
23 | +a short ranked list in the memory. | ||
24 | + | ||
25 | +``` | ||
26 | +import random | ||
27 | + | ||
28 | +p = np.array([random.random() for _ in xrange(10)]) | ||
29 | +a = np.array([random.choice([0, 1]) for _ in xrange(10)]) | ||
30 | + | ||
31 | +ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a) | ||
32 | +``` | ||
33 | + | ||
34 | +2) Use it as an object for long ranked list that cannot be stored in memory or | ||
35 | +the case where partial predictions can be observed at a time (Tensorflow | ||
36 | +predictions). In this case, we first call the function accumulate many times | ||
37 | +to process parts of the ranked list. After processing all the parts, we call | ||
38 | +peek_interpolated_ap_at_n. | ||
39 | +``` | ||
40 | +p1 = np.array([random.random() for _ in xrange(5)]) | ||
41 | +a1 = np.array([random.choice([0, 1]) for _ in xrange(5)]) | ||
42 | +p2 = np.array([random.random() for _ in xrange(5)]) | ||
43 | +a2 = np.array([random.choice([0, 1]) for _ in xrange(5)]) | ||
44 | + | ||
45 | +# interpolated average precision at 10 using 1000 break points | ||
46 | +calculator = average_precision_calculator.AveragePrecisionCalculator(10) | ||
47 | +calculator.accumulate(p1, a1) | ||
48 | +calculator.accumulate(p2, a2) | ||
49 | +ap3 = calculator.peek_ap_at_n() | ||
50 | +``` | ||
51 | +""" | ||
52 | + | ||
53 | +import heapq | ||
54 | +import random | ||
55 | +import numbers | ||
56 | + | ||
57 | +import numpy | ||
58 | + | ||
59 | + | ||
60 | +class AveragePrecisionCalculator(object): | ||
61 | + """Calculate the average precision and average precision at n.""" | ||
62 | + | ||
63 | + def __init__(self, top_n=None): | ||
64 | + """Construct an AveragePrecisionCalculator to calculate average precision. | ||
65 | + | ||
66 | + This class is used to calculate the average precision for a single label. | ||
67 | + | ||
68 | + Args: | ||
69 | + top_n: A positive Integer specifying the average precision at n, or None | ||
70 | + to use all provided data points. | ||
71 | + | ||
72 | + Raises: | ||
73 | + ValueError: An error occurred when the top_n is not a positive integer. | ||
74 | + """ | ||
75 | + if not ((isinstance(top_n, int) and top_n >= 0) or top_n is None): | ||
76 | + raise ValueError("top_n must be a positive integer or None.") | ||
77 | + | ||
78 | + self._top_n = top_n # average precision at n | ||
79 | + self._total_positives = 0 # total number of positives have seen | ||
80 | + self._heap = [] # max heap of (prediction, actual) | ||
81 | + | ||
82 | + @property | ||
83 | + def heap_size(self): | ||
84 | + """Gets the heap size maintained in the class.""" | ||
85 | + return len(self._heap) | ||
86 | + | ||
87 | + @property | ||
88 | + def num_accumulated_positives(self): | ||
89 | + """Gets the number of positive samples that have been accumulated.""" | ||
90 | + return self._total_positives | ||
91 | + | ||
92 | + def accumulate(self, predictions, actuals, num_positives=None): | ||
93 | + """Accumulate the predictions and their ground truth labels. | ||
94 | + | ||
95 | + After the function call, we may call peek_ap_at_n to actually calculate | ||
96 | + the average precision. | ||
97 | + Note predictions and actuals must have the same shape. | ||
98 | + | ||
99 | + Args: | ||
100 | + predictions: a list storing the prediction scores. | ||
101 | + actuals: a list storing the ground truth labels. Any value larger than 0 | ||
102 | + will be treated as positives, otherwise as negatives. num_positives = If | ||
103 | + the 'predictions' and 'actuals' inputs aren't complete, then it's | ||
104 | + possible some true positives were missed in them. In that case, you can | ||
105 | + provide 'num_positives' in order to accurately track recall. | ||
106 | + | ||
107 | + Raises: | ||
108 | + ValueError: An error occurred when the format of the input is not the | ||
109 | + numpy 1-D array or the shape of predictions and actuals does not match. | ||
110 | + """ | ||
111 | + if len(predictions) != len(actuals): | ||
112 | + raise ValueError("the shape of predictions and actuals does not match.") | ||
113 | + | ||
114 | + if num_positives is not None: | ||
115 | + if not isinstance(num_positives, numbers.Number) or num_positives < 0: | ||
116 | + raise ValueError( | ||
117 | + "'num_positives' was provided but it was a negative number.") | ||
118 | + | ||
119 | + if num_positives is not None: | ||
120 | + self._total_positives += num_positives | ||
121 | + else: | ||
122 | + self._total_positives += numpy.size( | ||
123 | + numpy.where(numpy.array(actuals) > 1e-5)) | ||
124 | + topk = self._top_n | ||
125 | + heap = self._heap | ||
126 | + | ||
127 | + for i in range(numpy.size(predictions)): | ||
128 | + if topk is None or len(heap) < topk: | ||
129 | + heapq.heappush(heap, (predictions[i], actuals[i])) | ||
130 | + else: | ||
131 | + if predictions[i] > heap[0][0]: # heap[0] is the smallest | ||
132 | + heapq.heappop(heap) | ||
133 | + heapq.heappush(heap, (predictions[i], actuals[i])) | ||
134 | + | ||
135 | + def clear(self): | ||
136 | + """Clear the accumulated predictions.""" | ||
137 | + self._heap = [] | ||
138 | + self._total_positives = 0 | ||
139 | + | ||
140 | + def peek_ap_at_n(self): | ||
141 | + """Peek the non-interpolated average precision at n. | ||
142 | + | ||
143 | + Returns: | ||
144 | + The non-interpolated average precision at n (default 0). | ||
145 | + If n is larger than the length of the ranked list, | ||
146 | + the average precision will be returned. | ||
147 | + """ | ||
148 | + if self.heap_size <= 0: | ||
149 | + return 0 | ||
150 | + predlists = numpy.array(list(zip(*self._heap))) | ||
151 | + | ||
152 | + ap = self.ap_at_n(predlists[0], | ||
153 | + predlists[1], | ||
154 | + n=self._top_n, | ||
155 | + total_num_positives=self._total_positives) | ||
156 | + return ap | ||
157 | + | ||
158 | + @staticmethod | ||
159 | + def ap(predictions, actuals): | ||
160 | + """Calculate the non-interpolated average precision. | ||
161 | + | ||
162 | + Args: | ||
163 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
164 | + actuals: a numpy 1-D array storing the ground truth labels. Any value | ||
165 | + larger than 0 will be treated as positives, otherwise as negatives. | ||
166 | + | ||
167 | + Returns: | ||
168 | + The non-interpolated average precision at n. | ||
169 | + If n is larger than the length of the ranked list, | ||
170 | + the average precision will be returned. | ||
171 | + | ||
172 | + Raises: | ||
173 | + ValueError: An error occurred when the format of the input is not the | ||
174 | + numpy 1-D array or the shape of predictions and actuals does not match. | ||
175 | + """ | ||
176 | + return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None) | ||
177 | + | ||
178 | + @staticmethod | ||
179 | + def ap_at_n(predictions, actuals, n=20, total_num_positives=None): | ||
180 | + """Calculate the non-interpolated average precision. | ||
181 | + | ||
182 | + Args: | ||
183 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
184 | + actuals: a numpy 1-D array storing the ground truth labels. Any value | ||
185 | + larger than 0 will be treated as positives, otherwise as negatives. | ||
186 | + n: the top n items to be considered in ap@n. | ||
187 | + total_num_positives : (optionally) you can specify the number of total | ||
188 | + positive in the list. If specified, it will be used in calculation. | ||
189 | + | ||
190 | + Returns: | ||
191 | + The non-interpolated average precision at n. | ||
192 | + If n is larger than the length of the ranked list, | ||
193 | + the average precision will be returned. | ||
194 | + | ||
195 | + Raises: | ||
196 | + ValueError: An error occurred when | ||
197 | + 1) the format of the input is not the numpy 1-D array; | ||
198 | + 2) the shape of predictions and actuals does not match; | ||
199 | + 3) the input n is not a positive integer. | ||
200 | + """ | ||
201 | + if len(predictions) != len(actuals): | ||
202 | + raise ValueError("the shape of predictions and actuals does not match.") | ||
203 | + | ||
204 | + if n is not None: | ||
205 | + if not isinstance(n, int) or n <= 0: | ||
206 | + raise ValueError("n must be 'None' or a positive integer." | ||
207 | + " It was '%s'." % n) | ||
208 | + | ||
209 | + ap = 0.0 | ||
210 | + | ||
211 | + predictions = numpy.array(predictions) | ||
212 | + actuals = numpy.array(actuals) | ||
213 | + | ||
214 | + # add a shuffler to avoid overestimating the ap | ||
215 | + predictions, actuals = AveragePrecisionCalculator._shuffle( | ||
216 | + predictions, actuals) | ||
217 | + sortidx = sorted(range(len(predictions)), | ||
218 | + key=lambda k: predictions[k], | ||
219 | + reverse=True) | ||
220 | + | ||
221 | + if total_num_positives is None: | ||
222 | + numpos = numpy.size(numpy.where(actuals > 0)) | ||
223 | + else: | ||
224 | + numpos = total_num_positives | ||
225 | + | ||
226 | + if numpos == 0: | ||
227 | + return 0 | ||
228 | + | ||
229 | + if n is not None: | ||
230 | + numpos = min(numpos, n) | ||
231 | + delta_recall = 1.0 / numpos | ||
232 | + poscount = 0.0 | ||
233 | + | ||
234 | + # calculate the ap | ||
235 | + r = len(sortidx) | ||
236 | + if n is not None: | ||
237 | + r = min(r, n) | ||
238 | + for i in range(r): | ||
239 | + if actuals[sortidx[i]] > 0: | ||
240 | + poscount += 1 | ||
241 | + ap += poscount / (i + 1) * delta_recall | ||
242 | + return ap | ||
243 | + | ||
244 | + @staticmethod | ||
245 | + def _shuffle(predictions, actuals): | ||
246 | + random.seed(0) | ||
247 | + suffidx = random.sample(range(len(predictions)), len(predictions)) | ||
248 | + predictions = predictions[suffidx] | ||
249 | + actuals = actuals[suffidx] | ||
250 | + return predictions, actuals | ||
251 | + | ||
252 | + @staticmethod | ||
253 | + def _zero_one_normalize(predictions, epsilon=1e-7): | ||
254 | + """Normalize the predictions to the range between 0.0 and 1.0. | ||
255 | + | ||
256 | + For some predictions like SVM predictions, we need to normalize them before | ||
257 | + calculate the interpolated average precision. The normalization will not | ||
258 | + change the rank in the original list and thus won't change the average | ||
259 | + precision. | ||
260 | + | ||
261 | + Args: | ||
262 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
263 | + epsilon: a small constant to avoid denominator being zero. | ||
264 | + | ||
265 | + Returns: | ||
266 | + The normalized prediction. | ||
267 | + """ | ||
268 | + denominator = numpy.max(predictions) - numpy.min(predictions) | ||
269 | + ret = (predictions - numpy.min(predictions)) / numpy.max( | ||
270 | + denominator, epsilon) | ||
271 | + return ret |
yt8m/convert_prediction_from_json_to_csv.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Utility to convert the output of batch prediction into a CSV submission. | ||
15 | + | ||
16 | +It converts the JSON files created by the command | ||
17 | +'gcloud beta ml jobs submit prediction' into a CSV file ready for submission. | ||
18 | +""" | ||
19 | + | ||
20 | +import json | ||
21 | +import tensorflow as tf | ||
22 | + | ||
23 | +from builtins import range | ||
24 | +from tensorflow import app | ||
25 | +from tensorflow import flags | ||
26 | +from tensorflow import gfile | ||
27 | +from tensorflow import logging | ||
28 | + | ||
29 | +FLAGS = flags.FLAGS | ||
30 | + | ||
31 | +if __name__ == "__main__": | ||
32 | + | ||
33 | + flags.DEFINE_string( | ||
34 | + "json_prediction_files_pattern", None, | ||
35 | + "Pattern specifying the list of JSON files that the command " | ||
36 | + "'gcloud beta ml jobs submit prediction' outputs. These files are " | ||
37 | + "located in the output path of the prediction command and are prefixed " | ||
38 | + "with 'prediction.results'.") | ||
39 | + flags.DEFINE_string( | ||
40 | + "csv_output_file", None, | ||
41 | + "The file to save the predictions converted to the CSV format.") | ||
42 | + | ||
43 | + | ||
44 | +def get_csv_header(): | ||
45 | + return "VideoId,LabelConfidencePairs\n" | ||
46 | + | ||
47 | + | ||
48 | +def to_csv_row(json_data): | ||
49 | + | ||
50 | + video_id = json_data["video_id"] | ||
51 | + | ||
52 | + class_indexes = json_data["class_indexes"] | ||
53 | + predictions = json_data["predictions"] | ||
54 | + | ||
55 | + if isinstance(video_id, list): | ||
56 | + video_id = video_id[0] | ||
57 | + class_indexes = class_indexes[0] | ||
58 | + predictions = predictions[0] | ||
59 | + | ||
60 | + if len(class_indexes) != len(predictions): | ||
61 | + raise ValueError( | ||
62 | + "The number of indexes (%s) and predictions (%s) must be equal." % | ||
63 | + (len(class_indexes), len(predictions))) | ||
64 | + | ||
65 | + return (video_id.decode("utf-8") + "," + | ||
66 | + " ".join("%i %f" % (class_indexes[i], predictions[i]) | ||
67 | + for i in range(len(class_indexes))) + "\n") | ||
68 | + | ||
69 | + | ||
70 | +def main(unused_argv): | ||
71 | + logging.set_verbosity(tf.logging.INFO) | ||
72 | + | ||
73 | + if not FLAGS.json_prediction_files_pattern: | ||
74 | + raise ValueError( | ||
75 | + "The flag --json_prediction_files_pattern must be specified.") | ||
76 | + | ||
77 | + if not FLAGS.csv_output_file: | ||
78 | + raise ValueError("The flag --csv_output_file must be specified.") | ||
79 | + | ||
80 | + logging.info("Looking for prediction files with pattern: %s", | ||
81 | + FLAGS.json_prediction_files_pattern) | ||
82 | + | ||
83 | + file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern) | ||
84 | + logging.info("Found files: %s", file_paths) | ||
85 | + | ||
86 | + logging.info("Writing submission file to: %s", FLAGS.csv_output_file) | ||
87 | + with gfile.Open(FLAGS.csv_output_file, "w+") as output_file: | ||
88 | + output_file.write(get_csv_header()) | ||
89 | + | ||
90 | + for file_path in file_paths: | ||
91 | + logging.info("processing file: %s", file_path) | ||
92 | + | ||
93 | + with gfile.Open(file_path) as input_file: | ||
94 | + | ||
95 | + for line in input_file: | ||
96 | + json_data = json.loads(line) | ||
97 | + output_file.write(to_csv_row(json_data)) | ||
98 | + | ||
99 | + output_file.flush() | ||
100 | + logging.info("done") | ||
101 | + | ||
102 | + | ||
103 | +if __name__ == "__main__": | ||
104 | + app.run() |
yt8m/esot3ria/features.pb
0 → 100644
No preview for this file type
yt8m/esot3ria/inference_pb.py
0 → 100644
1 | +import numpy as np | ||
2 | +import tensorflow as tf | ||
3 | +from tensorflow import logging | ||
4 | +from tensorflow import gfile | ||
5 | +import esot3ria.pbutil as pbutil | ||
6 | + | ||
7 | + | ||
8 | +def get_segments(batch_video_mtx, batch_num_frames, segment_size): | ||
9 | + """Get segment-level inputs from frame-level features.""" | ||
10 | + video_batch_size = batch_video_mtx.shape[0] | ||
11 | + max_frame = batch_video_mtx.shape[1] | ||
12 | + feature_dim = batch_video_mtx.shape[-1] | ||
13 | + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size | ||
14 | + padded_segment_sizes *= segment_size | ||
15 | + segment_mask = ( | ||
16 | + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) | ||
17 | + | ||
18 | + # Segment bags. | ||
19 | + frame_bags = batch_video_mtx.reshape((-1, feature_dim)) | ||
20 | + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( | ||
21 | + (-1, segment_size, feature_dim)) | ||
22 | + | ||
23 | + # Segment num frames. | ||
24 | + segment_start_times = np.arange(0, max_frame, segment_size) | ||
25 | + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times | ||
26 | + num_segment_bags = num_segments.reshape((-1)) | ||
27 | + valid_segment_mask = num_segment_bags > 0 | ||
28 | + segment_num_frames = num_segment_bags[valid_segment_mask] | ||
29 | + segment_num_frames[segment_num_frames > segment_size] = segment_size | ||
30 | + | ||
31 | + max_segment_num = (max_frame + segment_size - 1) // segment_size | ||
32 | + video_idxs = np.tile( | ||
33 | + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) | ||
34 | + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) | ||
35 | + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) | ||
36 | + video_segment_ids = idx_bags[valid_segment_mask] | ||
37 | + | ||
38 | + return { | ||
39 | + "video_batch": segment_frames, | ||
40 | + "num_frames_batch": segment_num_frames, | ||
41 | + "video_segment_ids": video_segment_ids | ||
42 | + } | ||
43 | + | ||
44 | + | ||
45 | +def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ||
46 | + batch_size = len(video_ids) | ||
47 | + for video_index in range(batch_size): | ||
48 | + video_prediction = predictions[video_index] | ||
49 | + if whitelisted_cls_mask is not None: | ||
50 | + # Whitelist classes. | ||
51 | + video_prediction *= whitelisted_cls_mask | ||
52 | + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] | ||
53 | + line = [(class_index, predictions[video_index][class_index]) | ||
54 | + for class_index in top_indices] | ||
55 | + line = sorted(line, key=lambda p: -p[1]) | ||
56 | + return (video_ids[video_index] + "," + | ||
57 | + " ".join("%i %g" % (label, score) for (label, score) in line) + | ||
58 | + "\n").encode("utf8") | ||
59 | + | ||
60 | + | ||
61 | +def inference_pb(filename): | ||
62 | + with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: | ||
63 | + | ||
64 | + # 200527 Esot3riA | ||
65 | + # 0. Import SequenceExample type target from pb. | ||
66 | + target_video = pbutil.convert_pb(filename) | ||
67 | + | ||
68 | + # 1. Load video features from pb. | ||
69 | + video_id_batch_val = np.array([b'video']) | ||
70 | + n_frames = len(target_video.feature_lists.feature_list['rgb'].feature) | ||
71 | + # Restrict frame size to 300 | ||
72 | + if n_frames > 300: | ||
73 | + n_frames = 300 | ||
74 | + video_batch_val = np.zeros((300, 1152)) | ||
75 | + for i in range(n_frames): | ||
76 | + video_batch_rgb_raw = target_video.feature_lists.feature_list['rgb'].feature[i].bytes_list.value[0] | ||
77 | + video_batch_rgb = np.array(tf.cast(tf.decode_raw(video_batch_rgb_raw, tf.float32), tf.float32).eval()) | ||
78 | + video_batch_audio_raw = target_video.feature_lists.feature_list['audio'].feature[i].bytes_list.value[0] | ||
79 | + video_batch_audio = np.array(tf.cast(tf.decode_raw(video_batch_audio_raw, tf.float32), tf.float32).eval()) | ||
80 | + video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0) | ||
81 | + video_batch_val = np.array([video_batch_val]) | ||
82 | + num_frames_batch_val = np.array([n_frames]) | ||
83 | + # 200527 Esot3riA End | ||
84 | + | ||
85 | + # Restore checkpoint and meta-graph file | ||
86 | + checkpoint_file = '/Users/esot3ria/PycharmProjects/yt8m/models/frame' \ | ||
87 | + '/sample_model/inference_model/segment_inference_model' | ||
88 | + if not gfile.Exists(checkpoint_file + ".meta"): | ||
89 | + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file) | ||
90 | + meta_graph_location = checkpoint_file + ".meta" | ||
91 | + logging.info("loading meta-graph: " + meta_graph_location) | ||
92 | + | ||
93 | + with tf.device("/cpu:0"): | ||
94 | + saver = tf.train.import_meta_graph(meta_graph_location, | ||
95 | + clear_devices=True) | ||
96 | + logging.info("restoring variables from " + checkpoint_file) | ||
97 | + saver.restore(sess, checkpoint_file) | ||
98 | + input_tensor = tf.get_collection("input_batch_raw")[0] | ||
99 | + num_frames_tensor = tf.get_collection("num_frames")[0] | ||
100 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
101 | + | ||
102 | + # Workaround for num_epochs issue. | ||
103 | + def set_up_init_ops(variables): | ||
104 | + init_op_list = [] | ||
105 | + for variable in list(variables): | ||
106 | + if "train_input" in variable.name: | ||
107 | + init_op_list.append(tf.assign(variable, 1)) | ||
108 | + variables.remove(variable) | ||
109 | + init_op_list.append(tf.variables_initializer(variables)) | ||
110 | + return init_op_list | ||
111 | + | ||
112 | + sess.run( | ||
113 | + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | ||
114 | + | ||
115 | + coord = tf.train.Coordinator() | ||
116 | + threads = tf.train.start_queue_runners(sess=sess, coord=coord) | ||
117 | + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | ||
118 | + dtype=np.float32) | ||
119 | + segment_label_ids_file = '../segment_label_ids.csv' | ||
120 | + with tf.io.gfile.GFile(segment_label_ids_file) as fobj: | ||
121 | + for line in fobj: | ||
122 | + try: | ||
123 | + cls_id = int(line) | ||
124 | + whitelisted_cls_mask[cls_id] = 1. | ||
125 | + except ValueError: | ||
126 | + # Simply skip the non-integer line. | ||
127 | + continue | ||
128 | + | ||
129 | + # 200527 Esot3riA | ||
130 | + # 2. Make segment features. | ||
131 | + results = get_segments(video_batch_val, num_frames_batch_val, 5) | ||
132 | + video_segment_ids = results["video_segment_ids"] | ||
133 | + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] | ||
134 | + video_id_batch_val = np.array([ | ||
135 | + "%s:%d" % (x.decode("utf8"), y) | ||
136 | + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) | ||
137 | + ]) | ||
138 | + video_batch_val = results["video_batch"] | ||
139 | + num_frames_batch_val = results["num_frames_batch"] | ||
140 | + if input_tensor.get_shape()[1] != video_batch_val.shape[1]: | ||
141 | + raise ValueError("max_frames mismatch. Please re-run the eval.py " | ||
142 | + "with correct segment_labels settings.") | ||
143 | + | ||
144 | + predictions_val, = sess.run([predictions_tensor], | ||
145 | + feed_dict={ | ||
146 | + input_tensor: video_batch_val, | ||
147 | + num_frames_tensor: num_frames_batch_val | ||
148 | + }) | ||
149 | + logging.info(predictions_val) | ||
150 | + logging.info("profit :D") | ||
151 | + | ||
152 | + # result = format_prediction(video_id_batch_val, predictions_val, 10, whitelisted_cls_mask) | ||
153 | + | ||
154 | + | ||
155 | +if __name__ == '__main__': | ||
156 | + logging.set_verbosity(tf.logging.INFO) | ||
157 | + | ||
158 | + filename = 'features.pb' | ||
159 | + inference_pb(filename) |
yt8m/esot3ria/pbutil.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy | ||
3 | + | ||
4 | + | ||
5 | +def _make_bytes(int_array): | ||
6 | + if bytes == str: # Python2 | ||
7 | + return ''.join(map(chr, int_array)) | ||
8 | + else: | ||
9 | + return bytes(int_array) | ||
10 | + | ||
11 | + | ||
12 | +def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0): | ||
13 | + """Quantizes float32 `features` into string.""" | ||
14 | + assert features.dtype == 'float32' | ||
15 | + assert len(features.shape) == 1 # 1-D array | ||
16 | + features = numpy.clip(features, min_quantized_value, max_quantized_value) | ||
17 | + quantize_range = max_quantized_value - min_quantized_value | ||
18 | + features = (features - min_quantized_value) * (255.0 / quantize_range) | ||
19 | + features = [int(round(f)) for f in features] | ||
20 | + | ||
21 | + return _make_bytes(features) | ||
22 | + | ||
23 | + | ||
24 | +# for parse feature.pb | ||
25 | + | ||
26 | +contexts = { | ||
27 | + 'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
28 | + 'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
29 | + 'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
30 | + 'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
31 | + 'clip/data_path': tf.io.FixedLenFeature([], tf.string), | ||
32 | + 'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64), | ||
33 | + 'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64) | ||
34 | +} | ||
35 | + | ||
36 | +features = { | ||
37 | + 'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
38 | + 'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64), | ||
39 | + 'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
40 | + 'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64) | ||
41 | + | ||
42 | +} | ||
43 | + | ||
44 | + | ||
45 | +def parse_exmp(serial_exmp): | ||
46 | + _, sequence_parsed = tf.io.parse_single_sequence_example( | ||
47 | + serialized=serial_exmp, | ||
48 | + context_features=contexts, | ||
49 | + sequence_features=features) | ||
50 | + | ||
51 | + sequence_parsed = tf.contrib.learn.run_n(sequence_parsed)[0] | ||
52 | + | ||
53 | + audio = sequence_parsed['AUDIO/feature/floats'].values | ||
54 | + rgb = sequence_parsed['RGB/feature/floats'].values | ||
55 | + | ||
56 | + # print(audio.values) | ||
57 | + # print(type(audio.values)) | ||
58 | + | ||
59 | + # audio is 128 8bit, rgb is 1024 8bit for every second | ||
60 | + audio_slices = [audio[128 * i: 128 * (i + 1)] for i in range(len(audio) // 128)] | ||
61 | + rgb_slices = [rgb[1024 * i: 1024 * (i + 1)] for i in range(len(rgb) // 1024)] | ||
62 | + | ||
63 | + byte_audio = [] | ||
64 | + byte_rgb = [] | ||
65 | + | ||
66 | + for seg in audio_slices: | ||
67 | + # audio_seg = quantize(seg) | ||
68 | + audio_seg = _make_bytes(seg) | ||
69 | + byte_audio.append(audio_seg) | ||
70 | + | ||
71 | + for seg in rgb_slices: | ||
72 | + # rgb_seg = quantize(seg) | ||
73 | + rgb_seg = _make_bytes(seg) | ||
74 | + byte_rgb.append(rgb_seg) | ||
75 | + | ||
76 | + return byte_audio, byte_rgb | ||
77 | + | ||
78 | + | ||
79 | +def make_exmp(id, audio, rgb): | ||
80 | + audio_features = [] | ||
81 | + rgb_features = [] | ||
82 | + | ||
83 | + for embedding in audio: | ||
84 | + embedding_feature = tf.train.Feature( | ||
85 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
86 | + audio_features.append(embedding_feature) | ||
87 | + | ||
88 | + for embedding in rgb: | ||
89 | + embedding_feature = tf.train.Feature( | ||
90 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
91 | + rgb_features.append(embedding_feature) | ||
92 | + | ||
93 | + # for construct yt8m data | ||
94 | + seq_exmp = tf.train.SequenceExample( | ||
95 | + context=tf.train.Features( | ||
96 | + feature={ | ||
97 | + 'id': tf.train.Feature(bytes_list=tf.train.BytesList( | ||
98 | + value=[id.encode('utf-8')])) | ||
99 | + }), | ||
100 | + feature_lists=tf.train.FeatureLists( | ||
101 | + feature_list={ | ||
102 | + 'audio': tf.train.FeatureList( | ||
103 | + feature=audio_features | ||
104 | + ), | ||
105 | + 'rgb': tf.train.FeatureList( | ||
106 | + feature=rgb_features | ||
107 | + ) | ||
108 | + }) | ||
109 | + ) | ||
110 | + serialized = seq_exmp.SerializeToString() | ||
111 | + return serialized | ||
112 | + | ||
113 | + | ||
114 | +def convert_pb(filename): | ||
115 | + sequence_example = open(filename, 'rb').read() | ||
116 | + | ||
117 | + audio, rgb = parse_exmp(sequence_example) | ||
118 | + tmp_example = make_exmp('video', audio, rgb) | ||
119 | + | ||
120 | + decoded = tf.train.SequenceExample.FromString(tmp_example) | ||
121 | + return decoded |
yt8m/esot3ria/readpb.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy as np | ||
3 | + | ||
4 | +frame_lvl_record = "test0000.tfrecord" | ||
5 | + | ||
6 | +feat_rgb = [] | ||
7 | +feat_audio = [] | ||
8 | + | ||
9 | +for example in tf.python_io.tf_record_iterator(frame_lvl_record): | ||
10 | + tf_seq_example = tf.train.SequenceExample.FromString(example) | ||
11 | + test = tf_seq_example.SerializeToString() | ||
12 | + n_frames = len(tf_seq_example.feature_lists.feature_list['audio'].feature) | ||
13 | + sess = tf.InteractiveSession() | ||
14 | + rgb_frame = [] | ||
15 | + audio_frame = [] | ||
16 | + # iterate through frames | ||
17 | + for i in range(n_frames): | ||
18 | + rgb_frame.append(tf.cast(tf.decode_raw( | ||
19 | + tf_seq_example.feature_lists.feature_list['rgb'] | ||
20 | + .feature[i].bytes_list.value[0], tf.uint8) | ||
21 | + , tf.float32).eval()) | ||
22 | + audio_frame.append(tf.cast(tf.decode_raw( | ||
23 | + tf_seq_example.feature_lists.feature_list['audio'] | ||
24 | + .feature[i].bytes_list.value[0], tf.uint8) | ||
25 | + , tf.float32).eval()) | ||
26 | + | ||
27 | + sess.close() | ||
28 | + | ||
29 | + feat_audio.append(audio_frame) | ||
30 | + feat_rgb.append(rgb_frame) | ||
31 | + break | ||
32 | + | ||
33 | +print('The first video has %d frames' %len(feat_rgb[0])) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
yt8m/esot3ria/test0000.tfrecord
0 → 100644
No preview for this file type
yt8m/eval.py
0 → 100644
This diff is collapsed. Click to expand it.
yt8m/eval_util.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides functions to help with evaluating models.""" | ||
15 | +import average_precision_calculator as ap_calculator | ||
16 | +import mean_average_precision_calculator as map_calculator | ||
17 | +import numpy | ||
18 | +from tensorflow.python.platform import gfile | ||
19 | + | ||
20 | + | ||
21 | +def flatten(l): | ||
22 | + """Merges a list of lists into a single list. """ | ||
23 | + return [item for sublist in l for item in sublist] | ||
24 | + | ||
25 | + | ||
26 | +def calculate_hit_at_one(predictions, actuals): | ||
27 | + """Performs a local (numpy) calculation of the hit at one. | ||
28 | + | ||
29 | + Args: | ||
30 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
31 | + 'batch' x 'num_classes'. | ||
32 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
33 | + 'num_classes'. | ||
34 | + | ||
35 | + Returns: | ||
36 | + float: The average hit at one across the entire batch. | ||
37 | + """ | ||
38 | + top_prediction = numpy.argmax(predictions, 1) | ||
39 | + hits = actuals[numpy.arange(actuals.shape[0]), top_prediction] | ||
40 | + return numpy.average(hits) | ||
41 | + | ||
42 | + | ||
43 | +def calculate_precision_at_equal_recall_rate(predictions, actuals): | ||
44 | + """Performs a local (numpy) calculation of the PERR. | ||
45 | + | ||
46 | + Args: | ||
47 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
48 | + 'batch' x 'num_classes'. | ||
49 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
50 | + 'num_classes'. | ||
51 | + | ||
52 | + Returns: | ||
53 | + float: The average precision at equal recall rate across the entire batch. | ||
54 | + """ | ||
55 | + aggregated_precision = 0.0 | ||
56 | + num_videos = actuals.shape[0] | ||
57 | + for row in numpy.arange(num_videos): | ||
58 | + num_labels = int(numpy.sum(actuals[row])) | ||
59 | + top_indices = numpy.argpartition(predictions[row], | ||
60 | + -num_labels)[-num_labels:] | ||
61 | + item_precision = 0.0 | ||
62 | + for label_index in top_indices: | ||
63 | + if predictions[row][label_index] > 0: | ||
64 | + item_precision += actuals[row][label_index] | ||
65 | + item_precision /= top_indices.size | ||
66 | + aggregated_precision += item_precision | ||
67 | + aggregated_precision /= num_videos | ||
68 | + return aggregated_precision | ||
69 | + | ||
70 | + | ||
71 | +def calculate_gap(predictions, actuals, top_k=20): | ||
72 | + """Performs a local (numpy) calculation of the global average precision. | ||
73 | + | ||
74 | + Only the top_k predictions are taken for each of the videos. | ||
75 | + | ||
76 | + Args: | ||
77 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
78 | + 'batch' x 'num_classes'. | ||
79 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
80 | + 'num_classes'. | ||
81 | + top_k: How many predictions to use per video. | ||
82 | + | ||
83 | + Returns: | ||
84 | + float: The global average precision. | ||
85 | + """ | ||
86 | + gap_calculator = ap_calculator.AveragePrecisionCalculator() | ||
87 | + sparse_predictions, sparse_labels, num_positives = top_k_by_class( | ||
88 | + predictions, actuals, top_k) | ||
89 | + gap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels), | ||
90 | + sum(num_positives)) | ||
91 | + return gap_calculator.peek_ap_at_n() | ||
92 | + | ||
93 | + | ||
94 | +def top_k_by_class(predictions, labels, k=20): | ||
95 | + """Extracts the top k predictions for each video, sorted by class. | ||
96 | + | ||
97 | + Args: | ||
98 | + predictions: A numpy matrix containing the outputs of the model. Dimensions | ||
99 | + are 'batch' x 'num_classes'. | ||
100 | + k: the top k non-zero entries to preserve in each prediction. | ||
101 | + | ||
102 | + Returns: | ||
103 | + A tuple (predictions,labels, true_positives). 'predictions' and 'labels' | ||
104 | + are lists of lists of floats. 'true_positives' is a list of scalars. The | ||
105 | + length of the lists are equal to the number of classes. The entries in the | ||
106 | + predictions variable are probability predictions, and | ||
107 | + the corresponding entries in the labels variable are the ground truth for | ||
108 | + those predictions. The entries in 'true_positives' are the number of true | ||
109 | + positives for each class in the ground truth. | ||
110 | + | ||
111 | + Raises: | ||
112 | + ValueError: An error occurred when the k is not a positive integer. | ||
113 | + """ | ||
114 | + if k <= 0: | ||
115 | + raise ValueError("k must be a positive integer.") | ||
116 | + k = min(k, predictions.shape[1]) | ||
117 | + num_classes = predictions.shape[1] | ||
118 | + prediction_triplets = [] | ||
119 | + for video_index in range(predictions.shape[0]): | ||
120 | + prediction_triplets.extend( | ||
121 | + top_k_triplets(predictions[video_index], labels[video_index], k)) | ||
122 | + out_predictions = [[] for _ in range(num_classes)] | ||
123 | + out_labels = [[] for _ in range(num_classes)] | ||
124 | + for triplet in prediction_triplets: | ||
125 | + out_predictions[triplet[0]].append(triplet[1]) | ||
126 | + out_labels[triplet[0]].append(triplet[2]) | ||
127 | + out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)] | ||
128 | + | ||
129 | + return out_predictions, out_labels, out_true_positives | ||
130 | + | ||
131 | + | ||
132 | +def top_k_triplets(predictions, labels, k=20): | ||
133 | + """Get the top_k for a 1-d numpy array. | ||
134 | + | ||
135 | + Returns a sparse list of tuples in | ||
136 | + (prediction, class) format | ||
137 | + """ | ||
138 | + m = len(predictions) | ||
139 | + k = min(k, m) | ||
140 | + indices = numpy.argpartition(predictions, -k)[-k:] | ||
141 | + return [(index, predictions[index], labels[index]) for index in indices] | ||
142 | + | ||
143 | + | ||
144 | +class EvaluationMetrics(object): | ||
145 | + """A class to store the evaluation metrics.""" | ||
146 | + | ||
147 | + def __init__(self, num_class, top_k, top_n): | ||
148 | + """Construct an EvaluationMetrics object to store the evaluation metrics. | ||
149 | + | ||
150 | + Args: | ||
151 | + num_class: A positive integer specifying the number of classes. | ||
152 | + top_k: A positive integer specifying how many predictions are considered | ||
153 | + per video. | ||
154 | + top_n: A positive Integer specifying the average precision at n, or None | ||
155 | + to use all provided data points. | ||
156 | + | ||
157 | + Raises: | ||
158 | + ValueError: An error occurred when MeanAveragePrecisionCalculator cannot | ||
159 | + not be constructed. | ||
160 | + """ | ||
161 | + self.sum_hit_at_one = 0.0 | ||
162 | + self.sum_perr = 0.0 | ||
163 | + self.sum_loss = 0.0 | ||
164 | + self.map_calculator = map_calculator.MeanAveragePrecisionCalculator( | ||
165 | + num_class, top_n=top_n) | ||
166 | + self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator() | ||
167 | + self.top_k = top_k | ||
168 | + self.num_examples = 0 | ||
169 | + | ||
170 | + def accumulate(self, predictions, labels, loss): | ||
171 | + """Accumulate the metrics calculated locally for this mini-batch. | ||
172 | + | ||
173 | + Args: | ||
174 | + predictions: A numpy matrix containing the outputs of the model. | ||
175 | + Dimensions are 'batch' x 'num_classes'. | ||
176 | + labels: A numpy matrix containing the ground truth labels. Dimensions are | ||
177 | + 'batch' x 'num_classes'. | ||
178 | + loss: A numpy array containing the loss for each sample. | ||
179 | + | ||
180 | + Returns: | ||
181 | + dictionary: A dictionary storing the metrics for the mini-batch. | ||
182 | + | ||
183 | + Raises: | ||
184 | + ValueError: An error occurred when the shape of predictions and actuals | ||
185 | + does not match. | ||
186 | + """ | ||
187 | + batch_size = labels.shape[0] | ||
188 | + mean_hit_at_one = calculate_hit_at_one(predictions, labels) | ||
189 | + mean_perr = calculate_precision_at_equal_recall_rate(predictions, labels) | ||
190 | + mean_loss = numpy.mean(loss) | ||
191 | + | ||
192 | + # Take the top 20 predictions. | ||
193 | + sparse_predictions, sparse_labels, num_positives = top_k_by_class( | ||
194 | + predictions, labels, self.top_k) | ||
195 | + self.map_calculator.accumulate(sparse_predictions, sparse_labels, | ||
196 | + num_positives) | ||
197 | + self.global_ap_calculator.accumulate(flatten(sparse_predictions), | ||
198 | + flatten(sparse_labels), | ||
199 | + sum(num_positives)) | ||
200 | + | ||
201 | + self.num_examples += batch_size | ||
202 | + self.sum_hit_at_one += mean_hit_at_one * batch_size | ||
203 | + self.sum_perr += mean_perr * batch_size | ||
204 | + self.sum_loss += mean_loss * batch_size | ||
205 | + | ||
206 | + return {"hit_at_one": mean_hit_at_one, "perr": mean_perr, "loss": mean_loss} | ||
207 | + | ||
208 | + def get(self): | ||
209 | + """Calculate the evaluation metrics for the whole epoch. | ||
210 | + | ||
211 | + Raises: | ||
212 | + ValueError: If no examples were accumulated. | ||
213 | + | ||
214 | + Returns: | ||
215 | + dictionary: a dictionary storing the evaluation metrics for the epoch. The | ||
216 | + dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and | ||
217 | + aps (default nan). | ||
218 | + """ | ||
219 | + if self.num_examples <= 0: | ||
220 | + raise ValueError("total_sample must be positive.") | ||
221 | + avg_hit_at_one = self.sum_hit_at_one / self.num_examples | ||
222 | + avg_perr = self.sum_perr / self.num_examples | ||
223 | + avg_loss = self.sum_loss / self.num_examples | ||
224 | + | ||
225 | + aps = self.map_calculator.peek_map_at_n() | ||
226 | + gap = self.global_ap_calculator.peek_ap_at_n() | ||
227 | + | ||
228 | + epoch_info_dict = { | ||
229 | + "avg_hit_at_one": avg_hit_at_one, | ||
230 | + "avg_perr": avg_perr, | ||
231 | + "avg_loss": avg_loss, | ||
232 | + "aps": aps, | ||
233 | + "gap": gap | ||
234 | + } | ||
235 | + return epoch_info_dict | ||
236 | + | ||
237 | + def clear(self): | ||
238 | + """Clear the evaluation metrics and reset the EvaluationMetrics object.""" | ||
239 | + self.sum_hit_at_one = 0.0 | ||
240 | + self.sum_perr = 0.0 | ||
241 | + self.sum_loss = 0.0 | ||
242 | + self.map_calculator.clear() | ||
243 | + self.global_ap_calculator.clear() | ||
244 | + self.num_examples = 0 |
yt8m/export_model.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Utilities to export a model for batch prediction.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | +import tensorflow.contrib.slim as slim | ||
18 | + | ||
19 | +from tensorflow.python.saved_model import builder as saved_model_builder | ||
20 | +from tensorflow.python.saved_model import signature_constants | ||
21 | +from tensorflow.python.saved_model import signature_def_utils | ||
22 | +from tensorflow.python.saved_model import tag_constants | ||
23 | +from tensorflow.python.saved_model import utils as saved_model_utils | ||
24 | + | ||
25 | +_TOP_PREDICTIONS_IN_OUTPUT = 20 | ||
26 | + | ||
27 | + | ||
28 | +class ModelExporter(object): | ||
29 | + | ||
30 | + def __init__(self, frame_features, model, reader): | ||
31 | + self.frame_features = frame_features | ||
32 | + self.model = model | ||
33 | + self.reader = reader | ||
34 | + | ||
35 | + with tf.Graph().as_default() as graph: | ||
36 | + self.inputs, self.outputs = self.build_inputs_and_outputs() | ||
37 | + self.graph = graph | ||
38 | + self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True) | ||
39 | + | ||
40 | + def export_model(self, model_dir, global_step_val, last_checkpoint): | ||
41 | + """Exports the model so that it can used for batch predictions.""" | ||
42 | + | ||
43 | + with self.graph.as_default(): | ||
44 | + with tf.Session() as session: | ||
45 | + session.run(tf.global_variables_initializer()) | ||
46 | + self.saver.restore(session, last_checkpoint) | ||
47 | + | ||
48 | + signature = signature_def_utils.build_signature_def( | ||
49 | + inputs=self.inputs, | ||
50 | + outputs=self.outputs, | ||
51 | + method_name=signature_constants.PREDICT_METHOD_NAME) | ||
52 | + | ||
53 | + signature_map = { | ||
54 | + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature | ||
55 | + } | ||
56 | + | ||
57 | + model_builder = saved_model_builder.SavedModelBuilder(model_dir) | ||
58 | + model_builder.add_meta_graph_and_variables( | ||
59 | + session, | ||
60 | + tags=[tag_constants.SERVING], | ||
61 | + signature_def_map=signature_map, | ||
62 | + clear_devices=True) | ||
63 | + model_builder.save() | ||
64 | + | ||
65 | + def build_inputs_and_outputs(self): | ||
66 | + if self.frame_features: | ||
67 | + serialized_examples = tf.placeholder(tf.string, shape=(None,)) | ||
68 | + | ||
69 | + fn = lambda x: self.build_prediction_graph(x) | ||
70 | + video_id_output, top_indices_output, top_predictions_output = (tf.map_fn( | ||
71 | + fn, serialized_examples, dtype=(tf.string, tf.int32, tf.float32))) | ||
72 | + | ||
73 | + else: | ||
74 | + serialized_examples = tf.placeholder(tf.string, shape=(None,)) | ||
75 | + | ||
76 | + video_id_output, top_indices_output, top_predictions_output = ( | ||
77 | + self.build_prediction_graph(serialized_examples)) | ||
78 | + | ||
79 | + inputs = { | ||
80 | + "example_bytes": | ||
81 | + saved_model_utils.build_tensor_info(serialized_examples) | ||
82 | + } | ||
83 | + | ||
84 | + outputs = { | ||
85 | + "video_id": | ||
86 | + saved_model_utils.build_tensor_info(video_id_output), | ||
87 | + "class_indexes": | ||
88 | + saved_model_utils.build_tensor_info(top_indices_output), | ||
89 | + "predictions": | ||
90 | + saved_model_utils.build_tensor_info(top_predictions_output) | ||
91 | + } | ||
92 | + | ||
93 | + return inputs, outputs | ||
94 | + | ||
95 | + def build_prediction_graph(self, serialized_examples): | ||
96 | + input_data_dict = ( | ||
97 | + self.reader.prepare_serialized_examples(serialized_examples)) | ||
98 | + video_id = input_data_dict["video_ids"] | ||
99 | + model_input_raw = input_data_dict["video_matrix"] | ||
100 | + labels_batch = input_data_dict["labels"] | ||
101 | + num_frames = input_data_dict["num_frames"] | ||
102 | + | ||
103 | + feature_dim = len(model_input_raw.get_shape()) - 1 | ||
104 | + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) | ||
105 | + | ||
106 | + with tf.variable_scope("tower"): | ||
107 | + result = self.model.create_model(model_input, | ||
108 | + num_frames=num_frames, | ||
109 | + vocab_size=self.reader.num_classes, | ||
110 | + labels=labels_batch, | ||
111 | + is_training=False) | ||
112 | + | ||
113 | + for variable in slim.get_model_variables(): | ||
114 | + tf.summary.histogram(variable.op.name, variable) | ||
115 | + | ||
116 | + predictions = result["predictions"] | ||
117 | + | ||
118 | + top_predictions, top_indices = tf.nn.top_k(predictions, | ||
119 | + _TOP_PREDICTIONS_IN_OUTPUT) | ||
120 | + return video_id, top_indices, top_predictions |
yt8m/export_model_mediapipe.py
0 → 100644
1 | +# Lint as: python3 | ||
2 | +import numpy as np | ||
3 | +import tensorflow as tf | ||
4 | +from tensorflow import app | ||
5 | +from tensorflow import flags | ||
6 | + | ||
7 | +FLAGS = flags.FLAGS | ||
8 | + | ||
9 | + | ||
10 | +def main(unused_argv): | ||
11 | + # Get the input tensor names to be replaced. | ||
12 | + tf.reset_default_graph() | ||
13 | + meta_graph_location = FLAGS.checkpoint_file + ".meta" | ||
14 | + tf.train.import_meta_graph(meta_graph_location, clear_devices=True) | ||
15 | + | ||
16 | + input_tensor_name = tf.get_collection("input_batch_raw")[0].name | ||
17 | + num_frames_tensor_name = tf.get_collection("num_frames")[0].name | ||
18 | + | ||
19 | + # Create output graph. | ||
20 | + saver = tf.train.Saver() | ||
21 | + tf.reset_default_graph() | ||
22 | + | ||
23 | + input_feature_placeholder = tf.placeholder( | ||
24 | + tf.float32, shape=(None, None, 1152)) | ||
25 | + num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1)) | ||
26 | + | ||
27 | + saver = tf.train.import_meta_graph( | ||
28 | + meta_graph_location, | ||
29 | + input_map={ | ||
30 | + input_tensor_name: input_feature_placeholder, | ||
31 | + num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1) | ||
32 | + }, | ||
33 | + clear_devices=True) | ||
34 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
35 | + | ||
36 | + with tf.Session() as sess: | ||
37 | + print("restoring variables from " + FLAGS.checkpoint_file) | ||
38 | + saver.restore(sess, FLAGS.checkpoint_file) | ||
39 | + tf.saved_model.simple_save( | ||
40 | + sess, | ||
41 | + FLAGS.output_dir, | ||
42 | + inputs={'rgb_and_audio': input_feature_placeholder, | ||
43 | + 'num_frames': num_frames_placeholder}, | ||
44 | + outputs={'predictions': predictions_tensor}) | ||
45 | + | ||
46 | + # Try running inference. | ||
47 | + predictions = sess.run( | ||
48 | + [predictions_tensor], | ||
49 | + feed_dict={ | ||
50 | + input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32), | ||
51 | + num_frames_placeholder: np.array([[7]], dtype=np.int32)}) | ||
52 | + print('Test inference:', predictions) | ||
53 | + | ||
54 | + print('Model saved to ', FLAGS.output_dir) | ||
55 | + | ||
56 | + | ||
57 | +if __name__ == '__main__': | ||
58 | + flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.') | ||
59 | + flags.DEFINE_string('output_dir', None, 'SavedModel output directory.') | ||
60 | + app.run(main) |
yt8m/frame_level_models.py
0 → 100644
This diff is collapsed. Click to expand it.
yt8m/inference.py
0 → 100644
This diff is collapsed. Click to expand it.
yt8m/inference_per_segment.py
0 → 100644
This diff is collapsed. Click to expand it.
yt8m/losses.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides definitions for non-regularized training or test losses.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | + | ||
18 | + | ||
19 | +class BaseLoss(object): | ||
20 | + """Inherit from this class when implementing new losses.""" | ||
21 | + | ||
22 | + def calculate_loss(self, unused_predictions, unused_labels, **unused_params): | ||
23 | + """Calculates the average loss of the examples in a mini-batch. | ||
24 | + | ||
25 | + Args: | ||
26 | + unused_predictions: a 2-d tensor storing the prediction scores, in which | ||
27 | + each row represents a sample in the mini-batch and each column | ||
28 | + represents a class. | ||
29 | + unused_labels: a 2-d tensor storing the labels, which has the same shape | ||
30 | + as the unused_predictions. The labels must be in the range of 0 and 1. | ||
31 | + unused_params: loss specific parameters. | ||
32 | + | ||
33 | + Returns: | ||
34 | + A scalar loss tensor. | ||
35 | + """ | ||
36 | + raise NotImplementedError() | ||
37 | + | ||
38 | + | ||
39 | +class CrossEntropyLoss(BaseLoss): | ||
40 | + """Calculate the cross entropy loss between the predictions and labels.""" | ||
41 | + | ||
42 | + def calculate_loss(self, | ||
43 | + predictions, | ||
44 | + labels, | ||
45 | + label_weights=None, | ||
46 | + **unused_params): | ||
47 | + with tf.name_scope("loss_xent"): | ||
48 | + epsilon = 1e-5 | ||
49 | + float_labels = tf.cast(labels, tf.float32) | ||
50 | + cross_entropy_loss = float_labels * tf.math.log(predictions + epsilon) + ( | ||
51 | + 1 - float_labels) * tf.math.log(1 - predictions + epsilon) | ||
52 | + cross_entropy_loss = tf.negative(cross_entropy_loss) | ||
53 | + if label_weights is not None: | ||
54 | + cross_entropy_loss *= label_weights | ||
55 | + return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1)) | ||
56 | + | ||
57 | + | ||
58 | +class HingeLoss(BaseLoss): | ||
59 | + """Calculate the hinge loss between the predictions and labels. | ||
60 | + | ||
61 | + Note the subgradient is used in the backpropagation, and thus the optimization | ||
62 | + may converge slower. The predictions trained by the hinge loss are between -1 | ||
63 | + and +1. | ||
64 | + """ | ||
65 | + | ||
66 | + def calculate_loss(self, predictions, labels, b=1.0, **unused_params): | ||
67 | + with tf.name_scope("loss_hinge"): | ||
68 | + float_labels = tf.cast(labels, tf.float32) | ||
69 | + all_zeros = tf.zeros(tf.shape(float_labels), dtype=tf.float32) | ||
70 | + all_ones = tf.ones(tf.shape(float_labels), dtype=tf.float32) | ||
71 | + sign_labels = tf.subtract(tf.scalar_mul(2, float_labels), all_ones) | ||
72 | + hinge_loss = tf.maximum( | ||
73 | + all_zeros, | ||
74 | + tf.scalar_mul(b, all_ones) - sign_labels * predictions) | ||
75 | + return tf.reduce_mean(tf.reduce_sum(hinge_loss, 1)) | ||
76 | + | ||
77 | + | ||
78 | +class SoftmaxLoss(BaseLoss): | ||
79 | + """Calculate the softmax loss between the predictions and labels. | ||
80 | + | ||
81 | + The function calculates the loss in the following way: first we feed the | ||
82 | + predictions to the softmax activation function and then we calculate | ||
83 | + the minus linear dot product between the logged softmax activations and the | ||
84 | + normalized ground truth label. | ||
85 | + | ||
86 | + It is an extension to the one-hot label. It allows for more than one positive | ||
87 | + labels for each sample. | ||
88 | + """ | ||
89 | + | ||
90 | + def calculate_loss(self, predictions, labels, **unused_params): | ||
91 | + with tf.name_scope("loss_softmax"): | ||
92 | + epsilon = 10e-8 | ||
93 | + float_labels = tf.cast(labels, tf.float32) | ||
94 | + # l1 normalization (labels are no less than 0) | ||
95 | + label_rowsum = tf.maximum(tf.reduce_sum(float_labels, 1, keep_dims=True), | ||
96 | + epsilon) | ||
97 | + norm_float_labels = tf.div(float_labels, label_rowsum) | ||
98 | + softmax_outputs = tf.nn.softmax(predictions) | ||
99 | + softmax_loss = tf.negative( | ||
100 | + tf.reduce_sum(tf.multiply(norm_float_labels, tf.log(softmax_outputs)), | ||
101 | + 1)) | ||
102 | + return tf.reduce_mean(softmax_loss) |
yt8m/mean_average_precision_calculator.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Calculate the mean average precision. | ||
15 | + | ||
16 | +It provides an interface for calculating mean average precision | ||
17 | +for an entire list or the top-n ranked items. | ||
18 | + | ||
19 | +Example usages: | ||
20 | +We first call the function accumulate many times to process parts of the ranked | ||
21 | +list. After processing all the parts, we call peek_map_at_n | ||
22 | +to calculate the mean average precision. | ||
23 | + | ||
24 | +``` | ||
25 | +import random | ||
26 | + | ||
27 | +p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)]) | ||
28 | +a = np.array([[random.choice([0, 1]) for _ in xrange(50)] | ||
29 | + for _ in xrange(1000)]) | ||
30 | + | ||
31 | +# mean average precision for 50 classes. | ||
32 | +calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator( | ||
33 | + num_class=50) | ||
34 | +calculator.accumulate(p, a) | ||
35 | +aps = calculator.peek_map_at_n() | ||
36 | +``` | ||
37 | +""" | ||
38 | + | ||
39 | +import average_precision_calculator | ||
40 | + | ||
41 | + | ||
42 | +class MeanAveragePrecisionCalculator(object): | ||
43 | + """This class is to calculate mean average precision.""" | ||
44 | + | ||
45 | + def __init__(self, num_class, filter_empty_classes=True, top_n=None): | ||
46 | + """Construct a calculator to calculate the (macro) average precision. | ||
47 | + | ||
48 | + Args: | ||
49 | + num_class: A positive Integer specifying the number of classes. | ||
50 | + filter_empty_classes: whether to filter classes without any positives. | ||
51 | + top_n: A positive Integer specifying the average precision at n, or None | ||
52 | + to use all provided data points. | ||
53 | + | ||
54 | + Raises: | ||
55 | + ValueError: An error occurred when num_class is not a positive integer; | ||
56 | + or the top_n_array is not a list of positive integers. | ||
57 | + """ | ||
58 | + if not isinstance(num_class, int) or num_class <= 1: | ||
59 | + raise ValueError("num_class must be a positive integer.") | ||
60 | + | ||
61 | + self._ap_calculators = [] # member of AveragePrecisionCalculator | ||
62 | + self._num_class = num_class # total number of classes | ||
63 | + self._filter_empty_classes = filter_empty_classes | ||
64 | + for _ in range(num_class): | ||
65 | + self._ap_calculators.append( | ||
66 | + average_precision_calculator.AveragePrecisionCalculator(top_n=top_n)) | ||
67 | + | ||
68 | + def accumulate(self, predictions, actuals, num_positives=None): | ||
69 | + """Accumulate the predictions and their ground truth labels. | ||
70 | + | ||
71 | + Args: | ||
72 | + predictions: A list of lists storing the prediction scores. The outer | ||
73 | + dimension corresponds to classes. | ||
74 | + actuals: A list of lists storing the ground truth labels. The dimensions | ||
75 | + should correspond to the predictions input. Any value larger than 0 will | ||
76 | + be treated as positives, otherwise as negatives. | ||
77 | + num_positives: If provided, it is a list of numbers representing the | ||
78 | + number of true positives for each class. If not provided, the number of | ||
79 | + true positives will be inferred from the 'actuals' array. | ||
80 | + | ||
81 | + Raises: | ||
82 | + ValueError: An error occurred when the shape of predictions and actuals | ||
83 | + does not match. | ||
84 | + """ | ||
85 | + if not num_positives: | ||
86 | + num_positives = [None for i in range(self._num_class)] | ||
87 | + | ||
88 | + calculators = self._ap_calculators | ||
89 | + for i in range(self._num_class): | ||
90 | + calculators[i].accumulate(predictions[i], actuals[i], num_positives[i]) | ||
91 | + | ||
92 | + def clear(self): | ||
93 | + for calculator in self._ap_calculators: | ||
94 | + calculator.clear() | ||
95 | + | ||
96 | + def is_empty(self): | ||
97 | + return ([calculator.heap_size for calculator in self._ap_calculators | ||
98 | + ] == [0 for _ in range(self._num_class)]) | ||
99 | + | ||
100 | + def peek_map_at_n(self): | ||
101 | + """Peek the non-interpolated mean average precision at n. | ||
102 | + | ||
103 | + Returns: | ||
104 | + An array of non-interpolated average precision at n (default 0) for each | ||
105 | + class. | ||
106 | + """ | ||
107 | + aps = [] | ||
108 | + for i in range(self._num_class): | ||
109 | + if (not self._filter_empty_classes or | ||
110 | + self._ap_calculators[i].num_accumulated_positives > 0): | ||
111 | + ap = self._ap_calculators[i].peek_ap_at_n() | ||
112 | + aps.append(ap) | ||
113 | + return aps |
yt8m/model_utils.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of util functions for model construction.""" | ||
15 | +import numpy | ||
16 | +import tensorflow as tf | ||
17 | +from tensorflow import logging | ||
18 | +from tensorflow import flags | ||
19 | +import tensorflow.contrib.slim as slim | ||
20 | + | ||
21 | + | ||
22 | +def SampleRandomSequence(model_input, num_frames, num_samples): | ||
23 | + """Samples a random sequence of frames of size num_samples. | ||
24 | + | ||
25 | + Args: | ||
26 | + model_input: A tensor of size batch_size x max_frames x feature_size | ||
27 | + num_frames: A tensor of size batch_size x 1 | ||
28 | + num_samples: A scalar | ||
29 | + | ||
30 | + Returns: | ||
31 | + `model_input`: A tensor of size batch_size x num_samples x feature_size | ||
32 | + """ | ||
33 | + | ||
34 | + batch_size = tf.shape(model_input)[0] | ||
35 | + frame_index_offset = tf.tile(tf.expand_dims(tf.range(num_samples), 0), | ||
36 | + [batch_size, 1]) | ||
37 | + max_start_frame_index = tf.maximum(num_frames - num_samples, 0) | ||
38 | + start_frame_index = tf.cast( | ||
39 | + tf.multiply(tf.random_uniform([batch_size, 1]), | ||
40 | + tf.cast(max_start_frame_index + 1, tf.float32)), tf.int32) | ||
41 | + frame_index = tf.minimum(start_frame_index + frame_index_offset, | ||
42 | + tf.cast(num_frames - 1, tf.int32)) | ||
43 | + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), | ||
44 | + [1, num_samples]) | ||
45 | + index = tf.stack([batch_index, frame_index], 2) | ||
46 | + return tf.gather_nd(model_input, index) | ||
47 | + | ||
48 | + | ||
49 | +def SampleRandomFrames(model_input, num_frames, num_samples): | ||
50 | + """Samples a random set of frames of size num_samples. | ||
51 | + | ||
52 | + Args: | ||
53 | + model_input: A tensor of size batch_size x max_frames x feature_size | ||
54 | + num_frames: A tensor of size batch_size x 1 | ||
55 | + num_samples: A scalar | ||
56 | + | ||
57 | + Returns: | ||
58 | + `model_input`: A tensor of size batch_size x num_samples x feature_size | ||
59 | + """ | ||
60 | + batch_size = tf.shape(model_input)[0] | ||
61 | + frame_index = tf.cast( | ||
62 | + tf.multiply(tf.random_uniform([batch_size, num_samples]), | ||
63 | + tf.tile(tf.cast(num_frames, tf.float32), [1, num_samples])), | ||
64 | + tf.int32) | ||
65 | + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), | ||
66 | + [1, num_samples]) | ||
67 | + index = tf.stack([batch_index, frame_index], 2) | ||
68 | + return tf.gather_nd(model_input, index) | ||
69 | + | ||
70 | + | ||
71 | +def FramePooling(frames, method, **unused_params): | ||
72 | + """Pools over the frames of a video. | ||
73 | + | ||
74 | + Args: | ||
75 | + frames: A tensor with shape [batch_size, num_frames, feature_size]. | ||
76 | + method: "average", "max", "attention", or "none". | ||
77 | + | ||
78 | + Returns: | ||
79 | + A tensor with shape [batch_size, feature_size] for average, max, or | ||
80 | + attention pooling. A tensor with shape [batch_size*num_frames, feature_size] | ||
81 | + for none pooling. | ||
82 | + | ||
83 | + Raises: | ||
84 | + ValueError: if method is other than "average", "max", "attention", or | ||
85 | + "none". | ||
86 | + """ | ||
87 | + if method == "average": | ||
88 | + return tf.reduce_mean(frames, 1) | ||
89 | + elif method == "max": | ||
90 | + return tf.reduce_max(frames, 1) | ||
91 | + elif method == "none": | ||
92 | + feature_size = frames.shape_as_list()[2] | ||
93 | + return tf.reshape(frames, [-1, feature_size]) | ||
94 | + else: | ||
95 | + raise ValueError("Unrecognized pooling method: %s" % method) |
yt8m/models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains the base class for models.""" | ||
15 | + | ||
16 | + | ||
17 | +class BaseModel(object): | ||
18 | + """Inherit from this class when implementing new models.""" | ||
19 | + | ||
20 | + def create_model(self, unused_model_input, **unused_params): | ||
21 | + raise NotImplementedError() |
yt8m/readers.py
0 → 100644
This diff is collapsed. Click to expand it.
yt8m/segment_eval_inference.py
0 → 100644
1 | +"""Eval mAP@N metric from inference file.""" | ||
2 | + | ||
3 | +from __future__ import absolute_import | ||
4 | +from __future__ import division | ||
5 | +from __future__ import print_function | ||
6 | + | ||
7 | +from absl import app | ||
8 | +from absl import flags | ||
9 | + | ||
10 | +import mean_average_precision_calculator as map_calculator | ||
11 | +import numpy as np | ||
12 | +import tensorflow as tf | ||
13 | + | ||
14 | +flags.DEFINE_string( | ||
15 | + "eval_data_pattern", "", | ||
16 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
17 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
18 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
19 | +flags.DEFINE_string( | ||
20 | + "label_cache", "", | ||
21 | + "The path for the label cache file. Leave blank for not to cache.") | ||
22 | +flags.DEFINE_string("submission_file", "", | ||
23 | + "The segment submission file generated by inference.py.") | ||
24 | +flags.DEFINE_integer( | ||
25 | + "top_n", 0, | ||
26 | + "The cap per-class predictions by a maximum of N. Use 0 for not capping.") | ||
27 | + | ||
28 | +FLAGS = flags.FLAGS | ||
29 | + | ||
30 | + | ||
31 | +class Labels(object): | ||
32 | + """Contains the class to hold label objects. | ||
33 | + | ||
34 | + This class can serialize and de-serialize the groundtruths. | ||
35 | + The ground truth is in a mapping from (segment_id, class_id) -> label_score. | ||
36 | + """ | ||
37 | + | ||
38 | + def __init__(self, labels): | ||
39 | + """__init__ method.""" | ||
40 | + self._labels = labels | ||
41 | + | ||
42 | + @property | ||
43 | + def labels(self): | ||
44 | + """Return the ground truth mapping. See class docstring for details.""" | ||
45 | + return self._labels | ||
46 | + | ||
47 | + def to_file(self, file_name): | ||
48 | + """Materialize the GT mapping to file.""" | ||
49 | + with tf.gfile.Open(file_name, "w") as fobj: | ||
50 | + for k, v in self._labels.items(): | ||
51 | + seg_id, label = k | ||
52 | + line = "%s,%s,%s\n" % (seg_id, label, v) | ||
53 | + fobj.write(line) | ||
54 | + | ||
55 | + @classmethod | ||
56 | + def from_file(cls, file_name): | ||
57 | + """Read the GT mapping from cached file.""" | ||
58 | + labels = {} | ||
59 | + with tf.gfile.Open(file_name) as fobj: | ||
60 | + for line in fobj: | ||
61 | + line = line.strip().strip("\n") | ||
62 | + seg_id, label, score = line.split(",") | ||
63 | + labels[(seg_id, int(label))] = float(score) | ||
64 | + return cls(labels) | ||
65 | + | ||
66 | + | ||
67 | +def read_labels(data_pattern, cache_path=""): | ||
68 | + """Read labels from TFRecords. | ||
69 | + | ||
70 | + Args: | ||
71 | + data_pattern: the data pattern to the TFRecords. | ||
72 | + cache_path: the cache path for the label file. | ||
73 | + | ||
74 | + Returns: | ||
75 | + a Labels object. | ||
76 | + """ | ||
77 | + if cache_path: | ||
78 | + if tf.gfile.Exists(cache_path): | ||
79 | + tf.logging.info("Reading cached labels from %s..." % cache_path) | ||
80 | + return Labels.from_file(cache_path) | ||
81 | + tf.enable_eager_execution() | ||
82 | + data_paths = tf.gfile.Glob(data_pattern) | ||
83 | + ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50) | ||
84 | + context_features = { | ||
85 | + "id": tf.FixedLenFeature([], tf.string), | ||
86 | + "segment_labels": tf.VarLenFeature(tf.int64), | ||
87 | + "segment_start_times": tf.VarLenFeature(tf.int64), | ||
88 | + "segment_scores": tf.VarLenFeature(tf.float32) | ||
89 | + } | ||
90 | + | ||
91 | + def _parse_se_func(sequence_example): | ||
92 | + return tf.parse_single_sequence_example(sequence_example, | ||
93 | + context_features=context_features) | ||
94 | + | ||
95 | + ds = ds.map(_parse_se_func) | ||
96 | + rated_labels = {} | ||
97 | + tf.logging.info("Reading labels from TFRecords...") | ||
98 | + last_batch = 0 | ||
99 | + batch_size = 5000 | ||
100 | + for cxt_feature_val, _ in ds: | ||
101 | + video_id = cxt_feature_val["id"].numpy() | ||
102 | + segment_labels = cxt_feature_val["segment_labels"].values.numpy() | ||
103 | + segment_start_times = cxt_feature_val["segment_start_times"].values.numpy() | ||
104 | + segment_scores = cxt_feature_val["segment_scores"].values.numpy() | ||
105 | + for label, start_time, score in zip(segment_labels, segment_start_times, | ||
106 | + segment_scores): | ||
107 | + rated_labels[("%s:%d" % (video_id, start_time), label)] = score | ||
108 | + batch_id = len(rated_labels) // batch_size | ||
109 | + if batch_id != last_batch: | ||
110 | + tf.logging.info("%d examples processed.", len(rated_labels)) | ||
111 | + last_batch = batch_id | ||
112 | + tf.logging.info("Finish reading labels from TFRecords...") | ||
113 | + labels_obj = Labels(rated_labels) | ||
114 | + if cache_path: | ||
115 | + tf.logging.info("Caching labels to %s..." % cache_path) | ||
116 | + labels_obj.to_file(cache_path) | ||
117 | + return labels_obj | ||
118 | + | ||
119 | + | ||
120 | +def read_segment_predictions(file_path, labels, top_n=None): | ||
121 | + """Read segement predictions. | ||
122 | + | ||
123 | + Args: | ||
124 | + file_path: the submission file path. | ||
125 | + labels: a Labels object containing the eval labels. | ||
126 | + top_n: the per-class class capping. | ||
127 | + | ||
128 | + Returns: | ||
129 | + a segment prediction list for each classes. | ||
130 | + """ | ||
131 | + cls_preds = {} # A label_id to pred list mapping. | ||
132 | + with tf.gfile.Open(file_path) as fobj: | ||
133 | + tf.logging.info("Reading predictions from %s..." % file_path) | ||
134 | + for line in fobj: | ||
135 | + label_id, pred_ids_val = line.split(",") | ||
136 | + pred_ids = pred_ids_val.split(" ") | ||
137 | + if top_n: | ||
138 | + pred_ids = pred_ids[:top_n] | ||
139 | + pred_ids = [ | ||
140 | + pred_id for pred_id in pred_ids | ||
141 | + if (pred_id, int(label_id)) in labels.labels | ||
142 | + ] | ||
143 | + cls_preds[int(label_id)] = pred_ids | ||
144 | + if len(cls_preds) % 50 == 0: | ||
145 | + tf.logging.info("Processed %d classes..." % len(cls_preds)) | ||
146 | + tf.logging.info("Finish reading predictions.") | ||
147 | + return cls_preds | ||
148 | + | ||
149 | + | ||
150 | +def main(unused_argv): | ||
151 | + """Entry function of the script.""" | ||
152 | + if not FLAGS.submission_file: | ||
153 | + raise ValueError("You must input submission file.") | ||
154 | + eval_labels = read_labels(FLAGS.eval_data_pattern, | ||
155 | + cache_path=FLAGS.label_cache) | ||
156 | + tf.logging.info("Total rated segments: %d." % len(eval_labels.labels)) | ||
157 | + positive_counter = {} | ||
158 | + for k, v in eval_labels.labels.items(): | ||
159 | + _, label_id = k | ||
160 | + if v > 0: | ||
161 | + positive_counter[label_id] = positive_counter.get(label_id, 0) + 1 | ||
162 | + | ||
163 | + seg_preds = read_segment_predictions(FLAGS.submission_file, | ||
164 | + eval_labels, | ||
165 | + top_n=FLAGS.top_n) | ||
166 | + map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds)) | ||
167 | + seg_labels = [] | ||
168 | + seg_scored_preds = [] | ||
169 | + num_positives = [] | ||
170 | + for label_id in sorted(seg_preds): | ||
171 | + class_preds = seg_preds[label_id] | ||
172 | + seg_label = [eval_labels.labels[(pred, label_id)] for pred in class_preds] | ||
173 | + seg_labels.append(seg_label) | ||
174 | + seg_scored_pred = [] | ||
175 | + if class_preds: | ||
176 | + seg_scored_pred = [ | ||
177 | + float(x) / len(class_preds) for x in range(len(class_preds), 0, -1) | ||
178 | + ] | ||
179 | + seg_scored_preds.append(seg_scored_pred) | ||
180 | + num_positives.append(positive_counter[label_id]) | ||
181 | + map_cal.accumulate(seg_scored_preds, seg_labels, num_positives) | ||
182 | + map_at_n = np.mean(map_cal.peek_map_at_n()) | ||
183 | + tf.logging.info("Num classes: %d | mAP@%d: %.6f" % | ||
184 | + (len(seg_preds), FLAGS.top_n, map_at_n)) | ||
185 | + | ||
186 | + | ||
187 | +if __name__ == "__main__": | ||
188 | + app.run(main) |
yt8m/segment_label_ids.csv
0 → 100644
1 | +Index | ||
2 | +3 | ||
3 | +7 | ||
4 | +8 | ||
5 | +11 | ||
6 | +12 | ||
7 | +17 | ||
8 | +18 | ||
9 | +19 | ||
10 | +21 | ||
11 | +22 | ||
12 | +23 | ||
13 | +28 | ||
14 | +31 | ||
15 | +30 | ||
16 | +32 | ||
17 | +33 | ||
18 | +34 | ||
19 | +41 | ||
20 | +43 | ||
21 | +45 | ||
22 | +46 | ||
23 | +48 | ||
24 | +53 | ||
25 | +54 | ||
26 | +52 | ||
27 | +55 | ||
28 | +58 | ||
29 | +59 | ||
30 | +60 | ||
31 | +61 | ||
32 | +65 | ||
33 | +68 | ||
34 | +73 | ||
35 | +71 | ||
36 | +74 | ||
37 | +75 | ||
38 | +76 | ||
39 | +77 | ||
40 | +80 | ||
41 | +83 | ||
42 | +90 | ||
43 | +88 | ||
44 | +89 | ||
45 | +92 | ||
46 | +95 | ||
47 | +100 | ||
48 | +101 | ||
49 | +99 | ||
50 | +104 | ||
51 | +105 | ||
52 | +109 | ||
53 | +113 | ||
54 | +112 | ||
55 | +115 | ||
56 | +116 | ||
57 | +118 | ||
58 | +120 | ||
59 | +121 | ||
60 | +123 | ||
61 | +125 | ||
62 | +127 | ||
63 | +131 | ||
64 | +128 | ||
65 | +129 | ||
66 | +130 | ||
67 | +137 | ||
68 | +141 | ||
69 | +143 | ||
70 | +145 | ||
71 | +148 | ||
72 | +152 | ||
73 | +151 | ||
74 | +156 | ||
75 | +155 | ||
76 | +158 | ||
77 | +160 | ||
78 | +164 | ||
79 | +163 | ||
80 | +169 | ||
81 | +170 | ||
82 | +172 | ||
83 | +171 | ||
84 | +173 | ||
85 | +174 | ||
86 | +175 | ||
87 | +176 | ||
88 | +178 | ||
89 | +182 | ||
90 | +184 | ||
91 | +186 | ||
92 | +188 | ||
93 | +187 | ||
94 | +192 | ||
95 | +191 | ||
96 | +190 | ||
97 | +194 | ||
98 | +197 | ||
99 | +196 | ||
100 | +198 | ||
101 | +201 | ||
102 | +202 | ||
103 | +200 | ||
104 | +199 | ||
105 | +205 | ||
106 | +204 | ||
107 | +209 | ||
108 | +207 | ||
109 | +206 | ||
110 | +210 | ||
111 | +213 | ||
112 | +214 | ||
113 | +220 | ||
114 | +218 | ||
115 | +217 | ||
116 | +226 | ||
117 | +227 | ||
118 | +231 | ||
119 | +232 | ||
120 | +229 | ||
121 | +233 | ||
122 | +235 | ||
123 | +237 | ||
124 | +244 | ||
125 | +240 | ||
126 | +249 | ||
127 | +246 | ||
128 | +248 | ||
129 | +239 | ||
130 | +250 | ||
131 | +245 | ||
132 | +255 | ||
133 | +253 | ||
134 | +256 | ||
135 | +261 | ||
136 | +259 | ||
137 | +263 | ||
138 | +262 | ||
139 | +266 | ||
140 | +267 | ||
141 | +268 | ||
142 | +269 | ||
143 | +271 | ||
144 | +276 | ||
145 | +273 | ||
146 | +277 | ||
147 | +274 | ||
148 | +278 | ||
149 | +279 | ||
150 | +280 | ||
151 | +288 | ||
152 | +291 | ||
153 | +295 | ||
154 | +294 | ||
155 | +293 | ||
156 | +297 | ||
157 | +296 | ||
158 | +300 | ||
159 | +299 | ||
160 | +303 | ||
161 | +302 | ||
162 | +304 | ||
163 | +305 | ||
164 | +313 | ||
165 | +307 | ||
166 | +311 | ||
167 | +310 | ||
168 | +312 | ||
169 | +316 | ||
170 | +318 | ||
171 | +321 | ||
172 | +322 | ||
173 | +331 | ||
174 | +333 | ||
175 | +329 | ||
176 | +330 | ||
177 | +334 | ||
178 | +343 | ||
179 | +349 | ||
180 | +340 | ||
181 | +344 | ||
182 | +348 | ||
183 | +358 | ||
184 | +347 | ||
185 | +359 | ||
186 | +355 | ||
187 | +361 | ||
188 | +360 | ||
189 | +364 | ||
190 | +365 | ||
191 | +368 | ||
192 | +369 | ||
193 | +366 | ||
194 | +370 | ||
195 | +374 | ||
196 | +380 | ||
197 | +373 | ||
198 | +385 | ||
199 | +384 | ||
200 | +388 | ||
201 | +389 | ||
202 | +382 | ||
203 | +393 | ||
204 | +381 | ||
205 | +390 | ||
206 | +394 | ||
207 | +399 | ||
208 | +397 | ||
209 | +396 | ||
210 | +402 | ||
211 | +400 | ||
212 | +398 | ||
213 | +401 | ||
214 | +405 | ||
215 | +406 | ||
216 | +410 | ||
217 | +408 | ||
218 | +416 | ||
219 | +415 | ||
220 | +419 | ||
221 | +422 | ||
222 | +414 | ||
223 | +421 | ||
224 | +424 | ||
225 | +429 | ||
226 | +418 | ||
227 | +427 | ||
228 | +434 | ||
229 | +428 | ||
230 | +435 | ||
231 | +430 | ||
232 | +441 | ||
233 | +439 | ||
234 | +437 | ||
235 | +443 | ||
236 | +440 | ||
237 | +442 | ||
238 | +445 | ||
239 | +446 | ||
240 | +448 | ||
241 | +454 | ||
242 | +444 | ||
243 | +453 | ||
244 | +455 | ||
245 | +451 | ||
246 | +452 | ||
247 | +458 | ||
248 | +460 | ||
249 | +465 | ||
250 | +457 | ||
251 | +463 | ||
252 | +462 | ||
253 | +461 | ||
254 | +464 | ||
255 | +469 | ||
256 | +468 | ||
257 | +472 | ||
258 | +473 | ||
259 | +471 | ||
260 | +475 | ||
261 | +474 | ||
262 | +477 | ||
263 | +485 | ||
264 | +491 | ||
265 | +488 | ||
266 | +482 | ||
267 | +490 | ||
268 | +496 | ||
269 | +494 | ||
270 | +483 | ||
271 | +495 | ||
272 | +493 | ||
273 | +507 | ||
274 | +501 | ||
275 | +499 | ||
276 | +503 | ||
277 | +498 | ||
278 | +514 | ||
279 | +504 | ||
280 | +502 | ||
281 | +506 | ||
282 | +508 | ||
283 | +511 | ||
284 | +527 | ||
285 | +526 | ||
286 | +532 | ||
287 | +513 | ||
288 | +519 | ||
289 | +525 | ||
290 | +518 | ||
291 | +528 | ||
292 | +522 | ||
293 | +523 | ||
294 | +535 | ||
295 | +539 | ||
296 | +540 | ||
297 | +533 | ||
298 | +521 | ||
299 | +541 | ||
300 | +547 | ||
301 | +550 | ||
302 | +544 | ||
303 | +549 | ||
304 | +551 | ||
305 | +554 | ||
306 | +543 | ||
307 | +548 | ||
308 | +557 | ||
309 | +560 | ||
310 | +552 | ||
311 | +559 | ||
312 | +563 | ||
313 | +565 | ||
314 | +567 | ||
315 | +555 | ||
316 | +576 | ||
317 | +568 | ||
318 | +564 | ||
319 | +573 | ||
320 | +581 | ||
321 | +580 | ||
322 | +572 | ||
323 | +571 | ||
324 | +584 | ||
325 | +590 | ||
326 | +585 | ||
327 | +587 | ||
328 | +588 | ||
329 | +592 | ||
330 | +598 | ||
331 | +597 | ||
332 | +599 | ||
333 | +603 | ||
334 | +600 | ||
335 | +604 | ||
336 | +605 | ||
337 | +614 | ||
338 | +602 | ||
339 | +610 | ||
340 | +608 | ||
341 | +611 | ||
342 | +612 | ||
343 | +613 | ||
344 | +617 | ||
345 | +620 | ||
346 | +607 | ||
347 | +624 | ||
348 | +627 | ||
349 | +625 | ||
350 | +631 | ||
351 | +629 | ||
352 | +638 | ||
353 | +632 | ||
354 | +634 | ||
355 | +644 | ||
356 | +641 | ||
357 | +642 | ||
358 | +646 | ||
359 | +652 | ||
360 | +647 | ||
361 | +637 | ||
362 | +661 | ||
363 | +635 | ||
364 | +658 | ||
365 | +648 | ||
366 | +663 | ||
367 | +668 | ||
368 | +664 | ||
369 | +656 | ||
370 | +666 | ||
371 | +671 | ||
372 | +683 | ||
373 | +675 | ||
374 | +669 | ||
375 | +676 | ||
376 | +667 | ||
377 | +691 | ||
378 | +685 | ||
379 | +673 | ||
380 | +688 | ||
381 | +702 | ||
382 | +684 | ||
383 | +679 | ||
384 | +694 | ||
385 | +686 | ||
386 | +689 | ||
387 | +680 | ||
388 | +693 | ||
389 | +703 | ||
390 | +697 | ||
391 | +698 | ||
392 | +692 | ||
393 | +705 | ||
394 | +706 | ||
395 | +712 | ||
396 | +711 | ||
397 | +709 | ||
398 | +710 | ||
399 | +726 | ||
400 | +713 | ||
401 | +721 | ||
402 | +720 | ||
403 | +715 | ||
404 | +717 | ||
405 | +730 | ||
406 | +728 | ||
407 | +723 | ||
408 | +716 | ||
409 | +722 | ||
410 | +718 | ||
411 | +732 | ||
412 | +724 | ||
413 | +736 | ||
414 | +725 | ||
415 | +742 | ||
416 | +727 | ||
417 | +735 | ||
418 | +740 | ||
419 | +748 | ||
420 | +738 | ||
421 | +746 | ||
422 | +751 | ||
423 | +749 | ||
424 | +752 | ||
425 | +754 | ||
426 | +760 | ||
427 | +763 | ||
428 | +756 | ||
429 | +758 | ||
430 | +766 | ||
431 | +764 | ||
432 | +757 | ||
433 | +780 | ||
434 | +767 | ||
435 | +769 | ||
436 | +771 | ||
437 | +786 | ||
438 | +785 | ||
439 | +781 | ||
440 | +787 | ||
441 | +778 | ||
442 | +783 | ||
443 | +792 | ||
444 | +791 | ||
445 | +795 | ||
446 | +788 | ||
447 | +805 | ||
448 | +802 | ||
449 | +801 | ||
450 | +793 | ||
451 | +796 | ||
452 | +804 | ||
453 | +803 | ||
454 | +797 | ||
455 | +814 | ||
456 | +813 | ||
457 | +789 | ||
458 | +808 | ||
459 | +818 | ||
460 | +816 | ||
461 | +817 | ||
462 | +811 | ||
463 | +820 | ||
464 | +826 | ||
465 | +829 | ||
466 | +824 | ||
467 | +821 | ||
468 | +825 | ||
469 | +822 | ||
470 | +835 | ||
471 | +833 | ||
472 | +843 | ||
473 | +823 | ||
474 | +827 | ||
475 | +830 | ||
476 | +832 | ||
477 | +837 | ||
478 | +852 | ||
479 | +844 | ||
480 | +841 | ||
481 | +812 | ||
482 | +847 | ||
483 | +862 | ||
484 | +869 | ||
485 | +860 | ||
486 | +838 | ||
487 | +870 | ||
488 | +846 | ||
489 | +858 | ||
490 | +854 | ||
491 | +880 | ||
492 | +876 | ||
493 | +857 | ||
494 | +859 | ||
495 | +877 | ||
496 | +871 | ||
497 | +855 | ||
498 | +875 | ||
499 | +861 | ||
500 | +867 | ||
501 | +892 | ||
502 | +898 | ||
503 | +888 | ||
504 | +884 | ||
505 | +887 | ||
506 | +891 | ||
507 | +906 | ||
508 | +900 | ||
509 | +878 | ||
510 | +885 | ||
511 | +883 | ||
512 | +901 | ||
513 | +903 | ||
514 | +907 | ||
515 | +930 | ||
516 | +897 | ||
517 | +914 | ||
518 | +917 | ||
519 | +910 | ||
520 | +905 | ||
521 | +909 | ||
522 | +933 | ||
523 | +932 | ||
524 | +922 | ||
525 | +913 | ||
526 | +923 | ||
527 | +931 | ||
528 | +911 | ||
529 | +937 | ||
530 | +918 | ||
531 | +955 | ||
532 | +915 | ||
533 | +944 | ||
534 | +952 | ||
535 | +945 | ||
536 | +948 | ||
537 | +946 | ||
538 | +970 | ||
539 | +974 | ||
540 | +958 | ||
541 | +925 | ||
542 | +979 | ||
543 | +942 | ||
544 | +965 | ||
545 | +975 | ||
546 | +950 | ||
547 | +982 | ||
548 | +940 | ||
549 | +973 | ||
550 | +962 | ||
551 | +972 | ||
552 | +957 | ||
553 | +984 | ||
554 | +983 | ||
555 | +964 | ||
556 | +1007 | ||
557 | +971 | ||
558 | +981 | ||
559 | +954 | ||
560 | +993 | ||
561 | +991 | ||
562 | +996 | ||
563 | +1005 | ||
564 | +1015 | ||
565 | +1009 | ||
566 | +995 | ||
567 | +986 | ||
568 | +1000 | ||
569 | +985 | ||
570 | +980 | ||
571 | +1016 | ||
572 | +1011 | ||
573 | +999 | ||
574 | +1002 | ||
575 | +994 | ||
576 | +1013 | ||
577 | +1010 | ||
578 | +992 | ||
579 | +1008 | ||
580 | +1036 | ||
581 | +1025 | ||
582 | +1012 | ||
583 | +990 | ||
584 | +1037 | ||
585 | +1040 | ||
586 | +1031 | ||
587 | +1019 | ||
588 | +1052 | ||
589 | +1001 | ||
590 | +1055 | ||
591 | +1032 | ||
592 | +1069 | ||
593 | +1058 | ||
594 | +1014 | ||
595 | +1023 | ||
596 | +1030 | ||
597 | +1061 | ||
598 | +1035 | ||
599 | +1034 | ||
600 | +1053 | ||
601 | +1045 | ||
602 | +1046 | ||
603 | +1067 | ||
604 | +1060 | ||
605 | +1049 | ||
606 | +1056 | ||
607 | +1074 | ||
608 | +1066 | ||
609 | +1044 | ||
610 | +1038 | ||
611 | +1073 | ||
612 | +1077 | ||
613 | +1068 | ||
614 | +1057 | ||
615 | +1072 | ||
616 | +1104 | ||
617 | +1083 | ||
618 | +1089 | ||
619 | +1087 | ||
620 | +1099 | ||
621 | +1076 | ||
622 | +1086 | ||
623 | +1098 | ||
624 | +1094 | ||
625 | +1095 | ||
626 | +1096 | ||
627 | +1101 | ||
628 | +1107 | ||
629 | +1105 | ||
630 | +1117 | ||
631 | +1093 | ||
632 | +1106 | ||
633 | +1122 | ||
634 | +1119 | ||
635 | +1103 | ||
636 | +1128 | ||
637 | +1120 | ||
638 | +1126 | ||
639 | +1102 | ||
640 | +1115 | ||
641 | +1124 | ||
642 | +1123 | ||
643 | +1131 | ||
644 | +1136 | ||
645 | +1144 | ||
646 | +1121 | ||
647 | +1137 | ||
648 | +1132 | ||
649 | +1133 | ||
650 | +1157 | ||
651 | +1134 | ||
652 | +1143 | ||
653 | +1159 | ||
654 | +1164 | ||
655 | +1155 | ||
656 | +1142 | ||
657 | +1150 | ||
658 | +1148 | ||
659 | +1161 | ||
660 | +1165 | ||
661 | +1147 | ||
662 | +1162 | ||
663 | +1152 | ||
664 | +1174 | ||
665 | +1160 | ||
666 | +1166 | ||
667 | +1190 | ||
668 | +1175 | ||
669 | +1167 | ||
670 | +1156 | ||
671 | +1180 | ||
672 | +1171 | ||
673 | +1179 | ||
674 | +1172 | ||
675 | +1186 | ||
676 | +1188 | ||
677 | +1201 | ||
678 | +1177 | ||
679 | +1208 | ||
680 | +1183 | ||
681 | +1189 | ||
682 | +1192 | ||
683 | +1209 | ||
684 | +1214 | ||
685 | +1197 | ||
686 | +1168 | ||
687 | +1202 | ||
688 | +1205 | ||
689 | +1203 | ||
690 | +1199 | ||
691 | +1219 | ||
692 | +1217 | ||
693 | +1187 | ||
694 | +1206 | ||
695 | +1210 | ||
696 | +1241 | ||
697 | +1221 | ||
698 | +1218 | ||
699 | +1223 | ||
700 | +1236 | ||
701 | +1212 | ||
702 | +1237 | ||
703 | +1195 | ||
704 | +1216 | ||
705 | +1247 | ||
706 | +1234 | ||
707 | +1240 | ||
708 | +1257 | ||
709 | +1224 | ||
710 | +1243 | ||
711 | +1259 | ||
712 | +1242 | ||
713 | +1282 | ||
714 | +1222 | ||
715 | +1254 | ||
716 | +1227 | ||
717 | +1235 | ||
718 | +1269 | ||
719 | +1258 | ||
720 | +1290 | ||
721 | +1275 | ||
722 | +1262 | ||
723 | +1252 | ||
724 | +1248 | ||
725 | +1272 | ||
726 | +1246 | ||
727 | +1225 | ||
728 | +1245 | ||
729 | +1277 | ||
730 | +1298 | ||
731 | +1288 | ||
732 | +1271 | ||
733 | +1265 | ||
734 | +1286 | ||
735 | +1260 | ||
736 | +1266 | ||
737 | +1296 | ||
738 | +1280 | ||
739 | +1285 | ||
740 | +1293 | ||
741 | +1276 | ||
742 | +1287 | ||
743 | +1289 | ||
744 | +1261 | ||
745 | +1264 | ||
746 | +1295 | ||
747 | +1291 | ||
748 | +1283 | ||
749 | +1311 | ||
750 | +1303 | ||
751 | +1330 | ||
752 | +1315 | ||
753 | +1300 | ||
754 | +1333 | ||
755 | +1307 | ||
756 | +1325 | ||
757 | +1334 | ||
758 | +1316 | ||
759 | +1314 | ||
760 | +1317 | ||
761 | +1310 | ||
762 | +1329 | ||
763 | +1324 | ||
764 | +1339 | ||
765 | +1346 | ||
766 | +1342 | ||
767 | +1352 | ||
768 | +1321 | ||
769 | +1376 | ||
770 | +1366 | ||
771 | +1308 | ||
772 | +1345 | ||
773 | +1348 | ||
774 | +1386 | ||
775 | +1383 | ||
776 | +1372 | ||
777 | +1367 | ||
778 | +1400 | ||
779 | +1382 | ||
780 | +1375 | ||
781 | +1392 | ||
782 | +1380 | ||
783 | +1371 | ||
784 | +1393 | ||
785 | +1389 | ||
786 | +1353 | ||
787 | +1387 | ||
788 | +1374 | ||
789 | +1379 | ||
790 | +1381 | ||
791 | +1359 | ||
792 | +1360 | ||
793 | +1396 | ||
794 | +1399 | ||
795 | +1365 | ||
796 | +1424 | ||
797 | +1373 | ||
798 | +1411 | ||
799 | +1401 | ||
800 | +1397 | ||
801 | +1395 | ||
802 | +1412 | ||
803 | +1394 | ||
804 | +1368 | ||
805 | +1423 | ||
806 | +1391 | ||
807 | +1435 | ||
808 | +1409 | ||
809 | +1443 | ||
810 | +1402 | ||
811 | +1425 | ||
812 | +1415 | ||
813 | +1421 | ||
814 | +1426 | ||
815 | +1433 | ||
816 | +1420 | ||
817 | +1452 | ||
818 | +1436 | ||
819 | +1430 | ||
820 | +1408 | ||
821 | +1458 | ||
822 | +1429 | ||
823 | +1453 | ||
824 | +1454 | ||
825 | +1447 | ||
826 | +1472 | ||
827 | +1486 | ||
828 | +1468 | ||
829 | +1461 | ||
830 | +1467 | ||
831 | +1484 | ||
832 | +1457 | ||
833 | +1444 | ||
834 | +1450 | ||
835 | +1451 | ||
836 | +1459 | ||
837 | +1462 | ||
838 | +1449 | ||
839 | +1476 | ||
840 | +1470 | ||
841 | +1471 | ||
842 | +1498 | ||
843 | +1488 | ||
844 | +1442 | ||
845 | +1480 | ||
846 | +1456 | ||
847 | +1466 | ||
848 | +1505 | ||
849 | +1517 | ||
850 | +1464 | ||
851 | +1503 | ||
852 | +1490 | ||
853 | +1519 | ||
854 | +1481 | ||
855 | +1493 | ||
856 | +1463 | ||
857 | +1532 | ||
858 | +1487 | ||
859 | +1501 | ||
860 | +1500 | ||
861 | +1495 | ||
862 | +1509 | ||
863 | +1535 | ||
864 | +1506 | ||
865 | +1521 | ||
866 | +1580 | ||
867 | +1540 | ||
868 | +1502 | ||
869 | +1520 | ||
870 | +1496 | ||
871 | +1569 | ||
872 | +1515 | ||
873 | +1489 | ||
874 | +1507 | ||
875 | +1527 | ||
876 | +1545 | ||
877 | +1560 | ||
878 | +1510 | ||
879 | +1514 | ||
880 | +1526 | ||
881 | +1594 | ||
882 | +1511 | ||
883 | +1572 | ||
884 | +1548 | ||
885 | +1584 | ||
886 | +1556 | ||
887 | +1588 | ||
888 | +1628 | ||
889 | +1555 | ||
890 | +1568 | ||
891 | +1550 | ||
892 | +1622 | ||
893 | +1563 | ||
894 | +1603 | ||
895 | +1616 | ||
896 | +1576 | ||
897 | +1549 | ||
898 | +1537 | ||
899 | +1593 | ||
900 | +1618 | ||
901 | +1645 | ||
902 | +1624 | ||
903 | +1617 | ||
904 | +1634 | ||
905 | +1595 | ||
906 | +1597 | ||
907 | +1590 | ||
908 | +1632 | ||
909 | +1575 | ||
910 | +1559 | ||
911 | +1625 | ||
912 | +1615 | ||
913 | +1591 | ||
914 | +1630 | ||
915 | +1608 | ||
916 | +1621 | ||
917 | +1589 | ||
918 | +1646 | ||
919 | +1643 | ||
920 | +1652 | ||
921 | +1627 | ||
922 | +1611 | ||
923 | +1626 | ||
924 | +1613 | ||
925 | +1639 | ||
926 | +1655 | ||
927 | +1620 | ||
928 | +1602 | ||
929 | +1651 | ||
930 | +1653 | ||
931 | +1669 | ||
932 | +1638 | ||
933 | +1696 | ||
934 | +1649 | ||
935 | +1675 | ||
936 | +1660 | ||
937 | +1683 | ||
938 | +1666 | ||
939 | +1671 | ||
940 | +1703 | ||
941 | +1716 | ||
942 | +1637 | ||
943 | +1672 | ||
944 | +1676 | ||
945 | +1692 | ||
946 | +1711 | ||
947 | +1680 | ||
948 | +1641 | ||
949 | +1688 | ||
950 | +1708 | ||
951 | +1704 | ||
952 | +1690 | ||
953 | +1674 | ||
954 | +1718 | ||
955 | +1699 | ||
956 | +1723 | ||
957 | +1756 | ||
958 | +1700 | ||
959 | +1662 | ||
960 | +1715 | ||
961 | +1657 | ||
962 | +1733 | ||
963 | +1728 | ||
964 | +1670 | ||
965 | +1712 | ||
966 | +1685 | ||
967 | +1724 | ||
968 | +1735 | ||
969 | +1714 | ||
970 | +1730 | ||
971 | +1747 | ||
972 | +1656 | ||
973 | +1737 | ||
974 | +1705 | ||
975 | +1693 | ||
976 | +1713 | ||
977 | +1689 | ||
978 | +1753 | ||
979 | +1739 | ||
980 | +1721 | ||
981 | +1725 | ||
982 | +1749 | ||
983 | +1732 | ||
984 | +1743 | ||
985 | +1731 | ||
986 | +1767 | ||
987 | +1738 | ||
988 | +1831 | ||
989 | +1771 | ||
990 | +1726 | ||
991 | +1746 | ||
992 | +1776 | ||
993 | +1775 | ||
994 | +1799 | ||
995 | +1774 | ||
996 | +1780 | ||
997 | +1781 | ||
998 | +1769 | ||
999 | +1805 | ||
1000 | +1788 | ||
1001 | +1801 |
yt8m/train.py
0 → 100644
This diff is collapsed. Click to expand it.
yt8m/utils.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of util functions for training and evaluating.""" | ||
15 | + | ||
16 | +import numpy | ||
17 | +import tensorflow as tf | ||
18 | +from tensorflow import logging | ||
19 | + | ||
20 | +try: | ||
21 | + xrange # Python 2 | ||
22 | +except NameError: | ||
23 | + xrange = range # Python 3 | ||
24 | + | ||
25 | + | ||
26 | +def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2): | ||
27 | + """Dequantize the feature from the byte format to the float format. | ||
28 | + | ||
29 | + Args: | ||
30 | + feat_vector: the input 1-d vector. | ||
31 | + max_quantized_value: the maximum of the quantized value. | ||
32 | + min_quantized_value: the minimum of the quantized value. | ||
33 | + | ||
34 | + Returns: | ||
35 | + A float vector which has the same shape as feat_vector. | ||
36 | + """ | ||
37 | + assert max_quantized_value > min_quantized_value | ||
38 | + quantized_range = max_quantized_value - min_quantized_value | ||
39 | + scalar = quantized_range / 255.0 | ||
40 | + bias = (quantized_range / 512.0) + min_quantized_value | ||
41 | + return feat_vector * scalar + bias | ||
42 | + | ||
43 | + | ||
44 | +def MakeSummary(name, value): | ||
45 | + """Creates a tf.Summary proto with the given name and value.""" | ||
46 | + summary = tf.Summary() | ||
47 | + val = summary.value.add() | ||
48 | + val.tag = str(name) | ||
49 | + val.simple_value = float(value) | ||
50 | + return summary | ||
51 | + | ||
52 | + | ||
53 | +def AddGlobalStepSummary(summary_writer, | ||
54 | + global_step_val, | ||
55 | + global_step_info_dict, | ||
56 | + summary_scope="Eval"): | ||
57 | + """Add the global_step summary to the Tensorboard. | ||
58 | + | ||
59 | + Args: | ||
60 | + summary_writer: Tensorflow summary_writer. | ||
61 | + global_step_val: a int value of the global step. | ||
62 | + global_step_info_dict: a dictionary of the evaluation metrics calculated for | ||
63 | + a mini-batch. | ||
64 | + summary_scope: Train or Eval. | ||
65 | + | ||
66 | + Returns: | ||
67 | + A string of this global_step summary | ||
68 | + """ | ||
69 | + this_hit_at_one = global_step_info_dict["hit_at_one"] | ||
70 | + this_perr = global_step_info_dict["perr"] | ||
71 | + this_loss = global_step_info_dict["loss"] | ||
72 | + examples_per_second = global_step_info_dict.get("examples_per_second", -1) | ||
73 | + | ||
74 | + summary_writer.add_summary( | ||
75 | + MakeSummary("GlobalStep/" + summary_scope + "_Hit@1", this_hit_at_one), | ||
76 | + global_step_val) | ||
77 | + summary_writer.add_summary( | ||
78 | + MakeSummary("GlobalStep/" + summary_scope + "_Perr", this_perr), | ||
79 | + global_step_val) | ||
80 | + summary_writer.add_summary( | ||
81 | + MakeSummary("GlobalStep/" + summary_scope + "_Loss", this_loss), | ||
82 | + global_step_val) | ||
83 | + | ||
84 | + if examples_per_second != -1: | ||
85 | + summary_writer.add_summary( | ||
86 | + MakeSummary("GlobalStep/" + summary_scope + "_Example_Second", | ||
87 | + examples_per_second), global_step_val) | ||
88 | + | ||
89 | + summary_writer.flush() | ||
90 | + info = ( | ||
91 | + "global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch " | ||
92 | + "Loss: {3:.3f} | Examples_per_sec: {4:.3f}").format( | ||
93 | + global_step_val, this_hit_at_one, this_perr, this_loss, | ||
94 | + examples_per_second) | ||
95 | + return info | ||
96 | + | ||
97 | + | ||
98 | +def AddEpochSummary(summary_writer, | ||
99 | + global_step_val, | ||
100 | + epoch_info_dict, | ||
101 | + summary_scope="Eval"): | ||
102 | + """Add the epoch summary to the Tensorboard. | ||
103 | + | ||
104 | + Args: | ||
105 | + summary_writer: Tensorflow summary_writer. | ||
106 | + global_step_val: a int value of the global step. | ||
107 | + epoch_info_dict: a dictionary of the evaluation metrics calculated for the | ||
108 | + whole epoch. | ||
109 | + summary_scope: Train or Eval. | ||
110 | + | ||
111 | + Returns: | ||
112 | + A string of this global_step summary | ||
113 | + """ | ||
114 | + epoch_id = epoch_info_dict["epoch_id"] | ||
115 | + avg_hit_at_one = epoch_info_dict["avg_hit_at_one"] | ||
116 | + avg_perr = epoch_info_dict["avg_perr"] | ||
117 | + avg_loss = epoch_info_dict["avg_loss"] | ||
118 | + aps = epoch_info_dict["aps"] | ||
119 | + gap = epoch_info_dict["gap"] | ||
120 | + mean_ap = numpy.mean(aps) | ||
121 | + | ||
122 | + summary_writer.add_summary( | ||
123 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one), | ||
124 | + global_step_val) | ||
125 | + summary_writer.add_summary( | ||
126 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Perr", avg_perr), | ||
127 | + global_step_val) | ||
128 | + summary_writer.add_summary( | ||
129 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Loss", avg_loss), | ||
130 | + global_step_val) | ||
131 | + summary_writer.add_summary( | ||
132 | + MakeSummary("Epoch/" + summary_scope + "_MAP", mean_ap), global_step_val) | ||
133 | + summary_writer.add_summary( | ||
134 | + MakeSummary("Epoch/" + summary_scope + "_GAP", gap), global_step_val) | ||
135 | + summary_writer.flush() | ||
136 | + | ||
137 | + info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} " | ||
138 | + "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:3f} | num_classes: {6}" | ||
139 | + ).format(epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss, | ||
140 | + len(aps)) | ||
141 | + return info | ||
142 | + | ||
143 | + | ||
144 | +def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes): | ||
145 | + """Extract the list of feature names and the dimensionality of each feature | ||
146 | + | ||
147 | + from string of comma separated values. | ||
148 | + | ||
149 | + Args: | ||
150 | + feature_names: string containing comma separated list of feature names | ||
151 | + feature_sizes: string containing comma separated list of feature sizes | ||
152 | + | ||
153 | + Returns: | ||
154 | + List of the feature names and list of the dimensionality of each feature. | ||
155 | + Elements in the first/second list are strings/integers. | ||
156 | + """ | ||
157 | + list_of_feature_names = [ | ||
158 | + feature_names.strip() for feature_names in feature_names.split(",") | ||
159 | + ] | ||
160 | + list_of_feature_sizes = [ | ||
161 | + int(feature_sizes) for feature_sizes in feature_sizes.split(",") | ||
162 | + ] | ||
163 | + if len(list_of_feature_names) != len(list_of_feature_sizes): | ||
164 | + logging.error("length of the feature names (=" + | ||
165 | + str(len(list_of_feature_names)) + ") != length of feature " | ||
166 | + "sizes (=" + str(len(list_of_feature_sizes)) + ")") | ||
167 | + | ||
168 | + return list_of_feature_names, list_of_feature_sizes | ||
169 | + | ||
170 | + | ||
171 | +def clip_gradient_norms(gradients_to_variables, max_norm): | ||
172 | + """Clips the gradients by the given value. | ||
173 | + | ||
174 | + Args: | ||
175 | + gradients_to_variables: A list of gradient to variable pairs (tuples). | ||
176 | + max_norm: the maximum norm value. | ||
177 | + | ||
178 | + Returns: | ||
179 | + A list of clipped gradient to variable pairs. | ||
180 | + """ | ||
181 | + clipped_grads_and_vars = [] | ||
182 | + for grad, var in gradients_to_variables: | ||
183 | + if grad is not None: | ||
184 | + if isinstance(grad, tf.IndexedSlices): | ||
185 | + tmp = tf.clip_by_norm(grad.values, max_norm) | ||
186 | + grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape) | ||
187 | + else: | ||
188 | + grad = tf.clip_by_norm(grad, max_norm) | ||
189 | + clipped_grads_and_vars.append((grad, var)) | ||
190 | + return clipped_grads_and_vars | ||
191 | + | ||
192 | + | ||
193 | +def combine_gradients(tower_grads): | ||
194 | + """Calculate the combined gradient for each shared variable across all towers. | ||
195 | + | ||
196 | + Note that this function provides a synchronization point across all towers. | ||
197 | + | ||
198 | + Args: | ||
199 | + tower_grads: List of lists of (gradient, variable) tuples. The outer list is | ||
200 | + over individual gradients. The inner list is over the gradient calculation | ||
201 | + for each tower. | ||
202 | + | ||
203 | + Returns: | ||
204 | + List of pairs of (gradient, variable) where the gradient has been summed | ||
205 | + across all towers. | ||
206 | + """ | ||
207 | + filtered_grads = [ | ||
208 | + [x for x in grad_list if x[0] is not None] for grad_list in tower_grads | ||
209 | + ] | ||
210 | + final_grads = [] | ||
211 | + for i in xrange(len(filtered_grads[0])): | ||
212 | + grads = [filtered_grads[t][i] for t in xrange(len(filtered_grads))] | ||
213 | + grad = tf.stack([x[0] for x in grads], 0) | ||
214 | + grad = tf.reduce_sum(grad, 0) | ||
215 | + final_grads.append(( | ||
216 | + grad, | ||
217 | + filtered_grads[0][i][1], | ||
218 | + )) | ||
219 | + | ||
220 | + return final_grads |
yt8m/video_level_models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains model definitions.""" | ||
15 | +import math | ||
16 | + | ||
17 | +import models | ||
18 | +import tensorflow as tf | ||
19 | +import utils | ||
20 | + | ||
21 | +from tensorflow import flags | ||
22 | +import tensorflow.contrib.slim as slim | ||
23 | + | ||
24 | +FLAGS = flags.FLAGS | ||
25 | +flags.DEFINE_integer( | ||
26 | + "moe_num_mixtures", 2, | ||
27 | + "The number of mixtures (excluding the dummy 'expert') used for MoeModel.") | ||
28 | + | ||
29 | + | ||
30 | +class LogisticModel(models.BaseModel): | ||
31 | + """Logistic model with L2 regularization.""" | ||
32 | + | ||
33 | + def create_model(self, | ||
34 | + model_input, | ||
35 | + vocab_size, | ||
36 | + l2_penalty=1e-8, | ||
37 | + **unused_params): | ||
38 | + """Creates a logistic model. | ||
39 | + | ||
40 | + Args: | ||
41 | + model_input: 'batch' x 'num_features' matrix of input features. | ||
42 | + vocab_size: The number of classes in the dataset. | ||
43 | + | ||
44 | + Returns: | ||
45 | + A dictionary with a tensor containing the probability predictions of the | ||
46 | + model in the 'predictions' key. The dimensions of the tensor are | ||
47 | + batch_size x num_classes. | ||
48 | + """ | ||
49 | + output = slim.fully_connected( | ||
50 | + model_input, | ||
51 | + vocab_size, | ||
52 | + activation_fn=tf.nn.sigmoid, | ||
53 | + weights_regularizer=slim.l2_regularizer(l2_penalty)) | ||
54 | + return {"predictions": output} | ||
55 | + | ||
56 | + | ||
57 | +class MoeModel(models.BaseModel): | ||
58 | + """A softmax over a mixture of logistic models (with L2 regularization).""" | ||
59 | + | ||
60 | + def create_model(self, | ||
61 | + model_input, | ||
62 | + vocab_size, | ||
63 | + num_mixtures=None, | ||
64 | + l2_penalty=1e-8, | ||
65 | + **unused_params): | ||
66 | + """Creates a Mixture of (Logistic) Experts model. | ||
67 | + | ||
68 | + The model consists of a per-class softmax distribution over a | ||
69 | + configurable number of logistic classifiers. One of the classifiers in the | ||
70 | + mixture is not trained, and always predicts 0. | ||
71 | + | ||
72 | + Args: | ||
73 | + model_input: 'batch_size' x 'num_features' matrix of input features. | ||
74 | + vocab_size: The number of classes in the dataset. | ||
75 | + num_mixtures: The number of mixtures (excluding a dummy 'expert' that | ||
76 | + always predicts the non-existence of an entity). | ||
77 | + l2_penalty: How much to penalize the squared magnitudes of parameter | ||
78 | + values. | ||
79 | + | ||
80 | + Returns: | ||
81 | + A dictionary with a tensor containing the probability predictions of the | ||
82 | + model in the 'predictions' key. The dimensions of the tensor are | ||
83 | + batch_size x num_classes. | ||
84 | + """ | ||
85 | + num_mixtures = num_mixtures or FLAGS.moe_num_mixtures | ||
86 | + | ||
87 | + gate_activations = slim.fully_connected( | ||
88 | + model_input, | ||
89 | + vocab_size * (num_mixtures + 1), | ||
90 | + activation_fn=None, | ||
91 | + biases_initializer=None, | ||
92 | + weights_regularizer=slim.l2_regularizer(l2_penalty), | ||
93 | + scope="gates") | ||
94 | + expert_activations = slim.fully_connected( | ||
95 | + model_input, | ||
96 | + vocab_size * num_mixtures, | ||
97 | + activation_fn=None, | ||
98 | + weights_regularizer=slim.l2_regularizer(l2_penalty), | ||
99 | + scope="experts") | ||
100 | + | ||
101 | + gating_distribution = tf.nn.softmax( | ||
102 | + tf.reshape( | ||
103 | + gate_activations, | ||
104 | + [-1, num_mixtures + 1])) # (Batch * #Labels) x (num_mixtures + 1) | ||
105 | + expert_distribution = tf.nn.sigmoid( | ||
106 | + tf.reshape(expert_activations, | ||
107 | + [-1, num_mixtures])) # (Batch * #Labels) x num_mixtures | ||
108 | + | ||
109 | + final_probabilities_by_class_and_batch = tf.reduce_sum( | ||
110 | + gating_distribution[:, :num_mixtures] * expert_distribution, 1) | ||
111 | + final_probabilities = tf.reshape(final_probabilities_by_class_and_batch, | ||
112 | + [-1, vocab_size]) | ||
113 | + return {"predictions": final_probabilities} |
-
Please register or login to post a comment