Showing 8 changed files with 536 additions and 0 deletions
movie2/dataset.py
0 → 100644
1 | +""" | ||
2 | +kin dataset | ||
3 | +""" | ||
4 | + | ||
5 | +import os | ||
6 | +import numpy as np | ||
7 | +# from kor_char_parser import decompose_str_as_one_hot | ||
8 | + | ||
9 | +import text_helpers | ||
10 | +from konlpy.tag import Twitter | ||
11 | +pos_tagger = Twitter() | ||
12 | + | ||
13 | +class KinQueryDataset: | ||
14 | + """ | ||
15 | + 지식인 데이터를 읽어서, tuple (데이터, 레이블)의 형태로 리턴하는 파이썬 오브젝트 입니다. | ||
16 | + """ | ||
17 | + def __init__(self, dataset_path: str, max_length: int): | ||
18 | + """ | ||
19 | + :param dataset_path: 데이터셋 root path | ||
20 | + :param max_length: 문자열의 최대 길이 | ||
21 | + """ | ||
22 | + # 데이터, 레이블 각각의 경로 | ||
23 | + queries_path = os.path.join(dataset_path, 'train', 'train_data') | ||
24 | + labels_path = os.path.join(dataset_path, 'train', 'train_label') | ||
25 | + | ||
26 | + # 지식인 데이터를 읽고 preprocess까지 진행합니다 | ||
27 | + with open(queries_path, 'rt', encoding='utf8') as f: | ||
28 | + self.queries = preprocess(f.readlines(), max_length) | ||
29 | + # 지식인 레이블을 읽고 preprocess까지 진행합니다. | ||
30 | + with open(labels_path) as f: | ||
31 | + self.labels = np.array([[np.float32(x)] for x in f.readlines()]) | ||
32 | + | ||
33 | + def __len__(self): | ||
34 | + """ | ||
35 | + :return: 전체 데이터의 수를 리턴합니다 | ||
36 | + """ | ||
37 | + return len(self.queries) | ||
38 | + | ||
39 | + def __getitem__(self, idx): | ||
40 | + """ | ||
41 | + :param idx: 필요한 데이터의 인덱스 | ||
42 | + :return: 인덱스에 맞는 데이터, 레이블 pair를 리턴합니다 | ||
43 | + """ | ||
44 | + return self.queries[idx], self.labels[idx] | ||
45 | + | ||
46 | +def tokenize(doc): | ||
47 | + # norm, stem은 optional | ||
48 | + return ['/'.join(t) for t in pos_tagger.pos(doc, norm=True, stem=True)] | ||
49 | + | ||
50 | +def preprocess(data: list, max_length: int): | ||
51 | + train_docs = [(tokenize(row[0]), tokenize(row[1])) for row in data] | ||
52 | + |
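Note that preprocess above stops before producing the fixed-length integer arrays that main.py feeds into its [None, strmaxlen] int32 placeholder. Below is a minimal sketch of one way that step could be finished, assuming character-level indices from kor_char_parser (which ships in this commit) and zero padding up to max_length; the helper name preprocess_to_indices is illustrative and not part of the commit.

import numpy as np
from kor_char_parser import decompose_str_as_one_hot

def preprocess_to_indices(data, max_length):
    # Decompose each raw line into character-level indices (0..250),
    # then zero-pad or truncate every row to exactly max_length entries.
    vectorized = np.zeros((len(data), max_length), dtype=np.int32)
    for i, line in enumerate(data):
        indices = decompose_str_as_one_hot(line.strip(), warning=False)
        length = min(len(indices), max_length)
        vectorized[i, :length] = indices[:length]
    return vectorized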
movie2/doc2vec.model
0 → 100644
No preview for this file type
movie2/embadding.py
0 → 100644
+# -*- coding: utf-8 -*-
+from konlpy.corpus import kolaw
+def read_data(filename):
+    with open(filename, 'r') as f:
+        data = [line.split('\t') for line in f.read().splitlines()]
+        data = data[1:]  # drop the header row
+    return data
+
+# kolaw provides the Korean constitution as one text document; split it into lines
+train_data = kolaw.open('constitution.txt').read().splitlines()
+
+print(len(train_data))     # number of lines
+print(len(train_data[0]))  # length of the first line
+
+from konlpy.tag import Twitter
+pos_tagger = Twitter()
+
+def tokenize(doc):
+    # norm and stem are optional
+    return ['/'.join(t) for t in pos_tagger.pos(doc, norm=True, stem=True)]
+
+train_docs = []
+for row in train_data:
+    train_docs.append((tokenize(row), '0'))
+
+# Check that the documents were tokenized correctly
+from pprint import pprint
+pprint(train_docs[0])
+
+from gensim.models.doc2vec import TaggedDocument
+tagged_train_docs = [TaggedDocument(d, [c]) for d, c in train_docs]
+
+from gensim.models import doc2vec
+import multiprocessing
+cores = multiprocessing.cpu_count()
+
+# Build the vocabulary and train the doc2vec model
+doc_vectorizer = doc2vec.Doc2Vec(vector_size=1000, alpha=0.025, min_alpha=0.025, seed=1234, epochs=100, workers=cores, hs=1)
+doc_vectorizer.build_vocab(tagged_train_docs)
+doc_vectorizer.train(tagged_train_docs, epochs=doc_vectorizer.epochs, total_examples=doc_vectorizer.corpus_count)
+
+# To save
+doc_vectorizer.save('doc2vec.model')
+
+doc_vectorizer = doc2vec.Doc2Vec.load('doc2vec.model')
+pprint(doc_vectorizer.wv.most_similar('한국/Noun'))
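Once doc2vec.model is saved, it can be loaded elsewhere (for example from dataset.py) to turn a tokenized sentence into a fixed-size feature vector. A small usage sketch, assuming the model trained above; the sample sentence is arbitrary and only for illustration.

from gensim.models import doc2vec
from konlpy.tag import Twitter

pos_tagger = Twitter()
model = doc2vec.Doc2Vec.load('doc2vec.model')

# Tokenize a new sentence the same way the training documents were tokenized
tokens = ['/'.join(t) for t in pos_tagger.pos('영화가 정말 재미있었다', norm=True, stem=True)]
vector = model.infer_vector(tokens)  # numpy array of shape (1000,), i.e. vector_size
print(vector[:5])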
movie2/kor_char_parser.py
0 → 100644
+# -*- coding: utf-8 -*-
+
+"""
+Copyright 2018 NAVER Corp.
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
+associated documentation files (the "Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+The above copyright notice and this permission notice shall be included in all copies or substantial
+portions of the Software.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
+OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+"""
+
+cho = "ㄱㄲㄴㄷㄸㄹㅁㅂㅃㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎ"  # len = 19
+jung = "ㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣ"  # len = 21
+# len = 27
+jong = "ㄱ/ㄲ/ㄱㅅ/ㄴ/ㄴㅈ/ㄴㅎ/ㄷ/ㄹ/ㄹㄱ/ㄹㅁ/ㄹㅂ/ㄹㅅ/ㄹㅌ/ㄹㅍ/ㄹㅎ/ㅁ/ㅂ/ㅂㅅ/ㅅ/ㅆ/ㅇ/ㅈ/ㅊ/ㅋ/ㅌ/ㅍ/ㅎ".split(
+    '/')
+test = cho + jung + ''.join(jong)
+
+hangul_length = len(cho) + len(jung) + len(jong)  # 67
+
+
+def is_valid_decomposition_atom(x):
+    return x in test
+
+
+def decompose(x):
+    in_char = x
+    if x < ord('가') or x > ord('힣'):
+        return chr(x)
+    x = x - ord('가')
+    y = x // 28
+    z = x % 28
+    x = y // 21
+    y = y % 21
+    # if there is a jongseong then z > 0, so z indexes into jong starting from 1
+    zz = jong[z - 1] if z > 0 else ''
+    if x >= len(cho):
+        print('Unknown Exception: ', in_char, chr(in_char), x, y, z, zz)
+    return cho[x] + jung[y] + zz
+
+
+def decompose_as_one_hot(in_char, warning=True):
+    one_hot = []
+    # print(ord('ㅣ'), chr(0xac00))
+    # [0,66]: hangul jamo / [67,194]: ASCII / [195,245]: standalone hangul consonants and vowels / [246,249]: special characters
+    # Total 251 indices (0-250, including the unknown-character index).
+    if ord('가') <= in_char <= ord('힣'):  # 가: 44032, 힣: 55203
+        x = in_char - 44032  # in_char - ord('가')
+        y = x // 28
+        z = x % 28
+        x = y // 21
+        y = y % 21
+        # if there is a jongseong then z > 0, so z indexes into jong starting from 1
+        zz = jong[z - 1] if z > 0 else ''
+        if x >= len(cho):
+            if warning:
+                print('Unknown Exception: ', in_char,
+                      chr(in_char), x, y, z, zz)
+
+        one_hot.append(x)
+        one_hot.append(len(cho) + y)
+        if z > 0:
+            one_hot.append(len(cho) + len(jung) + (z - 1))
+        return one_hot
+    else:
+        if in_char < 128:
+            result = hangul_length + in_char  # 67~
+        elif ord('ㄱ') <= in_char <= ord('ㅣ'):
+            # 195~ # [ㄱ:12593]~[ㅣ:12643] (len = 51)
+            result = hangul_length + 128 + (in_char - 12593)
+        elif in_char == ord('♡'):
+            result = hangul_length + 128 + 51  # 246~ # ♡
+        elif in_char == ord('♥'):
+            result = hangul_length + 128 + 51 + 1  # ♥
+        elif in_char == ord('★'):
+            result = hangul_length + 128 + 51 + 2  # ★
+        elif in_char == ord('☆'):
+            result = hangul_length + 128 + 51 + 3  # ☆
+        else:
+            if warning:
+                print('Unhandled character:', chr(in_char), in_char)
+            # unknown character
+            result = hangul_length + 128 + 51 + 4  # for unknown characters
+
+        return [result]
+
+
+def decompose_str(string):
+    return ''.join([decompose(ord(x)) for x in string])
+
+
+def decompose_str_as_one_hot(string, warning=True):
+    tmp_list = []
+    for x in string:
+        da = decompose_as_one_hot(ord(x), warning=warning)
+        tmp_list.extend(da)
+    return tmp_list
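A quick check of the index scheme above (not part of the commit; run from the same directory). Each syllable yields two or three indices (cho/jung/jong), while ASCII characters map to 67 + ord(c).

from kor_char_parser import decompose_str_as_one_hot

# '한' decomposes to (ㅎ, ㅏ, ㄴ) -> [18, 19, 43]; 'a' -> 67 + 97 = 164
print(decompose_str_as_one_hot('한a'))  # [18, 19, 43, 164]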
movie2/main.py
0 → 100644
+# -*- coding: utf-8 -*-
+
+"""
+Copyright 2018 NAVER Corp.
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
+associated documentation files (the "Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+The above copyright notice and this permission notice shall be included in all copies or substantial
+portions of the Software.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
+OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+"""
+
+
+import argparse
+import os
+
+import numpy as np
+import tensorflow as tf
+
+import nsml
+from nsml import DATASET_PATH, HAS_DATASET, IS_ON_NSML
+from dataset import KinQueryDataset, preprocess
+
+
+# DONOTCHANGE: They are reserved for nsml
+# This is for nsml leaderboard
+def bind_model(sess, config):
+    # Saves the trained model.
+    def save(dir_name, *args):
+        # directory
+        os.makedirs(dir_name, exist_ok=True)
+        saver = tf.train.Saver()
+        saver.save(sess, os.path.join(dir_name, 'model'))
+
+    # Loads a previously saved model.
+    def load(dir_name, *args):
+        saver = tf.train.Saver()
+        # find checkpoint
+        ckpt = tf.train.get_checkpoint_state(dir_name)
+        if ckpt and ckpt.model_checkpoint_path:
+            checkpoint = os.path.basename(ckpt.model_checkpoint_path)
+            saver.restore(sess, os.path.join(dir_name, checkpoint))
+        else:
+            raise NotImplementedError('No checkpoint!')
+        print('Model loaded')
+
+    def infer(raw_data, **kwargs):
+        """
+        :param raw_data: raw input (here, strings)
+        :param kwargs:
+        :return:
+        """
+        # Use the preprocess function from dataset.py to turn the strings into vectors
+        preprocessed_data = preprocess(raw_data, config.strmaxlen)
+        # Feed the input into the saved model and get the prediction back
+        pred = sess.run(output_sigmoid, feed_dict={x: preprocessed_data})
+        clipped = np.array(pred > config.threshold, dtype=int)
+        # DONOTCHANGE: They are reserved for nsml
+        # Results must be returned as [(probability, 0 or 1)] pairs to appear on the
+        # leaderboard; the probability value does not affect the leaderboard score.
+        return list(zip(pred.flatten(), clipped.flatten()))
+
+    # DONOTCHANGE: They are reserved for nsml
+    # Exposes the functions above to nsml.
+    nsml.bind(save=save, load=load, infer=infer)
+
+
+def _batch_loader(iterable, n=1):
+    """
+    Slices the data into batches; it plays the same role as PyTorch's DataLoader.
+    :param iterable: a list of data, or any other sliceable format
+    :param n: batch size
+    :return:
+    """
+    length = len(iterable)
+    for n_idx in range(0, length, n):
+        yield iterable[n_idx:min(n_idx + n, length)]
+
+
+def weight_variable(shape):
+    initial = tf.truncated_normal(shape, stddev=0.1)
+    return tf.Variable(initial)
+
+
+def bias_variable(shape):
+    initial = tf.constant(0.1, shape=shape)
+    return tf.Variable(initial)
+
+
+if __name__ == '__main__':
+    args = argparse.ArgumentParser()
+    # DONOTCHANGE: They are reserved for nsml
+    args.add_argument('--mode', type=str, default='train')
+    args.add_argument('--pause', type=int, default=0)
+    args.add_argument('--iteration', type=str, default='0')
+
+    # User options
+    args.add_argument('--output', type=int, default=1)
+    args.add_argument('--epochs', type=int, default=10)
+    args.add_argument('--batch', type=int, default=2000)
+    args.add_argument('--strmaxlen', type=int, default=400)
+    args.add_argument('--embedding', type=int, default=8)
+    args.add_argument('--threshold', type=float, default=0.5)
+    config = args.parse_args()
+
+    if not HAS_DATASET and not IS_ON_NSML:  # It is not running on nsml
+        DATASET_PATH = '../sample_data/kin/'
+
+    # Model specification
+    input_size = config.embedding*config.strmaxlen
+    output_size = 1
+    hidden_layer_size = 200
+    learning_rate = 0.001
+    character_size = 251
+
+    x = tf.placeholder(tf.int32, [None, config.strmaxlen])
+    y_ = tf.placeholder(tf.float32, [None, output_size])
+    # Embedding
+    char_embedding = tf.get_variable('char_embedding', [character_size, config.embedding])
+    embedded = tf.nn.embedding_lookup(char_embedding, x)
+
+    # First layer
+    first_layer_weight = weight_variable([input_size, hidden_layer_size])
+    first_layer_bias = bias_variable([hidden_layer_size])
+    hidden_layer = tf.matmul(tf.reshape(embedded, (-1, input_size)),
+                             first_layer_weight) + first_layer_bias
+
+    # Second (output) layer
+    second_layer_weight = weight_variable([hidden_layer_size, output_size])
+    second_layer_bias = bias_variable([output_size])
+    output = tf.matmul(hidden_layer, second_layer_weight) + second_layer_bias
+    output_sigmoid = tf.sigmoid(output)
+
+    # Loss and optimizer
+    binary_cross_entropy = tf.reduce_mean(-(y_ * tf.log(output_sigmoid)) - (1-y_) * tf.log(1-output_sigmoid))
+    train_step = tf.train.AdamOptimizer(learning_rate).minimize(binary_cross_entropy)
+
+    sess = tf.InteractiveSession()
+    tf.global_variables_initializer().run()
+
+    # DONOTCHANGE: Reserved for nsml
+    bind_model(sess=sess, config=config)
+
+    # DONOTCHANGE: Reserved for nsml
+    if config.pause:
+        nsml.paused(scope=locals())
+
+    if config.mode == 'train':
+        # Load the data.
+        dataset = KinQueryDataset(DATASET_PATH, config.strmaxlen)
+        dataset_len = len(dataset)
+        one_batch_size = dataset_len//config.batch
+        if dataset_len % config.batch != 0:
+            one_batch_size += 1
+        # Train once per epoch.
+        for epoch in range(config.epochs):
+            avg_loss = 0.0
+            for i, (data, labels) in enumerate(_batch_loader(dataset, config.batch)):
+                _, loss = sess.run([train_step, binary_cross_entropy],
+                                   feed_dict={x: data, y_: labels})
+                print('Batch : ', i + 1, '/', one_batch_size,
+                      ', BCE in this minibatch: ', float(loss))
+                avg_loss += float(loss)
+            print('epoch:', epoch, ' train_loss:', float(avg_loss/one_batch_size))
+            nsml.report(summary=True, scope=locals(), epoch=epoch, epoch_total=config.epochs,
+                        train__loss=float(avg_loss/one_batch_size), step=epoch)
+            # DONOTCHANGE (You can decide how often you want to save the model)
+            nsml.save(epoch)
+
+    # Used for local test mode.
+    # If the output looks like the example below, the model can be submitted via nsml submit.
+    # [(0.3, 0), (0.7, 1), ... ]
+    elif config.mode == 'test_local':
+        with open(os.path.join(DATASET_PATH, 'train/train_data'), 'rt', encoding='utf-8') as f:
+            queries = f.readlines()
+        res = []
+        for batch in _batch_loader(queries, config.batch):
+            temp_res = nsml.infer(batch)
+            res += temp_res
+        print(res)
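For reference, a toy illustration (not part of the commit) of the clipping done inside infer(): sigmoid outputs above the threshold map to class 1, everything else to class 0, producing the [(probability, 0 or 1)] pairs the leaderboard expects.

import numpy as np

pred = np.array([[0.3], [0.7], [0.95]], dtype=np.float32)
threshold = 0.5
clipped = np.array(pred > threshold, dtype=int)
print(list(zip(pred.flatten(), clipped.flatten())))  # -> pairs like (0.3, 0), (0.7, 1), (0.95, 1)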
movie2/setup.py
0 → 100644
1 | +""" | ||
2 | +Copyright 2018 NAVER Corp. | ||
3 | +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and | ||
4 | +associated documentation files (the "Software"), to deal in the Software without restriction, including | ||
5 | +without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
6 | +copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to | ||
7 | +the following conditions: | ||
8 | +The above copyright notice and this permission notice shall be included in all copies or substantial | ||
9 | +portions of the Software. | ||
10 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, | ||
11 | +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A | ||
12 | +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT | ||
13 | +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF | ||
14 | +CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE | ||
15 | +OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
16 | +""" | ||
17 | + | ||
18 | +from distutils.core import setup | ||
19 | +setup( | ||
20 | + name='nsml movie review', | ||
21 | + version='1.0', | ||
22 | + description='', | ||
23 | + install_requires=[ | ||
24 | + 'nltk', | ||
25 | + 'konlpy', | ||
26 | + 'twython' | ||
27 | + | ||
28 | + ] | ||
29 | +) | ||
\ No newline at end of file
movie2/test.txt
0 → 100644
(diff collapsed)
movie2/text_helpers.py
0 → 100644
+# Text Helper Functions
+#---------------------------------------
+#
+# We pull out text helper functions to reduce redundant code
+
+import string
+import os
+import urllib.request
+import io
+import tarfile
+import collections
+import numpy as np
+
+# Normalize text
+def normalize_text(texts, stops):
+    # Lower case
+    texts = [x.lower() for x in texts]
+
+    # Remove punctuation
+    texts = [''.join(c for c in x if c not in string.punctuation) for x in texts]
+
+    # Remove numbers
+    texts = [''.join(c for c in x if c not in '0123456789') for x in texts]
+
+    # Remove stopwords
+    texts = [' '.join([word for word in x.split() if word not in (stops)]) for x in texts]
+
+    # Trim extra whitespace
+    texts = [' '.join(x.split()) for x in texts]
+
+    return texts
+
+
+# Build dictionary of words
+def build_dictionary(sentences, vocabulary_size):
+    # Turn sentences (list of strings) into lists of words
+    split_sentences = [s.split() for s in sentences]
+    words = [x for sublist in split_sentences for x in sublist]
+
+    # Initialize list of [word, word_count] for each word, starting with unknown
+    count = [['RARE', -1]]
+
+    # Now add the most frequent words, limited to the N most frequent (N = vocabulary size)
+    count.extend(collections.Counter(words).most_common(vocabulary_size-1))
+
+    # Now create the dictionary
+    word_dict = {}
+    # For each word that we want in the dictionary, add it, then make it
+    # the value of the prior dictionary length
+    for word, word_count in count:
+        word_dict[word] = len(word_dict)
+
+    return word_dict
+
+
+# Turn text data into lists of integers from dictionary
+def text_to_numbers(sentences, word_dict):
+    # Initialize the returned data
+    data = []
+    for sentence in sentences:
+        sentence_data = []
+        # For each word, use either its dictionary index or the rare-word index
+        for word in sentence.split():
+            if word in word_dict:
+                word_ix = word_dict[word]
+            else:
+                word_ix = 0
+            sentence_data.append(word_ix)
+        data.append(sentence_data)
+    return data
+
+
+# Generate data randomly (N words behind, target, N words ahead)
+def generate_batch_data(sentences, batch_size, window_size, method='skip_gram'):
+    # Fill up data batch
+    batch_data = []
+    label_data = []
+    while len(batch_data) < batch_size:
+        # select a random sentence to start
+        rand_sentence_ix = int(np.random.choice(len(sentences), size=1))
+        rand_sentence = sentences[rand_sentence_ix]
+        # Generate consecutive windows to look at
+        window_sequences = [rand_sentence[max((ix-window_size), 0):(ix+window_size+1)] for ix, x in enumerate(rand_sentence)]
+        # Denote which element of each window is the center word of interest
+        label_indices = [ix if ix < window_size else window_size for ix, x in enumerate(window_sequences)]
+
+        # Pull out the center word of interest for each window and create a tuple for each window
+        if method == 'skip_gram':
+            batch_and_labels = [(x[y], x[:y] + x[(y+1):]) for x, y in zip(window_sequences, label_indices)]
+            # Make it into a big list of tuples (target word, surrounding word)
+            tuple_data = [(x, y_) for x, y in batch_and_labels for y_ in y]
+            batch, labels = [list(x) for x in zip(*tuple_data)]
+        elif method == 'cbow':
+            batch_and_labels = [(x[:y] + x[(y+1):], x[y]) for x, y in zip(window_sequences, label_indices)]
+            # Only keep windows with a consistent length of 2*window_size
+            batch_and_labels = [(x, y) for x, y in batch_and_labels if len(x) == 2*window_size]
+            batch, labels = [list(x) for x in zip(*batch_and_labels)]
+        elif method == 'doc2vec':
+            # For doc2vec we keep the LHS window only to predict the target word
+            batch_and_labels = [(rand_sentence[i:i+window_size], rand_sentence[i+window_size]) for i in range(0, len(rand_sentence)-window_size)]
+            batch, labels = [list(x) for x in zip(*batch_and_labels)]
+            # Add the document index to the batch!! Remember that we must extract the last index in the batch for the doc-index
+            batch = [x + [rand_sentence_ix] for x in batch]
+        else:
+            raise ValueError('Method {} not implemented yet.'.format(method))
+
+        # extract batch and labels
+        batch_data.extend(batch[:batch_size])
+        label_data.extend(labels[:batch_size])
+    # Trim batch and label at the end
+    batch_data = batch_data[:batch_size]
+    label_data = label_data[:batch_size]
+
+    # Convert to numpy arrays
+    batch_data = np.array(batch_data)
+    label_data = np.transpose(np.array([label_data]))
+
+    return batch_data, label_data
\ No newline at end of file
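A quick, self-contained check of generate_batch_data (not part of the commit); the integer lists stand in for the word indices that text_to_numbers would produce.

import numpy as np
from text_helpers import generate_batch_data

np.random.seed(0)
sentences = [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11]]
# skip_gram returns a flat array of target words and a column of context labels
batch, labels = generate_batch_data(sentences, batch_size=8, window_size=2, method='skip_gram')
print(batch.shape, labels.shape)  # (8,) and (8, 1)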