# app.py 8.04 KB
import base64
import configparser
import io
import json
import os
import re
import sys
import time
from datetime import datetime

import cv2
import numpy as np
import pymysql
import tensorflow as tf
from flask import Flask, render_template, Response, request, jsonify
from PIL import Image

# Server settings read from ./config.cnf, section [verification_server]:
#   model      - model path handed to load_model()
#   image_size - side length (pixels) images are resized to before embedding
#   threshold  - maximum embedding distance accepted as a match in /verify
config = configparser.ConfigParser()
if not config.read('./config.cnf'):
    # ConfigParser.read() silently skips files it cannot open; without this
    # check a missing file only shows up as an opaque KeyError below.
    raise FileNotFoundError('configuration file ./config.cnf could not be read')

# NOTE(review): `api` is not referenced in this part of the file; presumably
# used elsewhere. Its file is treated as optional, as before.
api = configparser.ConfigParser()
api.read('./API_form.cnf')

model_dir = config['verification_server']['model']
image_size = int(config['verification_server']['image_size'])
threshold = float(config['verification_server']['threshold'])

app = Flask(__name__)
# Single TF1-style session shared by load_model() and the request handlers.
sess = tf.compat.v1.Session()


def resize(image):
    """Return *image* scaled to an ``image_size`` x ``image_size`` square.

    Bicubic interpolation is used; the aspect ratio is not preserved.
    """
    target_shape = (image_size, image_size)
    return cv2.resize(image, target_shape, interpolation=cv2.INTER_CUBIC)

def prewhiten(x):
    """Standardize *x* to zero mean and (approximately) unit variance.

    The divisor is the standard deviation, floored at ``1/sqrt(x.size)`` so
    that near-constant inputs do not blow up.
    """
    centered = np.subtract(x, np.mean(x))
    floor = 1.0 / np.sqrt(x.size)
    inv_scale = 1 / np.maximum(np.std(x), floor)
    return np.multiply(centered, inv_scale)

def get_distance(arr1, arr2):
    """Return the Euclidean (L2) distance between two embedding arrays."""
    difference = arr1 - arr2
    return np.linalg.norm(difference)

@app.route('/')
def index():
    """Render the landing page (index.html) that hosts the video stream."""
    page = 'index.html'
    return render_template(page)

@app.route('/register', methods=['GET', 'POST'])
def register():
    """Show the registration form (GET) or register a student's face (POST).

    POST form fields:
        student_id:   student identifier.
        student_name: name stored when the student is not yet in ``student``.
        image:        base64-encoded image file of the student's face.

    A new student is inserted into ``student`` and linked to lecture '0' in
    ``lecture_students``. In every case, the embedding computed from the image
    is appended to ``student_embedding`` with today's date.

    Returns:
        JSON ``{'status': 'success', 'student_id': ...}``. If the image cannot
        be decoded, returns ``{'status': 'error', 'message': 'invalid image'}``
        with HTTP 400.
    """
    if request.method == 'GET':
        return render_template('register.html')

    student_id = request.form['student_id']
    student_name = request.form['student_name']
    msg = '[{id}] register face'.format(id=student_id)
    print(msg)

    # Decode and embed the image before opening a DB connection. A bad upload
    # is then rejected with a 400 instead of a 500, and nothing is written.
    try:
        image_bytes = base64.b64decode(request.form['image'])
    except ValueError:  # binascii.Error on malformed base64
        image_bytes = b''
    image_np = None
    if image_bytes:
        image_np = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8),
                                cv2.IMREAD_UNCHANGED)
    if image_np is None:
        return jsonify({'status': 'error', 'message': 'invalid image'}), 400
    # IMREAD_UNCHANGED keeps grayscale / alpha channels as-is, but the model
    # input built below must have exactly 3 channels.
    if image_np.ndim == 2:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
    elif image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGRA2BGR)
    image_np = resize(image_np)
    image_np = prewhiten(image_np)
    image_np = image_np.reshape(-1, image_size, image_size, 3)
    feed_dict = {input_placeholder: image_np, phase_train_placeholder: False}
    embedding = sess.run(embeddings_placeholder, feed_dict=feed_dict)
    # Stored as a raw byte blob.
    embedding_blob = embedding.tobytes()
    embedding_date = datetime.now().strftime('%Y-%m-%d')

    attendance_db = pymysql.connect(read_default_file="./DB.cnf")
    try:
        with attendance_db.cursor(pymysql.cursors.DictCursor) as cursor:
            sql = "SELECT student_id FROM student WHERE student_id = %s;"
            # Pass parameters as a tuple; `(student_id)` without a comma is
            # just the string itself.
            cursor.execute(sql, (student_id,))
            if cursor.rowcount == 0:
                sql = "INSERT INTO student(student_id, student_name) VALUES (%s, %s)"
                cursor.execute(sql, (student_id, student_name))
                sql = "INSERT INTO lecture_students(lecture_id, student_id) VALUES (%s, %s)"
                # temp: student in lecture 0
                cursor.execute(sql, ('0', student_id))
                msg = '[{id}] is registered'.format(id=student_id)
                print(msg)
            sql = "insert into student_embedding(student_id, embedding_date, embedding) values (%s, %s, _binary %s)"
            cursor.execute(sql, (student_id, embedding_date, embedding_blob))
        # Commit once, only after every statement above succeeded.
        attendance_db.commit()
    finally:
        # Always release the connection, even when a query raises.
        attendance_db.close()
    return jsonify({'status': 'success', 'student_id': student_id})
                

