Toggle navigation
Toggle navigation
This project
Loading...
Sign in
2020-1-capstone-design1
/
PSB_Project1
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
bongminkim
2020-04-02 14:17:45 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
877c7eff222222f909e4c7b34597d2b0136ba8fc
877c7eff
1 parent
c951d5e9
bert_SA_datasetpy
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
0 deletions
KoBERT/dataset_.py
KoBERT/dataset_.py
0 → 100644
View file @
877c7ef
import
torch
from
torch.utils.data
import
Dataset
import
gluonnlp
as
nlp
import
numpy
as
np
from
kobert.utils
import
get_tokenizer
from
KoBERT.Sentiment_Analysis_BERT_main
import
bertmodel
,
vocab
# Module-level tokenizer setup, shared by every dataset/function below.
# `get_tokenizer()` downloads/locates the KoBERT sentencepiece model file;
# BERTSPTokenizer wraps it with the project `vocab` imported above.
# lower=False preserves case — presumably chosen for Korean text, where
# lowercasing is meaningless; TODO confirm against training config.
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)
class BERTDataset(Dataset):
    """Labelled sentence dataset for KoBERT sentiment classification.

    Each item is the tuple produced by gluonnlp's BERTSentenceTransform
    (token ids, valid length, segment ids) with the row's int32 label
    appended as the final element.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len, pad, pair):
        # Build the sentence transform once, then encode every row with it.
        encode = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)
        self.sentences = []
        self.labels = []
        for row in dataset:
            self.sentences.append(encode([row[sent_idx]]))
            self.labels.append(np.int32(row[label_idx]))

    def __getitem__(self, i):
        encoded = self.sentences[i]
        # Append the label so a DataLoader batch yields (ids, len, segs, label).
        return encoded + (self.labels[i],)

    def __len__(self):
        return len(self.labels)
class infer_BERTDataset(Dataset):
    """Unlabelled sentence dataset for KoBERT inference.

    Mirrors BERTDataset but carries no labels: each item is the raw
    tuple produced by gluonnlp's BERTSentenceTransform.
    """

    def __init__(self, dataset, sent_idx, bert_tokenizer, max_len, pad, pair):
        # Build the sentence transform once and encode every row with it.
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)
        self.sentences = [transform([i[sent_idx]]) for i in dataset]

    def __getitem__(self, i):
        return self.sentences[i]

    def __len__(self):
        # Fix: the original class omitted __len__, so len(dataset) raised
        # TypeError and torch's map-style DataLoader/Sampler machinery
        # (which requires __len__) could not iterate this dataset.
        return len(self.sentences)
def get_loader(args):
    """Build train/test DataLoaders over the NSMC ratings TSV files.

    ``args`` must provide ``max_len`` and ``batch_size``.  Reads
    "ratings_train.txt" / "ratings_test.txt" from the working directory
    and returns ``(train_dataloader, test_dataloader)``.
    """
    # Column 1 is the sentence, column 2 the label; skip the header row.
    raw_train = nlp.data.TSVDataset(
        "ratings_train.txt", field_indices=[1, 2], num_discard_samples=1)
    raw_test = nlp.data.TSVDataset(
        "ratings_test.txt", field_indices=[1, 2], num_discard_samples=1)
    # chatbot_0325_label_0.txt
    train_set = BERTDataset(raw_train, 0, 1, tok, args.max_len, True, False)
    test_set = BERTDataset(raw_test, 0, 1, tok, args.max_len, True, False)
    # Training batches are shuffled and partial batches dropped;
    # evaluation keeps order and every sample.
    train_dataloader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, drop_last=True, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size, drop_last=False, shuffle=False)
    return train_dataloader, test_dataloader
def infer(args, src):
    """Wrap raw sentences ``src`` in an inference dataset (no labels).

    Uses the module-level tokenizer ``tok``; ``args`` must provide
    ``max_len``.  Returns the infer_BERTDataset directly.
    """
    return infer_BERTDataset(src, 0, tok, args.max_len, True, False)
# import csv
# num=0
# f = open('chatbot_0325_label_0.txt', 'r', encoding='utf-8')
# rdr = csv.reader(f, delimiter='\t')
# for idx, lin in enumerate(rdr):
# num+=1
# print(num)
Please
register
or
login
to post a comment