Showing 19 changed files with 712 additions and 0 deletions
docs/.gitkeep
0 → 100644
File mode changed
server/legacy/image server.py
0 → 100644
import numpy as np
import cv2
import asyncio
import websockets
from io import BytesIO

from PIL import Image

async def recv_image(websocket, path):
    buf = await websocket.recv()
    byte = BytesIO(buf)
    image = Image.open(byte)
    remote_ip = websocket.remote_address[0]
    msg = '[{ip}] received face image successfully, size={size}'.format(ip=remote_ip, size=image.size)
    print(msg)
    await websocket.send('100')
    # for debug
    #frame = np.array(image)
    #frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    #cv2.imshow('recv', frame)
    #cv2.waitKey(2000)
    #cv2.destroyAllWindows()

print('run image server')
start_image_server = websockets.serve(recv_image, '0.0.0.0', 8766)
asyncio.get_event_loop().run_until_complete(start_image_server)
asyncio.get_event_loop().run_forever()
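For context, a minimal client sketch for this legacy server (hypothetical, not part of the PR; the host and the relative image path are assumptions) that sends a JPEG over the websocket and waits for the '100' acknowledgement:

# Hypothetical test client for the legacy image server (sketch only).
import asyncio
import websockets

async def send_image(path):
    async with websockets.connect('ws://localhost:8766') as websocket:
        with open(path, 'rb') as f:
            await websocket.send(f.read())  # raw JPEG bytes, as recv_image expects
        ack = await websocket.recv()        # the server replies '100' on success
        print('server ack:', ack)

asyncio.get_event_loop().run_until_complete(send_image('image/test1.jpg'))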
server/legacy/image/test1.jpg
0 → 100644
337 KB
server/legacy/image/test2.jpg
0 → 100644
167 KB
server/legacy/verification server.py
0 → 100644
import torch
import numpy as np
import asyncio
import websockets

from models.inception_resnet_v1 import InceptionResnetV1

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

model = InceptionResnetV1().eval().to(device)

async def get_embeddings(face_list):
    x = torch.Tensor(face_list).to(device)
    with torch.no_grad():  # inference only; no gradients needed
        yhat = model(x)
    return yhat

def get_distance(someone, database):
    distance = [(someone - data).norm().item() for data in database]
    return distance

def get_argmin(someone, database):
    distance = get_distance(someone, database)
    if distance:
        return np.argmin(distance)
    return -1

async def recv_face(websocket, path):
    buf = await websocket.recv()
    face = np.frombuffer(buf, dtype=np.float32)
    face = face.reshape((1, 3, 160, 160))
    remote_ip = websocket.remote_address[0]
    msg = '[{ip}] received face successfully, numpy shape={shape}'.format(ip=remote_ip, shape=face.shape)
    print(msg)
    embedding = await get_embeddings(face)
    await websocket.send('100')
    ## TODO: forward the embedding to the DB server ##

print('run verification server')
start_server = websockets.serve(recv_face, '0.0.0.0', 8765)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
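This legacy verification server expects a single raw float32 buffer per message, shaped (1, 3, 160, 160) as produced by MTCNN. A minimal client sketch, using a random array as a stand-in for a real aligned face (hypothetical, not part of the PR):

# Hypothetical test client for the legacy verification server (sketch only).
import asyncio
import numpy as np
import websockets

async def send_face():
    # random stand-in for an MTCNN-aligned face crop
    face = np.random.rand(1, 3, 160, 160).astype(np.float32)
    async with websockets.connect('ws://localhost:8765') as websocket:
        await websocket.send(face.tobytes())  # raw float32 buffer, as recv_face expects
        ack = await websocket.recv()          # '100' on success
        print('server ack:', ack)

asyncio.get_event_loop().run_until_complete(send_face())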
server/models/data/onet.pt
0 → 100644
No preview for this file type
server/models/data/pnet.pt
0 → 100644
No preview for this file type
server/models/data/rnet.pt
0 → 100644
No preview for this file type
server/models/inception_resnet_v1.py
0 → 100644
import torch
from torch import nn
from torch.nn import functional as F
import os


class BasicConv2d(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes, out_planes,
            kernel_size=kernel_size, stride=stride,
            padding=padding, bias=False  # bias is redundant before batch norm
        )
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001,  # value found in tensorflow
            momentum=0.1,  # default pytorch value
            affine=True
        )
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Block35(nn.Module):

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(256, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(256, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(256, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Block17(nn.Module):

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(896, 128, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(896, 128, kernel_size=1, stride=1),
            BasicConv2d(128, 128, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(128, 128, kernel_size=(7, 1), stride=1, padding=(3, 0))
        )

        self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Block8(nn.Module):

    def __init__(self, scale=1.0, noReLU=False):
        super().__init__()

        self.scale = scale
        self.noReLU = noReLU

        self.branch0 = BasicConv2d(1792, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(1792, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(192, 192, kernel_size=(3, 1), stride=1, padding=(1, 0))
        )

        self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1)
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        if not self.noReLU:
            out = self.relu(out)
        return out


class Mixed_6a(nn.Module):

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(256, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1),
            BasicConv2d(192, 256, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Mixed_7a(nn.Module):

    def __init__(self):
        super().__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(896, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(896, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(896, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 256, kernel_size=3, stride=2)
        )

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class InceptionResnetV1(nn.Module):
    """Inception ResNet V1 model that produces face embeddings or classification logits.

    Weights pretrained on the VGGFace2 dataset are loaded from local state_dict files on
    instantiation (see load_weights below); there is no download step.

    Keyword Arguments:
        classify {bool} -- Whether the model should output classification logits over the
            8631 VGGFace2 identities rather than 512-dimensional feature embeddings.
            (default: {False})
        dropout_prob {float} -- Dropout probability. (default: {0.6})
        device {torch.device} -- Device to move the model to after initialization.
            (default: {None}, i.e. the model stays on the CPU)
    """
    def __init__(self, classify=False, dropout_prob=0.6, device=None):
        super().__init__()

        # Set simple attributes
        self.classify = classify
        self.num_classes = 8631

        # Define layers
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2)
        self.repeat_1 = nn.Sequential(
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
        )
        self.mixed_6a = Mixed_6a()
        self.repeat_2 = nn.Sequential(
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
        )
        self.mixed_7a = Mixed_7a()
        self.repeat_3 = nn.Sequential(
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
        )
        self.block8 = Block8(noReLU=True)
        self.avgpool_1a = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout_prob)
        self.last_linear = nn.Linear(1792, 512, bias=False)
        self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True)
        self.logits = nn.Linear(512, self.num_classes)
        load_weights(self)

        self.device = torch.device('cpu')
        if device is not None:
            self.device = device
            self.to(device)

    def forward(self, x):
        """Calculate embeddings or logits given a batch of input image tensors.

        Arguments:
            x {torch.tensor} -- Batch of image tensors representing faces.

        Returns:
            torch.tensor -- Batch of embedding vectors or multinomial logits.
        """
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
        x = self.conv2d_2b(x)
        x = self.maxpool_3a(x)
        x = self.conv2d_3b(x)
        x = self.conv2d_4a(x)
        x = self.conv2d_4b(x)
        x = self.repeat_1(x)
        x = self.mixed_6a(x)
        x = self.repeat_2(x)
        x = self.mixed_7a(x)
        x = self.repeat_3(x)
        x = self.block8(x)
        x = self.avgpool_1a(x)
        x = self.dropout(x)
        x = self.last_linear(x.view(x.shape[0], -1))
        x = self.last_bn(x)
        if self.classify:
            x = self.logits(x)
        else:
            x = F.normalize(x, p=2, dim=1)
        return x


def load_weights(mdl):
    """Load pretrained VGGFace2 weights from the local state_dict files."""
    base_dir = os.path.dirname(__file__)
    features_path = os.path.join(base_dir, 'vggface2-dict/20180402-114759-vggface2-features.pt')
    logits_path = os.path.join(base_dir, 'vggface2-dict/20180402-114759-vggface2-logits.pt')
    state_dict = {}
    for path in [features_path, logits_path]:
        state_dict.update(torch.load(path))
    mdl.load_state_dict(state_dict)


def get_torch_home():
    torch_home = os.path.expanduser(
        os.getenv(
            'TORCH_HOME',
            os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')
        )
    )
    return torch_home
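For reference, the embeddings this module produces are L2-normalized 512-dimensional vectors, so the Euclidean distance between two embeddings is the quantity the servers threshold on. A minimal usage sketch, assuming the vggface2-dict weight files are in place and using random tensors as stand-in faces:

# Minimal usage sketch (assumes vggface2-dict/*.pt exists; inputs are stand-ins).
import torch
from models.inception_resnet_v1 import InceptionResnetV1

model = InceptionResnetV1().eval()
faces = torch.rand(2, 3, 160, 160)          # two stand-in MTCNN-aligned face crops
with torch.no_grad():
    emb = model(faces)                      # shape (2, 512); rows are L2-normalized
distance = (emb[0] - emb[1]).norm().item()  # small distance suggests the same person
print('embedding distance:', distance)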
server/models/mtcnn.py
0 → 100644
This diff is collapsed.
server/models/utils/detect_face.py
0 → 100644
This diff is collapsed.
server/models/utils/tensorflow2pytorch.py
0 → 100644
This diff is collapsed.
server/models/utils/training.py
0 → 100644
import torch
import time


class Logger(object):

    def __init__(self, mode, length, calculate_mean=False):
        self.mode = mode
        self.length = length
        self.calculate_mean = calculate_mean
        if self.calculate_mean:
            self.fn = lambda x, i: x / (i + 1)
        else:
            self.fn = lambda x, i: x

    def __call__(self, loss, metrics, i):
        track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length)
        loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i))
        metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items())
        print(track_str + loss_str + metric_str + '   ', end='')
        if i + 1 == self.length:
            print('')


class BatchTimer(object):
    """Batch timing class.

    Use this class for tracking training and testing time/rate per batch or per sample.

    Keyword Arguments:
        rate {bool} -- Whether to report a rate (batches or samples per second) rather
            than a time (seconds per batch or sample). (default: {True})
        per_sample {bool} -- Whether to report times or rates per sample rather than per
            batch. (default: {True})
    """

    def __init__(self, rate=True, per_sample=True):
        self.start = time.time()
        self.end = None
        self.rate = rate
        self.per_sample = per_sample

    def __call__(self, y_pred, y):
        self.end = time.time()
        elapsed = self.end - self.start
        self.start = self.end
        self.end = None

        if self.per_sample:
            elapsed /= len(y_pred)
        if self.rate:
            elapsed = 1 / elapsed

        return torch.tensor(elapsed)


def accuracy(logits, y):
    _, preds = torch.max(logits, 1)
    return (preds == y).float().mean()


def pass_epoch(
    model, loss_fn, loader, optimizer=None, scheduler=None,
    batch_metrics={'time': BatchTimer()}, show_running=True,
    device='cpu', writer=None
):
    """Train or evaluate over a data epoch.

    Arguments:
        model {torch.nn.Module} -- Pytorch model.
        loss_fn {callable} -- A function to compute the (scalar) loss.
        loader {torch.utils.data.DataLoader} -- A pytorch data loader.

    Keyword Arguments:
        optimizer {torch.optim.Optimizer} -- A pytorch optimizer. (default: {None})
        scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler. (default: {None})
        batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The
            default is a simple timer. A progressive average of these metrics, along with
            the average loss, is printed every batch. (default: {{'time': BatchTimer()}})
        show_running {bool} -- Whether to print running averages rather than per-batch
            losses and metrics. (default: {True})
        device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'})
        writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter.
            (default: {None})

    Returns:
        tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of
            average metric values across the epoch.
    """

    mode = 'Train' if model.training else 'Valid'
    logger = Logger(mode, length=len(loader), calculate_mean=show_running)
    loss = 0
    metrics = {}

    for i_batch, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss_batch = loss_fn(y_pred, y)

        if model.training:
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()

        metrics_batch = {}
        for metric_name, metric_fn in batch_metrics.items():
            metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu()
            metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name]

        if writer is not None and model.training:
            if writer.iteration % writer.interval == 0:
                writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration)
                for metric_name, metric_batch in metrics_batch.items():
                    writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration)
            writer.iteration += 1

        loss_batch = loss_batch.detach().cpu()
        loss += loss_batch
        if show_running:
            logger(loss, metrics, i_batch)
        else:
            logger(loss_batch, metrics_batch, i_batch)

    if model.training and scheduler is not None:
        scheduler.step()

    loss = loss / (i_batch + 1)
    metrics = {k: v / (i_batch + 1) for k, v in metrics.items()}

    if writer is not None and not model.training:
        writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration)
        for metric_name, metric in metrics.items():
            writer.add_scalars(metric_name, {mode: metric})

    return loss, metrics


def collate_pil(x):
    out_x, out_y = [], []
    for xx, yy in x:
        out_x.append(xx)
        out_y.append(yy)
    return out_x, out_y
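A sketch of how pass_epoch fits together, with a stand-in linear classifier and random data (all hypothetical; nothing in this PR calls pass_epoch yet). The same call handles training after model.train() and evaluation after model.eval():

# Hypothetical pass_epoch usage (model, data, and hyperparameters are stand-ins).
import torch
from torch.utils.data import DataLoader, TensorDataset
from models.utils.training import pass_epoch, BatchTimer, accuracy

model = torch.nn.Linear(512, 10)  # stand-in for a real classifier head
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.rand(64, 512), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=16)

model.train()
loss, metrics = pass_epoch(
    model, loss_fn, loader, optimizer=optimizer,
    batch_metrics={'acc': accuracy, 'fps': BatchTimer()}, device='cpu'
)
print('epoch loss:', loss.item())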
server/server.py
0 → 100644
import torch
import numpy as np
import asyncio
import json
import websockets

import pymysql
from datetime import datetime

from models.inception_resnet_v1 import InceptionResnetV1

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

model = InceptionResnetV1().eval().to(device)
attendance_db = pymysql.connect(
    user='root',
    passwd='1234',
    host='localhost',
    db='attendance',
    charset='utf8'
)

lock = asyncio.Lock()
clients = set()

async def get_embeddings(face_list):
    x = torch.Tensor(face_list).to(device)
    with torch.no_grad():  # inference only; no gradients needed
        yhat = model(x)
    return yhat

async def get_distance(arr1, arr2):
    distance = np.linalg.norm(arr1 - arr2)
    return distance

async def get_cosine_similarity(arr1, arr2):
    similarity = np.inner(arr1, arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))
    return similarity

async def register(websocket):
    async with lock:
        clients.add(websocket)
        remote_ip = websocket.remote_address[0]
        print('[{ip}] connected'.format(ip=remote_ip))

async def unregister(websocket):
    async with lock:
        clients.remove(websocket)
        remote_ip = websocket.remote_address[0]
        print('[{ip}] disconnected'.format(ip=remote_ip))

async def thread(websocket, path):
    await register(websocket)
    try:
        async for message in websocket:
            data = json.loads(message)
            remote_ip = websocket.remote_address[0]
            if data['action'] == 'register':
                # log
                print('[{ip}] register face'.format(ip=remote_ip))

                # load json
                student_id = data['student_id']
                student_name = data['student_name']
                face = np.asarray(data['MTCNN'], dtype=np.float32)
                face = face.reshape((1, 3, 160, 160))

                # connect to the DB
                cursor = attendance_db.cursor(pymysql.cursors.DictCursor)

                # look up the student
                sql = "SELECT student_id FROM student WHERE student_id = %s;"
                cursor.execute(sql, (student_id,))

                # register the student if they are not in the DB yet
                if not cursor.fetchone():
                    sql = "INSERT INTO student(student_id, student_name) VALUES (%s, %s)"
                    cursor.execute(sql, (student_id, student_name))
                    sql = "INSERT INTO lecture_students(lecture_id, student_id) VALUES (%s, %s)"
                    cursor.execute(sql, ('0', student_id))
                    print('[{ip}] {id} is registered'.format(ip=remote_ip, id=student_id))

                # insert into the student_embedding table
                embedding = await get_embeddings(face)
                embedding = embedding.detach().cpu().numpy().tobytes()
                embedding_date = datetime.now().strftime('%Y-%m-%d')
                sql = "INSERT INTO student_embedding(student_id, embedding_date, embedding) VALUES (%s, %s, _binary %s)"
                cursor.execute(sql, (student_id, embedding_date, embedding))
                attendance_db.commit()
                send = json.dumps({'status': 'success', 'student_id': student_id})
                await websocket.send(send)

            elif data['action'] == 'verify':
                # log
                print('[{ip}] verify face'.format(ip=remote_ip))

                # load json
                face = np.asarray(data['MTCNN'], dtype=np.float32)
                face = face.reshape((1, 3, 160, 160))

                embedding = await get_embeddings(face)
                embedding = embedding.detach().cpu().numpy()

                # find the most similar embedding in the DB
                cursor = attendance_db.cursor(pymysql.cursors.DictCursor)
                sql = "SELECT student_id, embedding FROM student_embedding;"
                cursor.execute(sql)
                result = cursor.fetchall()
                verified_id = '0'
                distance_min = 99  # sentinel larger than any real embedding distance
                for row_data in result:
                    db_embedding = np.frombuffer(row_data['embedding'], dtype=np.float32)
                    db_embedding = db_embedding.reshape((1, 512))
                    distance = await get_distance(embedding, db_embedding)
                    if distance < distance_min:
                        verified_id = row_data['student_id']
                        distance_min = distance

                # send the attendance result
                print('[debug] distance:', distance_min)
                send = ''
                if distance_min < 0.4:  # verification threshold on embedding distance
                    # verification succeeded
                    # check whether attendance was already recorded today
                    sql = "SELECT DATE(timestamp) FROM student_attendance WHERE (lecture_id=%s) AND (student_id=%s) AND (DATE(timestamp) = CURDATE());"
                    cursor.execute(sql, ('0', verified_id))

                    # only if there is no attendance record for today
                    if not cursor.fetchone():
                        # the table has a trailing datetime column that defaults to the server time
                        sql = "INSERT INTO student_attendance(lecture_id, student_id, status) VALUES (%s, %s, %s)"
                        # TODO: distinguish attend / late status
                        cursor.execute(sql, ('0', verified_id, 'attend'))
                        attendance_db.commit()
                        # write log
                        print('[{ip}] verification success {id}'.format(ip=remote_ip, id=verified_id))
                        send = json.dumps({'status': 'success', 'student_id': verified_id})
                    else:
                        print('[{ip}] verification failed: {id} is already verified'.format(ip=remote_ip, id=verified_id))
                        send = json.dumps({'status': 'already', 'student_id': verified_id})
                else:
                    # verification failed
                    print('[{ip}] verification failed'.format(ip=remote_ip))
                    send = json.dumps({'status': 'fail'})
                await websocket.send(send)
            elif data['action'] == 'save_image':
                # if attendance could not be verified, save the image so that
                # the instructor can confirm attendance manually later
                print('[{ip}] save image'.format(ip=remote_ip))
                arr = np.asarray(data['image'], dtype=np.uint8)
                blob = arr.tobytes()
                # TODO: the code below only works after a tuple is inserted into the lecture table
                # the table has a trailing datetime column that defaults to the server time
                cursor = attendance_db.cursor(pymysql.cursors.DictCursor)
                sql = "INSERT INTO undefined_image(lecture_id, image, width, height) VALUES (%s, _binary %s, %s, %s)"
                cursor.execute(sql, ('0', blob, arr.shape[0], arr.shape[1]))
                attendance_db.commit()
            else:
                print('unsupported event: {}'.format(data))
    finally:
        await unregister(websocket)

print('run verification server')
start_server = websockets.serve(thread, '0.0.0.0', 8765)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
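The JSON protocol that server.py speaks is easiest to see from the client side. A hypothetical client sketch covering the register and verify actions (the student data and the random 160×160 array are stand-ins for real input):

# Hypothetical client sketch for the JSON protocol above (all values are stand-ins).
import asyncio
import json
import numpy as np
import websockets

async def main():
    face = np.random.rand(1, 3, 160, 160).astype(np.float32)  # stand-in MTCNN crop
    async with websockets.connect('ws://localhost:8765') as websocket:
        # register a student's face embedding
        await websocket.send(json.dumps({
            'action': 'register',
            'student_id': '20190001',
            'student_name': 'Hong Gildong',
            'MTCNN': face.tolist(),
        }))
        print('register:', await websocket.recv())
        # verify the same face; expect status 'success' or, on a repeat, 'already'
        await websocket.send(json.dumps({'action': 'verify', 'MTCNN': face.tolist()}))
        print('verify:', await websocket.recv())

asyncio.get_event_loop().run_until_complete(main())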