# dataset_.py (2.19 KB)
import torch
from torch.utils.data import Dataset
import gluonnlp as nlp
import numpy as np
from kobert.utils import get_tokenizer
from KoBERT.Sentiment_Analysis_BERT_main import bertmodel, vocab

# KoBERT SentencePiece tokenizer resource from kobert.utils (presumably the
# path to the SentencePiece model file — confirm against kobert.utils).
tokenizer = get_tokenizer()
# Module-wide BERT tokenizer built from the SentencePiece model and the KoBERT
# vocab; lower=False disables lowercasing. Used by get_loader() and infer().
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

class BERTDataset(Dataset):
    """Map-style dataset of BERT-transformed sentences with integer labels.

    Every row of ``dataset`` is converted eagerly in the constructor: the
    text at ``row[sent_idx]`` goes through ``nlp.data.BERTSentenceTransform``
    and the value at ``row[label_idx]`` is cast to ``np.int32``.

    Args:
        dataset: Indexable rows (e.g. an ``nlp.data.TSVDataset``). It is
            iterated twice: once for sentences, once for labels.
        sent_idx: Position of the sentence text inside each row.
        label_idx: Position of the label inside each row.
        bert_tokenizer: Tokenizer handed to ``BERTSentenceTransform``.
        max_len: ``max_seq_length`` for the transform.
        pad: Forwarded as the transform's ``pad`` flag.
        pair: Forwarded as the transform's ``pair`` flag.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        to_features = nlp.data.BERTSentenceTransform(
            bert_tokenizer,
            max_seq_length=max_len,
            pad=pad,
            pair=pair,
        )

        # The transform takes a list of texts; a single sentence is wrapped.
        self.sentences = [to_features([row[sent_idx]]) for row in dataset]
        self.labels = [np.int32(row[label_idx]) for row in dataset]

    def __getitem__(self, i):
        """Return the transform's output tuple with the label appended."""
        features, label = self.sentences[i], self.labels[i]
        return features + (label,)

    def __len__(self):
        """Return the number of rows (one label per row)."""
        return len(self.labels)

class infer_BERTDataset(Dataset):
    """Label-free counterpart of ``BERTDataset`` for inference inputs.

    Each row's text at ``row[sent_idx]`` is converted eagerly with
    ``nlp.data.BERTSentenceTransform``; no labels are read.

    Args:
        dataset: Indexable rows to transform.
        sent_idx: Position of the sentence text inside each row.
        bert_tokenizer: Tokenizer handed to ``BERTSentenceTransform``.
        max_len: ``max_seq_length`` for the transform.
        pad: Forwarded as the transform's ``pad`` flag.
        pair: Forwarded as the transform's ``pair`` flag.
    """

    def __init__(self, dataset, sent_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # The transform takes a list of texts; a single sentence is wrapped.
        self.sentences = [transform([i[sent_idx]]) for i in dataset]

    def __getitem__(self, i):
        """Return the transformed features for row ``i``."""
        return self.sentences[i]

    def __len__(self):
        """Return the number of transformed rows.

        torch's default samplers (and ``len(DataLoader)``) call ``len()`` on
        the dataset; without this method they raise TypeError.
        """
        return len(self.sentences)

def get_loader(args, train_path="ratings_train.txt",
               test_path="ratings_test.txt"):
    """Build the training and test DataLoaders from tab-separated files.

    Args:
        args: Object providing ``max_len`` (sequence length for the BERT
            transform) and ``batch_size``.
        train_path: TSV file with training rows. Defaults to the previously
            hard-coded ``"ratings_train.txt"``.
        test_path: TSV file with test rows. Defaults to the previously
            hard-coded ``"ratings_test.txt"``.

    Returns:
        ``(train_dataloader, test_dataloader)``. The training loader shuffles
        and drops the last incomplete batch. The test loader keeps file order
        and every sample.
    """
    # field_indices=[1, 2] keeps columns 1 and 2, which become positions 0
    # (sentence) and 1 (label) below. num_discard_samples=1 drops the first
    # line (presumably a header row).
    dataset_train = nlp.data.TSVDataset(
        train_path, field_indices=[1, 2], num_discard_samples=1)
    dataset_test = nlp.data.TSVDataset(
        test_path, field_indices=[1, 2], num_discard_samples=1)
    # Alternative data file noted in earlier revisions: chatbot_0325_label_0.txt

    data_train = BERTDataset(dataset_train, 0, 1, tok, args.max_len,
                             pad=True, pair=False)
    data_test = BERTDataset(dataset_test, 0, 1, tok, args.max_len,
                            pad=True, pair=False)

    train_dataloader = torch.utils.data.DataLoader(
        data_train, batch_size=args.batch_size, drop_last=True, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(
        data_test, batch_size=args.batch_size, drop_last=False, shuffle=False)

    return train_dataloader, test_dataloader

def infer(args, src):
    """Wrap raw inference rows in an ``infer_BERTDataset``.

    Args:
        args: Object providing ``max_len``.
        src: Rows whose element 0 is the text to transform.

    Returns:
        An ``infer_BERTDataset`` built with the module-level tokenizer
        ``tok``, with padding enabled and ``pair=False``.
    """
    return infer_BERTDataset(
        src,
        sent_idx=0,
        bert_tokenizer=tok,
        max_len=args.max_len,
        pad=True,
        pair=False,
    )

# import csv
# num=0
# f = open('chatbot_0325_label_0.txt', 'r', encoding='utf-8')
# rdr = csv.reader(f, delimiter='\t')
# for idx, lin in enumerate(rdr):
#     num+=1
# print(num)