# data_loader.py
import os
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from torch.utils.data import Dataset
from kogpt2.utils import get_tokenizer

class ArticleDataset(Dataset):
    """
    Dataset of news articles for language-model training.
    """
    def __init__(self, dataset_path: str,
                 topics=frozenset({'경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'}),
                 label: str = 'train', transform=None, use_cache=False):
        """
        Initializer
        :param dataset_path: path of the parquet dataset
        :param topics: only the given topics are loaded; must be a subset of
            {경제, 문화, 미용_건강, 사회, 생활, 스포츠, 연예, 정치, IT_과학}
        :param label: dataset split; must be one of [train, test, valid] (default: train)
        :param transform: if not None, applied to each sample:
            (paragraph: StringScalar, topic: StringScalar) -> Tensor
        :param use_cache: if True, __getitem__ reads from the cache; only enable
            this after the cache has been populated (i.e., after the first epoch)
        """
        expanded_dataset_path = os.path.expanduser(dataset_path)
        tables = []
        for topic in topics:
            # read_table on a single partition directory drops the partition
            # columns, so the topic column is re-attached manually.
            table = pq.read_table(f'{expanded_dataset_path}/topic={topic}/label={label}', columns=['paragraph'])
            tables.append(table.append_column('topic', pa.array([topic] * len(table))))
        self.data = pa.concat_tables(tables)
        self.transform = transform
        self.use_cache = use_cache
        self.cache = [None] * len(self.data)
        # Note: eagerly applying self.transform to every row here proved too
        # slow; samples are instead transformed lazily in __getitem__ and cached.

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Serve from the cache first so the (expensive) transform is skipped.
        if self.use_cache and self.cache[index] is not None:
            return self.cache[index]
        item = (self.data['paragraph'][index], self.data['topic'][index])
        if self.transform is not None:
            item = self.transform(item)
        # Always record the item so set_use_cache() can stack the cache after
        # the first epoch (cache writes require num_workers=0 to be visible).
        self.cache[index] = item
        return item

    def load_from_file(self, cache_file_path: str):
        """Load a previously saved cache (see set_use_cache) and enable it."""
        self.use_cache = True
        self.cache = torch.from_numpy(np.load(cache_file_path))
    
    def set_use_cache(self, use_cache: bool, cache_file_path: str = None):
        """Enable/disable the cache; optionally persist it to a .npy file."""
        self.use_cache = use_cache
        if use_cache:
            if isinstance(self.cache, torch.Tensor):
                # Already stacked (either here earlier or via load_from_file).
                if cache_file_path is not None:
                    np.save(cache_file_path, self.cache.numpy())
                else:
                    print("Already fully cached.")
                return
            try:
                # Stacking fails if any entry is still None (epoch incomplete).
                self.cache = torch.stack(self.cache)
                if cache_file_path is not None:
                    np.save(cache_file_path, self.cache.numpy())
            except (RuntimeError, TypeError):
                print("Not fully cached yet. Please run an epoch with num_workers=0 first.")
                return
        else:
            # Reset to an empty per-index cache so __getitem__ can refill it.
            self.cache = [None] * len(self.data)
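
# Caching workflow sketch (illustrative; the dataset path and cache file name
# below are hypothetical). The cache is populated as a side effect of iterating
# once, so the warm-up pass must run in the main process (num_workers=0):
#
#   dataset = ArticleDataset('~/data/articles', transform=ToTensor(tok, vocab))
#   for _ in DataLoader(dataset, batch_size=8, num_workers=0):
#       pass                                      # first epoch fills the cache
#   dataset.set_use_cache(True, 'train_cache.npy')    # stack and save to disk
#
# On a later run the warm-up epoch can be skipped entirely:
#
#   dataset.load_from_file('train_cache.npy')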

class ToTensor(object):
    """
    Convert an ArticleDataset (paragraph, topic) sample into a fixed-length
    Tensor of token ids using the given tokenizer and vocab.
    """
    def __init__(self, tokenizer, vocab, max_len=512):
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.max_len = max_len
    
    def __call__(self, sample):
        tokens = []
        paragraph, topic = sample
        for i, sentence in enumerate(paragraph):
            if i == 0:
                # Prepend the topic tokens to the first sentence so the model
                # is conditioned on the article's topic.
                line = [self.vocab[self.vocab.bos_token]] + self.vocab[self.tokenizer(topic.as_py()) + self.tokenizer(sentence.as_py())] + [self.vocab[self.vocab.eos_token]]
            else:
                line = [self.vocab[self.vocab.bos_token]] + self.vocab[self.tokenizer(sentence.as_py())] + [self.vocab[self.vocab.eos_token]]
            if len(tokens) + len(line) <= self.max_len:  # drop any sentence that would overflow, to avoid fragments
                tokens += line
            else:
                break
        # Pad to max_len with the vocab's padding token. Note that
        # GPT2LMHeadModel ignores *label* positions set to -100, so mask the
        # padding out when building labels; ref:
        # https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel
        tokens += [self.vocab[self.vocab.padding_token]] * (self.max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)
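

if __name__ == '__main__':
    # End-to-end sketch, assuming the SKT-AI KoGPT2 helper API
    # (get_tokenizer / get_pytorch_kogpt2_model) and a hypothetical dataset
    # path; adjust both to your environment.
    from gluonnlp.data import SentencepieceTokenizer
    from torch.utils.data import DataLoader
    from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model

    tok_path = get_tokenizer()                   # path to the sentencepiece model
    model, vocab = get_pytorch_kogpt2_model()    # vocab supplies bos/eos/pad ids
    tokenizer = SentencepieceTokenizer(tok_path)

    dataset = ArticleDataset('~/data/articles', label='train',
                             transform=ToTensor(tokenizer, vocab))
    loader = DataLoader(dataset, batch_size=8, num_workers=0)
    batch = next(iter(loader))                   # LongTensor of shape (8, 512)
    print(batch.shape)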