이현규

Convert submodule to files

youtube-8m @ e6f6bf68
1 -Subproject commit e6f6bf682d20bb21904ea9c081c15e070809d914
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Calculate or keep track of the interpolated average precision.
15 +
16 +It provides an interface for calculating interpolated average precision for an
17 +entire list or the top-n ranked items. For the definition of the
18 +(non-)interpolated average precision:
19 +http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf
20 +
21 +Example usages:
22 +1) Use it as a static function call to directly calculate average precision for
23 +a short ranked list in the memory.
24 +
25 +```
26 +import random
27 +
28 +p = np.array([random.random() for _ in xrange(10)])
29 +a = np.array([random.choice([0, 1]) for _ in xrange(10)])
30 +
31 +ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a)
32 +```
33 +
34 +2) Use it as an object for a long ranked list that cannot be stored in memory,
35 +or for the case where only partial predictions can be observed at a time
36 +(e.g. Tensorflow predictions). In this case, we first call accumulate() as many
37 +times as needed to process the parts of the ranked list. After processing all
38 +the parts, we call peek_ap_at_n().
39 +```
40 +p1 = np.array([random.random() for _ in xrange(5)])
41 +a1 = np.array([random.choice([0, 1]) for _ in xrange(5)])
42 +p2 = np.array([random.random() for _ in xrange(5)])
43 +a2 = np.array([random.choice([0, 1]) for _ in xrange(5)])
44 +
45 +# non-interpolated average precision at 10
46 +calculator = average_precision_calculator.AveragePrecisionCalculator(10)
47 +calculator.accumulate(p1, a1)
48 +calculator.accumulate(p2, a2)
49 +ap3 = calculator.peek_ap_at_n()
50 +```
51 +"""
52 +
53 +import heapq
54 +import random
55 +import numbers
56 +
57 +import numpy
58 +
59 +
class AveragePrecisionCalculator(object):
    """Calculate the average precision and average precision at n."""

    def __init__(self, top_n=None):
        """Construct an AveragePrecisionCalculator to calculate average precision.

        This class is used to calculate the average precision for a single label.

        Args:
          top_n: A positive integer specifying the average precision at n, or
            None to use all provided data points.

        Raises:
          ValueError: If top_n is neither None nor a positive integer.
        """
        # Zero is rejected as well: a zero-capacity heap can never accept an
        # item, and accumulate() would fail when peeking at the empty heap.
        if top_n is not None and not (isinstance(top_n, int) and top_n > 0):
            raise ValueError("top_n must be a positive integer or None.")

        self._top_n = top_n  # average precision at n
        self._total_positives = 0  # total number of positives seen so far
        self._heap = []  # min-heap of (prediction, actual); heap[0] is smallest

    @property
    def heap_size(self):
        """Gets the heap size maintained in the class."""
        return len(self._heap)

    @property
    def num_accumulated_positives(self):
        """Gets the number of positive samples that have been accumulated."""
        return self._total_positives

    def accumulate(self, predictions, actuals, num_positives=None):
        """Accumulate the predictions and their ground truth labels.

        After the function call, we may call peek_ap_at_n to actually calculate
        the average precision.
        Note predictions and actuals must have the same length.

        Args:
          predictions: a list storing the prediction scores.
          actuals: a list storing the ground truth labels. Values larger than
            1e-5 are counted as positives when num_positives is not given.
          num_positives: If the 'predictions' and 'actuals' inputs aren't
            complete, then it's possible some true positives were missed in
            them. In that case, you can provide 'num_positives' in order to
            accurately track recall.

        Raises:
          ValueError: If the lengths of predictions and actuals differ, or if
            num_positives is provided but is not a non-negative number.
        """
        if len(predictions) != len(actuals):
            raise ValueError(
                "the shape of predictions and actuals does not match.")

        if num_positives is not None:
            if (not isinstance(num_positives, numbers.Number) or
                    num_positives < 0):
                raise ValueError(
                    "'num_positives' was provided but it was a negative number.")
            self._total_positives += num_positives
        else:
            self._total_positives += numpy.size(
                numpy.where(numpy.array(actuals) > 1e-5))

        topk = self._top_n
        heap = self._heap

        for i in range(numpy.size(predictions)):
            item = (predictions[i], actuals[i])
            if topk is None or len(heap) < topk:
                heapq.heappush(heap, item)
            elif predictions[i] > heap[0][0]:
                # The heap is full: evict the smallest kept prediction.
                heapq.heapreplace(heap, item)

    def clear(self):
        """Clear the accumulated predictions."""
        self._heap = []
        self._total_positives = 0

    def peek_ap_at_n(self):
        """Peek the non-interpolated average precision at n.

        Returns:
          The non-interpolated average precision at n (0 when nothing has been
          accumulated). If n is larger than the length of the ranked list,
          the average precision will be returned.
        """
        if self.heap_size <= 0:
            return 0
        # Transpose the list of (prediction, actual) pairs into two rows.
        predlists = numpy.array(list(zip(*self._heap)))

        return self.ap_at_n(predlists[0],
                            predlists[1],
                            n=self._top_n,
                            total_num_positives=self._total_positives)

    @staticmethod
    def ap(predictions, actuals):
        """Calculate the non-interpolated average precision over the full list.

        Args:
          predictions: a numpy 1-D array storing the sparse prediction scores.
          actuals: a numpy 1-D array storing the ground truth labels. Any value
            larger than 0 will be treated as positives, otherwise as negatives.

        Returns:
          The non-interpolated average precision.

        Raises:
          ValueError: If the lengths of predictions and actuals do not match.
        """
        return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None)

    @staticmethod
    def ap_at_n(predictions, actuals, n=20, total_num_positives=None):
        """Calculate the non-interpolated average precision at n.

        Args:
          predictions: a numpy 1-D array storing the sparse prediction scores.
          actuals: a numpy 1-D array storing the ground truth labels. Any value
            larger than 0 will be treated as positives, otherwise as negatives.
          n: the top n items to be considered in ap@n, or None for all items.
          total_num_positives: (optionally) the total number of positives in
            the list. If specified, it is used instead of counting 'actuals'.

        Returns:
          The non-interpolated average precision at n (0 if there are no
          positives). If n is larger than the length of the ranked list,
          the average precision will be returned.

        Raises:
          ValueError: If
            1) the lengths of predictions and actuals do not match;
            2) the input n is neither None nor a positive integer.
        """
        if len(predictions) != len(actuals):
            raise ValueError(
                "the shape of predictions and actuals does not match.")

        if n is not None:
            if not isinstance(n, int) or n <= 0:
                raise ValueError("n must be 'None' or a positive integer."
                                 " It was '%s'." % n)

        ap = 0.0

        predictions = numpy.array(predictions)
        actuals = numpy.array(actuals)

        # add a shuffler to avoid overestimating the ap
        predictions, actuals = AveragePrecisionCalculator._shuffle(
            predictions, actuals)
        sortidx = sorted(range(len(predictions)),
                         key=lambda k: predictions[k],
                         reverse=True)

        if total_num_positives is None:
            numpos = numpy.size(numpy.where(actuals > 0))
        else:
            numpos = total_num_positives

        if numpos == 0:
            return 0

        if n is not None:
            numpos = min(numpos, n)
        delta_recall = 1.0 / numpos
        poscount = 0.0

        # calculate the ap
        r = len(sortidx)
        if n is not None:
            r = min(r, n)
        for i in range(r):
            if actuals[sortidx[i]] > 0:
                poscount += 1
                ap += poscount / (i + 1) * delta_recall
        return ap

    @staticmethod
    def _shuffle(predictions, actuals):
        """Apply one fixed pseudo-random permutation to both arrays.

        A private generator seeded with 0 yields the same permutation as
        reseeding the module-level generator would, but leaves the global
        `random` state of the calling program untouched.
        """
        rng = random.Random(0)
        suffidx = rng.sample(range(len(predictions)), len(predictions))
        return predictions[suffidx], actuals[suffidx]

    @staticmethod
    def _zero_one_normalize(predictions, epsilon=1e-7):
        """Normalize the predictions to the range between 0.0 and 1.0.

        For some predictions like SVM predictions, we need to normalize them
        before calculate the interpolated average precision. The normalization
        will not change the rank in the original list and thus won't change
        the average precision.

        Args:
          predictions: a numpy 1-D array storing the sparse prediction scores.
          epsilon: a small constant to avoid denominator being zero.

        Returns:
          The normalized prediction.
        """
        lowest = numpy.min(predictions)
        # Builtin max() compares the two scalars; numpy.max(x, epsilon) would
        # treat epsilon as the `axis` argument and raise.
        denominator = max(numpy.max(predictions) - lowest, epsilon)
        return (predictions - lowest) / denominator
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Utility to convert the output of batch prediction into a CSV submission.
15 +
16 +It converts the JSON files created by the command
17 +'gcloud beta ml jobs submit prediction' into a CSV file ready for submission.
18 +"""
19 +
20 +import json
21 +import tensorflow as tf
22 +
23 +from builtins import range
24 +from tensorflow import app
25 +from tensorflow import flags
26 +from tensorflow import gfile
27 +from tensorflow import logging
28 +
29 +FLAGS = flags.FLAGS
30 +
# The command-line flags are registered only when this file is executed as a
# script; importing the module does not define them.
if __name__ == "__main__":

    # Glob pattern matching the JSON prediction files to convert.
    flags.DEFINE_string(
        "json_prediction_files_pattern", None,
        "Pattern specifying the list of JSON files that the command "
        "'gcloud beta ml jobs submit prediction' outputs. These files are "
        "located in the output path of the prediction command and are prefixed "
        "with 'prediction.results'.")
    # Destination path of the generated CSV submission file.
    flags.DEFINE_string(
        "csv_output_file", None,
        "The file to save the predictions converted to the CSV format.")
42 +
43 +
def get_csv_header():
    """Return the header line (newline-terminated) of the submission CSV."""
    column_names = ("VideoId", "LabelConfidencePairs")
    return ",".join(column_names) + "\n"
46 +
47 +
def to_csv_row(json_data):
    """Convert one decoded JSON prediction record into a CSV row.

    Args:
      json_data: A dict with the keys "video_id", "class_indexes" and
        "predictions". When "video_id" is a list, the first element of each of
        the three fields is used.

    Returns:
      A newline-terminated string "<video_id>,<index> <score> <index> <score>...".

    Raises:
      ValueError: If the number of class indexes and predictions differ.
    """
    video_id = json_data["video_id"]
    class_indexes = json_data["class_indexes"]
    predictions = json_data["predictions"]

    if isinstance(video_id, list):
        video_id = video_id[0]
        class_indexes = class_indexes[0]
        predictions = predictions[0]

    if len(class_indexes) != len(predictions):
        raise ValueError(
            "The number of indexes (%s) and predictions (%s) must be equal." %
            (len(class_indexes), len(predictions)))

    # json.loads() yields text strings, which have no .decode() on Python 3;
    # only byte strings need decoding.
    if isinstance(video_id, bytes):
        video_id = video_id.decode("utf-8")

    pairs = " ".join("%i %f" % (class_index, score)
                     for class_index, score in zip(class_indexes, predictions))
    return video_id + "," + pairs + "\n"
68 +
69 +
def main(unused_argv):
    """Convert every matching JSON prediction file into one CSV submission.

    Reads the input glob pattern and output path from FLAGS, writes the CSV
    header, then appends one row per JSON line found in the input files.

    Raises:
      ValueError: If either required flag is missing.
    """
    logging.set_verbosity(tf.logging.INFO)

    files_pattern = FLAGS.json_prediction_files_pattern
    if not files_pattern:
        raise ValueError(
            "The flag --json_prediction_files_pattern must be specified.")

    csv_path = FLAGS.csv_output_file
    if not csv_path:
        raise ValueError("The flag --csv_output_file must be specified.")

    logging.info("Looking for prediction files with pattern: %s",
                 files_pattern)

    file_paths = gfile.Glob(files_pattern)
    logging.info("Found files: %s", file_paths)

    logging.info("Writing submission file to: %s", csv_path)
    with gfile.Open(csv_path, "w+") as output_file:
        output_file.write(get_csv_header())

        for file_path in file_paths:
            logging.info("processing file: %s", file_path)

            # Each line of an input file is an independent JSON record.
            with gfile.Open(file_path) as input_file:
                for line in input_file:
                    output_file.write(to_csv_row(json.loads(line)))

        output_file.flush()
    logging.info("done")
101 +
102 +
# Script entry point: hand control to the TensorFlow app runner.
if __name__ == "__main__":
    app.run()
No preview for this file type
1 +import numpy as np
2 +import tensorflow as tf
3 +from tensorflow import logging
4 +from tensorflow import gfile
5 +import esot3ria.pbutil as pbutil
6 +
7 +
def get_segments(batch_video_mtx, batch_num_frames, segment_size):
    """Get segment-level inputs from frame-level features.

    Args:
      batch_video_mtx: array whose axes are read as (video, frame, feature) —
        assumed to be a 3-D numpy array; TODO confirm against callers.
      batch_num_frames: 1-D integer array with the number of valid frames of
        each video in the batch.
      segment_size: number of frames per segment.

    Returns:
      A dict with:
        "video_batch": frames regrouped as (segment, segment_size, feature).
        "num_frames_batch": number of valid frames in each segment, capped at
          segment_size.
        "video_segment_ids": one [video index, segment start frame] row per
          segment.
    """
    video_batch_size = batch_video_mtx.shape[0]
    max_frame = batch_video_mtx.shape[1]
    feature_dim = batch_video_mtx.shape[-1]
    # Round each video's frame count up to a whole number of segments.
    padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size
    padded_segment_sizes *= segment_size
    # True for every frame slot that lies below its video's padded length.
    segment_mask = (
        0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame)))

    # Segment bags.
    frame_bags = batch_video_mtx.reshape((-1, feature_dim))
    # NOTE(review): this reshape needs the number of selected frames to be a
    # multiple of segment_size; presumably max_frame is expected to be one —
    # confirm.
    segment_frames = frame_bags[segment_mask.reshape(-1)].reshape(
        (-1, segment_size, feature_dim))

    # Segment num frames.
    segment_start_times = np.arange(0, max_frame, segment_size)
    # Frames remaining in each video counted from every segment start;
    # non-positive entries mark segments past the end of the video.
    num_segments = batch_num_frames[:, np.newaxis] - segment_start_times
    num_segment_bags = num_segments.reshape((-1))
    valid_segment_mask = num_segment_bags > 0
    segment_num_frames = num_segment_bags[valid_segment_mask]
    segment_num_frames[segment_num_frames > segment_size] = segment_size

    max_segment_num = (max_frame + segment_size - 1) // segment_size
    video_idxs = np.tile(
        np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num])
    segment_idxs = np.tile(segment_start_times, [video_batch_size, 1])
    # Pair every (video index, segment start) and keep only valid segments.
    idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2))
    video_segment_ids = idx_bags[valid_segment_mask]

    return {
        "video_batch": segment_frames,
        "num_frames_batch": segment_num_frames,
        "video_segment_ids": video_segment_ids
    }
43 +
44 +
def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None):
    """Format the top-k predictions of the first video as a UTF-8 CSV line.

    Args:
      video_ids: sequence of string ids, one per row of 'predictions'.
      predictions: 2-D array of scores indexed as [video, class]. It is not
        modified.
      top_k: number of highest-scoring classes to report.
      whitelisted_cls_mask: optional per-class multiplier (1 keeps a class,
        0 suppresses it) applied to the scores before ranking.

    Returns:
      The bytes "<video_id>,<class> <score> ...\n" with classes ordered by
      decreasing score, or None when 'video_ids' is empty.
    """
    # NOTE(review): only the first entry is ever formatted, as before; the
    # original loop returned on its first iteration. Confirm whether one line
    # per video was intended.
    if len(video_ids) == 0:
        return None

    video_index = 0
    video_prediction = predictions[video_index]
    if whitelisted_cls_mask is not None:
        # Whitelist classes. Multiply into a new array: an in-place `*=` on
        # the row view would overwrite the caller's predictions.
        video_prediction = video_prediction * whitelisted_cls_mask
    top_indices = np.argpartition(video_prediction, -top_k)[-top_k:]
    line = [(class_index, video_prediction[class_index])
            for class_index in top_indices]
    line = sorted(line, key=lambda p: -p[1])
    return (video_ids[video_index] + "," +
            " ".join("%i %g" % (label, score) for (label, score) in line) +
            "\n").encode("utf8")
59 +
60 +
def inference_pb(filename):
    """Run the saved segment-level model on one feature file and log the scores.

    The file is converted to a SequenceExample via pbutil.convert_pb, its
    per-frame 'rgb' and 'audio' features are decoded into a fixed-size frame
    matrix, split into 5-frame segments, and fed to the restored graph.

    Args:
      filename: path of the feature file passed to pbutil.convert_pb.

    Raises:
      IOError: If the checkpoint's ".meta" file does not exist.
      ValueError: If the segment length does not match the model input.
    """
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        # 200527 Esot3riA
        # 0. Import SequenceExample type target from pb.
        target_video = pbutil.convert_pb(filename)

        # 1. Load video features from pb.
        video_id_batch_val = np.array([b'video'])
        n_frames = len(target_video.feature_lists.feature_list['rgb'].feature)
        # Restrict frame size to 300
        if n_frames > 300:
            n_frames = 300
        # One row per frame; presumably 1152 = 1024 rgb + 128 audio values —
        # TODO confirm. Rows beyond n_frames stay zero.
        video_batch_val = np.zeros((300, 1152))
        # NOTE(review): every iteration adds new decode ops to the graph and
        # evaluates them in the session.
        for i in range(n_frames):
            video_batch_rgb_raw = target_video.feature_lists.feature_list['rgb'].feature[i].bytes_list.value[0]
            video_batch_rgb = np.array(tf.cast(tf.decode_raw(video_batch_rgb_raw, tf.float32), tf.float32).eval())
            video_batch_audio_raw = target_video.feature_lists.feature_list['audio'].feature[i].bytes_list.value[0]
            video_batch_audio = np.array(tf.cast(tf.decode_raw(video_batch_audio_raw, tf.float32), tf.float32).eval())
            video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0)
        # Add a leading batch axis of size one.
        video_batch_val = np.array([video_batch_val])
        num_frames_batch_val = np.array([n_frames])
        # 200527 Esot3riA End

        # Restore checkpoint and meta-graph file
        # NOTE(review): absolute, machine-specific path.
        checkpoint_file = '/Users/esot3ria/PycharmProjects/yt8m/models/frame' \
                          '/sample_model/inference_model/segment_inference_model'
        if not gfile.Exists(checkpoint_file + ".meta"):
            raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
        meta_graph_location = checkpoint_file + ".meta"
        logging.info("loading meta-graph: " + meta_graph_location)

        with tf.device("/cpu:0"):
            saver = tf.train.import_meta_graph(meta_graph_location,
                                               clear_devices=True)
        logging.info("restoring variables from " + checkpoint_file)
        saver.restore(sess, checkpoint_file)
        # The graph's input and output tensors are looked up by collection name.
        input_tensor = tf.get_collection("input_batch_raw")[0]
        num_frames_tensor = tf.get_collection("num_frames")[0]
        predictions_tensor = tf.get_collection("predictions")[0]

        # Workaround for num_epochs issue.
        def set_up_init_ops(variables):
            # Local variables whose name contains "train_input" are assigned 1
            # and removed from the list; the rest get a regular initializer.
            init_op_list = []
            for variable in list(variables):
                if "train_input" in variable.name:
                    init_op_list.append(tf.assign(variable, 1))
                    variables.remove(variable)
            init_op_list.append(tf.variables_initializer(variables))
            return init_op_list

        sess.run(
            set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

        # NOTE(review): the coordinator and queue-runner threads are never
        # stopped or joined in this function.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Per-class mask: 1 for class ids listed in the CSV, 0 otherwise.
        whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
                                        dtype=np.float32)
        # NOTE(review): path is relative to the current working directory.
        segment_label_ids_file = '../segment_label_ids.csv'
        with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
            for line in fobj:
                try:
                    cls_id = int(line)
                    whitelisted_cls_mask[cls_id] = 1.
                except ValueError:
                    # Simply skip the non-integer line.
                    continue

        # 200527 Esot3riA
        # 2. Make segment features.
        results = get_segments(video_batch_val, num_frames_batch_val, 5)
        video_segment_ids = results["video_segment_ids"]
        # Build one "<video id>:<segment start frame>" label per segment.
        video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
        video_id_batch_val = np.array([
            "%s:%d" % (x.decode("utf8"), y)
            for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
        ])
        video_batch_val = results["video_batch"]
        num_frames_batch_val = results["num_frames_batch"]
        if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
            raise ValueError("max_frames mismatch. Please re-run the eval.py "
                             "with correct segment_labels settings.")

        predictions_val, = sess.run([predictions_tensor],
                                    feed_dict={
                                        input_tensor: video_batch_val,
                                        num_frames_tensor: num_frames_batch_val
                                    })
        logging.info(predictions_val)
        logging.info("profit :D")

        # The predictions are only logged; whitelisted_cls_mask and the
        # segment ids are unused while the formatting call below is disabled.
        # result = format_prediction(video_id_batch_val, predictions_val, 10, whitelisted_cls_mask)
153 +
154 +
# Script entry point: run inference on the hard-coded 'features.pb' file.
if __name__ == '__main__':
    logging.set_verbosity(tf.logging.INFO)

    filename = 'features.pb'
    inference_pb(filename)
1 +import tensorflow as tf
2 +import numpy
3 +
4 +
5 +def _make_bytes(int_array):
6 + if bytes == str: # Python2
7 + return ''.join(map(chr, int_array))
8 + else:
9 + return bytes(int_array)
10 +
11 +
def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0):
    """Quantize a 1-D float32 array into a byte string, one byte per value.

    Values are clipped to [min_quantized_value, max_quantized_value] and
    mapped linearly onto the integer range 0..255.
    """
    assert features.dtype == 'float32'
    assert len(features.shape) == 1  # 1-D array
    clipped = numpy.clip(features, min_quantized_value, max_quantized_value)
    span = max_quantized_value - min_quantized_value
    scaled = (clipped - min_quantized_value) * (255.0 / span)
    levels = [int(round(value)) for value in scaled]

    return _make_bytes(levels)
