import base64
import configparser
import datetime
import json
import os
import re
import sys
import time

import cv2
import numpy as np
import pymysql
import tensorflow as tf
from flask import Flask, render_template, Response, request, jsonify

# Server settings: model location, input image size and match threshold.
config = configparser.ConfigParser()
# ConfigParser.read() silently skips missing files; fail fast here instead
# of surfacing a confusing KeyError on the lookups below.
if not config.read('./config.cnf'):
    raise FileNotFoundError('./config.cnf')

api = configparser.ConfigParser()
api.read('./API_form.cnf')

model_dir = config['verification_server']['model']
image_size = int(config['verification_server']['image_size'])
threshold = float(config['verification_server']['threshold'])

app = Flask(__name__)
# One TF1-style session shared by all requests; load_model() restores
# checkpoint-based models into it.
sess = tf.compat.v1.Session()

def resize(image):
    """Scale *image* to an ``image_size`` x ``image_size`` square.

    Uses bicubic interpolation. ``image_size`` is the module-level value
    read from the ``verification_server`` config section.
    """
    target = (image_size, image_size)
    return cv2.resize(image, target, interpolation=cv2.INTER_CUBIC)

def prewhiten(x):
    """Standardise *x* to zero mean and (roughly) unit variance.

    The standard deviation is floored at ``1 / sqrt(x.size)`` so that
    near-constant inputs are not amplified by a tiny divisor.
    """
    floor = 1.0 / np.sqrt(x.size)
    scale = np.maximum(np.std(x), floor)
    centred = np.subtract(x, np.mean(x))
    # Multiply by the reciprocal (not divide) to match the original rounding.
    return np.multiply(centred, 1 / scale)

def get_distance(arr1, arr2):
    """Return the Euclidean distance between *arr1* and *arr2*.

    ``np.linalg.norm`` is applied to the element-wise difference; for 2-D
    inputs that is the Frobenius norm, which equals the Euclidean distance
    between the flattened arrays.
    """
    return np.linalg.norm(arr1 - arr2)

@app.route('/')
def index():
    """Video streaming page: render the ``index.html`` template."""
    return render_template('index.html')

@app.route('/register', methods=['POST'])
def register():
    """Register (or re-register) a student's face.

    Form fields:
        student_id:   key in the ``student`` table.
        student_name: stored only when the student is new.
        image:        base64-encoded image file bytes (decoded with cv2).

    A new student gets a ``student`` row and a ``lecture_students`` row for
    lecture '0'. In every case, the image's embedding is inserted into
    ``student_embedding``. All writes are committed together, or rolled back
    if any statement fails.

    Returns:
        JSON ``{'status': 'success', 'student_id': ...}``, or
        ``{'status': 'fail'}`` with HTTP 400 if the image cannot be decoded.
    """
    student_id = request.form['student_id']
    student_name = request.form['student_name']

    # Compute the embedding before touching the database, so an undecodable
    # upload cannot leave a half-registered student behind.
    try:
        image = base64.b64decode(request.form['image'])
    except ValueError:  # binascii.Error is a ValueError subclass
        return jsonify({'status': 'fail'}), 400
    image_np = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), flags=1)
    if image_np is None:  # cv2.imdecode returns None for unreadable data
        return jsonify({'status': 'fail'}), 400
    image_np = prewhiten(resize(image_np))
    image_np = image_np.reshape(-1, image_size, image_size, 3)
    feed_dict = {input_placeholder: image_np, phase_train_placeholder: False}
    embedding = sess.run(embeddings_placeholder, feed_dict=feed_dict)
    # verify() decodes stored blobs as float32, so store exactly that dtype.
    embedding_blob = np.asarray(embedding, dtype=np.float32).tobytes()
    embedding_date = datetime.datetime.now().strftime('%Y-%m-%d')

    msg = '[{id}] register face'.format(id=student_id)
    attendance_db = pymysql.connect(read_default_file="./DB.cnf")
    try:
        cursor = attendance_db.cursor(pymysql.cursors.DictCursor)
        sql = "SELECT student_id FROM student WHERE student_id = %s;"
        cursor.execute(sql, (student_id,))
        if cursor.rowcount == 0:
            sql = "INSERT INTO student(student_id, student_name) VALUES (%s, %s)"
            cursor.execute(sql, (student_id, student_name))
            sql = "INSERT INTO lecture_students(lecture_id, student_id) VALUES (%s, %s)"
            # temp: every new student is enrolled in lecture 0
            cursor.execute(sql, ('0', student_id))
            msg = '[{id}] is registered'.format(id=student_id)
        sql = "insert into student_embedding(student_id, embedding_date, embedding) values (%s, %s, _binary %s)"
        cursor.execute(sql, (student_id, embedding_date, embedding_blob))
        # pymysql connections do not autocommit by default; without an
        # explicit commit, these inserts are discarded when the connection closes.
        attendance_db.commit()
    except Exception:
        attendance_db.rollback()
        raise
    finally:
        attendance_db.close()

    app.logger.info(msg)
    return jsonify({'status': 'success', 'student_id': student_id})