@app.route('/verify', methods=['POST'])
def verify():
    """Identify the face in a posted image and record attendance for lecture '0'.

    The form field ``image`` holds a base64-encoded image file. Its embedding is
    compared with every row of ``student_embedding``. The closest stored
    embedding whose distance is below ``threshold`` identifies the student.
    That row's embedding is blended towards the new one (80% old, 20% new).
    An 'attend' row is then inserted into ``student_attendance``, unless the
    student already has one for today.

    Returns:
        JSON with ``status``:
        - 'attend' or 'already' (plus ``student_id`` and ``student_name``);
        - 'fail' when no stored embedding is close enough;
        - 'error' with HTTP 400 when the image cannot be decoded.
    """
    try:
        image_bytes = base64.b64decode(request.form['image'])
    except ValueError:  # binascii.Error on malformed base64
        image_bytes = b''
    image_np = None
    if image_bytes:
        image_np = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8),
                                cv2.IMREAD_UNCHANGED)
    if image_np is None:
        return jsonify({'status': 'error', 'message': 'invalid image'}), 400
    # IMREAD_UNCHANGED keeps grayscale / alpha channels as-is, but the model
    # input built below must have exactly 3 channels.
    if image_np.ndim == 2:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
    elif image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGRA2BGR)
    image_np = resize(image_np)
    image_np = prewhiten(image_np)
    image_np = image_np.reshape(-1, image_size, image_size, 3)
    feed_dict = {input_placeholder: image_np, phase_train_placeholder: False}
    embedding = sess.run(embeddings_placeholder, feed_dict=feed_dict)

    verified_id = None
    verified_name = None
    already_attended = False
    attendance_db = pymysql.connect(read_default_file="./DB.cnf")
    try:
        with attendance_db.cursor(pymysql.cursors.DictCursor) as cursor:
            cursor.execute("SELECT student_id, embedding FROM student_embedding;")
            # Keep the closest match under the threshold, not the first one, so
            # a looser match earlier in the table cannot win.
            best_row = None
            best_embedding = None
            best_distance = None
            for row_data in cursor.fetchall():
                db_embedding = np.frombuffer(row_data['embedding'], dtype=np.float32)
                db_embedding = db_embedding.reshape(embedding.shape)
                distance = get_distance(embedding, db_embedding)
                print(distance)
                if distance < threshold and (best_distance is None or distance < best_distance):
                    best_row, best_embedding, best_distance = row_data, db_embedding, distance

            if best_row is not None:
                verified_id = best_row['student_id']
                # Blend the stored embedding towards the new sample. Cast to
                # float32 so the blob remains readable with dtype=np.float32.
                new_embedding = (best_embedding * 0.8 + embedding * 0.2).astype(np.float32)
                # register() appends one embedding row per call, so a student
                # may have several. Update only the matched row; filtering on
                # student_id alone would overwrite all of them.
                sql = ("UPDATE student_embedding SET embedding=_binary %s "
                       "WHERE student_id = %s AND embedding = _binary %s")
                cursor.execute(sql, (new_embedding.tobytes(), verified_id, best_row['embedding']))

                sql = "SELECT student_name FROM student WHERE student_id = %s"
                cursor.execute(sql, (verified_id,))
                student_row = cursor.fetchone()
                verified_name = student_row['student_name'] if student_row else None

                sql = "SELECT DATE(attendance_time) FROM student_attendance WHERE (lecture_id=%s) AND (student_id=%s) AND (DATE(attendance_time) = CURDATE());"
                cursor.execute(sql, ('0', verified_id))
                already_attended = cursor.rowcount != 0
                if not already_attended:
                    sql = "INSERT INTO student_attendance(lecture_id, student_id, status) VALUES (%s, %s, %s)"
                    # TODO: handle attend / late
                    cursor.execute(sql, ('0', verified_id, 'attend'))
        # Embedding update and attendance insert are committed together.
        attendance_db.commit()
    finally:
        # Always release the connection, even when a query raises.
        attendance_db.close()

    if verified_id is None:
        # Verification failed: no stored embedding within the threshold.
        msg = '[0000000000] verification failed'
        print(msg)
        return jsonify({'status': 'fail'})
    if already_attended:
        msg = '[{id}] verification failed: already verified'.format(id=verified_id)
        print(msg)
        return jsonify({'status': 'already', 'student_id': verified_id, 'student_name': verified_name})
    # Write the success log.
    msg = '[{id}] verification success'.format(id=verified_id)
    print(msg)
    return jsonify({'status': 'attend', 'student_id': verified_id, 'student_name': verified_name})

