# server.py 9.92 KB
import asyncio
import os
import websockets
import json
import cv2
import numpy as np
import tensorflow as tf
import base64
import pymysql
import configparser
import logging
from datetime import datetime

# Websocket connections that are currently open; maintained by
# register() / unregister().
clients = set()

# Server settings, parsed from ./server.cnf.
config = configparser.ConfigParser()
config.read('./server.cnf')

# Parsed ./API_form.cnf (not referenced again in the visible part of this file).
api = configparser.ConfigParser()
api.read('./API_form.cnf')

# All tunables live in the [verification_server] section.
_verification_settings = config['verification_server']
host = _verification_settings['host']
port = _verification_settings['port']
model_dir = _verification_settings['model']
# Side length (pixels) images are resized to, and the distance below which
# two embeddings are treated as the same person.
image_size = int(_verification_settings['image_size'])
threshold = float(_verification_settings['threshold'])

# One shared connection for the whole process; credentials come from ./DB.cnf.
print('connect to DB')
attendance_db = pymysql.connect(read_default_file="./DB.cnf")

async def register(websocket):
    """Add ``websocket`` to the set of connected clients and log its IP."""
    clients.add(websocket)
    print('[{ip}] connected'.format(ip=websocket.remote_address[0]))

async def unregister(websocket):
    """Drop ``websocket`` from the set of connected clients and log its IP.

    Raises:
        KeyError: if ``websocket`` is not currently in ``clients``.
    """
    clients.remove(websocket)
    print('[{ip}] disconnected'.format(ip=websocket.remote_address[0]))

def resize(image):
    """Return ``image`` scaled to ``image_size`` x ``image_size`` (bicubic)."""
    target_shape = (image_size, image_size)
    return cv2.resize(image, target_shape, interpolation=cv2.INTER_CUBIC)

def prewhiten(x):
    """Normalize ``x`` by subtracting its mean and dividing by its std.

    The divisor is floored at ``1 / sqrt(x.size)`` so that (nearly)
    constant inputs do not blow up or divide by zero.
    """
    center = np.mean(x)
    # Lower bound on the scale factor for low-variance inputs.
    scale = np.maximum(np.std(x), 1.0 / np.sqrt(x.size))
    return (x - center) * (1 / scale)

def get_distance(arr1, arr2):
    """Return the Euclidean distance between ``arr1`` and ``arr2``."""
    difference = arr1 - arr2
    return np.linalg.norm(difference)

def _embed_image(encoded_image):
    """Return the model embedding for a base64-encoded image.

    The image is decoded with OpenCV, resized, prewhitened and fed through
    the loaded TensorFlow graph.

    Raises:
        ValueError: if the decoded bytes are not an image OpenCV can read.
    """
    raw_bytes = base64.b64decode(encoded_image)
    pixels = cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), flags=1)
    if pixels is None:
        # cv2.imdecode reports an unreadable image by returning None; fail
        # here with a clear message instead of deep inside cv2.resize.
        raise ValueError('image could not be decoded')
    pixels = prewhiten(resize(pixels))
    batch = pixels.reshape(-1, image_size, image_size, 3)
    feed_dict = {input_placeholder: batch, phase_train_placeholder: False}
    return sess.run(embeddings_placeholder, feed_dict=feed_dict)


def _register_face(cursor, data, remote_ip):
    """Store a face embedding for ``data['student_id']`` and return the reply.

    Creates the student (and enrolls them in lecture '0') when the id is not
    yet in the ``student`` table, then inserts the embedding and commits.
    """
    print('[{ip}] register face'.format(ip=remote_ip))
    student_id = data['student_id']
    student_name = data['student_name']

    # Look the student up; create the rows only when they do not exist yet.
    sql = "SELECT student_id FROM student WHERE student_id = %s;"
    rows_count = cursor.execute(sql, (student_id,))
    if rows_count == 0:
        sql = "INSERT INTO student(student_id, student_name) VALUES (%s, %s)"
        cursor.execute(sql, (student_id, student_name))
        sql = "INSERT INTO lecture_students(lecture_id, student_id) VALUES (%s, %s)"
        cursor.execute(sql, ('0', student_id))
        print('[{ip}] {id} is registered'.format(ip=remote_ip, id=student_id))

    # Serialize the embedding to bytes so it can be stored in a BLOB column.
    embedding = _embed_image(data['image']).tobytes()
    embedding_date = datetime.now().strftime('%Y-%m-%d')
    sql = "insert into student_embedding(student_id, embedding_date, embedding) values (%s, %s, _binary %s)"
    cursor.execute(sql, (student_id, embedding_date, embedding))
    attendance_db.commit()
    return {'status': 'success', 'student_id': student_id}


