신은섭(Shin Eun Seop)

final 1st

"""
kin dataset
"""

import os
import numpy as np
# from kor_char_parser import decompose_str_as_one_hot

import text_helpers
from konlpy.tag import Twitter
pos_tagger = Twitter()

class KinQueryDataset:
    """
    Reads the Kin (지식인) data and returns it as (data, label) tuples.
    """
    def __init__(self, dataset_path: str, max_length: int):
        """
        :param dataset_path: root path of the dataset
        :param max_length: maximum string length
        """
        # Paths to the data and the labels
        queries_path = os.path.join(dataset_path, 'train', 'train_data')
        labels_path = os.path.join(dataset_path, 'train', 'train_label')

        # Read the Kin data and preprocess it
        with open(queries_path, 'rt', encoding='utf8') as f:
            self.queries = preprocess(f.readlines(), max_length)
        # Read the Kin labels and preprocess them
        with open(labels_path) as f:
            self.labels = np.array([[np.float32(x)] for x in f.readlines()])

    def __len__(self):
        """
        :return: the total number of examples
        """
        return len(self.queries)

    def __getitem__(self, idx):
        """
        :param idx: index of the requested example
        :return: the (data, label) pair at that index
        """
        return self.queries[idx], self.labels[idx]

def tokenize(doc):
    # norm and stem are optional
    return ['/'.join(t) for t in pos_tagger.pos(doc, norm=True, stem=True)]

def preprocess(data: list, max_length: int):
    # Note: each element of data is a raw input line, so row[0] and row[1] pick out its first two
    # characters; the line presumably still needs to be split into its two queries (see the sketch below).
    train_docs = [(tokenize(row[0]), tokenize(row[1])) for row in data]

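The diff ends with preprocess unfinished: it tokenizes each row but never returns the fixed-length arrays that KinQueryDataset stores and that the model's input placeholder expects. Below is a minimal sketch of one way the function could be completed, assuming each line of train_data holds two queries separated by a tab and that morpheme tokens are mapped to integer ids built on the fly; the helper name, the vocabulary handling, and the zero-padding scheme are illustrative assumptions, not part of the original commit.

# Hypothetical completion sketch (not in the original commit).
def preprocess_sketch(data: list, max_length: int, vocab: dict = None):
    # Assumed line format: "query1\tquery2"
    pairs = [line.rstrip('\n').split('\t') for line in data]
    train_docs = [(tokenize(p[0]), tokenize(p[-1])) for p in pairs]
    if vocab is None:
        vocab = {}
    # Map morpheme tokens to integer ids and zero-pad to max_length so the result
    # matches a placeholder of shape [None, max_length].
    vectors = np.zeros((len(train_docs), max_length), dtype=np.int32)
    for i, (left, right) in enumerate(train_docs):
        ids = [vocab.setdefault(t, len(vocab) + 1) for t in left + right]  # 0 is reserved for padding
        vectors[i, :min(len(ids), max_length)] = ids[:max_length]
    return vectors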
# -*- coding: utf-8 -*-
from konlpy.corpus import kolaw

def read_data(filename):
    with open(filename, 'r') as f:
        data = [line.split('\t') for line in f.read().splitlines()]
        data = data[1:]  # skip the header row
    return data

# kolaw.open(...).read() returns the whole constitution as a single string,
# so split it into non-empty lines to get one document per row.
train_data = [line for line in kolaw.open('constitution.txt').read().splitlines() if line.strip()]

print(len(train_data))     # number of lines
print(len(train_data[0]))  # length of the first line

from konlpy.tag import Twitter
pos_tagger = Twitter()

def tokenize(doc):
    # norm and stem are optional
    return ['/'.join(t) for t in pos_tagger.pos(doc, norm=True, stem=True)]

train_docs = []
for row in train_data:
    train_docs.append((tokenize(row), '0'))  # every document shares the tag '0' here

# Check that the documents were tokenized as expected
from pprint import pprint
pprint(train_docs[0])

from gensim.models.doc2vec import TaggedDocument
tagged_train_docs = [TaggedDocument(d, [c]) for d, c in train_docs]

from gensim.models import doc2vec
import multiprocessing
cores = multiprocessing.cpu_count()

# Build the vocabulary and train
doc_vectorizer = doc2vec.Doc2Vec(vector_size=1000, alpha=0.025, min_alpha=0.025, seed=1234, epochs=100, workers=cores, hs=1)
doc_vectorizer.build_vocab(tagged_train_docs)
doc_vectorizer.train(tagged_train_docs, epochs=doc_vectorizer.epochs, total_examples=doc_vectorizer.corpus_count)

# Save the model
doc_vectorizer.save('doc2vec.model')

doc_vectorizer = doc2vec.Doc2Vec.load('doc2vec.model')
pprint(doc_vectorizer.wv.most_similar('한국/Noun'))
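
As a quick sanity check on the saved model, an unseen sentence can be embedded with gensim's infer_vector; the example sentence below is made up purely for illustration.

# Hypothetical usage sketch: embed an unseen sentence with the trained model.
new_doc = tokenize('대한민국은 민주공화국이다.')  # "The Republic of Korea is a democratic republic."
vec = doc_vectorizer.infer_vector(new_doc)       # 1000-dimensional document vector
print(vec.shape)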
# -*- coding: utf-8 -*-

"""
Copyright 2018 NAVER Corp.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

cho = "ㄱㄲㄴㄷㄸㄹㅁㅂㅃㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎ"  # len = 19
jung = "ㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣ"  # len = 21
# len = 27
jong = "ㄱ/ㄲ/ㄱㅅ/ㄴ/ㄴㅈ/ㄴㅎ/ㄷ/ㄹ/ㄹㄱ/ㄹㅁ/ㄹㅂ/ㄹㅅ/ㄹㅌ/ㄹㅍ/ㄹㅎ/ㅁ/ㅂ/ㅂㅅ/ㅅ/ㅆ/ㅇ/ㅈ/ㅊ/ㅋ/ㅌ/ㅍ/ㅎ".split(
    '/')
test = cho + jung + ''.join(jong)

hangul_length = len(cho) + len(jung) + len(jong)  # 67


def is_valid_decomposition_atom(x):
    return x in test


def decompose(x):
    in_char = x
    if x < ord('가') or x > ord('힣'):
        return chr(x)
    x = x - ord('가')
    y = x // 28
    z = x % 28
    x = y // 21
    y = y % 21
    # z > 0 only when there is a jongseong (final consonant), so z is 1-indexed into jong.
    zz = jong[z - 1] if z > 0 else ''
    if x >= len(cho):
        print('Unknown Exception: ', in_char, chr(in_char), x, y, z, zz)
    return cho[x] + jung[y] + zz


def decompose_as_one_hot(in_char, warning=True):
    one_hot = []
    # print(ord('ㅣ'), chr(0xac00))
    # [0,66]: hangul jamo of composed syllables / [67,194]: ASCII / [195,245]: standalone hangul
    # consonants and vowels / [246,249]: special characters / 250: unknown. 251 dimensions in total.
    if ord('가') <= in_char <= ord('힣'):  # 가: 44032, 힣: 55203
        x = in_char - 44032  # in_char - ord('가')
        y = x // 28
        z = x % 28
        x = y // 21
        y = y % 21
        # z > 0 only when there is a jongseong (final consonant), so z is 1-indexed into jong.
        zz = jong[z - 1] if z > 0 else ''
        if x >= len(cho):
            if warning:
                print('Unknown Exception: ', in_char,
                      chr(in_char), x, y, z, zz)

        one_hot.append(x)
        one_hot.append(len(cho) + y)
        if z > 0:
            one_hot.append(len(cho) + len(jung) + (z - 1))
        return one_hot
    else:
        if in_char < 128:
            result = hangul_length + in_char  # 67~
        elif ord('ㄱ') <= in_char <= ord('ㅣ'):
            # 195~  # [ㄱ:12593]~[ㅣ:12643] (len = 51)
            result = hangul_length + 128 + (in_char - 12593)
        elif in_char == ord('♡'):
            result = hangul_length + 128 + 51  # 246~  # ♡
        elif in_char == ord('♥'):
            result = hangul_length + 128 + 51 + 1  # ♥
        elif in_char == ord('★'):
            result = hangul_length + 128 + 51 + 2  # ★
        elif in_char == ord('☆'):
            result = hangul_length + 128 + 51 + 3  # ☆
        else:
            if warning:
                print('Unhandled character:', chr(in_char), in_char)
            # unknown character
            result = hangul_length + 128 + 51 + 4  # 250, reserved for unknown characters

        return [result]


def decompose_str(string):
    return ''.join([decompose(ord(x)) for x in string])


def decompose_str_as_one_hot(string, warning=True):
    tmp_list = []
    for x in string:
        da = decompose_as_one_hot(ord(x), warning=warning)
        tmp_list.extend(da)
    return tmp_list
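
A small usage sketch of the two entry points above; the sample string is arbitrary, and the index ranges follow the layout described in the comments.

# Hypothetical usage sketch.
print(decompose_str('한글abc'))             # -> 'ㅎㅏㄴㄱㅡㄹabc'
print(decompose_str_as_one_hot('한글abc'))  # cho/jung/jong indices for the syllables, then 67 + ASCII code for 'a', 'b', 'c'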
# -*- coding: utf-8 -*-

"""
Copyright 2018 NAVER Corp.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""


import argparse
import os

import numpy as np
import tensorflow as tf

import nsml
from nsml import DATASET_PATH, HAS_DATASET, IS_ON_NSML
from dataset import KinQueryDataset, preprocess


# DONOTCHANGE: They are reserved for nsml
# This is for the nsml leaderboard
def bind_model(sess, config):
    # Saves the trained model.
    def save(dir_name, *args):
        # directory
        os.makedirs(dir_name, exist_ok=True)
        saver = tf.train.Saver()
        saver.save(sess, os.path.join(dir_name, 'model'))

    # Restores a previously saved model.
    def load(dir_name, *args):
        saver = tf.train.Saver()
        # find checkpoint
        ckpt = tf.train.get_checkpoint_state(dir_name)
        if ckpt and ckpt.model_checkpoint_path:
            checkpoint = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(dir_name, checkpoint))
        else:
            raise NotImplementedError('No checkpoint!')
        print('Model loaded')

    def infer(raw_data, **kwargs):
        """
        :param raw_data: raw input (here, strings)
        :param kwargs:
        :return:
        """
        # Call the preprocess function from dataset.py to turn the strings into vectors
        preprocessed_data = preprocess(raw_data, config.strmaxlen)
        # Feed the input to the saved model and get the predictions back
        pred = sess.run(output_sigmoid, feed_dict={x: preprocessed_data})
        clipped = np.array(pred > config.threshold, dtype=np.int32)
        # DONOTCHANGE: They are reserved for nsml
        # The result must be returned as a list of (probability, 0 or 1) pairs to be accepted by the
        # leaderboard. The probability value itself does not affect the leaderboard score.
        return list(zip(pred.flatten(), clipped.flatten()))

    # DONOTCHANGE: They are reserved for nsml
    # Exposes the functions above to nsml.
    nsml.bind(save=save, load=load, infer=infer)


def _batch_loader(iterable, n=1):
    """
    Splits the data into batches of size n, playing the same role as PyTorch's DataLoader.
    :param iterable: a list (or other sliceable sequence) of data
    :param n: batch size
    :return:
    """
    length = len(iterable)
    for n_idx in range(0, length, n):
        yield iterable[n_idx:min(n_idx + n, length)]


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


if __name__ == '__main__':
    args = argparse.ArgumentParser()
    # DONOTCHANGE: They are reserved for nsml
    args.add_argument('--mode', type=str, default='train')
    args.add_argument('--pause', type=int, default=0)
    args.add_argument('--iteration', type=str, default='0')

    # User options
    args.add_argument('--output', type=int, default=1)
    args.add_argument('--epochs', type=int, default=10)
    args.add_argument('--batch', type=int, default=2000)
    args.add_argument('--strmaxlen', type=int, default=400)
    args.add_argument('--embedding', type=int, default=8)
    args.add_argument('--threshold', type=float, default=0.5)
    config = args.parse_args()

    if not HAS_DATASET and not IS_ON_NSML:  # It is not running on nsml
        DATASET_PATH = '../sample_data/kin/'

    # Model specification
    input_size = config.embedding*config.strmaxlen
    output_size = 1
    hidden_layer_size = 200
    learning_rate = 0.001
    character_size = 251

    x = tf.placeholder(tf.int32, [None, config.strmaxlen])
    y_ = tf.placeholder(tf.float32, [None, output_size])
    # Embedding
    char_embedding = tf.get_variable('char_embedding', [character_size, config.embedding])
    embedded = tf.nn.embedding_lookup(char_embedding, x)

    # First layer
    first_layer_weight = weight_variable([input_size, hidden_layer_size])
    first_layer_bias = bias_variable([hidden_layer_size])
    hidden_layer = tf.matmul(tf.reshape(embedded, (-1, input_size)),
                             first_layer_weight) + first_layer_bias

    # Second (output) layer
    second_layer_weight = weight_variable([hidden_layer_size, output_size])
    second_layer_bias = bias_variable([output_size])
    output = tf.matmul(hidden_layer, second_layer_weight) + second_layer_bias
    output_sigmoid = tf.sigmoid(output)

    # Loss and optimizer
    binary_cross_entropy = tf.reduce_mean(-(y_ * tf.log(output_sigmoid)) - (1-y_) * tf.log(1-output_sigmoid))
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(binary_cross_entropy)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    # DONOTCHANGE: Reserved for nsml
    bind_model(sess=sess, config=config)

    # DONOTCHANGE: Reserved for nsml
    if config.pause:
        nsml.paused(scope=locals())

    if config.mode == 'train':
        # Load the data.
        dataset = KinQueryDataset(DATASET_PATH, config.strmaxlen)
        dataset_len = len(dataset)
        one_batch_size = dataset_len//config.batch  # number of minibatches per epoch
        if dataset_len % config.batch != 0:
            one_batch_size += 1
        # Train epoch by epoch.
        for epoch in range(config.epochs):
            avg_loss = 0.0
            for i, (data, labels) in enumerate(_batch_loader(dataset, config.batch)):
                _, loss = sess.run([train_step, binary_cross_entropy],
                                   feed_dict={x: data, y_: labels})
                print('Batch : ', i + 1, '/', one_batch_size,
                      ', BCE in this minibatch: ', float(loss))
                avg_loss += float(loss)
            print('epoch:', epoch, ' train_loss:', float(avg_loss/one_batch_size))
            nsml.report(summary=True, scope=locals(), epoch=epoch, epoch_total=config.epochs,
                        train__loss=float(avg_loss/one_batch_size), step=epoch)
            # DONOTCHANGE (You can decide how often you want to save the model)
            nsml.save(epoch)

    # Used in local test mode.
    # If the output looks like the line below, the model can be submitted with nsml submit.
    # [(0.3, 0), (0.7, 1), ... ]
    elif config.mode == 'test_local':
        with open(os.path.join(DATASET_PATH, 'train/train_data'), 'rt', encoding='utf-8') as f:
            queries = f.readlines()
        res = []
        for batch in _batch_loader(queries, config.batch):
            temp_res = nsml.infer(batch)
            res += temp_res
        print(res)
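
For reference, the script above can be exercised outside nsml using only the argparse flags it defines. This is a sketch: the file name main.py and the sample-data path are assumptions taken from nsml convention and from the DATASET_PATH fallback in the code, not from an actual run.

# Hypothetical local invocations (not part of the original commit):
#   python main.py --mode train --epochs 10 --batch 2000 --strmaxlen 400   # trains on ../sample_data/kin/
#   python main.py --mode test_local                                       # prints [(probability, 0 or 1), ...] pairs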
"""
Copyright 2018 NAVER Corp.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

from distutils.core import setup
setup(
    name='nsml movie review',
    version='1.0',
    description='',
    install_requires=[
        'nltk',
        'konlpy',
        'twython'
    ]
)
# Text Helper Functions
#---------------------------------------
#
# We pull out text helper functions to reduce redundant code

import string
import os
import urllib.request
import io
import tarfile
import collections
import numpy as np

# Normalize text
def normalize_text(texts, stops):
    # Lower case
    texts = [x.lower() for x in texts]

    # Remove punctuation
    texts = [''.join(c for c in x if c not in string.punctuation) for x in texts]

    # Remove numbers
    texts = [''.join(c for c in x if c not in '0123456789') for x in texts]

    # Remove stopwords
    texts = [' '.join([word for word in x.split() if word not in (stops)]) for x in texts]

    # Trim extra whitespace
    texts = [' '.join(x.split()) for x in texts]

    return(texts)


# Build dictionary of words
def build_dictionary(sentences, vocabulary_size):
    # Turn sentences (list of strings) into lists of words
    split_sentences = [s.split() for s in sentences]
    words = [x for sublist in split_sentences for x in sublist]

    # Initialize list of [word, word_count] for each word, starting with unknown
    count = [['RARE', -1]]

    # Now add most frequent words, limited to the N-most frequent (N=vocabulary size)
    count.extend(collections.Counter(words).most_common(vocabulary_size-1))

    # Now create the dictionary
    word_dict = {}
    # For each word that we want in the dictionary, add it with the current
    # dictionary length as its index
    for word, word_count in count:
        word_dict[word] = len(word_dict)

    return(word_dict)


# Turn text data into lists of integers from dictionary
def text_to_numbers(sentences, word_dict):
    # Initialize the returned data
    data = []
    for sentence in sentences:
        sentence_data = []
        # For each word, either use selected index or rare word index
        for word in sentence.split():
            if word in word_dict:
                word_ix = word_dict[word]
            else:
                word_ix = 0
            sentence_data.append(word_ix)
        data.append(sentence_data)
    return(data)


# Generate data randomly (N words behind, target, N words ahead)
def generate_batch_data(sentences, batch_size, window_size, method='skip_gram'):
    # Fill up data batch
    batch_data = []
    label_data = []
    while len(batch_data) < batch_size:
        # select random sentence to start
        rand_sentence_ix = int(np.random.choice(len(sentences), size=1))
        rand_sentence = sentences[rand_sentence_ix]
        # Generate consecutive windows to look at
        window_sequences = [rand_sentence[max((ix-window_size),0):(ix+window_size+1)] for ix, x in enumerate(rand_sentence)]
        # Denote which element of each window is the center word of interest
        label_indices = [ix if ix<window_size else window_size for ix,x in enumerate(window_sequences)]

        # Pull out center word of interest for each window and create a tuple for each window
        if method=='skip_gram':
            batch_and_labels = [(x[y], x[:y] + x[(y+1):]) for x,y in zip(window_sequences, label_indices)]
            # Make it into a big list of tuples (target word, surrounding word)
            tuple_data = [(x, y_) for x,y in batch_and_labels for y_ in y]
            batch, labels = [list(x) for x in zip(*tuple_data)]
        elif method=='cbow':
            batch_and_labels = [(x[:y] + x[(y+1):], x[y]) for x,y in zip(window_sequences, label_indices)]
            # Only keep windows with consistent 2*window_size
            batch_and_labels = [(x,y) for x,y in batch_and_labels if len(x)==2*window_size]
            batch, labels = [list(x) for x in zip(*batch_and_labels)]
        elif method=='doc2vec':
            # For doc2vec we keep LHS window only to predict target word
            batch_and_labels = [(rand_sentence[i:i+window_size], rand_sentence[i+window_size]) for i in range(0, len(rand_sentence)-window_size)]
            batch, labels = [list(x) for x in zip(*batch_and_labels)]
            # Add document index to batch!! Remember that we must extract the last index in batch for the doc-index
            batch = [x + [rand_sentence_ix] for x in batch]
        else:
            raise ValueError('Method {} not implemented yet.'.format(method))

        # extract batch and labels
        batch_data.extend(batch[:batch_size])
        label_data.extend(labels[:batch_size])
    # Trim batch and label at the end
    batch_data = batch_data[:batch_size]
    label_data = label_data[:batch_size]

    # Convert to numpy array
    batch_data = np.array(batch_data)
    label_data = np.transpose(np.array([label_data]))

    return(batch_data, label_data)
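
A brief sketch of how these helpers chain together; the sentences, stopword list, and parameters below are made up purely for illustration.

# Hypothetical usage sketch for the helpers above.
sentences = ['The quick brown fox', 'jumps over the lazy dog', 'the fox likes the dog']
stops = ['the', 'over']
normalized = normalize_text(sentences, stops)
word_dict = build_dictionary(normalized, vocabulary_size=10)
numbers = text_to_numbers(normalized, word_dict)
batch, labels = generate_batch_data(numbers, batch_size=4, window_size=1, method='skip_gram')
print(batch.shape, labels.shape)   # (4,) and (4, 1) for skip-gram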