def load_model(model, input_map=None):
    """Load a TensorFlow model into the default graph.

    Args:
        model: either a protobuf file holding a frozen graph, or a directory
            containing exactly one ``.meta`` metagraph plus checkpoint files.
            ``~`` is expanded.
        input_map: optional mapping forwarded to graph import, used to remap
            graph inputs to existing tensors.

    A frozen graph is imported directly. For a model directory, the metagraph
    is imported and its variables are restored into the module-level ``sess``.
    """
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        print('Model filename: %s' % model_exp)
        # `gfile` was never imported, and `tf.GraphDef` / `tf.import_graph_def`
        # are absent in TF2. Use tf.io.gfile and the tf.compat.v1 API, which
        # the rest of this file already relies on.
        with tf.io.gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.compat.v1.import_graph_def(graph_def, input_map=input_map, name='')
    else:
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)

        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)

        saver = tf.compat.v1.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
        saver.restore(sess, os.path.join(model_exp, ckpt_file))

def get_model_filenames(model_dir):
    """Locate the metagraph and checkpoint inside *model_dir*.

    The checkpoint recorded in the directory's checkpoint state is preferred.
    Otherwise, the file prefix ``model-<name>.ckpt-<step>`` with the highest
    step is used.

    Returns:
        ``(meta_file, ckpt_file)`` as names relative to *model_dir*.

    Raises:
        ValueError: if there is not exactly one ``.meta`` file, or if no
            checkpoint can be found.
    """
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if not meta_files:
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    if len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]

    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        return meta_file, os.path.basename(ckpt.model_checkpoint_path)

    # Fallback: scan for the highest-step checkpoint prefix. The dot is escaped
    # so only a literal ".ckpt" matches.
    ckpt_file = None
    max_step = -1
    for f in files:
        step_str = re.match(r'(^model-[\w\- ]+\.ckpt-(\d+))', f)
        if step_str is not None:
            step = int(step_str.group(2))
            if step > max_step:
                max_step = step
                ckpt_file = step_str.group(1)
    if ckpt_file is None:
        # Previously this surfaced as an UnboundLocalError.
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    return meta_file, ckpt_file

print('load tensorflow model')
load_model(model_dir)
# Look up the graph's input, embedding-output and phase_train tensors by name
# once at import time, so the request handlers can feed and fetch them.
input_placeholder, embeddings_placeholder, phase_train_placeholder = (
    tf.compat.v1.get_default_graph().get_tensor_by_name(tensor_name)
    for tensor_name in ("input:0", "embeddings:0", "phase_train:0")
)