Showing
2 changed files
with
163 additions
and
0 deletions
Code/bird_classficate_example.py
0 → 100644
1 | +#***************************************************** | ||
2 | +# * | ||
3 | +# Copyright 2018 Amazon.com, Inc. or its affiliates. * | ||
4 | +# All Rights Reserved. * | ||
5 | +# * | ||
6 | +#***************************************************** | ||
7 | +""" A sample lambda for bird detection""" | ||
8 | +from threading import Thread, Event | ||
9 | +import os | ||
10 | +import json | ||
11 | +import numpy as np | ||
12 | +import awscam | ||
13 | +import cv2 | ||
14 | +import mo | ||
15 | +import greengrasssdk | ||
16 | + | ||
class LocalDisplay(Thread):
    """ Class for facilitating the local display of inference results
    (as images). The class is designed to run on its own thread. In
    particular the class dumps the inference results into a FIFO
    located in the tmp directory (which lambda has access to). The
    results can be rendered using mplayer by typing:
    mplayer -demuxer lavf -lavfdopts format=mjpeg:probesize=32 /tmp/results.mjpeg
    """
    def __init__(self, resolution):
        """ resolution - Desired resolution of the project stream.
        Must be one of '1080p', '720p' or '480p'.
        Raises:
            Exception: if the resolution string is not recognized.
        """
        super(LocalDisplay, self).__init__()
        # List of valid resolutions mapped to (width, height) pixel sizes.
        RESOLUTION = {'1080p': (1920, 1080), '720p': (1280, 720), '480p': (858, 480)}
        if resolution not in RESOLUTION:
            raise Exception("Invalid resolution")
        self.resolution = RESOLUTION[resolution]
        # Initialize the default image to be a white canvas. Clients
        # will update the image when ready. The canvas must be uint8:
        # cv2.imencode expects 8-bit image data, and the original float64
        # array produced by 255*np.ones(...) is not a valid JPEG source.
        self.frame = cv2.imencode('.jpg', 255 * np.ones([640, 480, 3], dtype=np.uint8))[1]
        self.stop_request = Event()

    def run(self):
        """ Overridden method that continually dumps images to the desired
        FIFO file.
        """
        # Path to the FIFO file. The lambda only has permissions to the tmp
        # directory. Pointing to a FIFO file in another directory
        # will cause the lambda to crash.
        result_path = '/tmp/results.mjpeg'
        # Create the FIFO file if it doesn't exist.
        if not os.path.exists(result_path):
            os.mkfifo(result_path)
        # This call will block until a consumer is available.
        # Open in binary mode: the frames are raw JPEG bytes, and writing
        # bytes to a text-mode handle raises TypeError on Python 3.
        with open(result_path, 'wb') as fifo_file:
            while not self.stop_request.is_set():
                try:
                    # Write the data to the FIFO file. This call will block
                    # meaning the code will come to a halt here until a consumer
                    # is available.
                    fifo_file.write(self.frame.tobytes())
                except IOError:
                    # Consumer went away mid-write; keep retrying until a new
                    # one attaches or a stop is requested.
                    continue

    def set_frame_data(self, frame):
        """ Method updates the image data. This currently encodes the
        numpy array to jpg but can be modified to support other encodings.
        frame - Numpy array containing the image data of the next frame
                in the project stream.
        Raises:
            Exception: if JPEG encoding fails.
        """
        ret, jpeg = cv2.imencode('.jpg', cv2.resize(frame, self.resolution))
        if not ret:
            raise Exception('Failed to set frame data')
        self.frame = jpeg

    def join(self):
        """ Request the display thread to stop. Intentionally does NOT call
        super().join(): run() may be blocked inside a FIFO write waiting for
        a consumer, so waiting on the thread here could hang the caller.
        """
        self.stop_request.set()
73 | + | ||
def infinite_infer_run():
    """ Entry point of the lambda function.

    Loads the bird classification model onto the DeepLens GPU, then loops
    forever: grabs a frame from the video stream, classifies a center crop,
    overlays the top-k predictions on the local display stream, and publishes
    the results to the AWS IoT console via MQTT.
    """
    # Create the IoT client and topic BEFORE the try block: the exception
    # handler below publishes through them, and the original ordering raised
    # a NameError whenever a failure occurred before the client existed.
    client = greengrasssdk.client('iot-data')
    iot_topic = '$aws/things/{}/infer'.format(os.environ['AWS_IOT_THING_NAME'])
    try:
        # This bird detection model is implemented as multi classifier. The number of labels
        # is quite large so we upload them to a list to map the machine labels to human readable
        # labels.
        model_type = 'classification'
        with open('labels.txt', 'r') as labels_file:
            output_map = [class_label.rstrip() for class_label in labels_file]
        # Create a local display instance that will dump the image bytes to a FIFO
        # file so that the image can be rendered locally.
        local_display = LocalDisplay('480p')
        local_display.start()
        # The height and width of the training set images.
        input_height = 224
        input_width = 224
        # The sample projects come with optimized artifacts, hence only the artifact
        # path is required.
        ret, model_path = mo.optimize('bird_classification_resnet-18', input_width,
                                      input_height, 'mx')
        # Load the model onto the GPU.
        client.publish(topic=iot_topic, payload='Loading bird detection model')
        model = awscam.Model(model_path, {'GPU': 1})
        client.publish(topic=iot_topic, payload='Bird detection loaded')
        # The number of top results to stream to IoT.
        num_top_k = 5
        # Side length (pixels) of the square detection region cropped from the frame center.
        region_size = 800
        # Inference display region size. Chosen based on the longest label.
        label_region_width = 940
        label_region_height = 600
        # Heading for the inference display.
        prediction_label = 'Top 5 bird predictions'
        # Do inference until the lambda is killed.
        while True:
            # Get a frame from the video stream.
            ret, frame = awscam.getLastFrame()
            if not ret:
                raise Exception('Failed to get frame from the stream')
            # Crop the centered square detection region for inference.
            y_mid, x_mid = frame.shape[0] // 2, frame.shape[1] // 2
            half = region_size // 2
            frame_crop = frame[y_mid - half:y_mid + half, x_mid - half:x_mid + half, :]
            # Resize frame to the same size as the training set. Note that
            # cv2.resize takes (width, height); both are 224 here so the
            # argument order is harmless, but keep them in cv2's order.
            frame_resize = cv2.resize(frame_crop, (input_width, input_height))
            # Model was trained in RGB format but getLastFrame returns image
            # in BGR format so need to switch.
            frame_resize = cv2.cvtColor(frame_resize, cv2.COLOR_BGR2RGB)
            # Run the images through the inference engine and parse the results using
            # the parser API. Note it is possible to get the output of doInference
            # and do the parsing manually, but since it is a classification model,
            # a simple API is provided.
            parsed_inference_results = model.parseResult(model_type,
                                                         model.doInference(frame_resize))
            # Get top k results with highest probabilities.
            top_k = parsed_inference_results[model_type][0:num_top_k]
            # Create a copy of the original frame for the semi-transparent overlay.
            overlay = frame.copy()
            # Create the rectangle that shows the inference results.
            cv2.rectangle(overlay, (0, 0),
                          (int(label_region_width), int(label_region_height)), (211, 211, 211), -1)
            # Blend with the original frame.
            opacity = 0.7
            cv2.addWeighted(overlay, opacity, frame, 1 - opacity, 0, frame)
            # Add the header for the inference results.
            cv2.putText(frame, prediction_label, (0, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 4)
            # Add each label and its probability to the frame used by local display.
            # See https://docs.opencv.org/3.4.1/d6/d6e/group__imgproc__draw.html
            # for more information about the cv2.putText method.
            # Method signature: image, text, origin, font face, font scale, color, thickness.
            # Iterate over what the model actually returned (enumerate) instead of
            # range(num_top_k), which raised IndexError for models with < k classes.
            for i, result in enumerate(top_k):
                # Scale to percent BEFORE rounding: the original round(p, 3)*100
                # produced float artifacts such as '12.299999999999999%'.
                prob_pct = round(result['prob'] * 100, 1)
                cv2.putText(frame, output_map[result['label']] + ' ' + str(prob_pct) + '%',
                            (0, 100 * i + 150), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 0, 0), 3)
            # Display the detection region.
            cv2.rectangle(frame, (x_mid - half, y_mid - half),
                          (x_mid + half, y_mid + half), (255, 0, 0), 5)
            # Set the next frame in the local display stream.
            local_display.set_frame_data(frame)
            # Send the top k results to the IoT console via MQTT.
            cloud_output = {}
            for obj in top_k:
                cloud_output[output_map[obj['label']]] = obj['prob']
            client.publish(topic=iot_topic, payload=json.dumps(cloud_output))
    except Exception as ex:
        client.publish(topic=iot_topic, payload='Error in bird detection lambda: {}'.format(ex))
162 | + | ||
163 | +infinite_infer_run() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
Report/wk12 주간보고서.docx
0 → 100644
No preview for this file type
-
Please register or login to post a comment