22 +
23 +
# for parse feature.pb

# Context (per-example) feature spec handed to
# tf.io.parse_single_sequence_example in parse_exmp.
contexts = {
    'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64),
    'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32),
    'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64),
    'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32),
    'clip/data_path': tf.io.FixedLenFeature([], tf.string),
    'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64),
    'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64)
}

# Sequence (per-step) feature spec handed to the same parser; parse_exmp only
# reads the two '.../floats' entries.
features = {
    'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32),
    'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64),
    'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32),
    'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64)

}
43 +
44 +
def parse_exmp(serial_exmp):
    """Split a serialized SequenceExample into per-step audio and rgb chunks.

    Args:
      serial_exmp: serialized SequenceExample matching the module-level
        'contexts' and 'features' specs.

    Returns:
      A tuple (byte_audio, byte_rgb): lists with one byte string per complete
      128-value audio chunk and per complete 1024-value rgb chunk.
    """
    _, parsed_sequence = tf.io.parse_single_sequence_example(
        serialized=serial_exmp,
        context_features=contexts,
        sequence_features=features)

    # Evaluate the parsed tensors once to obtain concrete values.
    parsed_sequence = tf.contrib.learn.run_n(parsed_sequence)[0]

    audio = parsed_sequence['AUDIO/feature/floats'].values
    rgb = parsed_sequence['RGB/feature/floats'].values

    # audio is 128 values and rgb is 1024 values for every second; trailing
    # values that do not fill a whole chunk are dropped.
    audio_chunks = [audio[128 * i: 128 * (i + 1)]
                    for i in range(len(audio) // 128)]
    rgb_chunks = [rgb[1024 * i: 1024 * (i + 1)]
                  for i in range(len(rgb) // 1024)]

    # Chunks are passed to _make_bytes as-is; quantize() is not applied.
    byte_audio = [_make_bytes(chunk) for chunk in audio_chunks]
    byte_rgb = [_make_bytes(chunk) for chunk in rgb_chunks]

    return byte_audio, byte_rgb
77 +
78 +
def make_exmp(id, audio, rgb):
    """Build and serialize a SequenceExample holding audio and rgb features.

    Args:
      id: text id stored UTF-8 encoded under the 'id' context key.
      audio: iterable of byte strings, one per 'audio' feature-list entry.
      rgb: iterable of byte strings, one per 'rgb' feature-list entry.

    Returns:
      The serialized SequenceExample.
    """
    def _bytes_feature(payload):
        # Wrap one byte string in a single-valued bytes Feature.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

    audio_features = [_bytes_feature(embedding) for embedding in audio]
    rgb_features = [_bytes_feature(embedding) for embedding in rgb]

    # for construct yt8m data
    context = tf.train.Features(
        feature={
            'id': tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[id.encode('utf-8')]))
        })
    feature_lists = tf.train.FeatureLists(
        feature_list={
            'audio': tf.train.FeatureList(feature=audio_features),
            'rgb': tf.train.FeatureList(feature=rgb_features)
        })
    seq_exmp = tf.train.SequenceExample(context=context,
                                        feature_lists=feature_lists)
    return seq_exmp.SerializeToString()
112 +
113 +
def convert_pb(filename):
    """Load a feature file and rebuild it as a SequenceExample.

    The file content is parsed with parse_exmp, repacked by make_exmp under
    the fixed id 'video', and decoded back into a SequenceExample message.

    Args:
      filename: path of the serialized feature file to read.

    Returns:
      The decoded tf.train.SequenceExample.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(filename, 'rb') as pb_file:
        sequence_example = pb_file.read()

    audio, rgb = parse_exmp(sequence_example)
    tmp_example = make_exmp('video', audio, rgb)

    return tf.train.SequenceExample.FromString(tmp_example)
1 +import tensorflow as tf
2 +import numpy as np
3 +
# Path of the frame-level TFRecord file to inspect.
frame_lvl_record = "test0000.tfrecord"

# Decoded per-frame features, one inner list per video read.
feat_rgb = []
feat_audio = []

for example in tf.python_io.tf_record_iterator(frame_lvl_record):
    tf_seq_example = tf.train.SequenceExample.FromString(example)
    # NOTE(review): 'test' is never read afterwards.
    test = tf_seq_example.SerializeToString()
    n_frames = len(tf_seq_example.feature_lists.feature_list['audio'].feature)
    # An interactive session is needed for the .eval() calls below.
    sess = tf.InteractiveSession()
    rgb_frame = []
    audio_frame = []
    # iterate through frames
    for i in range(n_frames):
        # Each frame is stored as raw uint8 bytes and converted to float32.
        rgb_frame.append(tf.cast(tf.decode_raw(
            tf_seq_example.feature_lists.feature_list['rgb']
            .feature[i].bytes_list.value[0], tf.uint8)
            , tf.float32).eval())
        audio_frame.append(tf.cast(tf.decode_raw(
            tf_seq_example.feature_lists.feature_list['audio']
            .feature[i].bytes_list.value[0], tf.uint8)
            , tf.float32).eval())

    sess.close()

    feat_audio.append(audio_frame)
    feat_rgb.append(rgb_frame)
    # Only the first record of the file is processed.
    break

print('The first video has %d frames' %len(feat_rgb[0]))
...\ No newline at end of file ...\ No newline at end of file
No preview for this file type
This diff is collapsed. Click to expand it.
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Provides functions to help with evaluating models."""
15 +import average_precision_calculator as ap_calculator
16 +import mean_average_precision_calculator as map_calculator
17 +import numpy
18 +from tensorflow.python.platform import gfile
19 +
20 +
def flatten(l):
    """Concatenate the items of every sublist of 'l' into one flat list."""
    merged = []
    for sublist in l:
        merged.extend(sublist)
    return merged
24 +
25 +
def calculate_hit_at_one(predictions, actuals):
    """Compute the hit-at-one of a batch locally with numpy.

    Args:
      predictions: Matrix of model outputs, 'batch' x 'num_classes'.
      actuals: Matrix of ground truth labels, 'batch' x 'num_classes'.

    Returns:
      float: The mean, over the batch, of the ground truth value found at each
      row's highest-scoring class.
    """
    best_class = numpy.argmax(predictions, 1)
    row_ids = numpy.arange(actuals.shape[0])
    return numpy.average(actuals[row_ids, best_class])
41 +
42 +
def calculate_precision_at_equal_recall_rate(predictions, actuals):
    """Compute the PERR of a batch locally with numpy.

    For every video, the top-scoring classes are taken, as many as the video
    has ground truth labels, and the fraction of them that are true labels
    (with a positive score) is computed.

    Args:
      predictions: Matrix of model outputs, 'batch' x 'num_classes'.
      actuals: Matrix of ground truth labels, 'batch' x 'num_classes'.

    Returns:
      float: The average precision at equal recall rate across the batch.
    """
    num_videos = actuals.shape[0]
    precision_total = 0.0
    for row in numpy.arange(num_videos):
        row_scores = predictions[row]
        row_labels = actuals[row]
        num_labels = int(numpy.sum(row_labels))
        top_indices = numpy.argpartition(row_scores,
                                         -num_labels)[-num_labels:]
        # Only classes with a positive score contribute their label value.
        hits = sum((row_labels[label_index]
                    for label_index in top_indices
                    if row_scores[label_index] > 0), 0.0)
        precision_total += hits / top_indices.size
    return precision_total / num_videos
69 +
70 +
def calculate_gap(predictions, actuals, top_k=20):
    """Compute the global average precision of a batch locally with numpy.

    Only the top_k predictions are taken for each of the videos.

    Args:
      predictions: Matrix of model outputs, 'batch' x 'num_classes'.
      actuals: Matrix of ground truth labels, 'batch' x 'num_classes'.
      top_k: How many predictions to use per video.

    Returns:
      float: The global average precision.
    """
    sparse_predictions, sparse_labels, num_positives = top_k_by_class(
        predictions, actuals, top_k)
    # All classes are pooled into a single ranked list before scoring.
    pooled_predictions = flatten(sparse_predictions)
    pooled_labels = flatten(sparse_labels)
    gap_calculator = ap_calculator.AveragePrecisionCalculator()
    gap_calculator.accumulate(pooled_predictions, pooled_labels,
                              sum(num_positives))
    return gap_calculator.peek_ap_at_n()
92 +
93 +
def top_k_by_class(predictions, labels, k=20):
    """Extract the top k predictions of each video, grouped by class.

    Args:
      predictions: A numpy matrix of model outputs, 'batch' x 'num_classes'.
      labels: A numpy matrix of ground truth labels, 'batch' x 'num_classes'.
      k: how many top entries to keep from each video's predictions.

    Returns:
      A tuple (predictions, labels, true_positives). 'predictions' and 'labels'
      are lists of lists with one inner list per class: the kept prediction
      scores and the matching ground truth values. 'true_positives' holds, per
      class, the sum of that class's column in 'labels'.

    Raises:
      ValueError: An error occurred when the k is not a positive integer.
    """
    if k <= 0:
        raise ValueError("k must be a positive integer.")
    num_classes = predictions.shape[1]
    k = min(k, num_classes)

    out_predictions = [[] for _ in range(num_classes)]
    out_labels = [[] for _ in range(num_classes)]
    # Route each kept (class, score, label) triplet to its class bucket.
    for video_index in range(predictions.shape[0]):
        triplets = top_k_triplets(predictions[video_index],
                                  labels[video_index], k)
        for class_index, score, label in triplets:
            out_predictions[class_index].append(score)
            out_labels[class_index].append(label)
    out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)]

    return out_predictions, out_labels, out_true_positives
130 +
131 +
def top_k_triplets(predictions, labels, k=20):
    """Get the top k entries of a 1-d prediction array.

    Returns a list of (index, prediction, label) tuples for the k largest
    predictions, in the order produced by numpy.argpartition (not sorted).
    """
    k = min(k, len(predictions))
    indices = numpy.argpartition(predictions, -k)[-k:]
    triplets = []
    for index in indices:
        triplets.append((index, predictions[index], labels[index]))
    return triplets
