최인훈

fix

1 +from __future__ import division
2 +
3 +from roipool2 import *
4 +from models import *
5 +from utils.utils import *
6 +from utils.datasets import *
7 +from video_capture import BufferlessVideoCapture
8 +
9 +import serial
10 +import os
11 +import sys
12 +import time
13 +import datetime
14 +import argparse
15 +import cv2
16 +
17 +from PIL import Image
18 +
19 +import torch
20 +from torch.utils.data import DataLoader
21 +from torchvision import datasets
22 +from torch.autograd import Variable
23 +
24 +import matplotlib.pyplot as plt
25 +import matplotlib.patches as patches
26 +from matplotlib.ticker import NullLocator
27 +
28 +def changeRGB2BGR(img):
29 + r = img[:, :, 0].copy()
30 + g = img[:, :, 1].copy()
31 + b = img[:, :, 2].copy()
32 +
33 + # RGB > BGR
34 + img[:, :, 0] = b
35 + img[:, :, 1] = g
36 + img[:, :, 2] = r
37 +
38 + return img
39 +
40 +def changeBGR2RGB(img):
41 + b = img[:, :, 0].copy()
42 + g = img[:, :, 1].copy()
43 + r = img[:, :, 2].copy()
44 +
45 + img[:, :, 0] = r
46 + img[:, :, 1] = g
47 + img[:, :, 2] = b
48 +
49 + return img
50 +
51 +
52 +if __name__ == "__main__":
53 + parser = argparse.ArgumentParser()
54 + parser.add_argument("--image_folder", type=str, default="data/cafe_distance/1.jpg", help="path to dataset")
55 + parser.add_argument("--video_file", type=str, default="0", help="path to dataset")
56 + parser.add_argument("--model_def", type=str, default="config/yolov3-tiny.cfg", help="path to model definition file")
57 + # parser.add_argument("--weights_path", type=str, default="weights/yolov3-tiny.weights", help="path to weights file")
58 + parser.add_argument("--weights_path", type=str, default="checkpoints_yolo/tiny1_2500.pth", help="path to weights file")
59 + parser.add_argument("--class_path", type=str, default="data/cafe_distance/classes.names", help="path to class label file")
60 + parser.add_argument("--conf_thres", type=float, default=0.8, help="object confidence threshold")
61 + parser.add_argument("--nms_thres", type=float, default=0.4, help="iou thresshold for non-maximum suppression")
62 + parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
63 + parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
64 + parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
65 + parser.add_argument("--checkpoint_model", type=str, help="path to checkpoint model")
66 + parser.add_argument("--target_object", type=int, default=0)
67 + opt = parser.parse_args()
68 + print(opt)
69 +
70 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71 +
72 + os.makedirs("output", exist_ok=True)
73 +
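+ # Serial link to the motor controller; /dev/ttyAMA0 is presumably the Raspberry
+ # Pi hardware UART (an assumption based on the device path), 115200 baud.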
74 + sclient = serial.Serial(port='/dev/ttyAMA0', baudrate=115200, timeout=0.1)
75 + if sclient.is_open:
76 + print('Serial is Open')
77 +
78 + # Set up model
79 + model = Darknet(opt.model_def, img_size=opt.img_size).to(device)
80 + model_parameters = filter(lambda p: p.requires_grad, model.parameters())
81 + params = sum([np.prod(p.size()) for p in model_parameters])
82 + print('Params: ', params)
83 +
84 + if opt.weights_path.endswith(".weights"):
85 + # Load darknet weights
86 + model.load_darknet_weights(opt.weights_path)
87 + else:
88 + # Load checkpoint weights
89 + model.load_state_dict(torch.load(opt.weights_path, map_location=device))
90 +
91 + model.eval() # Set in evaluation mode
92 +
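+ # Distance head: an ROI-pooling module that takes the YOLO feature map and the
+ # detected boxes and regresses one distance value per detection.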
93 + model_distance = ROIPool((3, 3)).to(device)
94 + model_distance.load_state_dict(torch.load('checkpoints_distance/tiny1_340.pth', map_location=device))
95 + model_distance.eval()
96 +
97 + dataloader = DataLoader(
98 + ImageFolder(opt.image_folder, img_size=opt.img_size),
99 + batch_size=opt.batch_size,
100 + shuffle=False,
101 + num_workers=opt.n_cpu,
102 + )
103 +
104 + classes = load_classes(opt.class_path) # Extracts class labels from file
105 +
106 + Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
107 +
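+ # Bufferless capture from camera index 0, so each iteration processes the most
+ # recent frame instead of a queued backlog.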
108 + cap = BufferlessVideoCapture(0)
109 + # cap = cv2.VideoCapture('data/cafe_distance/videos/output17.avi')
110 + colors = np.random.randint(0, 255, size=(len(classes), 3), dtype="uint8")
111 + a=[]
112 + time_begin = time.time()
113 + NUM = cap.get(cv2.CAP_PROP_FRAME_COUNT)
114 +
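+ # Writer for saving annotated 640x480 BGR frames at 30 fps (DIVX codec);
+ # the corresponding out.write() call further down is currently commented out.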
115 + fourcc = cv2.VideoWriter_fourcc('D', 'I', 'V', 'X')
116 + out = cv2.VideoWriter('output/distance3.avi', fourcc, 30, (640,480))
117 +
118 + mode = 0
119 + while cap.isOpened():
120 + ret, img = cap.read()
121 + if not ret:
122 + break
123 + # img = cv2.resize(img, (1280, 960), interpolation=cv2.INTER_CUBIC)
124 +
125 + RGBimg=changeBGR2RGB(img)
126 + imgTensor = transforms.ToTensor()(RGBimg)
127 + imgTensor, _ = pad_to_square(imgTensor, 0)
128 + imgTensor = resize(imgTensor, 416)
129 +
130 + imgTensor = imgTensor.unsqueeze(0)
131 + imgTensor = Variable(imgTensor.type(Tensor))
132 +
133 +
134 + with torch.no_grad():
135 + # prev_time = time.time()
136 +
137 + featuremap, detections = model(imgTensor)
138 +
139 +
140 + # print(featuremap)
141 +
142 + # current_time = time.time()
143 + # sec = current_time - prev_time
144 + # fps = 1/sec
145 + # frame_per_sec = "FPS: %0.1f" % fps
146 + # print(frame_per_sec)
147 +
148 +
149 + detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)
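+ # non_max_suppression returns one entry per image in the batch (batch size 1
+ # here); an entry is None when nothing survives the thresholds, hence the
+ # None checks below.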
150 + # print(f'none test = {detections}')
151 +
152 +
153 +
154 + a.clear()
155 + if detections is not None and detections[0] is not None:
156 + # print(detections)
157 + featuremap = Variable(featuremap.to(device))
158 + detects = Variable(detections[0], requires_grad=False)
159 + # print(f'detects = {detects}')
160 + # print(f'featuremap = {featuremap.shape}')
161 + outputs = model_distance(featuremap, detects)
162 + print(f'distance = {outputs}')
163 +
164 + a.extend(detections)
165 + if len(a):
166 + for detections in a:
167 +
168 + if detections is not None:
169 + # print(detections)
170 +
171 + detections = rescale_boxes(detections, opt.img_size, RGBimg.shape[:2])
172 + # print(detections)
173 + unique_labels = detections[:, -1].cpu().unique()
174 + n_cls_preds = len(unique_labels)
175 + for i, (x1, y1, x2, y2, conf, cls_conf, cls_pred) in enumerate(detections):
176 + if int(cls_pred) == opt.target_object:  # compare class indices (opt.target_object is an int)
177 +
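+ # Inferred behaviour of the two modes: in mode 0 the robot keeps receiving
+ # 0x01 ("keep approaching") until the predicted distance drops to 8 or below,
+ # then mode 1 steers on the box centre: 0x02 if it is left of x=300, 0x03 if
+ # right of x=340, 0x04 when roughly centred.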
178 + target_distance = float(outputs[i])
179 + if(mode == 0):
180 + if target_distance > 8:
181 + sclient.write(serial.to_bytes([int('1', 16)]))
182 + break
183 + else:
184 + mode = 1
185 + break
186 + elif(mode == 1):
187 + box_w = x2 - x1
188 + target_location = int(x1+box_w/2)
189 + if target_location < 300:
190 + sclient.write(serial.to_bytes([int('2', 16)]))
191 + break
192 + elif target_location > 340:
193 + sclient.write(serial.to_bytes([int('3', 16)]))
194 + break
195 + else:
196 + sclient.write(serial.to_bytes([int('4', 16)]))
197 + break
198 +
199 +
200 +
201 +
202 + #box_w = x2 - x1
203 + # print(box_w)
204 + #box_h = y2 - y1
205 + # print(y2, y1)
206 + # color = [int(c) for c in colors[int(cls_pred)]]
207 + #print(cls_conf)
208 + # img = cv2.rectangle(img, (x1, y1 + box_h), (x2, y1), color, 2)
209 +
210 + # cv2.putText(img, classes[int(cls_pred)], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
211 + # cv2.putText(img, str("%.2f" % float(outputs[i])), (x2, y2 - box_h), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
212 + # color, 2)
213 +
214 + # print(classes[int(cls_pred)], int(x1+box_w/2), int(480-(y1+box_h/2)))
215 +
216 + #print()
217 + #print()
218 + #cv2.putText(img,"Hello World!",(400,50),cv2.FONT_HERSHEY_PLAIN,2.0,(0,0,255),2)
219 +
220 + # cv2.imshow('frame', changeRGB2BGR(RGBimg))
221 + # out.write(changeRGB2BGR(RGBimg))
222 + #cv2.waitKey(0)
223 +
224 + if cv2.waitKey(1) & 0xFF == ord('q'):
225 + break
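+ # Average throughput: frame count divided by elapsed time. Only meaningful
+ # when reading from a video file, where CAP_PROP_FRAME_COUNT is defined.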
226 + time_end = time.time()
227 + time_total = time_end - time_begin
228 + print(NUM // time_total)
229 +
230 + sclient.close()
231 + cap.release()
232 + out.release()
233 + cv2.destroyAllWindows()
234 +
235 +
236 +
237 +
238 +
239 +
240 +
241 +
242 +
243 +
244 +'''
245 + capture = cv2.VideoCapture("data/cafe/9.mp4")
246 + capture.set(cv2.CAP_PROP_FRAME_WIDTH, 416)
247 + capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 416)
248 + capture.set(cv2.CAP_PROP_FPS, 3)
249 +
250 + colors = np.random.randint(0, 255, size=(len(classes), 3), dtype="uint8")
251 + capture.set(5, 5)
252 + print(capture.get(cv2.CAP_PROP_FRAME_WIDTH), capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
253 + print("FPS: ", capture.get(5))
254 + startTime = time.time()
255 + a=[]
256 + while capture.isOpened():
257 + ret, frame = capture.read()
258 + # print()
259 + nowTime = time.time()
260 +
261 + PILimg = np.array(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
262 + # RGBimg = changeBGR2RGB(frame)
263 +
264 +
265 +
266 + imgTensor = transforms.ToTensor()(PILimg)
267 + imgTensor, _ = pad_to_square(imgTensor, 0)
268 + imgTensor = resize(imgTensor, 416)
269 + imgTensor = imgTensor.unsqueeze(0)
270 + imgTensor = Variable(imgTensor.type(Tensor))
271 +
272 + with torch.no_grad():
273 + prev_time = time.time()
274 + detections = model(imgTensor)
275 + current_time = time.time()
276 +
277 + sec = current_time - prev_time
278 + fps = 1/sec
279 + frame_per_sec = "FPS: %0.1f" % fps
280 + # inference_time = datetime.timedelta(seconds=current_time - prev_time)
281 + prev_time = current_time
282 +
283 + red = (0, 0, 255)
284 + cv2.putText(frame, frame_per_sec, (25, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, red, 2)
285 + detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)
286 +
287 + a.clear()
288 + if detections is not None:
289 + a.extend(detections)
290 + b=len(a)
291 + if len(a):
292 + for detections in a:
293 + if detections is not None:
294 + detections = rescale_boxes(detections, opt.img_size, PILimg.shape[:2])
295 + unique_labels = detections[:, -1].cpu().unique()
296 + n_cls_preds = len(unique_labels)
297 + for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
298 + if classes[int(cls_pred)] == 'shrimp cracker':
299 +
300 + box_w = x2 - x1
301 + box_h = y2 - y1
302 + color = [int(c) for c in colors[int(cls_pred)]]
303 + # print(cls_conf)
304 + frame = cv2.rectangle(frame, (x1, y1 + box_h), (x2, y1), color, 2)
305 + cv2.putText(frame, classes[int(cls_pred)], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
306 + cv2.putText(frame, str("%.2f" % float(conf)), (x2, y2 - box_h), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
307 + color, 2)
308 + print(classes[int(cls_pred)], int(x1+box_w/2), int(224-(y1+box_h/2)))
309 +
310 + print()
311 + #cv2.putText(img,"Hello World!",(400,50),cv2.FONT_HERSHEY_PLAIN,2.0,(0,0,255),2)
312 + #cv2.namedWindow('frame', cv2.WINDOW_NORMAL)
313 + cv2.imshow('frame', frame)
314 +
315 + #cv2.waitKey(0)
316 +
317 + if cv2.waitKey(25) & 0xFF == ord('q'):
318 + break
319 + capture.release()
320 + cv2.destroyAllWindows()
321 +
322 +'''
323 +
324 +'''
325 + imgs = [] # Stores image paths
326 + img_detections = [] # Stores detections for each image index
327 + print('parameter count: ', count_parameters(model))
328 + print("\nPerforming object detection:")
329 + prev_time = time.time()
330 + for batch_i, (img_paths, input_imgs) in enumerate(dataloader):
331 + # Configure input
332 + input_imgs = Variable(input_imgs.type(Tensor))
333 +
334 + # Get detections
335 + with torch.no_grad():
336 + detections = model(input_imgs)
337 + detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)
338 +
339 + # Log progress
340 + current_time = time.time()
341 + inference_time = datetime.timedelta(seconds=current_time - prev_time)
342 + prev_time = current_time
343 + print("\t+ Batch %d, Inference Time: %s" % (batch_i, inference_time))
344 +
345 + # Save image and detections
346 + imgs.extend(img_paths)
347 + img_detections.extend(detections)
348 +
349 + # Bounding-box colors
350 + cmap = plt.get_cmap("tab20b")
351 + colors = [cmap(i) for i in np.linspace(0, 1, 20)]
352 +
353 + print("\nSaving images:")
354 + # Iterate through images and save plot of detections
355 + for img_i, (path, detections) in enumerate(zip(imgs, img_detections)):
356 +
357 + print("(%d) Image: '%s'" % (img_i, path))
358 +
359 + # Create plot
360 + img = np.array(Image.open(path))
361 + plt.figure()
362 + fig, ax = plt.subplots(1)
363 + ax.imshow(img)
364 +
365 + # Draw bounding boxes and labels of detections
366 + if detections is not None:
367 + # Rescale boxes to original image
368 + detections = rescale_boxes(detections, opt.img_size, img.shape[:2])
369 + unique_labels = detections[:, -1].cpu().unique()
370 + n_cls_preds = len(unique_labels)
371 + bbox_colors = random.sample(colors, n_cls_preds)
372 + for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
373 +
374 + print("\t+ Label: %s, Conf: %.5f" % (classes[int(cls_pred)], cls_conf.item()))
375 +
376 + box_w = x2 - x1
377 + box_h = y2 - y1
378 +
379 +
380 + color = bbox_colors[int(np.where(unique_labels == int(cls_pred))[0])]
381 + # Create a Rectangle patch
382 + bbox = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2, edgecolor=color, facecolor="none")
383 + # Add the bbox to the plot
384 + ax.add_patch(bbox)
385 + # Add label
386 + plt.text(
387 + x1,
388 + y1,
389 + s=str(classes[int(cls_pred)])+' '+str(int(x1+box_w/2))+ ', '+str(int(y1+box_h/2)),
390 + color="white",
391 + verticalalignment="top",
392 + bbox={"color": color, "pad": 0},
393 + )
394 +
395 + # Save generated image with detections
396 + plt.axis("off")
397 + plt.gca().xaxis.set_major_locator(NullLocator())
398 + plt.gca().yaxis.set_major_locator(NullLocator())
399 + filename = path.split("/")[-1].split("\\")[-1].split(".")[0]
400 + plt.savefig(f"output/{filename}.png", bbox_inches="tight", pad_inches=0.0)
401 + plt.close()
402 +
403 +'''
1 +from __future__ import division
2 +
3 +from roipool2 import *
4 +from models import *
5 +from utils.utils import *
6 +from utils.datasets import *
7 +from utils.parse_config import *
8 +# from test import evaluate
9 +
10 +from terminaltables import AsciiTable
11 +
12 +import os
13 +import sys
14 +import time
15 +import datetime
16 +import argparse
17 +import warnings
18 +
19 +import torch
20 +from torch.utils.data import DataLoader
21 +from torchvision import datasets
22 +from torchvision import transforms
23 +from torch.autograd import Variable
24 +import torch.optim as optim
25 +import warnings
26 +warnings.filterwarnings("ignore", category=UserWarning)
27 +
28 +
29 +
30 +
31 +
32 +if __name__ == '__main__':
33 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34 + print('device: ', device)
35 +
36 + data_config = parse_data_config('config/cafe_distance.data')
37 + train_path = data_config["train"]
38 + valid_path = data_config["valid"]
39 + class_names = load_classes(data_config["names"])
40 +
41 + model = Darknet('config/yolov3-tiny.cfg', 416).to(device)
42 +
43 + model.load_state_dict(torch.load('checkpoints_cafe_distance/tiny1_2500.pth', map_location=device))
44 + model.eval()
45 +
46 + dataset = ListDataset(train_path, augment=True, multiscale=True)
47 + dataloader = torch.utils.data.DataLoader(
48 + dataset,
49 + batch_size=1,
50 + shuffle=True,
51 + num_workers=4,
52 + pin_memory=True,
53 + collate_fn=dataset.collate_fn,
54 + )
55 +
56 + model_distance = ROIPool((3, 3)).to(device)
57 + model_parameters = filter(lambda p: p.requires_grad, model_distance.parameters())
58 + params = sum([np.prod(p.size()) for p in model_parameters])
59 + print('Params: ', params)
60 +
61 + optimizer = torch.optim.Adam(model_distance.parameters())
62 +
63 +
64 + a = []
65 + for epoch in range(2000):
66 +
67 + warnings.filterwarnings('ignore', category=UserWarning)
68 + for batch_i, (img_path, imgs, targets, targets_distance) in enumerate(dataloader):
69 +
70 +
71 + imgs = Variable(imgs.to(device))
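+ # The YOLO backbone stays frozen here (eval mode, no_grad); only the ROIPool
+ # distance head below is updated by the optimizer.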
72 + with torch.no_grad():
73 +
74 + featuremap, detections = model(imgs)
75 + # print(featuremap.shape)
76 + featuremap = Variable(featuremap.to(device))
77 +
78 + detections = non_max_suppression(detections, 0.8, 0.4)
79 + targets_distance = torch.tensor(targets_distance[0])
80 + targets_distance = Variable(targets_distance, requires_grad=True)
81 +
82 +
83 +
84 + if detections is not None:
85 + detections[0] = Variable(detections[0], requires_grad=True)
86 +
87 +
88 + loss, outputs = model_distance(featuremap, detections[0], targets=targets_distance)
89 + # loss = torch.tensor([loss]).to(device)
90 + # loss.requires_grad = True
91 + # print(model_distance.fc1.bias)
92 + optimizer.zero_grad()
93 + loss.backward()
94 + optimizer.step()
95 + # print(model_distance.fc1.bias)
96 +
97 + # print(batch_i)
98 + print(epoch)
99 +
100 + # print(featuremap)
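+ # Every 10 epochs: halve the learning rate and save a distance-head checkpoint.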
101 + if epoch % 10 == 0:
102 + optimizer.param_groups[0]['lr'] /= 2
103 +
104 + if epoch % 10 == 0:
105 + torch.save(model_distance.state_dict(), f'checkpoints_distance11/tiny1_{epoch}.pth')
1 +from __future__ import division
2 +
3 +from models import *
4 +from roipool import *
5 +# from utils.logger import *
6 +from utils.utils import *
7 +from utils.datasets import *
8 +from utils.parse_config import *
9 +# from test import evaluate
10 +
11 +from terminaltables import AsciiTable
12 +
13 +import os
14 +import sys
15 +import time
16 +import datetime
17 +import argparse
18 +import warnings
19 +
20 +import torch
21 +from torch.utils.data import DataLoader
22 +from torchvision import datasets
23 +from torchvision import transforms
24 +from torch.autograd import Variable
25 +import torch.optim as optim
26 +import warnings
27 +warnings.filterwarnings("ignore", category=UserWarning)
28 +
29 +if __name__ == "__main__":
30 + warnings.filterwarnings("ignore", category=UserWarning)
31 + parser = argparse.ArgumentParser()
32 + parser.add_argument("--epochs", type=int, default=8001, help="number of epochs")
33 + parser.add_argument("--batch_size", type=int, default=1, help="size of each image batch")
34 + parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
35 + parser.add_argument("--model_def", type=str, default="config/yolov3-tiny.cfg", help="path to model definition file")
36 + parser.add_argument("--data_config", type=str, default="config/testdata.data", help="path to data config file")
37 + parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
38 + parser.add_argument("--n_cpu", type=int, default=4, help="number of cpu threads to use during batch generation")
39 + parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
40 + parser.add_argument("--checkpoint_interval", type=int, default=50, help="interval between saving model weights")
41 + parser.add_argument("--evaluation_interval", type=int, default=10000, help="interval evaluations on validation set")
42 + parser.add_argument("--compute_map", default=False, help="if True computes mAP every tenth batch")
43 + parser.add_argument("--multiscale_training", default=True, help="allow for multi-scale training")
44 + opt = parser.parse_args()
45 + print(opt)
46 +
47 + # logger = Logger("logs")
48 +
49 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50 + print('device: ', device)
51 +
52 + os.makedirs("output", exist_ok=True)
53 + os.makedirs("checkpoints", exist_ok=True)
54 +
55 + # Get data configuration
56 + data_config = parse_data_config(opt.data_config)
57 + train_path = data_config["train"]
58 + valid_path = data_config["valid"]
59 + class_names = load_classes(data_config["names"])
60 +
61 + # Initiate model
62 + model = Darknet(opt.model_def).to(device)
63 + model.apply(weights_init_normal)
64 +
65 + model_distance = ROIPool((7, 7)).to(device)
66 +
67 + # If specified we start from checkpoint
68 + if opt.pretrained_weights:
69 + if opt.pretrained_weights.endswith(".pth"):
70 + model.load_state_dict(torch.load(opt.pretrained_weights))
71 + else:
72 + model.load_darknet_weights(opt.pretrained_weights)
73 +
74 + model_parameters = filter(lambda p: p.requires_grad, model.parameters())
75 + params = sum([np.prod(p.size()) for p in model_parameters])
76 + print('Params: ', params)
77 + # Get dataloader
78 + dataset = ListDataset(train_path, augment=True, multiscale=opt.multiscale_training)
79 + dataloader = torch.utils.data.DataLoader(
80 + dataset,
81 + batch_size=opt.batch_size,
82 + shuffle=False,
83 + num_workers=opt.n_cpu,
84 + pin_memory=True,
85 + collate_fn=dataset.collate_fn,
86 + )
87 +
88 + optimizer = torch.optim.Adam(model.parameters())
89 +
90 + metrics = [
91 + "grid_size",
92 + "loss",
93 + "x",
94 + "y",
95 + "w",
96 + "h",
97 + "conf",
98 + "cls",
99 + "cls_acc",
100 + "recall50",
101 + "recall75",
102 + "precision",
103 + "conf_obj",
104 + "conf_noobj",
105 + ]
106 +
107 + for epoch in range(opt.epochs):
108 + model.train()
109 + warnings.filterwarnings('ignore', category=UserWarning)
110 + start_time = time.time()
111 + for batch_i, (_, imgs, targets) in enumerate(dataloader):
112 + batches_done = len(dataloader) * epoch + batch_i
113 +
114 + imgs = Variable(imgs.to(device))
115 + targets = Variable(targets.to(device), requires_grad=False)
116 +
117 + loss, outputs = model(imgs, targets)
118 + print(f'targets = {targets}')
119 + loss.backward()
120 +
121 + if batches_done % opt.gradient_accumulations:
122 + # Accumulates gradient before each step
123 + optimizer.step()
124 + optimizer.zero_grad()
125 +
126 + # ----------------
127 + # Log progress
128 + # ----------------
129 +
130 + log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (epoch, opt.epochs, batch_i, len(dataloader))
131 +
132 + metric_table = [["Metrics", *[f"YOLO Layer {i}" for i in range(len(model.yolo_layers))]]]
133 +
134 + # Log metrics at each YOLO layer
135 + for i, metric in enumerate(metrics):
136 + formats = {m: "%.6f" for m in metrics}
137 + formats["grid_size"] = "%2d"
138 + formats["cls_acc"] = "%.2f%%"
139 + row_metrics = [formats[metric] % yolo.metrics.get(metric, 0) for yolo in model.yolo_layers]
140 + metric_table += [[metric, *row_metrics]]
141 +
142 + # Tensorboard logging
143 + tensorboard_log = []
144 + for j, yolo in enumerate(model.yolo_layers):
145 + for name, metric in yolo.metrics.items():
146 + if name != "grid_size":
147 + tensorboard_log += [(f"{name}_{j+1}", metric)]
148 + tensorboard_log += [("loss", loss.item())]
149 + # logger.list_of_scalars_summary(tensorboard_log, batches_done)
150 +
151 + log_str += AsciiTable(metric_table).table
152 + log_str += f"\nTotal loss {loss.item()}"
153 +
154 + # Determine approximate time left for epoch
155 + epoch_batches_left = len(dataloader) - (batch_i + 1)
156 + time_left = datetime.timedelta(seconds=epoch_batches_left * (time.time() - start_time) / (batch_i + 1))
157 + log_str += f"\n---- ETA {time_left}"
158 +
159 + print(log_str)
160 +
161 + model.seen += imgs.size(0)
162 +
163 + if epoch % opt.evaluation_interval == 0 and epoch != 0:
164 + print("\n---- Evaluating Model ----")
165 + # Evaluate the model on the validation set
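+ # NOTE: evaluate() lives in test.py; the "from test import evaluate" import
+ # above is commented out, so it must be restored for this branch to run.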
166 + precision, recall, AP, f1, ap_class = evaluate(
167 + model,
168 + path=valid_path,
169 + iou_thres=0.5,
170 + conf_thres=0.5,
171 + nms_thres=0.5,
172 + img_size=opt.img_size,
173 + batch_size=1,
174 + )
175 + evaluation_metrics = [
176 + ("val_precision", precision.mean()),
177 + ("val_recall", recall.mean()),
178 + ("val_mAP", AP.mean()),
179 + ("val_f1", f1.mean()),
180 + ]
181 + # logger.list_of_scalars_summary(evaluation_metrics, epoch)
182 +
183 + # Print class APs and mAP
184 + ap_table = [["Index", "Class name", "AP"]]
185 + for i, c in enumerate(ap_class):
186 + ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
187 + print(AsciiTable(ap_table).table)
188 + print(f"---- mAP {AP.mean()}")
189 +
190 + if epoch % opt.checkpoint_interval == 0:
191 + torch.save(model.state_dict(), f"checkpoints_fire/tiny1_%d.pth" % (epoch))
1 +import cv2
2 +import queue
3 +import threading
4 +
5 +class BufferlessVideoCapture:
6 + '''
7 + A wrapper around cv2.VideoCapture that keeps no frame buffer: a background
8 + thread reads continuously and read() always returns the most recent frame.
9 + @param name: device index or source path passed to cv2.VideoCapture
10 + '''
11 + def __init__(self, name):
12 + self.cap = cv2.VideoCapture(name)
13 + self.q = queue.Queue()
14 + self.thr = threading.Thread(target=self._reader)
15 + self.thr.daemon = True
16 + self.thr.start()
17 +
18 + def _reader(self):
19 + '''
20 + Main loop for thread.
21 + '''
22 + while True:
23 + ret, frame = self.cap.read()
24 + if not ret:
25 + break
26 + if not self.q.empty():
27 + try:
28 + self.q.get_nowait() # discard previous (unprocessed) frame
29 + except queue.Empty:
30 + pass
31 + if self.q.qsize() > 2:
32 + print(self.q.qsize())
33 + self.q.put(frame)
34 +
35 + def isOpened(self):
36 + return self.cap.isOpened()
37 +
+ def get(self, prop):
+ # Pass-through so callers can query capture properties
+ # (e.g. cv2.CAP_PROP_FRAME_COUNT in the detect script above).
+ return self.cap.get(prop)
+
38 + def release(self):
39 + self.cap.release()
40 +
41 + def read(self):
42 + '''
43 + Return (True, frame) for the most recent frame; blocks until one is available.
44 + '''
45 + return True, self.q.get()
46 +
47 + def close(self):
48 + pass
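
For reference, a minimal usage sketch of the wrapper above (it assumes a camera
is available at index 0 and that this file is importable as video_capture, as
in the detect script):

import cv2
from video_capture import BufferlessVideoCapture

cap = BufferlessVideoCapture(0)              # starts the background reader thread
while cap.isOpened():
    ret, frame = cap.read()                  # always the most recent frame
    cv2.imshow('frame', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
cap.release()
cv2.destroyAllWindows()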