# data_loader.py
import os
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from torch.utils.data import Dataset
from kogpt2.utils import get_tokenizer

class ArticleDataset(Dataset):
    """
    기사 학습을 위한 데이터셋
    dataset for learn articles
    """
    def __init__(self, dataset_path: str, topics: list = ['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'], label: str = 'train'):
        """
        Initializer
        :param dataset_path: path of parquet dataset
        :param topic: if not None, only use specified topics; must be sublist of [경제, 문화, 미용_건강, 사회, 생활, 스포츠, 연예, 정치, IT_과학]
        :param label: specify type of dataset; must be one of [train, test, valid] (default is train)
        """
        expanded_dataset_path = os.path.expanduser(dataset_path)
        tables = []
        for topic in topics:
            # Each topic/label pair is a Hive-style partition directory;
            # only the 'paragraph' column is read.
            table = pq.read_table(f'{expanded_dataset_path}/topic={topic}/label={label}',
                                  columns=['paragraph'])
            # Partition values are not stored in the files themselves, so
            # re-attach the topic as a constant column.
            tables.append(table.append_column('topic', pa.array([topic] * len(table))))
        self.data = pa.concat_tables(tables)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Returns pyarrow scalars; ToTensor below unwraps them with .as_py().
        return self.data['paragraph'][index], self.data['topic'][index]
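
# Usage sketch (hypothetical path): ArticleDataset expects a Hive-partitioned
# parquet layout like <dataset_path>/topic=경제/label=train/*.parquet, matching
# the read_table calls in __init__, e.g.:
#   dataset = ArticleDataset('~/data/articles', topics=['경제', '스포츠'])
#   paragraph, topic = dataset[0]  # pyarrow scalars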

class ToTensor(object):
    """
    Convert an ArticleDataset sample (paragraph, topic) into a tensor of
    token ids using the given tokenizer and vocab.
    """
    def __init__(self, tokenizer, vocab):
        self.tokenizer = tokenizer
        self.vocab = vocab

    def __call__(self, sample):
        paragraph, topic = sample
        tokens = []
        for i, sentence in enumerate(paragraph):
            if i == 0:
                # The first sentence is conditioned on the topic:
                # <bos> topic sentence <eos>
                sentence_tokens = self.tokenizer(topic.as_py()) + self.tokenizer(sentence.as_py())
            else:
                # Remaining sentences: <bos> sentence <eos>
                sentence_tokens = self.tokenizer(sentence.as_py())
            tokens += ([self.vocab[self.vocab.bos_token]]
                       + self.vocab[sentence_tokens]
                       + [self.vocab[self.vocab.eos_token]])
        # Token ids are integers; torch.Tensor would produce float32.
        return torch.tensor(tokens, dtype=torch.long)
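
# Usage sketch: wiring ArticleDataset and ToTensor together. Assumes the SKT
# KoGPT2 package API (get_tokenizer, get_pytorch_kogpt2_model) and gluonnlp's
# SentencepieceTokenizer; the dataset path below is hypothetical.
if __name__ == '__main__':
    from gluonnlp.data import SentencepieceTokenizer
    from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model

    tok_path = get_tokenizer()                 # path to the SentencePiece model file
    model, vocab = get_pytorch_kogpt2_model()  # pretrained KoGPT2 and its vocab
    tokenizer = SentencepieceTokenizer(tok_path)

    dataset = ArticleDataset('~/data/articles')  # hypothetical path
    to_tensor = ToTensor(tokenizer, vocab)
    token_ids = to_tensor(dataset[0])            # 1-D LongTensor: <bos> ... <eos> per sentence
    print(token_ids.shape)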