142 +
143 +
class EvaluationMetrics(object):
    """Accumulates evaluation metrics over the mini-batches of an epoch."""

    def __init__(self, num_class, top_k, top_n):
        """Create the running sums and the average precision calculators.

        Args:
          num_class: A positive integer specifying the number of classes.
          top_k: A positive integer specifying how many predictions are
            considered per video.
          top_n: A positive integer specifying the average precision at n, or
            None to use all provided data points.

        Raises:
          ValueError: Propagated from MeanAveragePrecisionCalculator when it
            cannot be constructed from the given arguments.
        """
        self.sum_hit_at_one = 0.0
        self.sum_perr = 0.0
        self.sum_loss = 0.0
        # One calculator tracks per-class AP, the other a single global AP.
        self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(
            num_class, top_n=top_n)
        self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator()
        self.top_k = top_k
        self.num_examples = 0

    def accumulate(self, predictions, labels, loss):
        """Fold the metrics of one mini-batch into the running totals.

        Args:
          predictions: A numpy matrix containing the outputs of the model.
            Dimensions are 'batch' x 'num_classes'.
          labels: A numpy matrix containing the ground truth labels.
            Dimensions are 'batch' x 'num_classes'.
          loss: A numpy array containing the loss for each sample.

        Returns:
          A dictionary with the mini-batch means under the keys
          "hit_at_one", "perr" and "loss".
        """
        num_in_batch = labels.shape[0]
        batch_hit_at_one = calculate_hit_at_one(predictions, labels)
        batch_perr = calculate_precision_at_equal_recall_rate(
            predictions, labels)
        batch_loss = numpy.mean(loss)

        # Keep only the top_k predictions of each video, grouped by class.
        per_class_predictions, per_class_labels, per_class_positives = (
            top_k_by_class(predictions, labels, self.top_k))
        self.map_calculator.accumulate(
            per_class_predictions, per_class_labels, per_class_positives)
        self.global_ap_calculator.accumulate(
            flatten(per_class_predictions),
            flatten(per_class_labels),
            sum(per_class_positives))

        # The batch values are means, so weight them by the batch size.
        self.num_examples += num_in_batch
        self.sum_hit_at_one += batch_hit_at_one * num_in_batch
        self.sum_perr += batch_perr * num_in_batch
        self.sum_loss += batch_loss * num_in_batch

        return {
            "hit_at_one": batch_hit_at_one,
            "perr": batch_perr,
            "loss": batch_loss,
        }

    def get(self):
        """Compute the metrics over everything accumulated so far.

        Returns:
          A dictionary with the fields avg_hit_at_one, avg_perr, avg_loss,
          aps (as returned by the mean average precision calculator) and gap
          (as returned by the global average precision calculator).

        Raises:
          ValueError: If no examples were accumulated.
        """
        total = self.num_examples
        if total <= 0:
            raise ValueError("total_sample must be positive.")
        return {
            "avg_hit_at_one": self.sum_hit_at_one / total,
            "avg_perr": self.sum_perr / total,
            "avg_loss": self.sum_loss / total,
            "aps": self.map_calculator.peek_map_at_n(),
            "gap": self.global_ap_calculator.peek_ap_at_n(),
        }

    def clear(self):
        """Reset the running sums and both calculators to their empty state."""
        self.sum_hit_at_one = 0.0
        self.sum_perr = 0.0
        self.sum_loss = 0.0
        self.map_calculator.clear()
        self.global_ap_calculator.clear()
        self.num_examples = 0
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Utilities to export a model for batch prediction."""
15 +
16 +import tensorflow as tf
17 +import tensorflow.contrib.slim as slim
18 +
19 +from tensorflow.python.saved_model import builder as saved_model_builder
20 +from tensorflow.python.saved_model import signature_constants
21 +from tensorflow.python.saved_model import signature_def_utils
22 +from tensorflow.python.saved_model import tag_constants
23 +from tensorflow.python.saved_model import utils as saved_model_utils
24 +
25 +_TOP_PREDICTIONS_IN_OUTPUT = 20
26 +
27 +
class ModelExporter(object):
    """Builds a prediction graph for a model and exports it as a SavedModel."""

    def __init__(self, frame_features, model, reader):
        """Build the prediction graph and a saver for its trainable variables.

        Args:
          frame_features: When truthy, each serialized example is run through
            the prediction graph individually via tf.map_fn; otherwise the
            whole batch is passed to the prediction graph at once.
          model: An object exposing create_model(...), used to build the
            predictions.
          reader: An object exposing prepare_serialized_examples(...) and a
            num_classes attribute.
        """
        self.frame_features = frame_features
        self.model = model
        self.reader = reader

        self.graph = tf.Graph()
        with self.graph.as_default():
            self.inputs, self.outputs = self.build_inputs_and_outputs()
            self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True)

    def export_model(self, model_dir, global_step_val, last_checkpoint):
        """Exports the model so that it can be used for batch predictions.

        Args:
          model_dir: Directory handed to SavedModelBuilder.
          global_step_val: Not used by this method.
          last_checkpoint: Checkpoint path the variables are restored from.
        """
        with self.graph.as_default(), tf.Session() as session:
            session.run(tf.global_variables_initializer())
            self.saver.restore(session, last_checkpoint)

            predict_signature = signature_def_utils.build_signature_def(
                inputs=self.inputs,
                outputs=self.outputs,
                method_name=signature_constants.PREDICT_METHOD_NAME)

            model_builder = saved_model_builder.SavedModelBuilder(model_dir)
            model_builder.add_meta_graph_and_variables(
                session,
                tags=[tag_constants.SERVING],
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        predict_signature,
                },
                clear_devices=True)
            model_builder.save()

    def build_inputs_and_outputs(self):
        """Create the serialized-example placeholder and the output tensors.

        Returns:
          A tuple (inputs, outputs) of dictionaries mapping signature names to
          TensorInfo protos.
        """
        serialized_examples = tf.placeholder(tf.string, shape=(None,))

        if self.frame_features:
            # Apply the prediction graph to one serialized example at a time.
            video_id_output, top_indices_output, top_predictions_output = (
                tf.map_fn(lambda x: self.build_prediction_graph(x),
                          serialized_examples,
                          dtype=(tf.string, tf.int32, tf.float32)))
        else:
            video_id_output, top_indices_output, top_predictions_output = (
                self.build_prediction_graph(serialized_examples))

        to_info = saved_model_utils.build_tensor_info
        inputs = {"example_bytes": to_info(serialized_examples)}
        outputs = {
            "video_id": to_info(video_id_output),
            "class_indexes": to_info(top_indices_output),
            "predictions": to_info(top_predictions_output),
        }
        return inputs, outputs

    def build_prediction_graph(self, serialized_examples):
        """Build the ops mapping serialized examples to top predictions.

        Args:
          serialized_examples: A string tensor handed to the reader's
            prepare_serialized_examples.

        Returns:
          A tuple (video_id, top_indices, top_predictions), where the last two
          come from tf.nn.top_k over the model's "predictions" output.
        """
        parsed = self.reader.prepare_serialized_examples(serialized_examples)
        video_id = parsed["video_ids"]
        raw_features = parsed["video_matrix"]
        labels_batch = parsed["labels"]
        num_frames = parsed["num_frames"]

        # Normalize along the last (feature) axis.
        last_axis = len(raw_features.get_shape()) - 1
        model_input = tf.nn.l2_normalize(raw_features, last_axis)

        with tf.variable_scope("tower"):
            result = self.model.create_model(
                model_input,
                num_frames=num_frames,
                vocab_size=self.reader.num_classes,
                labels=labels_batch,
                is_training=False)

            for variable in slim.get_model_variables():
                tf.summary.histogram(variable.op.name, variable)

            top_predictions, top_indices = tf.nn.top_k(
                result["predictions"], _TOP_PREDICTIONS_IN_OUTPUT)
        return video_id, top_indices, top_predictions
1 +# Lint as: python3
2 +import numpy as np
3 +import tensorflow as tf
4 +from tensorflow import app
5 +from tensorflow import flags
6 +
7 +FLAGS = flags.FLAGS
8 +
9 +
def main(unused_argv):
    """Convert a training checkpoint into a SavedModel fed by placeholders.

    The checkpoint's meta graph is imported twice: once to look up the names
    of the input tensors, and once more with those tensors remapped onto
    fresh placeholders so the exported model can be fed directly.

    Raises:
      ValueError: If --checkpoint_file or --output_dir is not set.
    """
    # Both flags default to None; fail with a clear message instead of a
    # TypeError from concatenating None with ".meta" below.
    if not FLAGS.checkpoint_file:
        raise ValueError("--checkpoint_file must be specified.")
    if not FLAGS.output_dir:
        raise ValueError("--output_dir must be specified.")

    # Get the input tensor names to be replaced.
    tf.reset_default_graph()
    meta_graph_location = FLAGS.checkpoint_file + ".meta"
    tf.train.import_meta_graph(meta_graph_location, clear_devices=True)

    input_tensor_name = tf.get_collection("input_batch_raw")[0].name
    num_frames_tensor_name = tf.get_collection("num_frames")[0].name

    # Create output graph.
    tf.reset_default_graph()

    # NOTE(review): 1152 is hard-coded; presumably the concatenated rgb and
    # audio feature size -- confirm against the training configuration.
    input_feature_placeholder = tf.placeholder(
        tf.float32, shape=(None, None, 1152))
    num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1))

    saver = tf.train.import_meta_graph(
        meta_graph_location,
        input_map={
            input_tensor_name: input_feature_placeholder,
            num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1)
        },
        clear_devices=True)
    predictions_tensor = tf.get_collection("predictions")[0]

    with tf.Session() as sess:
        print("restoring variables from " + FLAGS.checkpoint_file)
        saver.restore(sess, FLAGS.checkpoint_file)
        tf.saved_model.simple_save(
            sess,
            FLAGS.output_dir,
            inputs={'rgb_and_audio': input_feature_placeholder,
                    'num_frames': num_frames_placeholder},
            outputs={'predictions': predictions_tensor})

        # Try running inference. Both feeds use the same batch size so the
        # frame counts line up with the feature batch.
        test_batch_size = 3
        test_num_frames = 7
        predictions = sess.run(
            [predictions_tensor],
            feed_dict={
                input_feature_placeholder: np.zeros(
                    (test_batch_size, test_num_frames, 1152),
                    dtype=np.float32),
                num_frames_placeholder: np.full(
                    (test_batch_size, 1), test_num_frames, dtype=np.int32)})
        print('Test inference:', predictions)

        print('Model saved to ', FLAGS.output_dir)


if __name__ == '__main__':
    flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.')
    flags.DEFINE_string('output_dir', None, 'SavedModel output directory.')
    app.run(main)
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Provides definitions for non-regularized training or test losses."""
15 +
16 +import tensorflow as tf
17 +
18 +
class BaseLoss(object):
    """Base class that every loss implementation should derive from."""

    def calculate_loss(self, unused_predictions, unused_labels, **unused_params):
        """Compute the average loss over the examples of a mini-batch.

        Subclasses must override this method; the base version always raises.

        Args:
          unused_predictions: a 2-d tensor of prediction scores, one row per
            sample in the mini-batch and one column per class.
          unused_labels: a 2-d tensor of labels with the same shape as
            unused_predictions. The labels must be in the range of 0 and 1.
          unused_params: loss specific parameters.

        Returns:
          A scalar loss tensor.

        Raises:
          NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
37 +
38 +
class CrossEntropyLoss(BaseLoss):
    """Cross entropy loss between the predictions and the labels."""

    def calculate_loss(self,
                       predictions,
                       labels,
                       label_weights=None,
                       **unused_params):
        """Return the mean over the batch of the per-example cross entropy.

        Args:
          predictions: tensor of prediction scores; 1e-5 is added inside the
            logarithms to keep them finite.
          labels: tensor of labels, cast to float32 before use.
          label_weights: optional tensor multiplied element-wise into the
            per-class loss before it is summed.
          **unused_params: ignored.

        Returns:
          A scalar tensor: the loss summed over axis 1, then averaged.
        """
        with tf.name_scope("loss_xent"):
            epsilon = 1e-5
            targets = tf.cast(labels, tf.float32)
            positive_term = targets * tf.math.log(predictions + epsilon)
            negative_term = (1 - targets) * tf.math.log(
                1 - predictions + epsilon)
            per_class_loss = tf.negative(positive_term + negative_term)
            if label_weights is not None:
                per_class_loss *= label_weights
            per_example_loss = tf.reduce_sum(per_class_loss, 1)
            return tf.reduce_mean(per_example_loss)
56 +
57 +
class HingeLoss(BaseLoss):
    """Hinge loss between the predictions and the labels.

    Note the subgradient is used in the backpropagation, and thus the
    optimization may converge slower. The predictions trained by the hinge
    loss are between -1 and +1.
    """

    def calculate_loss(self, predictions, labels, b=1.0, **unused_params):
        """Return the mean over the batch of the per-example hinge loss.

        Args:
          predictions: tensor of prediction scores.
          labels: tensor of labels, cast to float32 and mapped to signed
            targets via 2 * label - 1.
          b: the margin subtracted from in the hinge term.
          **unused_params: ignored.

        Returns:
          A scalar tensor: max(0, b - sign * prediction) summed over axis 1,
          then averaged.
        """
        with tf.name_scope("loss_hinge"):
            targets = tf.cast(labels, tf.float32)
            zeros = tf.zeros(tf.shape(targets), dtype=tf.float32)
            ones = tf.ones(tf.shape(targets), dtype=tf.float32)
            signed_targets = tf.subtract(tf.scalar_mul(2, targets), ones)
            margin = tf.scalar_mul(b, ones)
            per_class_loss = tf.maximum(
                zeros, margin - signed_targets * predictions)
            per_example_loss = tf.reduce_sum(per_class_loss, 1)
            return tf.reduce_mean(per_example_loss)
76 +
77 +
class SoftmaxLoss(BaseLoss):
    """Calculate the softmax loss between the predictions and labels.

    The function calculates the loss in the following way: first we feed the
    predictions to the softmax activation function and then we calculate
    the minus linear dot product between the logged softmax activations and
    the normalized ground truth label.

    It is an extension to the one-hot label. It allows for more than one
    positive labels for each sample.
    """

    def calculate_loss(self, predictions, labels, **unused_params):
        """Return the mean over the batch of the per-example softmax loss.

        Args:
          predictions: tensor of scores fed to the softmax.
          labels: tensor of non-negative labels; each row is l1-normalized
            before being multiplied with the log-softmax outputs.
          **unused_params: ignored.

        Returns:
          A scalar loss tensor.
        """
        with tf.name_scope("loss_softmax"):
            epsilon = 10e-8
            float_labels = tf.cast(labels, tf.float32)
            # l1 normalization (labels are no less than 0). The row sum is
            # floored at epsilon so all-zero rows do not divide by zero.
            label_rowsum = tf.maximum(
                tf.reduce_sum(float_labels, 1, keepdims=True), epsilon)
            norm_float_labels = float_labels / label_rowsum
            # log_softmax is computed directly rather than as
            # log(softmax(x)): the latter yields -inf once a probability
            # underflows to 0, and 0 * -inf then poisons the loss with NaN.
            log_softmax_outputs = tf.nn.log_softmax(predictions)
            softmax_loss = tf.negative(
                tf.reduce_sum(norm_float_labels * log_softmax_outputs, 1))
            return tf.reduce_mean(softmax_loss)
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Calculate the mean average precision.
15 +
16 +It provides an interface for calculating mean average precision
17 +for an entire list or the top-n ranked items.
18 +
19 +Example usages:
20 +We first call the function accumulate many times to process parts of the ranked
21 +list. After processing all the parts, we call peek_map_at_n
22 +to calculate the mean average precision.
23 +
24 +```
25 +import random
26 +
27 +p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)])
28 +a = np.array([[random.choice([0, 1]) for _ in xrange(50)]
29 + for _ in xrange(1000)])
30 +
31 +# mean average precision for 50 classes.
32 +calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator(
33 + num_class=50)
34 +calculator.accumulate(p, a)
35 +aps = calculator.peek_map_at_n()
36 +```
37 +"""
38 +
39 +import average_precision_calculator
40 +
41 +
class MeanAveragePrecisionCalculator(object):
    """Keeps one average precision calculator per class to compute mAP."""

    def __init__(self, num_class, filter_empty_classes=True, top_n=None):
        """Construct a calculator to calculate the (macro) average precision.

        Args:
          num_class: A positive integer specifying the number of classes.
          filter_empty_classes: whether to filter classes without any
            positives when reporting the per-class average precisions.
          top_n: A positive integer specifying the average precision at n, or
            None to use all provided data points. It is forwarded unchanged
            to every per-class AveragePrecisionCalculator.

        Raises:
          ValueError: If num_class is not a positive integer. The per-class
            calculators may presumably also reject an invalid top_n.
        """
        # A single class is a valid (if degenerate) configuration, so only
        # values below 1 are rejected.
        if not isinstance(num_class, int) or num_class < 1:
            raise ValueError("num_class must be a positive integer.")

        self._num_class = num_class  # total number of classes
        self._filter_empty_classes = filter_empty_classes
        # One AveragePrecisionCalculator per class.
        self._ap_calculators = [
            average_precision_calculator.AveragePrecisionCalculator(top_n=top_n)
            for _ in range(num_class)
        ]

    def accumulate(self, predictions, actuals, num_positives=None):
        """Accumulate the predictions and their ground truth labels.

        Args:
          predictions: A list of lists storing the prediction scores. The
            outer dimension corresponds to classes.
          actuals: A list of lists storing the ground truth labels. The
            dimensions should correspond to the predictions input.
          num_positives: If provided, a sequence (list or numpy array) with
            the number of true positives for each class. If None or empty,
            None is passed to every per-class calculator instead.
        """
        # 'not num_positives' would raise for a multi-element numpy array,
        # so test for None / emptiness explicitly.
        if num_positives is None or len(num_positives) == 0:
            num_positives = [None] * self._num_class

        for i, calculator in enumerate(self._ap_calculators):
            calculator.accumulate(predictions[i], actuals[i], num_positives[i])

    def clear(self):
        """Clear every per-class calculator."""
        for calculator in self._ap_calculators:
            calculator.clear()

    def is_empty(self):
        """Return True when every per-class calculator has heap_size 0."""
        return all(
            calculator.heap_size == 0 for calculator in self._ap_calculators)

    def peek_map_at_n(self):
        """Peek the per-class average precision at n.

        Returns:
          A list with the value of peek_ap_at_n() for each class, in class
          order. When filter_empty_classes is set, classes without any
          accumulated positives are left out of the list.
        """
        return [
            calculator.peek_ap_at_n()
            for calculator in self._ap_calculators
            if (not self._filter_empty_classes or
                calculator.num_accumulated_positives > 0)
        ]
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains a collection of util functions for model construction."""
15 +import numpy
16 +import tensorflow as tf
17 +from tensorflow import logging
18 +from tensorflow import flags
19 +import tensorflow.contrib.slim as slim
20 +
21 +
def SampleRandomSequence(model_input, num_frames, num_samples):
    """Pick a random run of num_samples consecutive frames per example.

    A start frame is drawn uniformly for each example and num_samples
    consecutive indices are taken from it; indices past the last frame are
    clamped to num_frames - 1.

    Args:
      model_input: A tensor of size batch_size x max_frames x feature_size
      num_frames: A tensor of size batch_size x 1
      num_samples: A scalar

    Returns:
      `model_input`: A tensor of size batch_size x num_samples x feature_size
    """
    batch_size = tf.shape(model_input)[0]

    # Offsets 0..num_samples-1, repeated for every example in the batch.
    offsets = tf.expand_dims(tf.range(num_samples), 0)
    frame_index_offset = tf.tile(offsets, [batch_size, 1])

    # Latest start that still leaves room for num_samples frames.
    max_start_frame_index = tf.maximum(num_frames - num_samples, 0)
    random_fraction = tf.random_uniform([batch_size, 1])
    start_range = tf.cast(max_start_frame_index + 1, tf.float32)
    start_frame_index = tf.cast(
        tf.multiply(random_fraction, start_range), tf.int32)

    unclamped_index = start_frame_index + frame_index_offset
    last_frame_index = tf.cast(num_frames - 1, tf.int32)
    frame_index = tf.minimum(unclamped_index, last_frame_index)

    example_ids = tf.expand_dims(tf.range(batch_size), 1)
    batch_index = tf.tile(example_ids, [1, num_samples])

    # Pair every frame index with its example index for gather_nd.
    gather_index = tf.stack([batch_index, frame_index], 2)
    return tf.gather_nd(model_input, gather_index)
47 +
48 +
def SampleRandomFrames(model_input, num_frames, num_samples):
    """Pick num_samples independently drawn random frames per example.

    Each index is a uniform random fraction of the example's frame count,
    truncated to an integer by the int32 cast.

    Args:
      model_input: A tensor of size batch_size x max_frames x feature_size
      num_frames: A tensor of size batch_size x 1
      num_samples: A scalar

    Returns:
      `model_input`: A tensor of size batch_size x num_samples x feature_size
    """
    batch_size = tf.shape(model_input)[0]

    random_fraction = tf.random_uniform([batch_size, num_samples])
    frame_counts = tf.tile(
        tf.cast(num_frames, tf.float32), [1, num_samples])
    frame_index = tf.cast(
        tf.multiply(random_fraction, frame_counts), tf.int32)

    example_ids = tf.expand_dims(tf.range(batch_size), 1)
    batch_index = tf.tile(example_ids, [1, num_samples])

    # Pair every frame index with its example index for gather_nd.
    gather_index = tf.stack([batch_index, frame_index], 2)
    return tf.gather_nd(model_input, gather_index)
69 +
70 +
def FramePooling(frames, method, **unused_params):
    """Pools over the frames of a video.

    Args:
      frames: A tensor with shape [batch_size, num_frames, feature_size].
      method: "average", "max", or "none".
      **unused_params: ignored.

    Returns:
      A tensor with shape [batch_size, feature_size] for average or max
      pooling. A tensor with shape [batch_size*num_frames, feature_size]
      for none pooling.

    Raises:
      ValueError: if method is other than "average", "max", or "none".
    """
    if method == "average":
        return tf.reduce_mean(frames, 1)
    if method == "max":
        return tf.reduce_max(frames, 1)
    if method == "none":
        # The static shape is read through get_shape(); tensors have no
        # 'shape_as_list' attribute.
        feature_size = frames.get_shape().as_list()[2]
        return tf.reshape(frames, [-1, feature_size])
    raise ValueError("Unrecognized pooling method: %s" % method)
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains the base class for models."""
15 +
16 +
class BaseModel(object):
    """Base class that every model implementation should derive from."""

    def create_model(self, unused_model_input, **unused_params):
        """Build the model; subclasses must override this method.

        Args:
          unused_model_input: the model input, ignored by the base class.
          **unused_params: model specific parameters, ignored here.

        Raises:
          NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
This diff is collapsed. Click to expand it.
1 +"""Eval mAP@N metric from inference file."""
2 +
3 +from __future__ import absolute_import
4 +from __future__ import division
5 +from __future__ import print_function
6 +
7 +from absl import app
8 +from absl import flags
9 +
10 +import mean_average_precision_calculator as map_calculator
11 +import numpy as np
12 +import tensorflow as tf
13 +
# Command-line flags of the evaluation script. All of them default to an
# empty / zero value, which the code below treats as "not provided".

# Glob of the TFRecord files holding the ground truth.
# NOTE(review): the help text mentions 'rgb' and 'labels' features, while
# read_labels() parses the "id" and "segment_*" context features -- confirm
# which description is current.
flags.DEFINE_string(
    "eval_data_pattern", "",
    "File glob defining the evaluation dataset in tensorflow.SequenceExample "
    "format. The SequenceExamples are expected to have an 'rgb' byte array "
    "sequence feature as well as a 'labels' int64 context feature.")
# Optional cache file for the parsed labels; empty disables caching.
flags.DEFINE_string(
    "label_cache", "",
    "The path for the label cache file. Leave blank for not to cache.")
# Submission file to evaluate; main() refuses to run without it.
flags.DEFINE_string("submission_file", "",
                    "The segment submission file generated by inference.py.")
# Maximum number of predictions kept per class; 0 keeps all of them.
flags.DEFINE_integer(
    "top_n", 0,
    "The cap per-class predictions by a maximum of N. Use 0 for not capping.")

FLAGS = flags.FLAGS
29 +
30 +
class Labels(object):
    """Holds the ground truth labels and (de-)serializes them.

    The ground truth is a mapping from (segment_id, class_id) -> label_score.
    On disk every entry is one "segment_id,class_id,label_score" line.
    """

    def __init__(self, labels):
        """Wrap an existing (segment_id, class_id) -> label_score mapping."""
        self._labels = labels

    @property
    def labels(self):
        """Return the ground truth mapping. See class docstring for details."""
        return self._labels

    def to_file(self, file_name):
        """Write the ground truth mapping to file_name, one entry per line."""
        with tf.gfile.Open(file_name, "w") as fobj:
            for (seg_id, label), score in self._labels.items():
                fobj.write("%s,%s,%s\n" % (seg_id, label, score))

    @classmethod
    def from_file(cls, file_name):
        """Build a Labels object from a file written by to_file."""
        parsed = {}
        with tf.gfile.Open(file_name) as fobj:
            for raw_line in fobj:
                fields = raw_line.strip().strip("\n").split(",")
                seg_id, label, score = fields
                parsed[(seg_id, int(label))] = float(score)
        return cls(parsed)
65 +
66 +
def read_labels(data_pattern, cache_path=""):
    """Read labels from TFRecords.

    When cache_path names an existing file the labels are loaded from it and
    the TFRecords are not touched. Otherwise the records are parsed eagerly
    and, if cache_path is set, the result is written there for the next run.

    Args:
      data_pattern: the data pattern to the TFRecords.
      cache_path: the cache path for the label file.

    Returns:
      a Labels object whose keys are ("<video_id>:<start_time>", label) and
      whose values are the segment scores.
    """
    if cache_path and tf.gfile.Exists(cache_path):
        tf.logging.info("Reading cached labels from %s..." % cache_path)
        return Labels.from_file(cache_path)
    tf.enable_eager_execution()
    data_paths = tf.gfile.Glob(data_pattern)
    ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50)
    context_features = {
        "id": tf.FixedLenFeature([], tf.string),
        "segment_labels": tf.VarLenFeature(tf.int64),
        "segment_start_times": tf.VarLenFeature(tf.int64),
        "segment_scores": tf.VarLenFeature(tf.float32)
    }

    def _parse_se_func(sequence_example):
        return tf.parse_single_sequence_example(
            sequence_example, context_features=context_features)

    ds = ds.map(_parse_se_func)
    rated_labels = {}
    tf.logging.info("Reading labels from TFRecords...")
    last_batch = 0
    batch_size = 5000  # Progress is logged roughly every 5000 labels.
    for cxt_feature_val, _ in ds:
        video_id = cxt_feature_val["id"].numpy()
        if isinstance(video_id, bytes):
            # String tensors come back as bytes. Formatting bytes with "%s"
            # under Python 3 would embed the "b'...'" repr in the key, which
            # then never matches the ids in the submission file.
            video_id = video_id.decode("utf-8")
        segment_labels = cxt_feature_val["segment_labels"].values.numpy()
        segment_start_times = (
            cxt_feature_val["segment_start_times"].values.numpy())
        segment_scores = cxt_feature_val["segment_scores"].values.numpy()
        for label, start_time, score in zip(segment_labels,
                                            segment_start_times,
                                            segment_scores):
            rated_labels[("%s:%d" % (video_id, start_time), label)] = score
        batch_id = len(rated_labels) // batch_size
        if batch_id != last_batch:
            tf.logging.info("%d examples processed.", len(rated_labels))
            last_batch = batch_id
    tf.logging.info("Finish reading labels from TFRecords...")
    labels_obj = Labels(rated_labels)
    if cache_path:
        tf.logging.info("Caching labels to %s..." % cache_path)
        labels_obj.to_file(cache_path)
    return labels_obj
118 +
119 +
def read_segment_predictions(file_path, labels, top_n=None):
    """Read segment predictions.

    Every data line has the form "<label_id>,<seg_id> <seg_id> ...". The
    predictions of a class are capped at top_n first and then restricted to
    the segments that have a rated label for that class.

    Args:
      file_path: the submission file path.
      labels: a Labels object containing the eval labels.
      top_n: the per-class class capping. A falsy value disables the cap.

    Returns:
      a dict mapping each label id to its list of rated predicted segment
      ids, in the order they appear in the file.

    Raises:
      ValueError: if a line other than the first has a non-integer label id,
        or a line does not contain exactly one comma.
    """
    cls_preds = {}  # A label_id to pred list mapping.
    rated_labels = labels.labels
    with tf.gfile.Open(file_path) as fobj:
        tf.logging.info("Reading predictions from %s..." % file_path)
        seen_content = False
        for line in fobj:
            # Strip the line ending; otherwise the last segment id keeps its
            # trailing newline and never matches a rated label.
            line = line.strip()
            if not line:
                continue
            label_field, pred_ids_val = line.split(",")
            is_first_line = not seen_content
            seen_content = True
            try:
                label_id = int(label_field)
            except ValueError:
                if is_first_line:
                    # Presumably a "Class,Segments" style header row --
                    # TODO confirm against the submission format.
                    tf.logging.info("Skipping header line: %s" % line)
                    continue
                raise
            pred_ids = pred_ids_val.split()
            if top_n:
                pred_ids = pred_ids[:top_n]
            cls_preds[label_id] = [
                pred_id for pred_id in pred_ids
                if (pred_id, label_id) in rated_labels
            ]
            if len(cls_preds) % 50 == 0:
                tf.logging.info("Processed %d classes..." % len(cls_preds))
        tf.logging.info("Finish reading predictions.")
    return cls_preds
148 +
149 +
def main(unused_argv):
    """Entry function of the script.

    Reads the rated segment labels and the submission file, then logs the
    mean of the per-class average precisions.

    Raises:
      ValueError: if --submission_file is not set.
    """
    if not FLAGS.submission_file:
        raise ValueError("You must input submission file.")
    eval_labels = read_labels(FLAGS.eval_data_pattern,
                              cache_path=FLAGS.label_cache)
    tf.logging.info("Total rated segments: %d." % len(eval_labels.labels))

    # Number of positively rated segments per class.
    positive_counter = {}
    for (_, label_id), score in eval_labels.labels.items():
        if score > 0:
            positive_counter[label_id] = positive_counter.get(label_id, 0) + 1

    seg_preds = read_segment_predictions(FLAGS.submission_file,
                                         eval_labels,
                                         top_n=FLAGS.top_n)
    map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds))
    seg_labels = []
    seg_scored_preds = []
    num_positives = []
    for label_id in sorted(seg_preds):
        class_preds = seg_preds[label_id]
        seg_labels.append(
            [eval_labels.labels[(pred, label_id)] for pred in class_preds])
        # The submission only carries a ranking, so synthesize descending
        # scores 1.0, (n-1)/n, ..., 1/n that preserve the file order.
        num_preds = len(class_preds)
        seg_scored_preds.append(
            [float(rank) / num_preds for rank in range(num_preds, 0, -1)])
        # A predicted class may have no positively rated segment at all;
        # count it as zero positives instead of failing with a KeyError.
        num_positives.append(positive_counter.get(label_id, 0))
    map_cal.accumulate(seg_scored_preds, seg_labels, num_positives)
    map_at_n = np.mean(map_cal.peek_map_at_n())
    tf.logging.info("Num classes: %d | mAP@%d: %.6f" %
                    (len(seg_preds), FLAGS.top_n, map_at_n))


if __name__ == "__main__":
    app.run(main)
1 +Index
2 +3
3 +7
4 +8
5 +11
6 +12
7 +17
8 +18
9 +19
10 +21
11 +22
12 +23
13 +28
14 +31
15 +30
16 +32
17 +33
18 +34
19 +41
20 +43
21 +45
22 +46
23 +48
24 +53
25 +54
26 +52
27 +55
28 +58
29 +59
30 +60
31 +61
32 +65
33 +68
34 +73
35 +71
36 +74
37 +75
38 +76
39 +77
40 +80
41 +83
42 +90
43 +88
44 +89
45 +92
46 +95
47 +100
48 +101
49 +99
50 +104
51 +105
52 +109
53 +113
54 +112
55 +115
56 +116
57 +118
58 +120
59 +121
60 +123
61 +125
62 +127
63 +131
64 +128
65 +129
66 +130
67 +137
68 +141
69 +143
70 +145
71 +148
72 +152
73 +151
74 +156
75 +155
76 +158
77 +160
78 +164
79 +163
80 +169
81 +170
82 +172
83 +171
84 +173
85 +174
86 +175
87 +176
88 +178
89 +182
90 +184
91 +186
92 +188
93 +187
94 +192
95 +191
96 +190
97 +194
98 +197
99 +196
100 +198
101 +201
102 +202
103 +200
104 +199
105 +205
106 +204
107 +209
108 +207
109 +206
110 +210
111 +213
112 +214
113 +220
114 +218
115 +217
116 +226
117 +227
118 +231
119 +232
120 +229
121 +233
122 +235
123 +237
124 +244
125 +240
126 +249
127 +246
128 +248
129 +239
130 +250
131 +245
132 +255
133 +253
134 +256
135 +261
136 +259
137 +263
138 +262
139 +266
140 +267
141 +268
142 +269
143 +271
144 +276
145 +273
146 +277
147 +274
148 +278
149 +279
150 +280
151 +288
152 +291
153 +295
154 +294
155 +293
156 +297
157 +296
158 +300
159 +299
160 +303
161 +302
162 +304
163 +305
164 +313
165 +307
166 +311
167 +310
168 +312
169 +316
170 +318
171 +321
172 +322
173 +331
174 +333
175 +329
176 +330
177 +334
178 +343
179 +349
180 +340
181 +344
182 +348
183 +358
184 +347
185 +359
186 +355
187 +361
188 +360
189 +364
190 +365
191 +368
192 +369
193 +366
194 +370
195 +374
196 +380
197 +373
198 +385
199 +384
200 +388
201 +389
202 +382
203 +393
204 +381
205 +390
206 +394
207 +399
208 +397
209 +396
210 +402
211 +400
212 +398
213 +401
214 +405
215 +406
216 +410
217 +408
218 +416
219 +415
220 +419
221 +422
222 +414
223 +421
224 +424
225 +429
226 +418
227 +427
228 +434
229 +428
230 +435
231 +430
232 +441
233 +439
234 +437
235 +443
236 +440
237 +442
238 +445
239 +446
240 +448
241 +454
242 +444
243 +453
244 +455
245 +451
246 +452
247 +458
248 +460
249 +465
250 +457
251 +463
252 +462
253 +461
254 +464
255 +469
256 +468
257 +472
258 +473
259 +471
260 +475
261 +474
262 +477
263 +485
264 +491
265 +488
266 +482
267 +490
268 +496
269 +494
270 +483
271 +495
272 +493
273 +507
274 +501
275 +499
276 +503
277 +498
278 +514
279 +504
280 +502
281 +506
282 +508
283 +511
284 +527
285 +526
286 +532
287 +513
288 +519
289 +525
290 +518
291 +528
292 +522
293 +523
294 +535
295 +539
296 +540
297 +533
298 +521
299 +541
300 +547
301 +550
302 +544
303 +549
304 +551
305 +554
306 +543
307 +548
308 +557
309 +560
310 +552
311 +559
312 +563
313 +565
314 +567
315 +555
316 +576
317 +568
318 +564
319 +573
320 +581
321 +580
322 +572
323 +571
324 +584
325 +590
326 +585
327 +587
328 +588
329 +592
330 +598
331 +597
332 +599
333 +603
334 +600
335 +604
336 +605
337 +614
338 +602
339 +610
340 +608
341 +611
342 +612
343 +613
344 +617
345 +620
346 +607
347 +624
348 +627
349 +625
350 +631
351 +629
352 +638
353 +632
354 +634
355 +644
356 +641
357 +642
358 +646
359 +652
360 +647
361 +637
362 +661
363 +635
364 +658
365 +648
366 +663
367 +668
368 +664
369 +656
370 +666
371 +671
372 +683
373 +675
374 +669
375 +676
376 +667
377 +691
378 +685
379 +673
380 +688
381 +702
382 +684
383 +679
384 +694
385 +686
386 +689
387 +680
388 +693
389 +703
390 +697
391 +698
392 +692
393 +705
394 +706
395 +712
396 +711
397 +709
398 +710
399 +726
400 +713
401 +721
402 +720
403 +715
404 +717
405 +730
406 +728
407 +723
408 +716
409 +722
410 +718
411 +732
412 +724
413 +736
414 +725
415 +742
416 +727
417 +735
418 +740
419 +748
420 +738
421 +746
422 +751
423 +749
424 +752
425 +754
426 +760
427 +763
428 +756
429 +758
430 +766
431 +764
432 +757
433 +780
434 +767
435 +769
436 +771
437 +786
438 +785
439 +781
440 +787
441 +778
442 +783
443 +792
444 +791
445 +795
446 +788
447 +805
448 +802
449 +801
450 +793
451 +796
452 +804
453 +803
454 +797
455 +814
456 +813
457 +789
458 +808
459 +818
460 +816
461 +817
462 +811
463 +820
464 +826
465 +829
466 +824
467 +821
468 +825
469 +822
470 +835
471 +833
472 +843
473 +823
474 +827
475 +830
476 +832
477 +837
478 +852
479 +844
480 +841
481 +812
482 +847
483 +862
484 +869
485 +860
486 +838
487 +870
488 +846
489 +858
490 +854
491 +880
492 +876
493 +857
494 +859
495 +877
496 +871
497 +855
498 +875
499 +861
500 +867
501 +892
502 +898
503 +888
504 +884
505 +887
506 +891
507 +906
508 +900
509 +878
510 +885
511 +883
512 +901
513 +903
514 +907
515 +930
516 +897
517 +914
518 +917
519 +910
520 +905
521 +909
522 +933
523 +932
524 +922
525 +913
526 +923
527 +931
528 +911
529 +937
530 +918
531 +955
532 +915
533 +944
534 +952
535 +945
536 +948
537 +946
538 +970
539 +974
540 +958
541 +925
542 +979
543 +942
544 +965
545 +975
546 +950
547 +982
548 +940
549 +973
550 +962
551 +972
552 +957
553 +984
554 +983
555 +964
556 +1007
557 +971
558 +981
559 +954
560 +993
561 +991
562 +996
563 +1005
564 +1015
565 +1009
566 +995
567 +986
568 +1000
569 +985
570 +980
571 +1016
572 +1011
573 +999
574 +1002
575 +994
576 +1013
577 +1010
578 +992
579 +1008
580 +1036
581 +1025
582 +1012
583 +990
584 +1037
585 +1040
586 +1031
587 +1019
588 +1052
589 +1001
590 +1055
591 +1032
592 +1069
593 +1058
594 +1014
595 +1023
596 +1030
597 +1061
598 +1035
599 +1034
600 +1053
601 +1045
602 +1046
603 +1067
604 +1060
605 +1049
606 +1056
607 +1074
608 +1066
609 +1044
610 +1038
611 +1073
612 +1077
613 +1068
614 +1057
615 +1072
616 +1104
617 +1083
618 +1089
619 +1087
620 +1099
621 +1076
622 +1086
623 +1098
624 +1094
625 +1095
626 +1096
627 +1101
628 +1107
629 +1105
630 +1117
631 +1093
632 +1106
633 +1122
634 +1119
635 +1103
636 +1128
637 +1120
638 +1126
639 +1102
640 +1115
641 +1124
642 +1123
643 +1131
644 +1136
645 +1144
646 +1121
647 +1137
648 +1132
649 +1133
650 +1157
651 +1134
652 +1143
653 +1159
654 +1164
655 +1155
656 +1142
657 +1150
658 +1148
659 +1161
660 +1165
661 +1147
662 +1162
663 +1152
664 +1174
665 +1160
666 +1166
667 +1190
668 +1175
669 +1167
670 +1156
671 +1180
672 +1171
673 +1179
674 +1172
675 +1186
676 +1188
677 +1201
678 +1177
679 +1208
680 +1183
681 +1189
682 +1192
683 +1209
684 +1214
685 +1197
686 +1168
687 +1202
688 +1205
689 +1203
690 +1199
691 +1219
692 +1217
693 +1187
694 +1206
695 +1210
696 +1241
697 +1221
698 +1218
699 +1223
700 +1236
701 +1212
702 +1237
703 +1195
704 +1216
705 +1247
706 +1234
707 +1240
708 +1257
709 +1224
710 +1243
711 +1259
712 +1242
713 +1282
714 +1222
715 +1254
716 +1227
717 +1235
718 +1269
719 +1258
720 +1290
721 +1275
722 +1262
723 +1252
724 +1248
725 +1272
726 +1246
727 +1225
728 +1245
729 +1277
730 +1298
731 +1288
732 +1271
733 +1265
734 +1286
735 +1260
736 +1266
737 +1296
738 +1280
739 +1285
740 +1293
741 +1276
742 +1287
743 +1289
744 +1261
745 +1264
746 +1295
747 +1291
748 +1283
749 +1311
750 +1303
751 +1330
752 +1315
753 +1300
754 +1333
755 +1307
756 +1325
757 +1334
758 +1316
759 +1314
760 +1317
761 +1310
762 +1329
763 +1324
764 +1339
765 +1346
766 +1342
767 +1352
768 +1321
769 +1376
770 +1366
771 +1308
772 +1345
773 +1348
774 +1386
775 +1383
776 +1372
777 +1367
778 +1400
779 +1382
780 +1375
781 +1392
782 +1380
783 +1371
784 +1393
785 +1389
786 +1353
787 +1387
788 +1374
789 +1379
790 +1381
791 +1359
792 +1360
793 +1396
794 +1399
795 +1365
796 +1424
797 +1373
798 +1411
799 +1401
800 +1397
801 +1395
802 +1412
803 +1394
804 +1368
805 +1423
806 +1391
807 +1435
808 +1409
809 +1443
810 +1402
811 +1425
812 +1415
813 +1421
814 +1426
815 +1433
816 +1420
817 +1452
818 +1436
819 +1430
820 +1408
821 +1458
822 +1429
823 +1453
824 +1454
825 +1447
826 +1472
827 +1486
828 +1468
829 +1461
830 +1467
831 +1484
832 +1457
833 +1444
834 +1450
835 +1451
836 +1459
837 +1462
838 +1449
839 +1476
840 +1470
841 +1471
842 +1498
843 +1488
844 +1442
845 +1480
846 +1456
847 +1466
848 +1505
849 +1517
850 +1464
851 +1503
852 +1490
853 +1519
854 +1481
855 +1493
856 +1463
857 +1532
858 +1487
859 +1501
860 +1500
861 +1495
862 +1509
863 +1535
864 +1506
865 +1521
866 +1580
867 +1540
868 +1502
869 +1520
870 +1496
871 +1569
872 +1515
873 +1489
874 +1507
875 +1527
876 +1545
877 +1560
878 +1510
879 +1514
880 +1526
881 +1594
882 +1511
883 +1572
884 +1548
885 +1584
886 +1556
887 +1588
888 +1628
889 +1555
890 +1568
891 +1550
892 +1622
893 +1563
894 +1603
895 +1616
896 +1576
897 +1549
898 +1537
899 +1593
900 +1618
901 +1645
902 +1624
903 +1617
904 +1634
905 +1595
906 +1597
907 +1590
908 +1632
909 +1575
910 +1559
911 +1625
912 +1615
913 +1591
914 +1630
915 +1608
916 +1621
917 +1589
918 +1646
919 +1643
920 +1652
921 +1627
922 +1611
923 +1626
924 +1613
925 +1639
926 +1655
927 +1620
928 +1602
929 +1651
930 +1653
931 +1669
932 +1638
933 +1696
934 +1649
935 +1675
936 +1660
937 +1683
938 +1666
939 +1671
940 +1703
941 +1716
942 +1637
943 +1672
944 +1676
945 +1692
946 +1711
947 +1680
948 +1641
949 +1688
950 +1708
951 +1704
952 +1690
953 +1674
954 +1718
955 +1699
956 +1723
957 +1756
958 +1700
959 +1662
960 +1715
961 +1657
962 +1733
963 +1728
964 +1670
965 +1712
966 +1685
967 +1724
968 +1735
969 +1714
970 +1730
971 +1747
972 +1656
973 +1737
974 +1705
975 +1693
976 +1713
977 +1689
978 +1753
979 +1739
980 +1721
981 +1725
982 +1749
983 +1732
984 +1743
985 +1731
986 +1767
987 +1738
988 +1831
989 +1771
990 +1726
991 +1746
992 +1776
993 +1775
994 +1799
995 +1774
996 +1780
997 +1781
998 +1769
999 +1805
1000 +1788
1001 +1801
This diff is collapsed. Click to expand it.
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains a collection of util functions for training and evaluating."""
15 +
16 +import numpy
17 +import tensorflow as tf
18 +from tensorflow import logging
19 +
# Python 3 has no `xrange` builtin. Probe for the name and, when it is
# missing, alias it to `range` so code in this module can call `xrange`
# on either interpreter.
try:
  xrange
except NameError:
  xrange = range
24 +
25 +
def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
  """Map a byte-quantized feature back onto its float range.

  The result is `feat_vector * (range / 255.0) + (range / 512.0 + min)`,
  where `range = max_quantized_value - min_quantized_value`.

  Args:
    feat_vector: the input 1-d vector of quantized values.
    max_quantized_value: upper end of the dequantized range.
    min_quantized_value: lower end of the dequantized range.

  Returns:
    A float vector with the same shape as feat_vector.
  """
  assert max_quantized_value > min_quantized_value
  span = max_quantized_value - min_quantized_value
  # One quantization step, given 255 intervals across the span.
  step = span / 255.0
  # Constant shift: the range minimum plus span / 512.
  offset = min_quantized_value + span / 512.0
  return feat_vector * step + offset
42 +
43 +
def MakeSummary(name, value):
  """Build a tf.Summary proto holding one scalar.

  Args:
    name: tag for the scalar; converted with str().
    value: the scalar itself; converted with float().

  Returns:
    A tf.Summary whose single value entry carries the tag and simple_value.
  """
  scalar_entry = tf.Summary.Value(tag=str(name), simple_value=float(value))
  return tf.Summary(value=[scalar_entry])
51 +
52 +
def AddGlobalStepSummary(summary_writer,
                         global_step_val,
                         global_step_info_dict,
                         summary_scope="Eval"):
  """Write the per-step metrics to the summary writer and describe them.

  Args:
    summary_writer: Tensorflow summary_writer.
    global_step_val: an int value of the global step.
    global_step_info_dict: a dictionary of the evaluation metrics calculated
      for a mini-batch. Must contain "hit_at_one", "perr" and "loss"; may
      contain "examples_per_second".
    summary_scope: Train or Eval.

  Returns:
    A one-line string describing this global step's metrics.
  """
  # Read every required metric up front so a missing key fails before any
  # summary is written.
  hit_at_one = global_step_info_dict["hit_at_one"]
  perr = global_step_info_dict["perr"]
  loss = global_step_info_dict["loss"]
  # -1 stands for "throughput not reported".
  examples_per_second = global_step_info_dict.get("examples_per_second", -1)

  tag_prefix = "GlobalStep/" + summary_scope
  for suffix, scalar in (("_Hit@1", hit_at_one),
                         ("_Perr", perr),
                         ("_Loss", loss)):
    summary_writer.add_summary(
        MakeSummary(tag_prefix + suffix, scalar), global_step_val)

  if examples_per_second != -1:
    summary_writer.add_summary(
        MakeSummary(tag_prefix + "_Example_Second", examples_per_second),
        global_step_val)

  summary_writer.flush()

  template = (
      "global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch "
      "Loss: {3:.3f} | Examples_per_sec: {4:.3f}")
  return template.format(global_step_val, hit_at_one, perr, loss,
                         examples_per_second)
96 +
97 +
def AddEpochSummary(summary_writer,
                    global_step_val,
                    epoch_info_dict,
                    summary_scope="Eval"):
  """Write the per-epoch metrics to the summary writer and describe them.

  Args:
    summary_writer: Tensorflow summary_writer.
    global_step_val: an int value of the global step.
    epoch_info_dict: a dictionary of the evaluation metrics calculated for the
      whole epoch. Must contain "epoch_id", "avg_hit_at_one", "avg_perr",
      "avg_loss", "aps" and "gap".
    summary_scope: Train or Eval.

  Returns:
    A one-line string describing this epoch's metrics.
  """
  epoch_id = epoch_info_dict["epoch_id"]
  avg_hit_at_one = epoch_info_dict["avg_hit_at_one"]
  avg_perr = epoch_info_dict["avg_perr"]
  avg_loss = epoch_info_dict["avg_loss"]
  aps = epoch_info_dict["aps"]
  gap = epoch_info_dict["gap"]
  # MAP is the plain mean of the per-class average precisions.
  mean_ap = numpy.mean(aps)

  tag_prefix = "Epoch/" + summary_scope
  for suffix, scalar in (("_Avg_Hit@1", avg_hit_at_one),
                         ("_Avg_Perr", avg_perr),
                         ("_Avg_Loss", avg_loss),
                         ("_MAP", mean_ap),
                         ("_GAP", gap)):
    summary_writer.add_summary(
        MakeSummary(tag_prefix + suffix, scalar), global_step_val)
  summary_writer.flush()

  # Avg_Loss uses ".3f" like the other fields. The earlier "3f" spec was a
  # minimum width of 3 with default precision, so it printed six decimals.
  info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} "
          "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:.3f} | num_classes: {6}"
         ).format(epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss,
                  len(aps))
  return info
