Showing
1 changed file
with
0 additions
and
53 deletions
final/flask/server.py
deleted
100644 → 0
1 | -import numpy as np | ||
2 | -import pickle | ||
3 | -import tensorflow as tf | ||
4 | -from flask import Flask, jsonify, render_template, request | ||
5 | -import model | ||
6 | - | ||
7 | -# Load in data structures | ||
8 | -with open("data/wordList.txt", "rb") as fp: | ||
9 | - wordList = pickle.load(fp) | ||
10 | -wordList.append('<pad>') | ||
11 | -wordList.append('<EOS>') | ||
12 | - | ||
13 | -# Load in hyperparamters | ||
14 | -vocabSize = len(wordList) | ||
15 | -batchSize = 24 | ||
16 | -maxEncoderLength = 15 | ||
17 | -maxDecoderLength = 15 | ||
18 | -lstmUnits = 112 | ||
19 | -numLayersLSTM = 3 | ||
20 | - | ||
21 | -# Create placeholders | ||
22 | -encoderInputs = [tf.placeholder(tf.int32, shape=(None,)) for i in range(maxEncoderLength)] | ||
23 | -decoderLabels = [tf.placeholder(tf.int32, shape=(None,)) for i in range(maxDecoderLength)] | ||
24 | -decoderInputs = [tf.placeholder(tf.int32, shape=(None,)) for i in range(maxDecoderLength)] | ||
25 | -feedPrevious = tf.placeholder(tf.bool) | ||
26 | - | ||
27 | -encoderLSTM = tf.nn.rnn_cell.BasicLSTMCell(lstmUnits, state_is_tuple=True) | ||
28 | -#encoderLSTM = tf.nn.rnn_cell.MultiRNNCell([singleCell]*numLayersLSTM, state_is_tuple=True) | ||
29 | -decoderOutputs, decoderFinalState = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(encoderInputs, decoderInputs, encoderLSTM, | ||
30 | - vocabSize, vocabSize, lstmUnits, feed_previous=feedPrevious) | ||
31 | - | ||
32 | -decoderPrediction = tf.argmax(decoderOutputs, 2) | ||
33 | - | ||
34 | -# Start session and get graph | ||
35 | -sess = tf.Session() | ||
36 | -#y, variables = model.getModel(encoderInputs, decoderLabels, decoderInputs, feedPrevious) | ||
37 | - | ||
38 | -# Load in pretrained model | ||
39 | -saver = tf.train.Saver() | ||
40 | -saver.restore(sess, tf.train.latest_checkpoint('models')) | ||
41 | -zeroVector = np.zeros((1), dtype='int32') | ||
42 | - | ||
43 | -def pred(inputString): | ||
44 | - inputVector = model.getTestInput(inputString, wordList, maxEncoderLength) | ||
45 | - feedDict = {encoderInputs[t]: inputVector[t] for t in range(maxEncoderLength)} | ||
46 | - feedDict.update({decoderLabels[t]: zeroVector for t in range(maxDecoderLength)}) | ||
47 | - feedDict.update({decoderInputs[t]: zeroVector for t in range(maxDecoderLength)}) | ||
48 | - feedDict.update({feedPrevious: True}) | ||
49 | - ids = (sess.run(decoderPrediction, feed_dict=feedDict)) | ||
50 | - return model.idsToSentence(ids, wordList) | ||
51 | - | ||
52 | -# webapp | ||
53 | -app = Flask(__name__, template_folder='./') | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment