Showing
1 changed file
with
21 additions
and
35 deletions
1 | # -*- coding: utf-8 -*- | 1 | # -*- coding: utf-8 -*- |
2 | 2 | ||
3 | -""" | ||
4 | -Copyright 2018 NAVER Corp. | ||
5 | -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and | ||
6 | -associated documentation files (the "Software"), to deal in the Software without restriction, including | ||
7 | -without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
8 | -copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to | ||
9 | -the following conditions: | ||
10 | -The above copyright notice and this permission notice shall be included in all copies or substantial | ||
11 | -portions of the Software. | ||
12 | -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, | ||
13 | -INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A | ||
14 | -PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT | ||
15 | -HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF | ||
16 | -CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE | ||
17 | -OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
18 | -""" | ||
19 | - | ||
20 | - | ||
21 | import argparse | 3 | import argparse |
22 | import os | 4 | import os |
23 | 5 | ||
... | @@ -103,9 +85,9 @@ if __name__ == '__main__': | ... | @@ -103,9 +85,9 @@ if __name__ == '__main__': |
103 | # User options | 85 | # User options |
104 | args.add_argument('--output', type=int, default=1) | 86 | args.add_argument('--output', type=int, default=1) |
105 | args.add_argument('--epochs', type=int, default=10) | 87 | args.add_argument('--epochs', type=int, default=10) |
106 | - args.add_argument('--batch', type=int, default=2000) | 88 | + args.add_argument('--batch', type=int, default=3000) |
107 | args.add_argument('--strmaxlen', type=int, default=400) | 89 | args.add_argument('--strmaxlen', type=int, default=400) |
108 | - args.add_argument('--embedding', type=int, default=8) | 90 | + args.add_argument('--embedding', type=int, default=20) |
109 | args.add_argument('--threshold', type=float, default=0.5) | 91 | args.add_argument('--threshold', type=float, default=0.5) |
110 | config = args.parse_args() | 92 | config = args.parse_args() |
111 | 93 | ||
... | @@ -115,27 +97,31 @@ if __name__ == '__main__': | ... | @@ -115,27 +97,31 @@ if __name__ == '__main__': |
115 | # 모델의 specification | 97 | # 모델의 specification |
116 | input_size = config.embedding*config.strmaxlen | 98 | input_size = config.embedding*config.strmaxlen |
117 | output_size = 1 | 99 | output_size = 1 |
118 | - hidden_layer_size = 200 | 100 | + learning_rate = 0.01 |
119 | - learning_rate = 0.001 | ||
120 | character_size = 251 | 101 | character_size = 251 |
121 | 102 | ||
122 | x = tf.placeholder(tf.int32, [None, config.strmaxlen]) | 103 | x = tf.placeholder(tf.int32, [None, config.strmaxlen]) |
123 | y_ = tf.placeholder(tf.float32, [None, output_size]) | 104 | y_ = tf.placeholder(tf.float32, [None, output_size]) |
105 | + keep_probs = tf.placeholder(tf.float32) | ||
124 | # 임베딩 | 106 | # 임베딩 |
125 | char_embedding = tf.get_variable('char_embedding', [character_size, config.embedding]) | 107 | char_embedding = tf.get_variable('char_embedding', [character_size, config.embedding]) |
126 | - embedded = tf.nn.embedding_lookup(char_embedding, x) | 108 | + embedded_chars_base = tf.nn.embedding_lookup(char_embedding, x) |
127 | - | 109 | + embedded = tf.expand_dims(embedded_chars_base, -1) |
128 | - # 첫 번째 레이어 | 110 | + print("emb", embedded.shape) |
129 | - first_layer_weight = weight_variable([input_size, hidden_layer_size]) | 111 | + ## MODEL |
130 | - first_layer_bias = bias_variable([hidden_layer_size]) | 112 | + l3_1 = tf.layers.conv2d(embedded, 512, [3, config.embedding], activation=tf.nn.relu) |
131 | - hidden_layer = tf.matmul(tf.reshape(embedded, (-1, input_size)), | 113 | + print("l3-1", l3_1.shape) |
132 | - first_layer_weight) + first_layer_bias | 114 | + l3_1 = tf.layers.max_pooling2d(l3_1, [character_size-3+1, 1]) |
115 | + print("l3-1 pool", l3_1.shape) | ||
116 | + l3_2 = tf.layers.conv2d(l3_1, 1024, [3, config.embedding], activation=tf.nn.relu) | ||
117 | + l3_2 = tf.layers.max_pooling2d(l3_2, [character_size-3+1, 1]) | ||
118 | + l3_3 = tf.layers.conv2d(l3_2, 512, [3, config.embedding], activation=tf.nn.relu) | ||
119 | + l3_3 = tf.layers.max_pooling2d(l3_3, [character_size-3+1, 1]) | ||
120 | + flatten = tf.fontrib.layers.flatten(l3_3) | ||
121 | + | ||
122 | + drop = tf.layers.dropout(l3_2, keep_probs) | ||
123 | + output_sigmoid = tf.layers.dense(flatten, output_size, activation=tf.nn.sigmoid) | ||
133 | 124 | ||
134 | - # 두 번째 (아웃풋) 레이어 | ||
135 | - second_layer_weight = weight_variable([hidden_layer_size, output_size]) | ||
136 | - second_layer_bias = bias_variable([output_size]) | ||
137 | - output = tf.matmul(hidden_layer, second_layer_weight) + second_layer_bias | ||
138 | - output_sigmoid = tf.sigmoid(output) | ||
139 | 125 | ||
140 | # loss와 optimizer | 126 | # loss와 optimizer |
141 | binary_cross_entropy = tf.reduce_mean(-(y_ * tf.log(output_sigmoid)) - (1-y_) * tf.log(1-output_sigmoid)) | 127 | binary_cross_entropy = tf.reduce_mean(-(y_ * tf.log(output_sigmoid)) - (1-y_) * tf.log(1-output_sigmoid)) |
... | @@ -163,7 +149,7 @@ if __name__ == '__main__': | ... | @@ -163,7 +149,7 @@ if __name__ == '__main__': |
163 | avg_loss = 0.0 | 149 | avg_loss = 0.0 |
164 | for i, (data, labels) in enumerate(_batch_loader(dataset, config.batch)): | 150 | for i, (data, labels) in enumerate(_batch_loader(dataset, config.batch)): |
165 | _, loss = sess.run([train_step, binary_cross_entropy], | 151 | _, loss = sess.run([train_step, binary_cross_entropy], |
166 | - feed_dict={x: data, y_: labels}) | 152 | + feed_dict={x: data, y_: labels, keep_probs: 0.9}) |
167 | print('Batch : ', i + 1, '/', one_batch_size, | 153 | print('Batch : ', i + 1, '/', one_batch_size, |
168 | ', BCE in this minibatch: ', float(loss)) | 154 | ', BCE in this minibatch: ', float(loss)) |
169 | avg_loss += float(loss) | 155 | avg_loss += float(loss) | ... | ... |
-
Please register or login to post a comment