142 +
143 +
def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes):
  """Parse comma separated feature names and feature sizes into two lists.

  A length mismatch between the two lists is logged as an error; both lists
  are still returned as parsed.

  Args:
    feature_names: string containing comma separated list of feature names
    feature_sizes: string containing comma separated list of feature sizes

  Returns:
    List of the feature names and list of the dimensionality of each feature.
    Elements in the first/second list are strings/integers.
  """
  parsed_names = [name.strip() for name in feature_names.split(",")]
  parsed_sizes = [int(size) for size in feature_sizes.split(",")]

  name_count = len(parsed_names)
  size_count = len(parsed_sizes)
  if name_count != size_count:
    logging.error("length of the feature names (=" + str(name_count) +
                  ") != length of feature " "sizes (=" + str(size_count) + ")")

  return parsed_names, parsed_sizes
169 +
170 +
def clip_gradient_norms(gradients_to_variables, max_norm):
  """Clip each gradient with tf.clip_by_norm, leaving None gradients alone.

  Args:
    gradients_to_variables: A list of gradient to variable pairs (tuples).
    max_norm: the maximum norm value.

  Returns:
    A list of clipped gradient to variable pairs, in the input order.
  """

  def _clip(gradient):
    # Pairs without a gradient are passed through untouched.
    if gradient is None:
      return gradient
    if isinstance(gradient, tf.IndexedSlices):
      # Clip only the values and rebuild the slices around them.
      clipped_values = tf.clip_by_norm(gradient.values, max_norm)
      return tf.IndexedSlices(clipped_values, gradient.indices,
                              gradient.dense_shape)
    return tf.clip_by_norm(gradient, max_norm)

  return [(_clip(gradient), variable)
          for gradient, variable in gradients_to_variables]
191 +
192 +
def combine_gradients(tower_grads):
  """Calculate the combined gradient for each shared variable across all towers.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over towers. Each inner list holds that tower's (gradient, variable)
      pairs. Pairs whose gradient is None are dropped before combining; the
      remaining pairs are matched across towers by position.

  Returns:
    List of pairs of (gradient, variable) where the gradient has been summed
    across all towers. The variable of each pair is taken from the first
    tower. Returns an empty list when tower_grads is empty.
  """
  # With no towers there is nothing to combine. Indexing the first tower
  # below would otherwise raise IndexError.
  if not tower_grads:
    return []

  filtered_grads = [
      [pair for pair in grad_list if pair[0] is not None]
      for grad_list in tower_grads
  ]

  final_grads = []
  # The first tower decides how many gradients are combined and which
  # variable object each result is paired with.
  for position, (_, variable) in enumerate(filtered_grads[0]):
    per_tower = [tower[position][0] for tower in filtered_grads]
    # Stack along a new leading axis, then sum that axis away.
    summed = tf.reduce_sum(tf.stack(per_tower, 0), 0)
    final_grads.append((summed, variable))

  return final_grads
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains model definitions."""
15 +import math
16 +
17 +import models
18 +import tensorflow as tf
19 +import utils
20 +
21 +from tensorflow import flags
22 +import tensorflow.contrib.slim as slim
23 +
# Shared flag registry, read by MoeModel when no explicit mixture count is
# passed to create_model.
FLAGS = flags.FLAGS

# Default number of trained experts per class for MoeModel.
flags.DEFINE_integer(
    "moe_num_mixtures",
    2,
    "The number of mixtures (excluding the dummy 'expert') used for MoeModel.")
28 +
29 +
class LogisticModel(models.BaseModel):
  """Logistic model with L2 regularization."""

  def create_model(self,
                   model_input,
                   vocab_size,
                   l2_penalty=1e-8,
                   **unused_params):
    """Builds a single sigmoid-activated fully connected layer.

    Args:
      model_input: 'batch' x 'num_features' matrix of input features.
      vocab_size: The number of classes in the dataset.
      l2_penalty: Scale passed to the L2 regularizer on the layer weights.
      **unused_params: Extra keyword arguments; accepted and ignored.

    Returns:
      A dictionary with a tensor containing the probability predictions of the
      model in the 'predictions' key. The dimensions of the tensor are
      batch_size x num_classes.
    """
    regularizer = slim.l2_regularizer(l2_penalty)
    predictions = slim.fully_connected(
        model_input,
        vocab_size,
        activation_fn=tf.nn.sigmoid,
        weights_regularizer=regularizer)
    return {"predictions": predictions}
55 +
56 +
class MoeModel(models.BaseModel):
  """A softmax over a mixture of logistic models (with L2 regularization)."""

  def create_model(self,
                   model_input,
                   vocab_size,
                   num_mixtures=None,
                   l2_penalty=1e-8,
                   **unused_params):
    """Builds a Mixture of (Logistic) Experts model.

    Each class gets a softmax gate over `num_mixtures + 1` slots and
    `num_mixtures` sigmoid experts. The last gate slot has no expert: its
    weight is dropped from the final sum, which is equivalent to an expert
    that always predicts 0.

    Args:
      model_input: 'batch_size' x 'num_features' matrix of input features.
      vocab_size: The number of classes in the dataset.
      num_mixtures: The number of mixtures (excluding a dummy 'expert' that
        always predicts the non-existence of an entity). Falls back to
        FLAGS.moe_num_mixtures when None or 0.
      l2_penalty: How much to penalize the squared magnitudes of parameter
        values.
      **unused_params: Extra keyword arguments; accepted and ignored.

    Returns:
      A dictionary with a tensor containing the probability predictions of the
      model in the 'predictions' key. The dimensions of the tensor are
      batch_size x num_classes.
    """
    num_mixtures = num_mixtures or FLAGS.moe_num_mixtures
    # One extra gate slot per class for the dummy expert.
    num_gate_slots = num_mixtures + 1

    gate_logits = slim.fully_connected(
        model_input,
        vocab_size * num_gate_slots,
        activation_fn=None,
        biases_initializer=None,
        weights_regularizer=slim.l2_regularizer(l2_penalty),
        scope="gates")
    expert_logits = slim.fully_connected(
        model_input,
        vocab_size * num_mixtures,
        activation_fn=None,
        weights_regularizer=slim.l2_regularizer(l2_penalty),
        scope="experts")

    # (Batch * #Labels) x (num_mixtures + 1)
    gate_probs = tf.nn.softmax(tf.reshape(gate_logits, [-1, num_gate_slots]))
    # (Batch * #Labels) x num_mixtures
    expert_probs = tf.nn.sigmoid(tf.reshape(expert_logits, [-1, num_mixtures]))

    # Weight each real expert by its gate; the dummy slot's gate column is
    # sliced off and contributes nothing.
    gated_experts = gate_probs[:, :num_mixtures] * expert_probs
    per_class_and_batch = tf.reduce_sum(gated_experts, 1)
    predictions = tf.reshape(per_class_and_batch, [-1, vocab_size])
    return {"predictions": predictions}