김성주

added TFRecord writer

1 +import tensorflow as tf
2 +
3 +def bytes_feature(values):
4 + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
5 +
6 +def int64_feature(values):
7 + if not isinstance(values, (tuple, list)):
8 + values = [values]
9 +
10 + return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
11 +
12 +def readImage(path):
13 + file = open(path, 'rb')
14 + byte = file.read()
15 + return byte
16 +
17 +def main():
18 + ANNOTATION_PATH = 'train.txt' #annotation set (train/val/test) text file
19 + IMAGE_DIRECTORY = 'image_data/' #image directory
20 + SAVE_PATH = 'train.tfrecord' #save path for tfrecord
21 +
22 + print('Tensorflow version:', tf.__version__) #tensorflow version should be 1.x
23 +
24 + file = open(ANNOTATION_PATH, 'r')
25 + lines = file.readlines()
26 + file.close()
27 +
28 + options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP) #compress option
29 + writer = tf.python_io.TFRecordWriter(path=SAVE_PATH, options=options)
30 +
31 + for line in lines:
32 + parsed = line.split(' ')
33 + print('Current Doing...', parsed[1]) #debug messages
34 + image = readImage(IMAGE_DIRECTORY + '/' + parsed[1])
35 + boxes = []
36 +
37 + for i in range(4, len(parsed)):
38 + boxes.append(int(parsed[i]))
39 +
40 + data = tf.train.Example(features=tf.train.Features(feature={
41 + 'index': int64_feature(int(parsed[0])),
42 + 'image': bytes_feature(image),
43 + 'width': int64_feature(int(parsed[2])),
44 + 'height': int64_feature(int(parsed[3])),
45 + 'boxes': int64_feature(boxes) # boxes = [label1, xmin1, ymin1, xmax1, ymax1, label2, xmax2, ...]
46 + }))
47 +
48 + writer.write(data.SerializeToString())
49 +
50 + writer.close()
51 +
52 +if __name__ == '__main__':
53 + main()
...\ No newline at end of file ...\ No newline at end of file