data_loader.py
import os
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from torch.utils.data import Dataset
from kogpt2.utils import get_tokenizer
class ArticleDataset(Dataset):
    """
    Dataset for training on news articles.
    """
    def __init__(self, dataset_path: str,
                 topics: set = {'경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'},
                 label: str = 'train', transform=None, use_cache=False):
        """
        Initializer

        :param dataset_path: path of the parquet dataset
        :param topics: only the specified topics are loaded; must be a subset of
            {경제, 문화, 미용_건강, 사회, 생활, 스포츠, 연예, 정치, IT_과학}
        :param label: dataset split to use; must be one of [train, test, valid] (default: train)
        :param transform: if not None, transforms each item:
            (paragraph: StringScalar, topic: StringScalar) -> Tensor
        :param use_cache: if True, __getitem__ serves items from the cache.
            Only enable after the cache has been populated by a first epoch.
        """
        expanded_dataset_path = os.path.expanduser(dataset_path)
        tables = []
        for topic in topics:
            # Each topic/label pair is stored as a separate parquet partition.
            table = pq.read_table(f'{expanded_dataset_path}/topic={topic}/label={label}', columns=['paragraph'])
            tables.append(table.append_column('topic', pa.array([topic] * len(table))))
        self.data = pa.concat_tables(tables)
        self.transform = transform
        self.use_cache = use_cache
        self.cache = [None] * len(self.data)
        # Eagerly transforming the whole table up front proved too slow:
        # self.data = [self.transform((p, t)) for p, t in zip(self.data['paragraph'], self.data['topic'])]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        # Serve from the cache first so the (expensive) transform is skipped on hits.
        if self.use_cache and self.cache[index] is not None:
            return self.cache[index]
        item = (self.data['paragraph'][index], self.data['topic'][index])
        if self.transform is not None:
            item = self.transform(item)
        self.cache[index] = item
        return item
    def load_from_file(self, cache_file_path: str):
        """Load a previously saved cache (see set_use_cache) and start serving from it."""
        self.use_cache = True
        self.cache = torch.from_numpy(np.load(cache_file_path))
    def set_use_cache(self, use_cache: bool, cache_file_path: str = None):
        """Enable or disable the cache; optionally persist it to cache_file_path."""
        self.use_cache = use_cache
        if use_cache:
            if isinstance(self.cache, torch.Tensor):
                # Already stacked into a single tensor; just save it if requested.
                if cache_file_path is not None:
                    np.save(cache_file_path, self.cache.numpy())
                else:
                    print("Already fully cached.")
                return
            try:
                self.cache = torch.stack(self.cache)
                if cache_file_path is not None:
                    np.save(cache_file_path, self.cache.numpy())
            except (RuntimeError, TypeError):
                # torch.stack fails while any entry is still None.
                print("Not fully cached yet. Please run an epoch with num_workers=0 first.")
                return
        else:
            self.cache = [None] * len(self.data)
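
# --- Usage sketch: caching workflow ------------------------------------------
# A minimal sketch (not part of the original pipeline) of how the cache above is
# meant to be populated, persisted, and reloaded. The dataset path, batch size,
# and cache file name are illustrative assumptions; `transform` must return
# fixed-size tensors (e.g. the ToTensor class below).
def _example_cache_workflow(dataset_path, transform, cache_file_path='article_cache.npy'):
    from torch.utils.data import DataLoader
    dataset = ArticleDataset(dataset_path, transform=transform)
    # First epoch with num_workers=0 so cache writes happen in this process.
    for _ in DataLoader(dataset, batch_size=8, num_workers=0):
        pass
    # Stack the per-item cache into one tensor and save it for later runs.
    dataset.set_use_cache(True, cache_file_path)
    # A later run can then skip the expensive transform entirely:
    reloaded = ArticleDataset(dataset_path, transform=transform)
    reloaded.load_from_file(cache_file_path)
    return reloaded
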
class ToTensor(object):
    """
    Convert an ArticleDataset item (paragraph, topic) to a Tensor using the given tokenizer and vocab.
    """
    def __init__(self, tokenizer, vocab, max_len=512):
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.max_len = max_len
    def __call__(self, sample):
        tokens = []
        paragraph = sample[0]
        topic = sample[1]
        for i, sentence in enumerate(paragraph):
            # Prepend the topic to the first sentence so the model can condition on it.
            if i == 0:
                line = ([self.vocab[self.vocab.bos_token]]
                        + self.vocab[self.tokenizer(topic.as_py()) + self.tokenizer(sentence.as_py())]
                        + [self.vocab[self.vocab.eos_token]])
            else:
                line = ([self.vocab[self.vocab.bos_token]]
                        + self.vocab[self.tokenizer(sentence.as_py())]
                        + [self.vocab[self.vocab.eos_token]])
            if len(tokens) + len(line) <= self.max_len:  # drop trailing sentences rather than truncating mid-sentence
                tokens += line
            else:
                break
        # Pad with the vocab's padding token up to max_len. (GPT2LMHeadModel ignores
        # label positions set to -100, so padding can be masked in the labels later;
        # see https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel)
        tokens += [self.vocab[self.vocab.padding_token]] * (self.max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)
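
# --- Usage sketch: end-to-end loading -----------------------------------------
# A minimal sketch of wiring the dataset to the KoGPT2 tokenizer and vocab. The
# dataset path is a placeholder, and get_pytorch_kogpt2_model is assumed to be
# the companion helper from the same kogpt2 package that provides get_tokenizer.
if __name__ == '__main__':
    from gluonnlp.data import SentencepieceTokenizer
    from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model  # assumed companion helper

    _, vocab = get_pytorch_kogpt2_model()
    tokenizer = SentencepieceTokenizer(get_tokenizer())
    dataset = ArticleDataset('~/dataset/articles',  # placeholder path
                             label='valid',
                             transform=ToTensor(tokenizer, vocab, max_len=512))
    print(len(dataset), dataset[0].shape)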