Showing
1 changed file
with
91 additions
and
0 deletions
for_dataset/csv_to_tfrecord.py
0 → 100644
1 | +from __future__ import division | ||
2 | +from __future__ import print_function | ||
3 | +from __future__ import absolute_import | ||
4 | + | ||
5 | +import os | ||
6 | +import io | ||
7 | +import pandas as pd | ||
8 | +import tensorflow as tf | ||
9 | + | ||
10 | +from PIL import Image | ||
11 | +from object_detection.utils import dataset_util | ||
12 | +from collections import namedtuple, OrderedDict | ||
13 | + | ||
14 | +flags = tf.app.flags | ||
15 | +flags.DEFINE_string('csv_input', '', 'Path to the CSV input') | ||
16 | +flags.DEFINE_string('output_path', '', 'Path to output TFRecord') | ||
17 | +flags.DEFINE_string('image_dir', '', 'Path to images') | ||
18 | +FLAGS = flags.FLAGS | ||
19 | + | ||
20 | + | ||
21 | +# TO-DO replace this with label map | ||
22 | +def class_text_to_int(row_label): | ||
23 | + if row_label == 'fire': | ||
24 | + return 1 | ||
25 | + else: | ||
26 | + None | ||
27 | + | ||
28 | + | ||
29 | +def split(df, group): | ||
30 | + data = namedtuple('data', ['filename', 'object']) | ||
31 | + gb = df.groupby(group) | ||
32 | + return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)] | ||
33 | + | ||
34 | + | ||
35 | +def create_tf_example(group, path): | ||
36 | + with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid: | ||
37 | + encoded_jpg = fid.read() | ||
38 | + encoded_jpg_io = io.BytesIO(encoded_jpg) | ||
39 | + image = Image.open(encoded_jpg_io) | ||
40 | + width, height = image.size | ||
41 | + | ||
42 | + filename = group.filename.encode('utf8') | ||
43 | + image_format = b'jpg' | ||
44 | + xmins = [] | ||
45 | + xmaxs = [] | ||
46 | + ymins = [] | ||
47 | + ymaxs = [] | ||
48 | + classes_text = [] | ||
49 | + classes = [] | ||
50 | + | ||
51 | + for index, row in group.object.iterrows(): | ||
52 | + xmins.append(row['xmin'] / width) | ||
53 | + xmaxs.append(row['xmax'] / width) | ||
54 | + ymins.append(row['ymin'] / height) | ||
55 | + ymaxs.append(row['ymax'] / height) | ||
56 | + classes_text.append(row['class'].encode('utf8')) | ||
57 | + classes.append(class_text_to_int(row['class'])) | ||
58 | + | ||
59 | + tf_example = tf.train.Example(features=tf.train.Features(feature={ | ||
60 | + 'image/height': dataset_util.int64_feature(height), | ||
61 | + 'image/width': dataset_util.int64_feature(width), | ||
62 | + 'image/filename': dataset_util.bytes_feature(filename), | ||
63 | + 'image/source_id': dataset_util.bytes_feature(filename), | ||
64 | + 'image/encoded': dataset_util.bytes_feature(encoded_jpg), | ||
65 | + 'image/format': dataset_util.bytes_feature(image_format), | ||
66 | + 'image/object/bbox/xmin': dataset_util.float_list_feature(xmins), | ||
67 | + 'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs), | ||
68 | + 'image/object/bbox/ymin': dataset_util.float_list_feature(ymins), | ||
69 | + 'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs), | ||
70 | + 'image/object/class/text': dataset_util.bytes_list_feature(classes_text), | ||
71 | + 'image/object/class/label': dataset_util.int64_list_feature(classes), | ||
72 | + })) | ||
73 | + return tf_example | ||
74 | + | ||
75 | + | ||
76 | +def main(_): | ||
77 | + writer = tf.python_io.TFRecordWriter(FLAGS.output_path) | ||
78 | + path = os.path.join(FLAGS.image_dir) | ||
79 | + examples = pd.read_csv(FLAGS.csv_input) | ||
80 | + grouped = split(examples, 'filename') | ||
81 | + for group in grouped: | ||
82 | + tf_example = create_tf_example(group, path) | ||
83 | + writer.write(tf_example.SerializeToString()) | ||
84 | + | ||
85 | + writer.close() | ||
86 | + output_path = os.path.join(os.getcwd(), FLAGS.output_path) | ||
87 | + print('Successfully created the TFRecords: {}'.format(output_path)) | ||
88 | + | ||
89 | + | ||
90 | +if __name__ == '__main__': | ||
91 | + tf.app.run() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment