Showing 4 changed files with 747 additions and 0 deletions
source_code/detect.py
0 → 100644
from __future__ import division

from roipool2 import *
from models import *
from utils.utils import *
from utils.datasets import *
from video_capture import BufferlessVideoCapture

import serial
import os
import sys
import time
import datetime
import argparse
import cv2

import numpy as np
from PIL import Image

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import NullLocator


def changeRGB2BGR(img):
    r = img[:, :, 0].copy()
    g = img[:, :, 1].copy()
    b = img[:, :, 2].copy()

    # RGB -> BGR: swap the red and blue channels in place
    img[:, :, 0] = b
    img[:, :, 1] = g
    img[:, :, 2] = r

    return img


def changeBGR2RGB(img):
    b = img[:, :, 0].copy()
    g = img[:, :, 1].copy()
    r = img[:, :, 2].copy()

    # BGR -> RGB: swap the blue and red channels in place
    img[:, :, 0] = r
    img[:, :, 1] = g
    img[:, :, 2] = b

    return img

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_folder", type=str, default="data/cafe_distance/1.jpg", help="path to dataset")
    parser.add_argument("--video_file", type=str, default="0", help="path to video file or camera index")
    parser.add_argument("--model_def", type=str, default="config/yolov3-tiny.cfg", help="path to model definition file")
    # parser.add_argument("--weights_path", type=str, default="weights/yolov3-tiny.weights", help="path to weights file")
    parser.add_argument("--weights_path", type=str, default="checkpoints_yolo/tiny1_2500.pth", help="path to weights file")
    parser.add_argument("--class_path", type=str, default="data/cafe_distance/classes.names", help="path to class label file")
    parser.add_argument("--conf_thres", type=float, default=0.8, help="object confidence threshold")
    parser.add_argument("--nms_thres", type=float, default=0.4, help="IoU threshold for non-maximum suppression")
    parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
    parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_model", type=str, help="path to checkpoint model")
    parser.add_argument("--target_object", type=int, default=0, help="class index of the object to follow")
    opt = parser.parse_args()
    print(opt)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("output", exist_ok=True)

    # Serial port used to send control command bytes (see the main loop below)
    sclient = serial.Serial(port='/dev/ttyAMA0', baudrate=115200, timeout=0.1)
    if sclient.isOpen():
        print('Serial is Open')

    # Set up model
    model = Darknet(opt.model_def, img_size=opt.img_size).to(device)
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('Params: ', params)

    if opt.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(opt.weights_path)
    else:
        # Load checkpoint weights
        model.load_state_dict(torch.load(opt.weights_path, map_location=device))

    model.eval()  # Set in evaluation mode

    # Distance-estimation head: ROI pooling over the YOLO feature map
    model_distance = ROIPool((3, 3)).to(device)
    model_distance.load_state_dict(torch.load('checkpoints_distance/tiny1_340.pth', map_location=device))
    model_distance.eval()

    dataloader = DataLoader(
        ImageFolder(opt.image_folder, img_size=opt.img_size),
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.n_cpu,
    )

    classes = load_classes(opt.class_path)  # Extracts class labels from file

    Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    cap = BufferlessVideoCapture(0)
    # cap = cv2.VideoCapture('data/cafe_distance/videos/output17.avi')
    colors = np.random.randint(0, 255, size=(len(classes), 3), dtype="uint8")
    a = []
    time_begin = time.time()
    # Frame count is only meaningful for file sources; a live camera typically reports 0 or -1
    NUM = cap.get(cv2.CAP_PROP_FRAME_COUNT)

    fourcc = cv2.VideoWriter_fourcc('D', 'I', 'V', 'X')
    out = cv2.VideoWriter('output/distance3.avi', fourcc, 30, (640, 480))
    mode = 0
    while cap.isOpened():
        ret, img = cap.read()
        if ret is False:
            break
        # img = cv2.resize(img, (1280, 960), interpolation=cv2.INTER_CUBIC)

        RGBimg = changeBGR2RGB(img)
        imgTensor = transforms.ToTensor()(RGBimg)
        imgTensor, _ = pad_to_square(imgTensor, 0)
        imgTensor = resize(imgTensor, 416)

        imgTensor = imgTensor.unsqueeze(0)
        imgTensor = Variable(imgTensor.type(Tensor))

        with torch.no_grad():
            # prev_time = time.time()

            # The model returns both the backbone feature map and the raw detections;
            # the feature map is reused by the ROI-pooling distance head below.
            featuremap, detections = model(imgTensor)
            # print(featuremap)

            # current_time = time.time()
            # sec = current_time - prev_time
            # fps = 1 / sec
            # frame_per_sec = "FPS: %0.1f" % fps
            # print(frame_per_sec)

            detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)
            # print(f'none test = {detections}')

        a.clear()
        if detections is not None and detections[0] is not None:
            # print(detections)
            featuremap = Variable(featuremap.to(device))
            detects = Variable(detections[0], requires_grad=False)
            # print(f'detects = {detects}')
            # print(f'featuremap = {featuremap.shape}')
            # Estimate a distance for every box that survived NMS
            outputs = model_distance(featuremap, detects)
            print(f'distance = {outputs}')

            a.extend(detections)
            if len(a):
                for detections in a:

                    if detections is not None:
                        # print(detections)

                        detections = rescale_boxes(detections, opt.img_size, RGBimg.shape[:2])
                        # print(detections)
                        unique_labels = detections[:, -1].cpu().unique()
                        n_cls_preds = len(unique_labels)
                        for i, (x1, y1, x2, y2, conf, cls_conf, cls_pred) in enumerate(detections):
                            # opt.target_object is a class index, so compare indices
                            # (comparing the class-name string against an int never matches)
                            if int(cls_pred) == opt.target_object:

                                target_distance = float(outputs[i])
                                if mode == 0:
                                    # Approach phase: send 0x01 while the target is farther than
                                    # 8 distance units, then switch to centering mode
                                    if target_distance > 8:
                                        sclient.write(serial.to_bytes([int('1', 16)]))
                                        break
                                    else:
                                        mode = 1
                                        break
                                elif mode == 1:
                                    # Centering phase: steer left (0x02), right (0x03) or
                                    # straight (0x04) depending on the box centre
                                    box_w = x2 - x1
                                    target_location = int(x1 + box_w / 2)
                                    if target_location < 300:
                                        sclient.write(serial.to_bytes([int('2', 16)]))
                                        break
                                    elif target_location > 340:
                                        sclient.write(serial.to_bytes([int('3', 16)]))
                                        break
                                    else:
                                        sclient.write(serial.to_bytes([int('4', 16)]))
                                        break

                            # box_w = x2 - x1
                            # box_h = y2 - y1
                            # color = [int(c) for c in colors[int(cls_pred)]]
                            # img = cv2.rectangle(img, (x1, y1 + box_h), (x2, y1), color, 2)
                            # cv2.putText(img, classes[int(cls_pred)], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                            # cv2.putText(img, str("%.2f" % float(outputs[i])), (x2, y2 - box_h), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                            #             color, 2)
                            # print(classes[int(cls_pred)], int(x1 + box_w / 2), int(480 - (y1 + box_h / 2)))

        # cv2.putText(img, "Hello World!", (400, 50), cv2.FONT_HERSHEY_PLAIN, 2.0, (0, 0, 255), 2)
        # cv2.imshow('frame', changeRGB2BGR(RGBimg))
        # out.write(changeRGB2BGR(RGBimg))
        # cv2.waitKey(0)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    time_end = time.time()
    time_total = time_end - time_begin
    print(NUM // time_total)

    sclient.close()
    cap.release()
    out.release()
    cv2.destroyAllWindows()

'''
    capture = cv2.VideoCapture("data/cafe/9.mp4")
    capture.set(cv2.CAP_PROP_FRAME_WIDTH, 416)
    capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 416)
    capture.set(cv2.CAP_PROP_FPS, 3)

    colors = np.random.randint(0, 255, size=(len(classes), 3), dtype="uint8")
    capture.set(5, 5)
    print(capture.get(cv2.CAP_PROP_FRAME_WIDTH), capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print("FPS: ", capture.get(5))
    startTime = time.time()
    a = []
    while capture.isOpened():
        ret, frame = capture.read()
        # print()
        nowTime = time.time()

        PILimg = np.array(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        # RGBimg = changeBGR2RGB(frame)

        imgTensor = transforms.ToTensor()(PILimg)
        imgTensor, _ = pad_to_square(imgTensor, 0)
        imgTensor = resize(imgTensor, 416)
        imgTensor = imgTensor.unsqueeze(0)
        imgTensor = Variable(imgTensor.type(Tensor))

        with torch.no_grad():
            prev_time = time.time()
            detections = model(imgTensor)
            current_time = time.time()

            sec = current_time - prev_time
            fps = 1 / sec
            frame_per_sec = "FPS: %0.1f" % fps
            # inference_time = datetime.timedelta(seconds=current_time - prev_time)
            prev_time = current_time

            red = (0, 0, 255)
            cv2.putText(frame, frame_per_sec, (25, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, red, 2)
            detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)

        a.clear()
        if detections is not None:
            a.extend(detections)
        b = len(a)
        if len(a):
            for detections in a:
                if detections is not None:
                    detections = rescale_boxes(detections, opt.img_size, PILimg.shape[:2])
                    unique_labels = detections[:, -1].cpu().unique()
                    n_cls_preds = len(unique_labels)
                    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
                        if classes[int(cls_pred)] == 'shrimp cracker':

                            box_w = x2 - x1
                            box_h = y2 - y1
                            color = [int(c) for c in colors[int(cls_pred)]]
                            # print(cls_conf)
                            frame = cv2.rectangle(frame, (x1, y1 + box_h), (x2, y1), color, 2)
                            cv2.putText(frame, classes[int(cls_pred)], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                            cv2.putText(frame, str("%.2f" % float(conf)), (x2, y2 - box_h), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                        color, 2)
                            print(classes[int(cls_pred)], int(x1 + box_w / 2), int(224 - (y1 + box_h / 2)))

        print()
        # cv2.putText(img, "Hello World!", (400, 50), cv2.FONT_HERSHEY_PLAIN, 2.0, (0, 0, 255), 2)
        # cv2.namedWindow('frame', cv2.WINDOW_NORMAL)
        cv2.imshow('frame', frame)

        # cv2.waitKey(0)

        if cv2.waitKey(25) & 0xFF == ord('q'):
            break
    capture.release()
    cv2.destroyAllWindows()
'''

'''
    imgs = []  # Stores image paths
    img_detections = []  # Stores detections for each image index
    print('parameter count: ', count_parameters(model))
    print("\nPerforming object detection:")
    prev_time = time.time()
    for batch_i, (img_paths, input_imgs) in enumerate(dataloader):
        # Configure input
        input_imgs = Variable(input_imgs.type(Tensor))

        # Get detections
        with torch.no_grad():
            detections = model(input_imgs)
            detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)

        # Log progress
        current_time = time.time()
        inference_time = datetime.timedelta(seconds=current_time - prev_time)
        prev_time = current_time
        print("\t+ Batch %d, Inference Time: %s" % (batch_i, inference_time))

        # Save image and detections
        imgs.extend(img_paths)
        img_detections.extend(detections)

    # Bounding-box colors
    cmap = plt.get_cmap("tab20b")
    colors = [cmap(i) for i in np.linspace(0, 1, 20)]

    print("\nSaving images:")
    # Iterate through images and save plot of detections
    for img_i, (path, detections) in enumerate(zip(imgs, img_detections)):

        print("(%d) Image: '%s'" % (img_i, path))

        # Create plot
        img = np.array(Image.open(path))
        plt.figure()
        fig, ax = plt.subplots(1)
        ax.imshow(img)

        # Draw bounding boxes and labels of detections
        if detections is not None:
            # Rescale boxes to original image
            detections = rescale_boxes(detections, opt.img_size, img.shape[:2])
            unique_labels = detections[:, -1].cpu().unique()
            n_cls_preds = len(unique_labels)
            bbox_colors = random.sample(colors, n_cls_preds)
            for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:

                print("\t+ Label: %s, Conf: %.5f" % (classes[int(cls_pred)], cls_conf.item()))

                box_w = x2 - x1
                box_h = y2 - y1

                color = bbox_colors[int(np.where(unique_labels == int(cls_pred))[0])]
                # Create a Rectangle patch
                bbox = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2, edgecolor=color, facecolor="none")
                # Add the bbox to the plot
                ax.add_patch(bbox)
                # Add label
                plt.text(
                    x1,
                    y1,
                    s=str(classes[int(cls_pred)]) + ' ' + str(int(x1 + box_w / 2)) + ', ' + str(int(y1 + box_h / 2)),
                    color="white",
                    verticalalignment="top",
                    bbox={"color": color, "pad": 0},
                )

        # Save generated image with detections
        plt.axis("off")
        plt.gca().xaxis.set_major_locator(NullLocator())
        plt.gca().yaxis.set_major_locator(NullLocator())
        filename = path.split("/")[-1].split("\\")[-1].split(".")[0]
        plt.savefig(f"output/{filename}.png", bbox_inches="tight", pad_inches=0.0)
        plt.close()
'''
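
# For reference, the serial control implemented in the main loop above condenses to the
# small state machine sketched below. This is only an illustrative sketch of the logic in
# this file, not code that is used anywhere; what each command byte makes the receiving
# controller do is defined by the firmware on the other end of the serial link, not here.
#
#   def choose_command(mode, distance, x_center):
#       if mode == 0:
#           # approach phase: 0x01 while the target is farther than 8 units,
#           # otherwise switch to centering mode without sending anything this frame
#           return (0x01, 0) if distance > 8 else (None, 1)
#       # centering phase: steer by where the box centre sits in the 640-px-wide frame
#       if x_center < 300:
#           return (0x02, 1)   # target left of the 300-340 dead band
#       if x_center > 340:
#           return (0x03, 1)   # target right of the dead band
#       return (0x04, 1)       # target centred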
source_code/roitrain.py
0 → 100644
from __future__ import division

from roipool2 import *
from models import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
# from test import evaluate

from terminaltables import AsciiTable

import os
import sys
import time
import datetime
import argparse
import warnings

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim

warnings.filterwarnings("ignore", category=UserWarning)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('device: ', device)

    data_config = parse_data_config('config/cafe_distance.data')
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])

    # The YOLO backbone stays frozen here; it only supplies feature maps and detections
    model = Darknet('config/yolov3-tiny.cfg', 416).to(device)
    model.load_state_dict(torch.load('checkpoints_cafe_distance/tiny1_2500.pth', map_location=device))
    model.eval()

    dataset = ListDataset(train_path, augment=True, multiscale=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    # Only the ROI-pooling distance head is trained in this script
    model_distance = ROIPool((3, 3)).to(device)
    model_parameters = filter(lambda p: p.requires_grad, model_distance.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('Params: ', params)

    optimizer = torch.optim.Adam(model_distance.parameters())

    # Make sure the checkpoint directory exists before the first save
    os.makedirs('checkpoints_distance11', exist_ok=True)

    a = []
    for epoch in range(2000):

        for batch_i, (img_path, imgs, targets, targets_distance) in enumerate(dataloader):

            imgs = Variable(imgs.to(device))
            with torch.no_grad():
                # Frozen forward pass: feature map + raw detections, no gradients
                featuremap, detections = model(imgs)
                # print(featuremap.shape)
            featuremap = Variable(featuremap.to(device))

            detections = non_max_suppression(detections, 0.8, 0.4)
            targets_distance = torch.tensor(targets_distance[0])
            targets_distance = Variable(targets_distance, requires_grad=True)

            # non_max_suppression returns a list with one entry per image;
            # the entry is None when no box survived the thresholds
            if detections is not None and detections[0] is not None:
                detections[0] = Variable(detections[0], requires_grad=True)

                loss, outputs = model_distance(featuremap, detections[0], targets=targets_distance)
                # loss = torch.tensor([loss]).to(device)
                # loss.requires_grad = True
                # print(model_distance.fc1.bias)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # print(model_distance.fc1.bias)

            # print(batch_i)
        print(epoch)

        # Halve the learning rate every 10 epochs
        if epoch % 10 == 0:
            optimizer.param_groups[0]['lr'] /= 2

        if epoch % 10 == 0:
            torch.save(model_distance.state_dict(), f'checkpoints_distance11/tiny1_{epoch}.pth')
source_code/train.py
0 → 100644
from __future__ import division

from models import *
from roipool import *
# from utils.logger import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
# from test import evaluate

from terminaltables import AsciiTable

import os
import sys
import time
import datetime
import argparse
import warnings

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim

warnings.filterwarnings("ignore", category=UserWarning)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=8001, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=1, help="size of each image batch")
    parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
    parser.add_argument("--model_def", type=str, default="config/yolov3-tiny.cfg", help="path to model definition file")
    parser.add_argument("--data_config", type=str, default="config/testdata.data", help="path to data config file")
    parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
    parser.add_argument("--n_cpu", type=int, default=4, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=50, help="interval between saving model weights")
    parser.add_argument("--evaluation_interval", type=int, default=10000, help="interval between evaluations on the validation set")
    parser.add_argument("--compute_map", default=False, help="if True computes mAP every tenth batch")
    parser.add_argument("--multiscale_training", default=True, help="allow for multi-scale training")
    opt = parser.parse_args()
    print(opt)

    # logger = Logger("logs")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('device: ', device)

    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("checkpoints_fire", exist_ok=True)  # checkpoints are written here below

    # Get data configuration
    data_config = parse_data_config(opt.data_config)
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])

    # Initiate model
    model = Darknet(opt.model_def).to(device)
    model.apply(weights_init_normal)

    model_distance = ROIPool((7, 7)).to(device)

    # If specified we start from checkpoint
    if opt.pretrained_weights:
        if opt.pretrained_weights.endswith(".pth"):
            model.load_state_dict(torch.load(opt.pretrained_weights))
        else:
            model.load_darknet_weights(opt.pretrained_weights)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('Params: ', params)

    # Get dataloader
    dataset = ListDataset(train_path, augment=True, multiscale=opt.multiscale_training)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.n_cpu,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    optimizer = torch.optim.Adam(model.parameters())

    metrics = [
        "grid_size",
        "loss",
        "x",
        "y",
        "w",
        "h",
        "conf",
        "cls",
        "cls_acc",
        "recall50",
        "recall75",
        "precision",
        "conf_obj",
        "conf_noobj",
    ]

    for epoch in range(opt.epochs):
        model.train()
        start_time = time.time()
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            batches_done = len(dataloader) * epoch + batch_i

            imgs = Variable(imgs.to(device))
            targets = Variable(targets.to(device), requires_grad=False)

            loss, outputs = model(imgs, targets)
            print(f'targets = {targets}')
            loss.backward()

            if batches_done % opt.gradient_accumulations:
                # Accumulates gradient before each step
                optimizer.step()
                optimizer.zero_grad()

            # ----------------
            #   Log progress
            # ----------------

            log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (epoch, opt.epochs, batch_i, len(dataloader))

            metric_table = [["Metrics", *[f"YOLO Layer {i}" for i in range(len(model.yolo_layers))]]]

            # Log metrics at each YOLO layer
            for i, metric in enumerate(metrics):
                formats = {m: "%.6f" for m in metrics}
                formats["grid_size"] = "%2d"
                formats["cls_acc"] = "%.2f%%"
                row_metrics = [formats[metric] % yolo.metrics.get(metric, 0) for yolo in model.yolo_layers]
                metric_table += [[metric, *row_metrics]]

            # Tensorboard logging
            tensorboard_log = []
            for j, yolo in enumerate(model.yolo_layers):
                for name, metric in yolo.metrics.items():
                    if name != "grid_size":
                        tensorboard_log += [(f"{name}_{j + 1}", metric)]
            tensorboard_log += [("loss", loss.item())]
            # logger.list_of_scalars_summary(tensorboard_log, batches_done)

            log_str += AsciiTable(metric_table).table
            log_str += f"\nTotal loss {loss.item()}"

            # Determine approximate time left for epoch
            epoch_batches_left = len(dataloader) - (batch_i + 1)
            time_left = datetime.timedelta(seconds=epoch_batches_left * (time.time() - start_time) / (batch_i + 1))
            log_str += f"\n---- ETA {time_left}"

            print(log_str)

            model.seen += imgs.size(0)

        if epoch % opt.evaluation_interval == 0 and epoch != 0:
            print("\n---- Evaluating Model ----")
            # Evaluate the model on the validation set
            # NOTE: requires `from test import evaluate` (commented out at the top of this file)
            precision, recall, AP, f1, ap_class = evaluate(
                model,
                path=valid_path,
                iou_thres=0.5,
                conf_thres=0.5,
                nms_thres=0.5,
                img_size=opt.img_size,
                batch_size=1,
            )
            evaluation_metrics = [
                ("val_precision", precision.mean()),
                ("val_recall", recall.mean()),
                ("val_mAP", AP.mean()),
                ("val_f1", f1.mean()),
            ]
            # logger.list_of_scalars_summary(evaluation_metrics, epoch)

            # Print class APs and mAP
            ap_table = [["Index", "Class name", "AP"]]
            for i, c in enumerate(ap_class):
                ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
            print(AsciiTable(ap_table).table)
            print(f"---- mAP {AP.mean()}")

        if epoch % opt.checkpoint_interval == 0:
            torch.save(model.state_dict(), "checkpoints_fire/tiny1_%d.pth" % epoch)
source_code/video_capture.py
0 → 100644
import cv2
import queue
import threading


class BufferlessVideoCapture:
    '''
    BufferlessVideoCapture is a wrapper around cv2.VideoCapture that keeps no
    frame buffer: a background thread reads frames continuously and retains
    only the most recent one, so read() never returns a stale frame.
    @param name: source passed to cv2.VideoCapture (camera index or file path)
    '''
    def __init__(self, name):
        self.cap = cv2.VideoCapture(name)
        self.q = queue.Queue()
        self.thr = threading.Thread(target=self._reader)
        self.thr.daemon = True
        self.thr.start()

    def _reader(self):
        '''
        Main loop for the reader thread.
        '''
        while True:
            ret, frame = self.cap.read()
            if not ret:
                break
            if not self.q.empty():
                try:
                    self.q.get_nowait()  # discard previous (unprocessed) frame
                except queue.Empty:
                    pass
            if self.q.qsize() > 2:
                print(self.q.qsize())
            self.q.put(frame)

    def isOpened(self):
        return self.cap.isOpened()

    def get(self, prop_id):
        # Forward property queries (e.g. cv2.CAP_PROP_FRAME_COUNT) to the
        # underlying capture, as detect.py expects when it calls cap.get().
        return self.cap.get(prop_id)

    def release(self):
        self.cap.release()

    def read(self):
        '''
        Return the most recent frame (blocks until one is available).
        '''
        return True, self.q.get()

    def close(self):
        pass
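
# A minimal usage sketch (assumes a camera is available at index 0), mirroring how
# detect.py consumes this wrapper; run the module directly to try it.
if __name__ == "__main__":
    cap = BufferlessVideoCapture(0)
    while cap.isOpened():
        ok, frame = cap.read()  # always the most recent frame; older ones are dropped
        if not ok:
            break
        cv2.imshow("latest frame", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()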