@app.route('/verify', methods=['POST'])
def verify():
    """Identify the face in an uploaded image and record attendance.

    Form fields:
        image: base64-encoded image file bytes (decoded with cv2).

    The image's embedding is compared against every row of
    ``student_embedding``. The closest stored embedding whose Euclidean
    distance is below ``threshold`` identifies the student. That student's
    stored embedding is blended 80/20 toward the new one. An attendance row
    for lecture '0' is then inserted, unless one already exists for today.

    Returns:
        JSON whose ``status`` is one of:
          'attend'  - attendance recorded (also ``student_id``, ``student_name``)
          'already' - already recorded today (also ``student_id``)
          'fail'    - no stored embedding within the threshold; or, with
                      HTTP 400, the image could not be decoded.
    """
    try:
        image = base64.b64decode(request.form['image'])
    except ValueError:  # binascii.Error is a ValueError subclass
        return jsonify({'status': 'fail'}), 400
    image_np = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), flags=1)
    if image_np is None:  # cv2.imdecode returns None for unreadable data
        return jsonify({'status': 'fail'}), 400
    image_np = prewhiten(resize(image_np))
    image_np = image_np.reshape(-1, image_size, image_size, 3)
    feed_dict = {input_placeholder: image_np, phase_train_placeholder: False}
    embedding = sess.run(embeddings_placeholder, feed_dict=feed_dict)

    attendance_db = pymysql.connect(read_default_file="./DB.cnf")
    try:
        cursor = attendance_db.cursor(pymysql.cursors.DictCursor)
        cursor.execute("SELECT student_id, embedding FROM student_embedding;")

        # Keep the closest stored embedding under the threshold, rather than
        # whichever matching row happened to be scanned last.
        verified_id = None
        matched_embedding = None
        best_distance = threshold
        for row_data in cursor.fetchall():
            db_embedding = np.frombuffer(row_data['embedding'], dtype=np.float32)
            db_embedding = db_embedding.reshape(embedding.shape)
            distance = get_distance(embedding, db_embedding)
            if distance < best_distance:
                best_distance = distance
                verified_id = row_data['student_id']
                matched_embedding = db_embedding

        if verified_id is None:
            # verification failed: no stored face close enough
            app.logger.info('[0000000000] verification failed')
            return jsonify({'status': 'fail'})

        # Blend the stored template 80/20 toward the new capture (presumably
        # so it tracks gradual appearance changes).
        new_embedding = (matched_embedding * 0.8 + embedding * 0.2).astype(np.float32)
        sql = "UPDATE student_embedding SET embedding=_binary %s WHERE student_id = %s"
        cursor.execute(sql, (new_embedding.tobytes(), verified_id))

        sql = "SELECT DATE(attendance_time) FROM student_attendance WHERE (lecture_id=%s) AND (student_id=%s) AND (DATE(attendance_time) = CURDATE());"
        cursor.execute(sql, ('0', verified_id))
        if cursor.rowcount == 0:
            sql = "INSERT INTO student_attendance(lecture_id, student_id, status) VALUES (%s, %s, %s)"
            # TODO: distinguish 'attend' from 'late'
            cursor.execute(sql, ('0', verified_id, 'attend'))
            sql = "SELECT student_name FROM student WHERE student_id = %s"
            cursor.execute(sql, (verified_id,))
            row_data = cursor.fetchone()
            verified_name = row_data['student_name'] if row_data else None
            app.logger.info('[{id}] verification success'.format(id=verified_id))
            send = jsonify({'status': 'attend', 'student_id': verified_id, 'student_name': verified_name})
        else:
            app.logger.info('[{id}] verification failed: already verified'.format(id=verified_id))
            send = jsonify({'status': 'already', 'student_id': verified_id})
        # pymysql does not autocommit by default; persist the template
        # update and the attendance row together.
        attendance_db.commit()
    except Exception:
        attendance_db.rollback()
        raise
    finally:
        attendance_db.close()
    return send