def _verify_face(cursor, data, remote_ip):
    """Match the face in ``data['image']`` against stored embeddings.

    Returns the reply dict: ``attend`` when a match is found and attendance
    is recorded, ``already`` when the matched student already has an
    attendance row for today, and ``fail`` when nothing is within
    ``threshold``.
    """
    print('[{ip}] verify face'.format(ip=remote_ip))
    embedding = _embed_image(data['image'])

    # Compare against every stored embedding; the first one closer than the
    # threshold wins (not necessarily the closest overall).
    verified_id = None
    cursor.execute("SELECT student_id, embedding FROM student_embedding;")
    for row_data in cursor.fetchall():
        db_embedding = np.frombuffer(row_data['embedding'], dtype=np.float32)
        # Mirror the live embedding's shape instead of hard-coding its width.
        db_embedding = db_embedding.reshape(embedding.shape)
        distance = get_distance(embedding, db_embedding)
        print(distance)
        if distance < threshold:
            verified_id = row_data['student_id']
            break

    if verified_id is None:
        # Verification failed.
        print('[{ip}] verification failed'.format(ip=remote_ip))
        return {'status': 'fail'}

    # Verification succeeded: check whether attendance was already recorded today.
    sql = "SELECT DATE(timestamp) FROM student_attendance WHERE (lecture_id=%s) AND (student_id=%s) AND (DATE(timestamp) = CURDATE());"
    if cursor.execute(sql, ('0', verified_id)) != 0:
        print('[{ip}] verification failed: {id} is already verified'.format(ip=remote_ip, id=verified_id))
        return {'status': 'already', 'student_id': verified_id}

    # Only record attendance when there is no row for today. The table's
    # trailing datetime column defaults to the server time, so it is omitted.
    sql = "INSERT INTO student_attendance(lecture_id, student_id, status) VALUES (%s, %s, %s)"
    # TODO: distinguish attend / late
    cursor.execute(sql, ('0', verified_id, 'attend'))
    attendance_db.commit()

    sql = "SELECT student_name FROM student WHERE student_id = %s"
    cursor.execute(sql, (verified_id,))
    verified_name = cursor.fetchone()['student_name']
    print('[{ip}] verification success {id}'.format(ip=remote_ip, id=verified_id))
    return {'status': 'attend', 'student_id': verified_id, 'student_name': verified_name}


async def thread(websocket, path=None):
    """Websocket connection handler: serve register/verify requests.

    Each incoming message is a JSON object whose ``action`` is either
    ``'register'`` or ``'verify'``; one JSON reply is sent per message.
    Any exception is printed, pending DB changes are rolled back, and the
    connection handler ends.

    Args:
        websocket: the client connection.
        path: unused. It defaults to ``None`` so the handler can be invoked
            both as ``handler(websocket, path)`` and ``handler(websocket)``.
    """
    await register(websocket)
    remote_ip = websocket.remote_address[0]
    try:
        # The cursor is opened inside the try so that unregister() always
        # runs, and closed automatically when the connection ends.
        with attendance_db.cursor(pymysql.cursors.DictCursor) as cursor:
            async for message in websocket:
                data = json.loads(message)
                action = data['action']
                if action == 'register':
                    reply = _register_face(cursor, data, remote_ip)
                elif action == 'verify':
                    reply = _verify_face(cursor, data, remote_ip)
                else:
                    raise Exception('Undefined Action', data)
                await websocket.send(json.dumps(reply))
    except Exception as e:
        print(e)
        # Discard half-finished writes so a later commit on the shared
        # connection cannot persist them.
        try:
            attendance_db.rollback()
        except pymysql.MySQLError as rollback_error:
            print(rollback_error)
    finally:
        await unregister(websocket)

def load_model(model, input_map=None):
    """Load a TensorFlow model into the current default graph/session.

    Args:
        model: path to either a frozen-graph protobuf file, or a directory
            containing a metagraph (``.meta``) and a checkpoint.
        input_map: optional mapping forwarded to the graph import call.

    Raises:
        ValueError: propagated from ``get_model_filenames`` when ``model``
            is a directory without a usable metagraph.
    """
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        # Frozen graph: parse the protobuf and import it into the graph.
        print('Model filename: %s' % model_exp)
        # tf.io.gfile / tf.compat.v1 replace the undefined `gfile` name and
        # the TF1-only `tf.GraphDef`, matching the compat.v1 API used by the
        # rest of this file.
        with tf.io.gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, input_map=input_map, name='')
    else:
        # Model directory: import the metagraph, then restore the weights
        # into the session that is current at call time.
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)

        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)

        saver = tf.compat.v1.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
        saver.restore(tf.compat.v1.get_default_session(), os.path.join(model_exp, ckpt_file))
    
def get_model_filenames(model_dir):
    """Return ``(meta_file, ckpt_file)`` basenames found in ``model_dir``.

    The checkpoint recorded in the directory's checkpoint state is
    preferred; otherwise the ``model-*.ckpt-<step>`` file with the highest
    step number is used.

    Raises:
        ValueError: if the directory does not contain exactly one ``.meta``
            file, or if no checkpoint can be found.
    """
    # Local import: `re` is not imported at the top of this file.
    import re

    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if len(meta_files) == 0:
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    elif len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]

    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        return meta_file, os.path.basename(ckpt.model_checkpoint_path)

    # No checkpoint state: pick the checkpoint with the largest step number.
    step_pattern = re.compile(r'(^model-[\w\- ]+\.ckpt-(\d+))')
    ckpt_file = None
    max_step = -1
    for f in files:
        step_match = step_pattern.match(f)
        if step_match is None:
            continue
        step = int(step_match.group(2))
        if step > max_step:
            max_step = step
            ckpt_file = step_match.group(1)
    if ckpt_file is None:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    return meta_file, ckpt_file
  

async def _serve_forever():
    """Run the websocket server with ``thread`` as handler until cancelled."""
    async with websockets.serve(thread, host, port):
        # A future that is never resolved keeps the server alive.
        await asyncio.Future()


# Script entry: load the model once, expose the graph tensors as module
# globals (the request handler reads them), then serve until interrupted.
with tf.Graph().as_default():
    with tf.compat.v1.Session() as sess:
        print('load tensorflow model')
        load_model(model_dir)
        graph = tf.compat.v1.get_default_graph()
        input_placeholder = graph.get_tensor_by_name("input:0")
        embeddings_placeholder = graph.get_tensor_by_name("embeddings:0")
        phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")
        print('run verification server')
        # asyncio.run creates and owns the event loop, so the server is
        # created inside a running loop instead of relying on the deprecated
        # implicit loop returned by asyncio.get_event_loop().
        asyncio.run(_serve_forever())