이동찬

Add Python code for converting CSV to TFRecord

1 +from __future__ import division
2 +from __future__ import print_function
3 +from __future__ import absolute_import
4 +
5 +import os
6 +import io
7 +import pandas as pd
8 +import tensorflow as tf
9 +
10 +from PIL import Image
11 +from object_detection.utils import dataset_util
12 +from collections import namedtuple, OrderedDict
13 +
flags = tf.app.flags

# (flag name, help text) for each string flag; every flag defaults to ''.
_STRING_FLAGS = (
    ('csv_input', 'Path to the CSV input'),
    ('output_path', 'Path to output TFRecord'),
    ('image_dir', 'Path to images'),
)
for _flag_name, _flag_help in _STRING_FLAGS:
    flags.DEFINE_string(_flag_name, '', _flag_help)

# Parsed flag values, read by main().
FLAGS = flags.FLAGS
19 +
20 +
# TODO: replace this hard-coded mapping with a label map.
def class_text_to_int(row_label):
    """Map a class name to its integer label id.

    Args:
        row_label: Class name to look up.

    Returns:
        1 when ``row_label`` is ``'fire'``; ``None`` for any other value.
    """
    if row_label == 'fire':
        return 1
    # The original fallback was a bare ``None`` expression with no
    # ``return``; state the "unknown label" result explicitly.
    return None
27 +
28 +
def split(df, group):
    """Split *df* into one ``(filename, object)`` record per group key.

    Args:
        df: pandas DataFrame to partition.
        group: Key passed to ``DataFrame.groupby`` (e.g. a column name).

    Returns:
        A list of ``data`` namedtuples. ``filename`` holds the group key and
        ``object`` holds the sub-DataFrame of the rows in that group.
    """
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    # Iterating ``gb.groups`` yields each group key exactly once, so there is
    # no need to zip the key set with itself.
    return [data(key, gb.get_group(key)) for key in gb.groups]
33 +
34 +
def create_tf_example(group, path):
    """Build a ``tf.train.Example`` for one image and its annotation rows.

    Args:
        group: Record exposing ``filename`` (the image file name) and
            ``object`` (a DataFrame whose rows provide the ``xmin``,
            ``xmax``, ``ymin``, ``ymax`` and ``class`` columns).
        path: Directory that ``group.filename`` is joined onto.

    Returns:
        A ``tf.train.Example`` holding the raw image bytes, the image size,
        the box coordinates divided by the image width/height, and the
        class names with their integer labels.
    """
    image_path = os.path.join(path, '{}'.format(group.filename))
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()

    # The image is decoded only to read its dimensions; the bytes stored in
    # the example are the file contents as read above.
    width, height = Image.open(io.BytesIO(encoded_jpg)).size

    # Materialise the rows once, then derive each feature list from them.
    rows = [row for _, row in group.object.iterrows()]
    xmins = [row['xmin'] / width for row in rows]
    xmaxs = [row['xmax'] / width for row in rows]
    ymins = [row['ymin'] / height for row in rows]
    ymaxs = [row['ymax'] / height for row in rows]
    classes_text = [row['class'].encode('utf8') for row in rows]
    classes = [class_text_to_int(row['class']) for row in rows]

    filename = group.filename.encode('utf8')
    # The format tag is always b'jpg', whatever the file on disk is.
    image_format = b'jpg'

    feature = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
74 +
75 +
def main(_):
    """Convert the CSV named by the flags into a TFRecord file.

    Reads ``FLAGS.csv_input``, groups its rows by the ``filename`` column,
    and writes one serialized example per group to ``FLAGS.output_path``.
    Images are looked up under ``FLAGS.image_dir``.

    Args:
        _: Unused positional argument passed in by ``tf.app.run``.
    """
    # Load and group the CSV before opening the output, so a bad CSV does
    # not leave an empty TFRecord file behind.
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')

    # ``with`` closes the writer even if building an example raises.
    with tf.python_io.TFRecordWriter(FLAGS.output_path) as writer:
        for group in grouped:
            tf_example = create_tf_example(group, FLAGS.image_dir)
            writer.write(tf_example.SerializeToString())

    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    tf.app.run()
...\ No newline at end of file ...\ No newline at end of file