def load_model(model, input_map=None):
    """Load the embedding model into the default graph and ``sess``.

    Args:
        model: Path (``~`` is expanded) to either a frozen-graph protobuf
            file, or a directory holding one ``.meta`` metagraph plus its
            checkpoint.
        input_map: Optional tensor remapping, passed through to the graph
            import.

    A protobuf file is parsed and imported into the current default graph.
    A directory is resolved with ``get_model_filenames``; its metagraph is
    imported and its checkpoint restored into the module-level ``sess``.
    """
    # Check if the model is a model directory (containing a metagraph and a checkpoint file)
    #  or if it is a protobuf file with a frozen graph
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        print('Model filename: %s' % model_exp)
        with tf.io.gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.compat.v1.import_graph_def(graph_def, input_map=input_map, name='')
    else:
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)

        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)

        saver = tf.compat.v1.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
        saver.restore(sess, os.path.join(model_exp, ckpt_file))

def get_model_filenames(model_dir):
    """Locate the metagraph and checkpoint files inside *model_dir*.

    Args:
        model_dir: Directory expected to contain exactly one ``.meta`` file.

    Returns:
        ``(meta_file, ckpt_file)`` as base names relative to *model_dir*.
        The checkpoint comes from ``tf.train.get_checkpoint_state`` when it
        reports one. Otherwise, the ``model-<name>.ckpt-<step>`` prefix with
        the highest step is used.

    Raises:
        ValueError: if there is not exactly one ``.meta`` file, or if no
            checkpoint can be found.
    """
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if not meta_files:
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    if len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]

    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        return meta_file, os.path.basename(ckpt.model_checkpoint_path)

    # No checkpoint state: fall back to the highest-numbered checkpoint
    # prefix among the directory entries.
    max_step = -1
    ckpt_file = None
    for f in files:
        match = re.match(r'(^model-[\w\- ]+\.ckpt-(\d+))', f)
        if match is not None:
            step = int(match.group(2))
            if step > max_step:
                max_step = step
                ckpt_file = match.group(1)
    if ckpt_file is None:
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    return meta_file, ckpt_file

print('load tensorflow model')
# The tensors below only exist once the model graph has been loaded.
load_model(model_dir)
# Look up the model's input/output tensors by name in the loaded graph.
input_placeholder = tf.compat.v1.get_default_graph().get_tensor_by_name("input:0")
embeddings_placeholder = tf.compat.v1.get_default_graph().get_tensor_by_name("embeddings:0")
phase_train_placeholder = tf.compat.v1.get_default_graph().get_tensor_by_name("phase_train:0")