graykode / commit-autosuggestions
Authored by graykode, 2020-09-09 19:58:20 +0900
Commit 9b9ed4f689ae444f6de6f5748588aa20a94cf721 (9b9ed4f6), 1 parent: fef8c9aa

(refactor) black style
Showing 11 changed files with 1028 additions and 381 deletions
commit_suggester.py
preprocess/__init__.py
preprocess/gitcommit.py
train/__init__.py
train/callbacks.py
train/finetune.py
train/generation_utils.py
train/lightning_base.py
train/modeling_bart.py
train/modeling_utils.py
train/utils.py
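The commit message "(refactor) black style" indicates that the hunks below are the output of running the black code formatter over these files. As a hedged illustration of what that reformatting does (quote normalization, re-wrapping of long calls), black can also be applied programmatically; the snippet below is a sketch only, assumes the black package is installed, and is not part of this repository:

# Sketch: reproduce the kind of rewrite black applies in this commit.
# Assumes `pip install black`; the example line is modeled on the
# parser.add_argument calls changed below, not copied from the repo.
import black

source = "parser.add_argument('--url', type=str, required=True, help='github url')\n"
print(black.format_str(source, mode=black.Mode()), end="")
# Prints the normalized line:
#   parser.add_argument("--url", type=str, required=True, help="github url")

The same result comes from the command line, for example by running `black .` at the repository root (assuming default settings).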
commit_suggester.py (view file @ 9b9ed4f)

@@ -68,6 +68,7 @@ def main(args):
     )
     print(commit_message)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Code to collect commits on github")
     parser.add_argument(
preprocess/__init__.py (view file @ 9b9ed4f)

 # Copyright 2020-present Tae Hwan Jung
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@@ -15,6 +15,6 @@
 from .gitcommit import diff_parse, truncate
 
 __all__ = [
-    'diff_parse',
-    'truncate',
-]
\ No newline at end of file
+    "diff_parse",
+    "truncate",
+]
preprocess/gitcommit.py (view file @ 9b9ed4f)

@@ -36,9 +36,11 @@ logging.basicConfig(
     level=logging.INFO,
 )
 
+
 class PATCH(enum.Enum):
     PLUS = 1
     MINUS = 2
 
+
 def truncate(tuple, max_length, value=0):
     ls = []

@@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0):
         if isinstance(t, int):
             t = [t]
         ls.extend(t)
-    ls = ls[:max_length - 1]
+    ls = ls[: max_length - 1]
     ls.insert(0, value)
     if len(ls) < max_length:
         ls.extend([0] * (max_length - len(ls)))
     assert len(ls) == max_length
     return ls
 
 
 def encode_line(tokenizer, line, patch):
-    line = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip()
+    line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip()
     tokens = tokenizer.tokenize(line)
     tokens = tokenizer.convert_tokens_to_ids(tokens)
-    return (
-        tokens, [1] * len(tokens), len(tokens) * [patch.value]
-    )
+    return (tokens, [1] * len(tokens), len(tokens) * [patch.value])
 
 
 def diff_parse(diff, tokenizer):
     chunks = []

@@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer):
             chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
     return chunks
 
+
 def sha_parse(sha, tokenizer, max_length=1024):
     chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer)

@@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024):
     return (input_ids, attention_masks, patch_ids)
 
+
 def message_parse(msg, tokenizer, max_length=56):
-    msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg)
-    msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip()
+    msg = re.sub(r"(\(|)#([0-9])+(\)|)", "", msg)
+    msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip()
     msg = tokenizer.tokenize(msg)
     msg = tokenizer.convert_tokens_to_ids(msg)
     msg = truncate(msg, max_length, value=0)
     return msg
 
+
 def jobs(sha_msgs, args, data_config, train=True):
     input_ids, attention_masks, patch_ids, targets = [], [], [], []

@@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True):
         sha, msg = sha_msg
 
         source = sha_parse(
-            sha,
-            tokenizer=args.tokenizer,
-            max_length=args.max_source_length
+            sha, tokenizer=args.tokenizer, max_length=args.max_source_length
         )
         if not source:
             continue

@@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True):
         target = message_parse(
             msg,
             tokenizer=args.tokenizer,
-            max_length=(args.max_target_length if train else args.val_max_target_length),
+            max_length=(
+                args.max_target_length if train else args.val_max_target_length
+            ),
         )
         input_ids.append(input_id)

@@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True):
         patch_ids.append(patch_id)
         targets.append(target)
 
-    data_saver({
-        "input_ids": np.asarray(input_ids),
-        "attention_masks": np.asarray(attention_masks),
-        "patch_ids": np.asarray(patch_ids),
-        "targets": np.asarray(targets),
-    })
+    data_saver(
+        {
+            "input_ids": np.asarray(input_ids),
+            "attention_masks": np.asarray(attention_masks),
+            "patch_ids": np.asarray(patch_ids),
+            "targets": np.asarray(targets),
+        }
+    )
     data_saver.disconnect()
 
+
 def start(chunked_sha_msgs, train=True):
     logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation"))

@@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True):
     data_config = DataConfig(
         endpoint=args.endpoint,
-        access_key=os.environ['access_key'],
-        secret_key=os.environ['secret_key'],
+        access_key=os.environ["access_key"],
+        secret_key=os.environ["secret_key"],
         region=args.region,
-        dataset_name='commit-autosuggestions',
+        dataset_name="commit-autosuggestions",
         additional={
             "mode": ("training" if train else "evaluation"),
             "max_source_length": args.max_source_length,
             "max_target_length": max_target_length,
             "url": args.url,
         },
         attributes=[
-            ('input_ids', 'int32', (args.max_source_length,)),
-            ('attention_masks', 'int32', (args.max_source_length,)),
-            ('patch_ids', 'int32', (args.max_source_length,)),
-            ('targets', 'int32', (max_target_length,))
-        ]
+            ("input_ids", "int32", (args.max_source_length,)),
+            ("attention_masks", "int32", (args.max_source_length,)),
+            ("patch_ids", "int32", (args.max_source_length,)),
+            ("targets", "int32", (max_target_length,)),
+        ],
     )
 
     func = partial(jobs, args=args, data_config=data_config, train=train)

@@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True):
             for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
                 pbar.update()
 
+
 def main(args):
-    if 'access_key' not in os.environ or 'secret_key' not in os.environ:
+    if "access_key" not in os.environ or "secret_key" not in os.environ:
         raise OSError("access_key or secret_key are not found.")
 
     sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
     random.shuffle(sha_msgs)
     chunked_sha_msgs = [
-        sha_msgs[x:x + args.matorage_batch]
+        sha_msgs[x : x + args.matorage_batch]
         for x in range(0, len(sha_msgs), args.matorage_batch)
     ]

@@ -185,29 +192,25 @@ def main(args):
     if args.do_predict:
         start(chunked_sha_msgs[barrier:], train=False)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Code to collect commits on github")
-    parser.add_argument(
-        "--url",
-        type=str,
-        required=True,
-        help="github url"
-    )
+    parser.add_argument("--url", type=str, required=True, help="github url")
     parser.add_argument(
         "--endpoint",
         type=str,
         required=True,
-        help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
+        help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
     )
     parser.add_argument(
         "--region",
         type=str,
        default=None,
-        help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
+        help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
     )
     parser.add_argument(
         "--tokenizer_name",
-        default='sshleifer/distilbart-xsum-6-6',
+        default="sshleifer/distilbart-xsum-6-6",
         type=str,
         help="Pretrained tokenizer name or path if not the same as model_name",
     )

@@ -215,41 +218,40 @@ if __name__ == "__main__":
         "--matorage_batch",
         default=1024,
         type=int,
-        help='The smallest batch size stored atomically in matorage.'
+        help="The smallest batch size stored atomically in matorage.",
     )
     parser.add_argument(
         "--num_workers",
         default=4,
         type=int,
         help="number of process",
     )
     parser.add_argument(
         "--max_source_length",
         default=1024,
         type=int,
         help="The maximum total input sequence length after tokenization. Sequences longer "
         "than this will be truncated, sequences shorter will be padded.",
     )
     parser.add_argument(
         "--max_target_length",
         default=56,
         type=int,
         help="The maximum total input sequence length after tokenization. Sequences longer "
         "than this will be truncated, sequences shorter will be padded.",
     )
     parser.add_argument(
         "--val_max_target_length",
         default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
         type=int,
         help="The maximum total input sequence length after tokenization. Sequences longer "
         "than this will be truncated, sequences shorter will be padded.",
     )
-    parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset")
+    parser.add_argument(
+        "--p_val", type=float, default=0.25, help="percent of validation dataset"
+    )
     parser.add_argument("--do_train", action="store_true", default=False)
     parser.add_argument("--do_predict", action="store_true", default=False)
     args = parser.parse_args()
-    args.local_path = args.url.split('/')[-1]
+    args.local_path = args.url.split("/")[-1]
     logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}")
     repo = (
         Repo(args.local_path)
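As an aside on the truncate helper reformatted in preprocess/gitcommit.py above: it flattens a mix of token ids and id lists, clips the result to max_length, prepends a marker value, and zero-pads. A standalone usage sketch follows (a copy of the function as it appears in the hunk, not an import from preprocess.gitcommit; the input values are made up):

# Hedged sketch of how the truncate helper shown above behaves.
def truncate(seq, max_length, value=0):
    # flatten ints and lists of ints into one list of token ids
    ls = []
    for t in seq:
        if isinstance(t, int):
            t = [t]
        ls.extend(t)
    ls = ls[: max_length - 1]  # leave room for the leading marker value
    ls.insert(0, value)
    if len(ls) < max_length:
        ls.extend([0] * (max_length - len(ls)))  # right-pad with zeros
    assert len(ls) == max_length
    return ls


print(truncate([[5, 6, 7], 8], max_length=6))  # -> [0, 5, 6, 7, 8, 0]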
train/__init__.py (view file @ 9b9ed4f)

 # Copyright 2020-present Tae Hwan Jung
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@@ -14,6 +14,4 @@
 from .modeling_bart import BartForConditionalGeneration
 
-__all__ = [
-    'BartForConditionalGeneration'
-]
\ No newline at end of file
+__all__ = ["BartForConditionalGeneration"]
train/callbacks.py (view file @ 9b9ed4f)

@@ -20,16 +20,31 @@ logger = logging.getLogger(__name__)
 class Seq2SeqLoggingCallback(pl.Callback):
     def on_batch_end(self, trainer, pl_module):
-        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
+        lrs = {
+            f"lr_group_{i}": param["lr"]
+            for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)
+        }
         pl_module.logger.log_metrics(lrs)
 
     @rank_zero_only
     def _write_logs(
-        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        type_path: str,
+        save_generations=True,
     ) -> None:
-        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
+        logger.info(
+            f"***** {type_path} results at step {trainer.global_step:05d} *****"
+        )
         metrics = trainer.callback_metrics
-        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
+        trainer.logger.log_metrics(
+            {
+                k: v
+                for k, v in metrics.items()
+                if k not in ["log", "progress_bar", "preds"]
+            }
+        )
         # Log results
         od = Path(pl_module.hparams.output_dir)
         if type_path == "test":

@@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
             # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
             # If people want this it will be easy enough to add back.
             results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
-            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
+            generations_file = (
+                od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
+            )
         results_file.parent.mkdir(exist_ok=True)
         generations_file.parent.mkdir(exist_ok=True)
         with open(results_file, "a+") as writer:

@@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
         n_trainable_pars = count_trainable_parameters(pl_module)
         # mp stands for million parameters
-        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
+        trainer.logger.log_metrics(
+            {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}
+        )
 
     @rank_zero_only
     def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):

@@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric):
 def get_early_stopping_callback(metric, patience):
     return EarlyStopping(
-        monitor=f"val_{metric}",
-        mode="max",
-        patience=patience,
-        verbose=True,
+        monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,
     )
train/finetune.py (view file @ 9b9ed4f)

@@ -21,7 +21,11 @@ from matorage.torch import Dataset
 try:
-    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
+    from .callbacks import (
+        Seq2SeqLoggingCallback,
+        get_checkpoint_callback,
+        get_early_stopping_callback,
+    )
     from .utils import (
         ROUGE_KEYS,
         LegacySeq2SeqDataset,

@@ -40,7 +44,11 @@ try:
         use_task_specific_params,
     )
 except ImportError:
-    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
+    from callbacks import (
+        Seq2SeqLoggingCallback,
+        get_checkpoint_callback,
+        get_early_stopping_callback,
+    )
     from utils import (
         ROUGE_KEYS,
         LegacySeq2SeqDataset,

@@ -83,8 +91,12 @@ class SummarizationModule(BaseTransformer):
             "val": self.hparams.val_max_target_length,
             "test": self.hparams.test_max_target_length,
         }
-        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
-        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
+        assert (
+            self.target_lens["train"] <= self.target_lens["val"]
+        ), f"target_lens: {self.target_lens}"
+        assert (
+            self.target_lens["train"] <= self.target_lens["test"]
+        ), f"target_lens: {self.target_lens}"
         if self.hparams.freeze_embeds:
             self.freeze_embeds()

@@ -95,13 +107,27 @@ class SummarizationModule(BaseTransformer):
         self.hparams.git_sha = get_git_info()["repo_sha"]
         self.num_workers = hparams.num_workers
         self.decoder_start_token_id = None  # default to config
-        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
-            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+        if self.model.config.decoder_start_token_id is None and isinstance(
+            self.tokenizer, MBartTokenizer
+        ):
+            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[
+                hparams.tgt_lang
+            ]
             self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
-        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
-        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
+        self.eval_beams = (
+            self.model.config.num_beams
+            if self.hparams.eval_beams is None
+            else self.hparams.eval_beams
+        )
+        assert (
+            self.eval_beams >= 1
+        ), f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
+        self.val_metric = (
+            self.default_val_metric
+            if self.hparams.val_metric is None
+            else self.hparams.val_metric
+        )
 
     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""

@@ -133,7 +159,13 @@ class SummarizationModule(BaseTransformer):
         else:
             decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
 
-        outputs = self(src_ids, src_patch, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
+        outputs = self(
+            src_ids,
+            src_patch,
+            attention_mask=src_mask,
+            decoder_input_ids=decoder_input_ids,
+            use_cache=False,
+        )
         lm_logits = outputs[0]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id

@@ -157,7 +189,9 @@ class SummarizationModule(BaseTransformer):
         logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         # tokens per batch
-        logs["tpb"] = batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
+        logs["tpb"] = (
+            batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
+        )
         return {"loss": loss_tensors[0], "log": logs}
 
     def validation_step(self, batch, batch_idx) -> Dict:

@@ -165,17 +199,29 @@ class SummarizationModule(BaseTransformer):
     def validation_epoch_end(self, outputs, prefix="val") -> Dict:
         self.step_count += 1
-        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
+        losses = {
+            k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names
+        }
         loss = losses["loss"]
-        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
-        rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
+        rouges = {
+            k: np.array([x[k] for x in outputs]).mean()
+            for k in self.metric_names + ["gen_time", "gen_len"]
+        }
+        rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(
+            loss
+        )
         rouges.update({k: v.item() for k, v in losses.items()})
         losses.update(rouges)
         metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
         metrics["step_count"] = self.step_count
         self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
         preds = flatten_list([x["preds"] for x in outputs])
-        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
+        return {
+            "log": metrics,
+            "preds": preds,
+            f"{prefix}_loss": loss,
+            f"{prefix}_{self.val_metric}": rouge_tensor,
+        }
 
     def save_metrics(self, latest_metrics, type_path) -> None:
         self.metrics[type_path].append(latest_metrics)

@@ -200,7 +246,9 @@ class SummarizationModule(BaseTransformer):
         base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         rouge: Dict = self.calc_generative_metrics(preds, target)
         summ_len = np.mean(lmap(len, generated_ids))
-        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
+        base_metrics.update(
+            gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge
+        )
         return base_metrics
 
     def test_step(self, batch, batch_idx):

@@ -213,10 +261,10 @@ class SummarizationModule(BaseTransformer):
         max_target_length = self.target_lens[type_path]
         data_config = DataConfig(
             endpoint=args.endpoint,
-            access_key=os.environ['access_key'],
-            secret_key=os.environ['secret_key'],
+            access_key=os.environ["access_key"],
+            secret_key=os.environ["secret_key"],
             region=args.region,
-            dataset_name='commit-autosuggestions',
+            dataset_name="commit-autosuggestions",
             additional={
                 "mode": ("training" if type_path == "train" else "evaluation"),
                 "max_source_length": self.hparams.max_source_length,

@@ -224,15 +272,17 @@ class SummarizationModule(BaseTransformer):
                 "url": args.url,
             },
             attributes=[
-                ('input_ids', 'int32', (self.hparams.max_source_length,)),
-                ('attention_masks', 'int32', (self.hparams.max_source_length,)),
-                ('patch_ids', 'int32', (self.hparams.max_source_length,)),
-                ('targets', 'int32', (max_target_length,))
-            ]
+                ("input_ids", "int32", (self.hparams.max_source_length,)),
+                ("attention_masks", "int32", (self.hparams.max_source_length,)),
+                ("patch_ids", "int32", (self.hparams.max_source_length,)),
+                ("targets", "int32", (max_target_length,)),
+            ],
         )
         return Dataset(config=data_config, clear=True)
 
-    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
+    def get_dataloader(
+        self, type_path: str, batch_size: int, shuffle: bool = False
+    ) -> DataLoader:
         dataset = self.get_dataset(type_path)
         sampler = None

@@ -246,7 +296,9 @@ class SummarizationModule(BaseTransformer):
         return dataloader
 
     def train_dataloader(self) -> DataLoader:
-        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
+        dataloader = self.get_dataloader(
+            "train", batch_size=self.hparams.train_batch_size, shuffle=True
+        )
         return dataloader
 
     def val_dataloader(self) -> DataLoader:

@@ -259,23 +311,18 @@ class SummarizationModule(BaseTransformer):
     def add_model_specific_args(parser, root_dir):
         BaseTransformer.add_model_specific_args(parser, root_dir)
         add_generic_args(parser, root_dir)
-        parser.add_argument(
-            "--url",
-            type=str,
-            required=True,
-            help="github url"
-        )
+        parser.add_argument("--url", type=str, required=True, help="github url")
         parser.add_argument(
             "--endpoint",
             type=str,
             required=True,
-            help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
+            help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
         )
         parser.add_argument(
             "--region",
             type=str,
             default=None,
-            help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
+            help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
         )
         parser.add_argument(
             "--max_source_length",

@@ -308,14 +355,43 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument("--freeze_encoder", action="store_true")
         parser.add_argument("--freeze_embeds", action="store_true")
         parser.add_argument("--sortish_sampler", action="store_true", default=False)
-        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
-        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
-        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
-        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
-        parser.add_argument("--task", type=str, default="summarization", required=False, help="# examples. -1 means use all.")
-        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
+        parser.add_argument(
+            "--logger_name",
+            type=str,
+            choices=["default", "wandb", "wandb_shared"],
+            default="default",
+        )
+        parser.add_argument(
+            "--n_train",
+            type=int,
+            default=-1,
+            required=False,
+            help="# examples. -1 means use all.",
+        )
+        parser.add_argument(
+            "--n_val",
+            type=int,
+            default=500,
+            required=False,
+            help="# examples. -1 means use all.",
+        )
+        parser.add_argument(
+            "--n_test",
+            type=int,
+            default=-1,
+            required=False,
+            help="# examples. -1 means use all.",
+        )
+        parser.add_argument(
+            "--task",
+            type=str,
+            default="summarization",
+            required=False,
+            help="# examples. -1 means use all.",
+        )
+        parser.add_argument(
+            "--label_smoothing", type=float, default=0.0, required=False
+        )
         parser.add_argument("--src_lang", type=str, default="", required=False)
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
         parser.add_argument("--eval_beams", type=int, default=None, required=False)

@@ -348,7 +424,11 @@ class TranslationModule(SummarizationModule):
 def main(args, model=None) -> SummarizationModule:
     Path(args.output_dir).mkdir(exist_ok=True)
     if len(os.listdir(args.output_dir)) > 3 and args.do_train:
-        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
+        raise ValueError(
+            "Output directory ({}) already exists and is not empty.".format(
+                args.output_dir
+            )
+        )
     if model is None:
         if args.task == "summarization":
             model: SummarizationModule = SummarizationModule(args)

@@ -371,7 +451,9 @@ def main(args, model=None) -> SummarizationModule:
         return model
     model.hparams.test_checkpoint = ""
-    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
+    checkpoints = list(
+        sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
+    )
     if checkpoints:
         model.hparams.test_checkpoint = checkpoints[-1]
         trainer.resume_from_checkpoint = checkpoints[-1]
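A note on the wrapping pattern visible throughout these hunks: with its default 88-character line length, black joins a call onto one line when it fits and otherwise splits at the outermost brackets, which is why some calls above collapse to a single line while others are expanded. A hedged, self-contained illustration (assumes the black package is installed; the example lines are modeled on the callbacks.py hunks, not copied from them):

# Sketch of black's line-length-driven wrapping with default settings.
import black

short_call = "EarlyStopping(monitor='val_rouge2', mode='max', patience=1, verbose=True)\n"
long_call = (
    "trainer.logger.log_metrics({'n_params': npars, 'mp': npars / 1e6, "
    "'grad_mp': n_trainable_pars / 1e6})\n"
)

# Fits within 88 columns: kept on one line, quotes normalized.
print(black.format_str(short_call, mode=black.Mode()), end="")

# Too long for one line: split at the outermost call parentheses,
# similar to the log_metrics hunk in train/callbacks.py above.
print(black.format_str(long_call, mode=black.Mode()), end="")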
train/generation_utils.py (view file @ 9b9ed4f)

@@ -30,6 +30,7 @@ logging.basicConfig(
     level=logging.INFO,
 )
 
+
 class GenerationMixin:
     """
     A class contraining all of the functions supporting generation, to be used as a mixin in

@@ -50,7 +51,9 @@ class GenerationMixin:
         """
         return logits
 
-    def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
+    def enforce_repetition_penalty_(
+        self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty
+    ):
         """
         Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
         """

@@ -79,11 +82,7 @@ class GenerationMixin:
         # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
         if repetition_penalty != 1.0:
             self.enforce_repetition_penalty_(
-                scores,
-                batch_size,
-                num_beams,
-                input_ids,
-                repetition_penalty,
+                scores, batch_size, num_beams, input_ids, repetition_penalty,
             )
 
         # set eos token prob to zero if min_length is not reached

@@ -102,7 +101,11 @@ class GenerationMixin:
         if bad_words_ids is not None:
             # Exclude EOS token (already processed)
-            bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
+            bad_words_ids = list(
+                filter(
+                    lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids
+                )
+            )
             # calculate a list of banned tokens according to bad words
             banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
             # Modify the scores in place by setting the banned tokens logits to `-inf`

@@ -134,7 +137,7 @@ class GenerationMixin:
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_start_token_id: Optional[int] = None,
         use_cache: Optional[bool] = None,
-        **model_kwargs
+        **model_kwargs,
     ) -> torch.LongTensor:
         r"""
         Generates sequences for models with a language modeling head. The method currently supports greedy decoding,

@@ -262,26 +265,50 @@ class GenerationMixin:
         max_length = max_length if max_length is not None else self.config.max_length
         min_length = min_length if min_length is not None else self.config.min_length
         do_sample = do_sample if do_sample is not None else self.config.do_sample
-        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+        early_stopping = (
+            early_stopping if early_stopping is not None else self.config.early_stopping
+        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         num_beams = num_beams if num_beams is not None else self.config.num_beams
-        temperature = temperature if temperature is not None else self.config.temperature
+        temperature = (
+            temperature if temperature is not None else self.config.temperature
+        )
         top_k = top_k if top_k is not None else self.config.top_k
         top_p = top_p if top_p is not None else self.config.top_p
-        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        repetition_penalty = (
+            repetition_penalty
+            if repetition_penalty is not None
+            else self.config.repetition_penalty
+        )
+        bos_token_id = (
+            bos_token_id if bos_token_id is not None else self.config.bos_token_id
+        )
+        pad_token_id = (
+            pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        )
+        eos_token_id = (
+            eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        )
+        length_penalty = (
+            length_penalty if length_penalty is not None else self.config.length_penalty
+        )
         no_repeat_ngram_size = (
-            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
+            no_repeat_ngram_size
+            if no_repeat_ngram_size is not None
+            else self.config.no_repeat_ngram_size
         )
-        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
+        bad_words_ids = (
+            bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
+        )
         num_return_sequences = (
-            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
+            num_return_sequences
+            if num_return_sequences is not None
+            else self.config.num_return_sequences
         )
         decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+            decoder_start_token_id
+            if decoder_start_token_id is not None
+            else self.config.decoder_start_token_id
         )
 
         if input_ids is not None:

@@ -289,14 +316,22 @@ class GenerationMixin:
         else:
             batch_size = 1
 
-        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
-        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
+        assert (
+            isinstance(max_length, int) and max_length > 0
+        ), "`max_length` should be a strictly positive integer."
+        assert (
+            isinstance(min_length, int) and min_length >= 0
+        ), "`min_length` should be a positive integer."
         assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
         assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
         assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
-        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
+        assert (
+            isinstance(num_beams, int) and num_beams > 0
+        ), "`num_beams` should be a strictly positive integer."
         assert temperature > 0, "`temperature` should be strictly positive."
-        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
+        assert (
+            isinstance(top_k, int) and top_k >= 0
+        ), "`top_k` should be a positive integer."
         assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
         assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
         assert input_ids is not None or (

@@ -316,7 +351,9 @@ class GenerationMixin:
             isinstance(num_return_sequences, int) and num_return_sequences > 0
         ), "`num_return_sequences` should be a strictly positive integer."
         assert (
-            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
+            bad_words_ids is None
+            or isinstance(bad_words_ids, list)
+            and isinstance(bad_words_ids[0], list)
         ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
 
         if input_ids is None:

@@ -331,7 +368,9 @@ class GenerationMixin:
                 device=next(self.parameters()).device,
             )
         else:
-            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
+            assert (
+                input_ids.dim() == 2
+            ), "Input prompt should be of shape (batch_size, sequence length)."
 
         # not allow to duplicate outputs when greedy decoding
         if do_sample is False:

@@ -349,7 +388,11 @@ class GenerationMixin:
         # create attention mask if necessary
         # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
-        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
+        if (
+            (attention_mask is None)
+            and (pad_token_id is not None)
+            and (pad_token_id in input_ids)
+        ):
             attention_mask = input_ids.ne(pad_token_id).long()
         elif attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)

@@ -358,7 +401,9 @@ class GenerationMixin:
         # attention_mask is created
         if pad_token_id is None and eos_token_id is not None:
             logger.warning(
-                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
+                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(
+                    eos_token_id
+                )
             )
             pad_token_id = eos_token_id

@@ -385,25 +430,37 @@ class GenerationMixin:
             # see if BOS token can be used for decoder_start_token_id
             if bos_token_id is not None:
                 decoder_start_token_id = bos_token_id
-            elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
+            elif hasattr(self.config, "decoder") and hasattr(
+                self.config.decoder, "bos_token_id"
+            ):
                 decoder_start_token_id = self.config.decoder.bos_token_id
             else:
                 raise ValueError(
                     "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
                 )
 
-            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
-            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
+            assert hasattr(
+                self, "get_encoder"
+            ), "{} should have a 'get_encoder' function defined".format(self)
+            assert callable(self.get_encoder), "{} should be a method".format(
+                self.get_encoder
+            )
 
             # get encoder and store encoder outputs
             encoder = self.get_encoder()
-            encoder_outputs: ModelOutput = encoder(input_ids, patch_ids, attention_mask=attention_mask, return_dict=True)
+            encoder_outputs: ModelOutput = encoder(
+                input_ids, patch_ids, attention_mask=attention_mask, return_dict=True
+            )
 
         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
             input_ids_len = input_ids.shape[-1]
-            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
-            patch_ids = patch_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
+            input_ids = input_ids.unsqueeze(1).expand(
+                batch_size, effective_batch_mult * num_beams, input_ids_len
+            )
+            patch_ids = patch_ids.unsqueeze(1).expand(
+                batch_size, effective_batch_mult * num_beams, input_ids_len
+            )
             attention_mask = attention_mask.unsqueeze(1).expand(
                 batch_size, effective_batch_mult * num_beams, input_ids_len
             )

@@ -442,9 +499,9 @@ class GenerationMixin:
             )
 
             # expand encoder_outputs
             encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                 0, expanded_batch_idxs
             )
 
             # save encoder_outputs in `model_kwargs`
             model_kwargs["encoder_outputs"] = encoder_outputs

@@ -534,7 +591,11 @@ class GenerationMixin:
         past = None
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
+                input_ids,
+                past=past,
+                attention_mask=attention_mask,
+                use_cache=use_cache,
+                **model_kwargs,
             )
 
             outputs = self(**model_inputs, return_dict=True)

@@ -565,7 +626,9 @@ class GenerationMixin:
             if temperature != 1.0:
                 scores = scores / temperature
             # Top-p/top-k filtering
-            next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
+            next_token_logscores = top_k_top_p_filtering(
+                scores, top_k=top_k, top_p=top_p
+            )
             # Sample
             probs = F.softmax(next_token_logscores, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

@@ -576,7 +639,9 @@ class GenerationMixin:
             # update generations and finished sentences
             if eos_token_id is not None:
                 # pad finished sentences if eos_token_id exist
-                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
+                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (
+                    1 - unfinished_sents
+                )
             else:
                 tokens_to_add = next_token

@@ -587,8 +652,12 @@ class GenerationMixin:
             if eos_token_id is not None:
                 eos_in_sents = tokens_to_add == eos_token_id
                 # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
-                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
-                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
+                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(
+                    eos_in_sents.long()
+                ).bool()
+                sent_lengths.masked_fill_(
+                    is_sents_unfinished_and_token_to_add_is_eos, cur_len
+                )
                 # unfinished_sents is set to zero if eos in sentence
                 unfinished_sents.mul_((~eos_in_sents).long())

@@ -599,7 +668,11 @@ class GenerationMixin:
         # extend attention_mask for new generated input if only decoder
         if self.config.is_encoder_decoder is False:
             attention_mask = torch.cat(
-                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                [
+                    attention_mask,
+                    attention_mask.new_ones((attention_mask.shape[0], 1)),
+                ],
+                dim=-1,
             )
 
         return input_ids

@@ -633,12 +706,16 @@ class GenerationMixin:
         # generated hypotheses
         generated_hyps = [
-            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
+            BeamHypotheses(
+                num_beams, max_length, length_penalty, early_stopping=early_stopping
+            )
             for _ in range(batch_size)
         ]
 
         # scores for each sentence in the beam
-        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
+        beam_scores = torch.zeros(
+            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
+        )
 
         # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
         if do_sample is False:

@@ -653,10 +730,18 @@ class GenerationMixin:
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
+                input_ids,
+                past=past,
+                attention_mask=attention_mask,
+                use_cache=use_cache,
+                **model_kwargs,
             )
-            outputs = self(**model_inputs, return_dict=True)  # (batch_size * num_beams, cur_len, vocab_size)
-            next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)
+            outputs = self(
+                **model_inputs, return_dict=True
+            )  # (batch_size * num_beams, cur_len, vocab_size)
+            next_token_logits = outputs.logits[
+                :, -1, :
+            ]  # (batch_size * num_beams, vocab_size)
 
             # if model has past, then set the past variable to speed up decoding
             if "past_key_values" in outputs:

@@ -670,7 +755,9 @@ class GenerationMixin:
                     next_token_logits, cur_len=cur_len, max_length=max_length
                 )
 
-            scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
+            scores = F.log_softmax(
+                next_token_logits, dim=-1
+            )  # (batch_size * num_beams, vocab_size)
 
             scores = self.postprocess_next_token_scores(
                 scores=scores,

@@ -686,12 +773,17 @@ class GenerationMixin:
                 num_beams=num_beams,
             )
 
-            assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
-                scores.shape, (batch_size * num_beams, vocab_size)
-            )
+            assert scores.shape == (
+                batch_size * num_beams,
+                vocab_size,
+            ), "Shapes of scores: {} != {}".format(
+                scores.shape, (batch_size * num_beams, vocab_size)
+            )
 
             if do_sample:
-                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
+                _scores = scores + beam_scores[:, None].expand_as(
+                    scores
+                )  # (batch_size * num_beams, vocab_size)
 
                 # Temperature
                 if temperature != 1.0:
                     _scores = _scores / temperature

@@ -706,24 +798,38 @@ class GenerationMixin:
                 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
                 probs = F.softmax(_scores, dim=-1)
-                next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
+                next_tokens = torch.multinomial(
+                    probs, num_samples=2 * num_beams
+                )  # (batch_size, num_beams * 2)
                 # Compute next scores
-                next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
+                next_scores = torch.gather(
+                    _scores, -1, next_tokens
+                )  # (batch_size, num_beams * 2)
                 # sort the sampled vector to make sure that the first num_beams samples are the best
-                next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
-                next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)
+                next_scores, next_scores_indices = torch.sort(
+                    next_scores, descending=True, dim=1
+                )
+                next_tokens = torch.gather(
+                    next_tokens, -1, next_scores_indices
+                )  # (batch_size, num_beams * 2)
             else:
-                next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
+                next_scores = scores + beam_scores[:, None].expand_as(
+                    scores
+                )  # (batch_size * num_beams, vocab_size)
 
                 # re-organize to group the beam together (we are keeping top hypothesis accross beams)
                 next_scores = next_scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
-                next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+                next_scores, next_tokens = torch.topk(
+                    next_scores, 2 * num_beams, dim=1, largest=True, sorted=True
+                )
 
-            assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
+            assert (
+                next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
+            )
 
             # next batch beam content
             next_batch_beam = []

@@ -735,11 +841,15 @@ class GenerationMixin:
                 if done[batch_idx]:
                     assert (
                         len(generated_hyps[batch_idx]) >= num_beams
-                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
+                    ), "Batch can only be done if at least {} beams have been generated".format(
+                        num_beams
+                    )
                     assert (
                         eos_token_id is not None and pad_token_id is not None
                     ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
-                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
+                    next_batch_beam.extend(
+                        [(0, pad_token_id, 0)] * num_beams
+                    )  # pad the batch
                     continue
 
                 # next sentence beam content, this will get added to next_batch_beam

@@ -757,7 +867,9 @@ class GenerationMixin:
                     # add to generated hypotheses if end of sentence
                     if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                         # if beam_token does not belong to top num_beams tokens, it should not be added
-                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
+                        is_beam_token_worse_than_top_num_beams = (
+                            beam_token_rank >= num_beams
+                        )
                         if is_beam_token_worse_than_top_num_beams:
                             continue
                         generated_hyps[batch_idx].add(

@@ -766,7 +878,9 @@ class GenerationMixin:
                         )
                     else:
                         # add next predicted token since it is not eos_token
-                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
+                        next_sent_beam.append(
+                            (beam_token_score, token_id, effective_beam_id)
+                        )
 
                     # once the beam for next step is full, don't add more tokens to it.
                     if len(next_sent_beam) == num_beams:

@@ -780,7 +894,9 @@ class GenerationMixin:
                 # update next beam content
                 assert len(next_sent_beam) == num_beams, "Beam should always be full"
                 next_batch_beam.extend(next_sent_beam)
-                assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
+                assert len(next_batch_beam) == num_beams * (
+                    batch_idx + 1
+                ), "We should have added num_beams each step"
 
             # stop when we are done with each sentence
             if all(done):

@@ -804,7 +920,11 @@ class GenerationMixin:
             # extend attention_mask for new generated input if only decoder
             if self.config.is_encoder_decoder is False:
                 attention_mask = torch.cat(
-                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                    [
+                        attention_mask,
+                        attention_mask.new_ones((attention_mask.shape[0], 1)),
+                    ],
+                    dim=-1,
                 )
 
         # finalize all open beam hypotheses and add to generated hypotheses

@@ -814,10 +934,12 @@ class GenerationMixin:
             # test that beam scores match previously calculated scores if not eos and batch_idx not done
             if eos_token_id is not None and all(
-                (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
+                (token_id % vocab_size).item() != eos_token_id
+                for token_id in next_tokens[batch_idx]
             ):
                 assert torch.all(
-                    next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
+                    next_scores[batch_idx, :num_beams]
+                    == beam_scores.view(batch_size, num_beams)[batch_idx]
                 ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                     next_scores[:, :num_beams][batch_idx],
                     beam_scores.view(batch_size, num_beams)[batch_idx],

@@ -831,7 +953,9 @@ class GenerationMixin:
                 generated_hyps[batch_idx].add(final_tokens, final_score)
 
         # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
-        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
+        output_batch_size = (
+            batch_size if do_sample else batch_size * num_return_sequences
+        )
         output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
 
         # select the best hypotheses

@@ -861,7 +985,9 @@ class GenerationMixin:
         else:
             # none of the hypotheses have an eos_token
             assert (len(hypo) == max_length for hypo in best)
-            decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
+            decoded = (
+                torch.stack(best).type(torch.long).to(next(self.parameters()).device)
+            )
 
         return decoded

@@ -870,7 +996,9 @@ class GenerationMixin:
         return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
 
 
-def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
+def calc_banned_ngram_tokens(
+    prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int
+) -> None:
     """Copied from fairseq for no_repeat_ngram in beam_search"""
     if cur_len + 1 < no_repeat_ngram_size:
         # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet

@@ -881,7 +1009,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n
         generated_ngram = generated_ngrams[idx]
         for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
             prev_ngram_tuple = tuple(ngram[:-1])
-            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
+            generated_ngram[prev_ngram_tuple] = generated_ngram.get(
+                prev_ngram_tuple, []
+            ) + [ngram[-1]]
 
     def _get_generated_ngrams(hypo_idx):
         # Before decoding the next token, prevent decoding of ngrams that have already appeared

@@ -893,7 +1023,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n
     return banned_tokens
 
 
-def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
+def calc_banned_bad_words_ids(
+    prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]
+) -> Iterable[int]:
     banned_tokens = []
 
     def _tokens_match(prev_tokens, tokens):

@@ -914,7 +1046,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
         banned_tokens_slice = []
         for banned_token_seq in bad_words_ids:
-            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
+            assert (
+                len(banned_token_seq) > 0
+            ), "Banned words token sequences {} cannot have an empty list".format(
                 bad_words_ids
             )

@@ -929,7 +1063,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
     return banned_tokens
 
 
-def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
+def set_scores_to_inf_for_banned_tokens(
+    scores: torch.Tensor, banned_tokens: List[List[int]]
+) -> None:
     """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
     a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
         Args:

@@ -949,7 +1085,12 @@ def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: Lis
     # [ 0  0  0 ]
     # [ 1  0  0 ]
 
-    banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
+    banned_mask = (
+        torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
+        .to(scores.device)
+        .to_dense()
+        .bool()
+    )
     scores.masked_fill_(banned_mask, -float("inf"))

@@ -989,7 +1130,9 @@ def top_k_top_p_filtering(
         sorted_indices_to_remove[..., 0] = 0
 
         # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
         logits[indices_to_remove] = filter_value
     return logits

@@ -1020,7 +1163,9 @@ class BeamHypotheses(object):
         if len(self) < self.num_beams or score > self.worst_score:
             self.beams.append((score, hyp))
             if len(self) > self.num_beams:
-                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                sorted_scores = sorted(
+                    [(s, idx) for idx, (s, _) in enumerate(self.beams)]
+                )
                 del self.beams[sorted_scores[0][1]]
                 self.worst_score = sorted_scores[1][0]
             else:
train/lightning_base.py
View file @
9b9ed4f
...
...
@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
config
=
None
,
tokenizer
=
None
,
model
=
None
,
**
config_kwargs
**
config_kwargs
,
):
"""Initialize a model, tokenizer and config."""
super
()
.
__init__
()
...
...
@@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule):
cache_dir
=
self
.
hparams
.
cache_dir
if
self
.
hparams
.
cache_dir
else
None
if
config
is
None
:
self
.
config
=
AutoConfig
.
from_pretrained
(
self
.
hparams
.
config_name
if
self
.
hparams
.
config_name
else
self
.
hparams
.
model_name_or_path
,
self
.
hparams
.
config_name
if
self
.
hparams
.
config_name
else
self
.
hparams
.
model_name_or_path
,
**
({
"num_labels"
:
num_labels
}
if
num_labels
is
not
None
else
{}),
cache_dir
=
cache_dir
,
**
config_kwargs
,
...
...
@@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule):
else
:
self
.
config
:
PretrainedConfig
=
config
extra_model_params
=
(
"encoder_layerdrop"
,
"decoder_layerdrop"
,
"dropout"
,
"attention_dropout"
)
extra_model_params
=
(
"encoder_layerdrop"
,
"decoder_layerdrop"
,
"dropout"
,
"attention_dropout"
,
)
for
p
in
extra_model_params
:
if
getattr
(
self
.
hparams
,
p
,
None
):
assert
hasattr
(
self
.
config
,
p
),
f
"model config doesn't have a `{p}` attribute"
assert
hasattr
(
self
.
config
,
p
),
f
"model config doesn't have a `{p}` attribute"
setattr
(
self
.
config
,
p
,
getattr
(
self
.
hparams
,
p
))
if
tokenizer
is
None
:
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
hparams
.
tokenizer_name
if
self
.
hparams
.
tokenizer_name
else
self
.
hparams
.
model_name_or_path
,
self
.
hparams
.
tokenizer_name
if
self
.
hparams
.
tokenizer_name
else
self
.
hparams
.
model_name_or_path
,
cache_dir
=
cache_dir
,
)
else
:
...
...
@@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule):
def
get_lr_scheduler
(
self
):
get_schedule_func
=
arg_to_scheduler
[
self
.
hparams
.
lr_scheduler
]
scheduler
=
get_schedule_func
(
self
.
opt
,
num_warmup_steps
=
self
.
hparams
.
warmup_steps
,
num_training_steps
=
self
.
total_steps
self
.
opt
,
num_warmup_steps
=
self
.
hparams
.
warmup_steps
,
num_training_steps
=
self
.
total_steps
,
)
scheduler
=
{
"scheduler"
:
scheduler
,
"interval"
:
"step"
,
"frequency"
:
1
}
return
scheduler
...
...
@@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule):
no_decay
=
[
"bias"
,
"LayerNorm.weight"
]
optimizer_grouped_parameters
=
[
{
"params"
:
[
p
for
n
,
p
in
model
.
named_parameters
()
if
not
any
(
nd
in
n
for
nd
in
no_decay
)],
"params"
:
[
p
for
n
,
p
in
model
.
named_parameters
()
if
not
any
(
nd
in
n
for
nd
in
no_decay
)
],
"weight_decay"
:
self
.
hparams
.
weight_decay
,
},
{
"params"
:
[
p
for
n
,
p
in
model
.
named_parameters
()
if
any
(
nd
in
n
for
nd
in
no_decay
)],
"params"
:
[
p
for
n
,
p
in
model
.
named_parameters
()
if
any
(
nd
in
n
for
nd
in
no_decay
)
],
"weight_decay"
:
0.0
,
},
]
if
self
.
hparams
.
adafactor
:
optimizer
=
Adafactor
(
optimizer_grouped_parameters
,
lr
=
self
.
hparams
.
learning_rate
,
scale_parameter
=
False
,
relative_step
=
False
optimizer_grouped_parameters
,
lr
=
self
.
hparams
.
learning_rate
,
scale_parameter
=
False
,
relative_step
=
False
,
)
else
:
optimizer
=
AdamW
(
optimizer_grouped_parameters
,
lr
=
self
.
hparams
.
learning_rate
,
eps
=
self
.
hparams
.
adam_epsilon
optimizer_grouped_parameters
,
lr
=
self
.
hparams
.
learning_rate
,
eps
=
self
.
hparams
.
adam_epsilon
,
)
self
.
opt
=
optimizer
...
...
@@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule):
def
total_steps
(
self
)
->
int
:
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
num_devices
=
max
(
1
,
self
.
hparams
.
gpus
)
# TODO: consider num_tpu_cores
effective_batch_size
=
self
.
hparams
.
train_batch_size
*
self
.
hparams
.
accumulate_grad_batches
*
num_devices
effective_batch_size
=
(
self
.
hparams
.
train_batch_size
*
self
.
hparams
.
accumulate_grad_batches
*
num_devices
)
dataset_size
=
len
(
self
.
train_loader
.
dataset
)
return
(
dataset_size
/
effective_batch_size
)
*
self
.
hparams
.
max_epochs
def
setup
(
self
,
mode
):
if
mode
==
"fit"
:
self
.
train_loader
=
self
.
get_dataloader
(
"train"
,
self
.
hparams
.
train_batch_size
,
shuffle
=
True
)
self
.
train_loader
=
self
.
get_dataloader
(
"train"
,
self
.
hparams
.
train_batch_size
,
shuffle
=
True
)
def
get_dataloader
(
self
,
type_path
,
batch_size
,
shuffle
=
False
):
raise
NotImplementedError
(
"You must implement this for your task"
)
...
...
@@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule):
help
=
"Path to pretrained model or model identifier from huggingface.co/models"
,
)
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
,
)
parser
.
add_argument
(
"--tokenizer_name"
,
...
...
@@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule):
            type=float,
            help="Attention dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--learning_rate",
            default=5e-5,
            type=float,
            help="The initial learning rate for Adam.",
        )
        parser.add_argument(
            "--lr_scheduler",
            default="linear",
...
...
@@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule):
            type=str,
            help="Learning rate scheduler",
        )
        parser.add_argument(
            "--weight_decay",
            default=0.0,
            type=float,
            help="Weight decay if we apply some.",
        )
        parser.add_argument(
            "--adam_epsilon",
            default=1e-8,
            type=float,
            help="Epsilon for Adam optimizer.",
        )
        parser.add_argument(
            "--warmup_steps",
            default=0,
            type=int,
            help="Linear warmup over warmup_steps.",
        )
        parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
        parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
        parser.add_argument("--train_batch_size", default=32, type=int)
        parser.add_argument("--eval_batch_size", default=32, type=int)
        parser.add_argument("--adafactor", action="store_true")
...
...
@@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback):
        rank_zero_info("***** Test results *****")
        metrics = trainer.callback_metrics
        # Log and save results to file
        output_test_results_file = os.path.join(
            pl_module.hparams.output_dir, "test_results.txt"
        )
        with open(output_test_results_file, "w") as writer:
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
...
...
@@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None:
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
    parser.add_argument(
        "--max_grad_norm",
        dest="gradient_clip_val",
        default=1.0,
        type=float,
        help="Max gradient norm",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument(
        "--do_predict",
        action="store_true",
        help="Whether to run predictions on the test set.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        dest="accumulate_grad_batches",
...
...
@@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None:
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")


def generic_train(
...
...
@@ -335,7 +410,7 @@ def generic_train(
    extra_callbacks=[],
    checkpoint_callback=None,
    logging_callback=None,
    **extra_train_kwargs,
):
    pl.seed_everything(args.seed)
...
...
@@ -346,7 +421,11 @@ def generic_train(
    # add custom checkpoints
    if checkpoint_callback is None:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            filepath=args.output_dir,
            prefix="checkpoint",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )
    if logging_callback is None:
        logging_callback = LoggingCallback()
...
...
train/modeling_bart.py
View file @
9b9ed4f
...
...
@@ -141,7 +141,11 @@ def invert_mask(attention_mask):
def _prepare_bart_decoder_inputs(
    config,
    input_ids,
    decoder_input_ids=None,
    decoder_padding_mask=None,
    causal_mask_dtype=torch.float32,
):
    """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
    none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
...
...
@@ -184,7 +188,9 @@ class PretrainedBartModel(PreTrainedModel):
    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor(
            [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device
        )
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
...
...
@@ -229,7 +235,11 @@ class EncoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = Attention(
            self.embed_dim,
            config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.normalize_before = config.normalize_before
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = config.dropout
...
...
@@ -255,7 +265,10 @@ class EncoderLayer(nn.Module):
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, attn_weights = self.self_attn(
            query=x,
            key=x,
            key_padding_mask=encoder_padding_mask,
            output_attentions=output_attentions,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
...
...
@@ -308,13 +321,23 @@ class BartEncoder(nn.Module):
            config.extra_pos_embeddings,
        )
        self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0)
        self.layers = nn.ModuleList(
            [EncoderLayer(config) for _ in range(config.encoder_layers)]
        )
        self.layernorm_embedding = (
            LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
        )
        # mbart has one extra layer_norm
        self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None

    def forward(
        self,
        input_ids,
        patch_ids,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """
        Args:
...
...
@@ -352,10 +375,14 @@ class BartEncoder(nn.Module):
                encoder_states.append(x)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                attn = None
            else:
                x, attn = encoder_layer(
                    x, attention_mask, output_attentions=output_attentions
                )

            if output_attentions:
                all_attentions = all_attentions + (attn,)
...
...
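The comment in the hunk above points to LayerDrop (https://arxiv.org/abs/1909.11556). As a hedged, self-contained sketch of the same idea, independent of BART and using a made-up stack of layers:

import random
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(6)])
layerdrop = 0.1
training = True

x = torch.randn(2, 4)
for layer in layers:
    # With probability `layerdrop`, skip the whole layer during training.
    if training and random.uniform(0, 1) < layerdrop:
        continue
    x = layer(x)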
@@ -365,14 +392,20 @@ class BartEncoder(nn.Module):
        if output_hidden_states:
            encoder_states.append(x)
            # T x B x C -> B x T x C
            encoder_states = tuple(
                hidden_state.transpose(0, 1) for hidden_state in encoder_states
            )
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        if not return_dict:
            return tuple(
                v for v in [x, encoder_states, all_attentions] if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions
        )


class DecoderLayer(nn.Module):
...
...
@@ -498,8 +531,12 @@ class BartDecoder(nn.Module):
        self.layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.decoder_layers)]
        )  # type: List[DecoderLayer]
        self.layernorm_embedding = (
            LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
        )
        self.layer_norm = (
            LayerNorm(config.d_model) if config.add_final_layer_norm else None
        )

    def forward(
        self,
...
...
@@ -595,23 +632,34 @@ class BartDecoder(nn.Module):
            if use_cache:
                next_decoder_cache.append(layer_past.copy())

            if self.layer_norm and (idx == len(self.layers) - 1):  # if config.add_final_layer_norm (mBART)
                x = self.layer_norm(x)
            if output_attentions:
                all_self_attns += (layer_self_attn,)

        # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
        if output_hidden_states:
            all_hidden_states = tuple(
                hidden_state.transpose(0, 1) for hidden_state in all_hidden_states
            )
        x = x.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
...
...
@@ -638,7 +686,9 @@ class Attention(nn.Module):
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        self.encoder_decoder_attention = encoder_decoder_attention
...
...
@@ -649,7 +699,11 @@ class Attention(nn.Module):
        self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"

    def _shape(self, tensor, seq_len, bsz):
        return (
            tensor.contiguous()
            .view(seq_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

    def forward(
        self,
...
...
@@ -693,7 +747,9 @@ class Attention(nn.Module):
        v = self._shape(v, -1, bsz)

        if saved_state is not None:
            k, v, key_padding_mask = self._use_saved_state(
                k, v, saved_state, key_padding_mask, static_kv, bsz
            )

        # Update cache
        layer_state[self.cache_key] = {
...
...
@@ -708,7 +764,9 @@ class Attention(nn.Module):
        assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
...
...
@@ -725,16 +783,14 @@ class Attention(nn.Module):
            attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
        assert v is not None
        attn_output = torch.bmm(attn_probs, v)
        assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        )
        attn_output = self.out_proj(attn_output)
        if output_attentions:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
...
...
@@ -763,12 +819,16 @@ class Attention(nn.Module):
            assert v is not None
            v = torch.cat([prev_value, v], dim=1)
        assert k is not None and v is not None
        prev_key_padding_mask: Optional[Tensor] = saved_state.get(
            "prev_key_padding_mask", None
        )
        if prev_key_padding_mask is not None:
            if static_kv:
                new_key_padding_mask = prev_key_padding_mask
            else:
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask, key_padding_mask], dim=1
                )
        else:
            new_key_padding_mask = key_padding_mask
        return k, v, new_key_padding_mask
...
...
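A small, hedged illustration of the head-splitting reshape that _shape performs above: a (seq_len, bsz, embed_dim) tensor is viewed as (seq_len, bsz * num_heads, head_dim) and transposed so each head becomes a batch entry. Shapes here are arbitrary:

import torch

seq_len, bsz, embed_dim, num_heads = 5, 2, 8, 4
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

tensor = torch.randn(seq_len, bsz, embed_dim)
shaped = tensor.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
print(shaped.shape)  # torch.Size([8, 5, 2]) -> (bsz * num_heads, seq_len, head_dim)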
@@ -780,11 +840,7 @@ class BartClassificationHead(nn.Module):
    # This can trivially be shared with RobertaClassificationHead

    def __init__(
        self, input_dim, inner_dim, num_classes, pooler_dropout,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
...
...
@@ -808,7 +864,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
    position ids are passed to the forward function.
    """

    def __init__(
        self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset
    ):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models dont have this hack
        self.offset = offset
...
...
@@ -820,10 +878,14 @@ class LearnedPositionalEmbedding(nn.Embedding):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        if use_cache:
            positions = input_ids.data.new(1, 1).fill_(
                seq_len - 1
            )  # called before slicing
        else:
            # starts at 0, ends at 1-seq_len
            positions = torch.arange(
                seq_len, dtype=torch.long, device=self.weight.device
            )
        return super().forward(positions + self.offset)
...
...
@@ -896,16 +958,28 @@ class BartModel(PretrainedBartModel):
        if decoder_input_ids is None:
            use_cache = False

        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # make masks if user doesn't supply
        if not use_cache:
            (
                decoder_input_ids,
                decoder_padding_mask,
                causal_mask,
            ) = _prepare_bart_decoder_inputs(
                self.config,
                input_ids,
                decoder_input_ids=decoder_input_ids,
...
...
@@ -974,17 +1048,24 @@ class BartModel(PretrainedBartModel):
@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.",
    BART_START_DOCSTRING,
)
class BartForConditionalGeneration(PretrainedBartModel):
    base_model_prefix = "model"
    authorized_missing_keys = [
        r"final_logits_bias",
        r"encoder\.version",
        r"decoder\.version",
    ]

    def __init__(self, config: BartConfig):
        super().__init__(config)
        base_model = BartModel(config)
        self.model = base_model
        self.register_buffer(
            "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
        )

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        old_num_tokens = self.model.shared.num_embeddings
...
...
@@ -993,16 +1074,23 @@ class BartForConditionalGeneration(PretrainedBartModel):
        self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(
        self, new_num_tokens: int, old_num_tokens: int
    ) -> None:
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens),
                device=self.final_logits_bias.device,
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
    )
    @add_end_docstrings(BART_GENERATION_EXAMPLE)
    def forward(
        self,
...
...
@@ -1065,7 +1153,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
                FutureWarning,
            )
            past_key_values = unused.pop("decoder_past_key_values")
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if labels is not None:
            use_cache = False
...
...
@@ -1085,17 +1175,23 @@ class BartForConditionalGeneration(PretrainedBartModel):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = F.linear(
            outputs[0], self.model.shared.weight, bias=self.final_logits_bias
        )

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # TODO(SS): do we need to ignore pad tokens in labels?
            masked_lm_loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
...
...
@@ -1109,7 +1205,13 @@ class BartForConditionalGeneration(PretrainedBartModel):
    )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past,
        attention_mask,
        use_cache,
        encoder_outputs,
        **kwargs,
    ):
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
...
...
@@ -1130,7 +1232,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
    def _force_token_ids_generation(self, scores, token_id) -> None:
        """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
        scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float(
            "inf"
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
...
...
@@ -1138,7 +1242,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
        for layer_past in past:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
                attn_key: _reorder_buffer(attn_cache, beam_idx)
                for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
        return reordered_past
...
...
@@ -1159,10 +1264,7 @@ class BartForSequenceClassification(PretrainedBartModel):
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model, config.d_model, config.num_labels, config.classif_dropout,
        )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)
...
...
@@ -1193,7 +1295,9 @@ class BartForSequenceClassification(PretrainedBartModel):
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if labels is not None:
            use_cache = False
...
...
@@ -1212,7 +1316,9 @@ class BartForSequenceClassification(PretrainedBartModel):
        eos_mask = input_ids.eq(self.config.eos_token_id)
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[
            :, -1, :
        ]
        logits = self.classification_head(sentence_representation)

        loss = None
...
...
@@ -1284,7 +1390,9 @@ class BartForQuestionAnswering(PretrainedBartModel):
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if start_positions is not None and end_positions is not None:
            use_cache = False
...
...
@@ -1325,10 +1433,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits,) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return Seq2SeqQuestionAnsweringModelOutput(
...
...
@@ -1350,7 +1455,9 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
    def __init__(self, num_positions, embedding_dim, padding_idx=None):
        super().__init__(num_positions, embedding_dim)
        if embedding_dim % 2 != 0:
            raise NotImplementedError(
                f"odd embedding_dim {embedding_dim} not supported"
            )
        self.weight = self._init_weight(self.weight)

    @staticmethod
...
...
@@ -1360,9 +1467,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [
                [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
                for pos in range(n_pos)
            ]
        )
        out[:, 0 : dim // 2] = torch.FloatTensor(
            np.sin(position_enc[:, 0::2])
        )  # This line breaks for odd n_pos
        out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        out.requires_grad = False
...
...
@@ -1373,8 +1485,12 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        if use_cache:
            positions = input_ids.data.new(1, 1).fill_(
                seq_len - 1
            )  # called before slicing
        else:
            # starts at 0, ends at 1-seq_len
            positions = torch.arange(
                seq_len, dtype=torch.long, device=self.weight.device
            )
        return super().forward(positions)
...
...
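For reference, the sinusoidal table built by _init_weight above follows the standard positional-encoding formula; a small standalone sketch using the same np.power expression as the diff, with arbitrarily chosen dimensions (dim must be even, matching the NotImplementedError guard):

import numpy as np
import torch

n_pos, dim = 10, 6
position_enc = np.array(
    [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
out = torch.zeros(n_pos, dim)
out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # even columns -> sin
out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))    # odd columns -> cos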
train/modeling_utils.py
View file @
9b9ed4f
...
...
@@ -80,7 +80,9 @@ def find_pruneable_heads_and_indices(
        :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
    """
    mask = torch.ones(n_heads, head_size)
    heads = (
        set(heads) - already_pruned_heads
    )  # Convert to set and remove already pruned heads
    for head in heads:
        # Compute how many pruned heads are before the head and move the index accordingly
        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
...
...
@@ -106,7 +108,11 @@ class ModuleUtilsMixin:
        Returns:
            :obj:`int`: The number of parameters.
        """
        params = (
            filter(lambda x: x.requires_grad, self.parameters())
            if only_trainable
            else self.parameters()
        )
        return sum(p.numel() for p in params)

    @staticmethod
...
...
@@ -114,7 +120,9 @@ class ModuleUtilsMixin:
        try:
            import psutil
        except (ImportError):
            raise ImportError(
                "You need to install psutil (pip install psutil) to use memory tracing."
            )
        process = psutil.Process(os.getpid())
        mem = process.memory_info()
...
...
@@ -126,13 +134,17 @@ class ModuleUtilsMixin:
        try:
            import psutil
        except (ImportError):
            raise ImportError(
                "You need to install psutil (pip install psutil) to use memory tracing."
            )
        process = psutil.Process(os.getpid())
        mem = process.memory_info()
        module.mem_rss_post_forward = mem.rss
        mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
        module.mem_rss_diff = mem_rss_diff + (
            module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0
        )
        return None

    def add_memory_hooks(self):
...
...
@@ -169,7 +181,9 @@ class ModuleUtilsMixin:
        # For nn.DataParallel compatibility in PyTorch 1.5
        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [
                (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)
            ]
            return tuples

        gen = self._named_members(get_members_fn=find_tensor_attributes)
...
...
@@ -187,7 +201,9 @@ class ModuleUtilsMixin:
        # For nn.DataParallel compatibility in PyTorch 1.5
        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [
                (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)
            ]
            return tuples

        gen = self._named_members(get_members_fn=find_tensor_attributes)
...
...
@@ -213,12 +229,18 @@ class ModuleUtilsMixin:
            # /transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
            # encoder_extended_attention_mask.transpose(-1, -2))
        encoder_extended_attention_mask = encoder_extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility

        if self.dtype == torch.float16:
            encoder_extended_attention_mask = (
                1.0 - encoder_extended_attention_mask
            ) * -1e4
        elif self.dtype == torch.float32:
            encoder_extended_attention_mask = (
                1.0 - encoder_extended_attention_mask
            ) * -1e9
        else:
            raise ValueError(
                "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
...
...
@@ -228,7 +250,9 @@ class ModuleUtilsMixin:
        return encoder_extended_attention_mask

    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], device: device
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
...
...
@@ -254,10 +278,15 @@ class ModuleUtilsMixin:
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (
                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                    <= seq_ids[None, :, None]
                )
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)
                extended_attention_mask = (
                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
...
...
@@ -272,12 +301,17 @@ class ModuleUtilsMixin:
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def get_head_mask(
        self,
        head_mask: Optional[Tensor],
        num_hidden_layers: int,
        is_attention_chunked: bool = False,
    ) -> Tensor:
        """
        Prepare the head mask if needed.
...
...
@@ -309,9 +343,13 @@ class ModuleUtilsMixin:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = (
                head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            )  # We can specify head_mask for each layer
        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        head_mask = head_mask.to(
            dtype=self.dtype
        )  # switch to fload if need + fp16 compatibility
        return head_mask
...
...
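A self-contained sketch of the additive-mask trick reformatted above: a 0/1 padding mask is broadcast to (bsz, 1, 1, seq_len) and turned into 0.0 for positions to attend and -10000.0 for masked positions, so it can simply be added to raw attention scores before the softmax. Values are made up for illustration:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])       # 1 = real token, 0 = padding
extended = attention_mask[:, None, None, :].float()     # (bsz, 1, 1, seq_len)
extended = (1.0 - extended) * -10000.0
print(extended)  # tensor([[[[-0., -0., -0., -10000., -10000.]]]])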
@@ -420,12 +458,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

        if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
            self._tie_encoder_decoder_weights(
                self.encoder, self.decoder, self.base_model_prefix
            )

    @staticmethod
    def _tie_encoder_decoder_weights(
        encoder: nn.Module, decoder: nn.Module, base_model_prefix: str
    ):
        uninitialized_encoder_weights: List[str] = []
        assert (
            decoder.__class__ == encoder.__class__
        ), f"{decoder.__class__} and {encoder.__class__} have to be equal."

        def tie_encoder_to_decoder_recursively(
            decoder_pointer: nn.Module,
...
...
@@ -452,13 +496,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                    len(encoder_modules) > 0
                ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

                all_encoder_weights = set(
                    [
                        module_name + "/" + sub_name
                        for sub_name in encoder_modules.keys()
                    ]
                )
                encoder_layer_pos = 0
                for name, module in decoder_modules.items():
                    if name.isdigit():
                        encoder_name = str(int(name) + encoder_layer_pos)
                        decoder_name = name
                        if not isinstance(
                            decoder_modules[decoder_name],
                            type(encoder_modules[encoder_name]),
                        ):
                            # this can happen if the name corresponds to the position in a list module list of layers
                            # in this case the decoder has added a cross-attention that the encoder does not have
                            # thus skip this step and substract one layer pos from encoder
...
...
@@ -484,7 +536,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
            uninitialized_encoder_weights += list(all_encoder_weights)

        # tie weights recursively
        tie_encoder_to_decoder_recursively(
            decoder, encoder, base_model_prefix, uninitialized_encoder_weights
        )
        if len(uninitialized_encoder_weights) > 0:
            logger.warning(
                f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
...
...
@@ -507,10 +561,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                "constant",
                0,
            )
        if hasattr(output_embeddings, "out_features") and hasattr(
            input_embeddings, "num_embeddings"
        ):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None
    ) -> torch.nn.Embedding:
        """
        Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
...
...
@@ -526,7 +584,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        Return:
            :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
        """
        base_model = getattr(
            self, self.base_model_prefix, self
        )  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds
...
...
@@ -583,7 +643,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[
            :num_tokens_to_copy, :
        ]

        return new_embeddings
...
...
@@ -614,7 +676,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(
                union_heads
            )  # Unfortunately we have to store it as list for JSON

        self.base_model._prune_heads(heads_to_prune)
...
...
@@ -628,7 +692,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error(
                "Provided path ({}) should be a directory, not a file".format(
                    save_directory
                )
            )
            return
        os.makedirs(save_directory, exist_ok=True)
...
...
@@ -775,7 +843,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = (
                config if config is not None else pretrained_model_name_or_path
            )
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
...
...
@@ -793,23 +863,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                ):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(
                        pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index"
                    )
                elif from_tf and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                ):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(
                        pretrained_model_name_or_path, TF2_WEIGHTS_NAME
                    )
                elif os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                ):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(
                        pretrained_model_name_or_path, WEIGHTS_NAME
                    )
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False".format(
                            [
                                WEIGHTS_NAME,
                                TF2_WEIGHTS_NAME,
                                TF_WEIGHTS_NAME + ".index",
                            ],
                            pretrained_model_name_or_path,
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                pretrained_model_name_or_path
            ):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (
...
...
@@ -848,7 +938,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info(
                    "loading weights file {} from cache at {}".format(
                        archive_file, resolved_archive_file
                    )
                )
        else:
            resolved_archive_file = None
...
...
@@ -871,13 +965,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        if from_tf:
            if resolved_archive_file.endswith(".index"):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(
                    model, config, resolved_archive_file[:-6]
                )  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model

                    model = load_tf2_checkpoint_in_pytorch_model(
                        model, resolved_archive_file, allow_missing_keys=True
                    )
                except ImportError:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
...
...
@@ -909,7 +1007,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module: nn.Module, prefix=""):
                local_metadata = (
                    {} if metadata is None else metadata.get(prefix[:-1], {})
                )
                module._load_from_state_dict(
                    state_dict,
                    prefix,
...
...
@@ -926,7 +1026,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ""
            model_to_load = model
            has_prefix_module = any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()
            )
            if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
                start_prefix = cls.base_model_prefix + "."
            if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
...
...
@@ -937,15 +1039,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
            if model.__class__.__name__ != model_to_load.__class__.__name__:
                base_model_state_dict = model_to_load.state_dict().keys()
                head_model_state_dict_without_base_prefix = [
                    key.split(cls.base_model_prefix + ".")[-1]
                    for key in model.state_dict().keys()
                ]
                missing_keys.extend(
                    head_model_state_dict_without_base_prefix - base_model_state_dict
                )

            # Some models may have keys that are not in the state by design, removing them before needlessly warning
            # the user.
            if cls.authorized_missing_keys is not None:
                for pat in cls.authorized_missing_keys:
                    missing_keys = [
                        k for k in missing_keys if re.search(pat, k) is None
                    ]

            if len(unexpected_keys) > 0:
                logger.warning(
...
...
@@ -957,7 +1064,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                    f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
                )
            else:
                logger.info(
                    f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
                )
            if len(missing_keys) > 0:
                logger.warning(
                    f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
...
...
@@ -990,7 +1099,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
            }
            return model, loading_info

        if (
            hasattr(config, "xla_device")
            and config.xla_device
            and is_torch_tpu_available()
        ):
            import torch_xla.core.xla_model as xm

            model = xm.send_cpu_data_to_device(model, xm.xla_device())
...
...
@@ -1039,7 +1152,9 @@ class PoolerStartLogits(nn.Module):
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        p_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
...
...
@@ -1112,8 +1227,12 @@ class PoolerEndLogits(nn.Module):
        ), "One of start_states, start_positions should be not None"
        if start_positions is not None:
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions[:, None, None].expand(
                -1, -1, hsz
            )  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(
                -2, start_positions
            )  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
...
...
@@ -1177,12 +1296,20 @@ class PoolerAnswerClass(nn.Module):
            start_states is not None or start_positions is not None
        ), "One of start_states, start_positions should be not None"
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(
                -1, -1, hsz
            )  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(
                -2
            )  # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(
                -1, -1, hsz
            )  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(
                -2
            )  # shape (bsz, hsz)
        else:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
...
...
@@ -1241,7 +1368,9 @@ class SQuADHead(nn.Module):
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    @replace_return_docstrings(
        output_type=SquadHeadOutput, config_class=PretrainedConfig
    )
    def forward(
        self,
        hidden_states: torch.FloatTensor,
...
...
@@ -1281,7 +1410,9 @@ class SQuADHead(nn.Module):
            x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(
                hidden_states, start_positions=start_positions, p_mask=p_mask
            )
            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
...
...
@@ -1290,7 +1421,9 @@ class SQuADHead(nn.Module):
            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(
                    hidden_states, start_positions=start_positions, cls_index=cls_index
                )
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)
...
...
@@ -1307,28 +1440,48 @@ class SQuADHead(nn.Module):
            start_top_log_probs, start_top_index = torch.topk(
                start_log_probs, self.start_n_top, dim=-1
            )  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(
                -1, -1, hsz
            )  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(
                hidden_states, -2, start_top_index_exp
            )  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(
                -1, slen, -1, -1
            )  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states
            )  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(
                hidden_states_expanded, start_states=start_states, p_mask=p_mask
            )
            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)
            end_top_log_probs, end_top_index = torch.topk(
                end_log_probs, self.end_n_top, dim=1
            )  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(
                -1, self.start_n_top * self.end_n_top
            )
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(
                hidden_states, start_states=start_states, cls_index=cls_index
            )

            if not return_dict:
                return (
                    start_top_log_probs,
                    start_top_index,
                    end_top_log_probs,
                    end_top_index,
                    cls_logits,
                )
            else:
                return SquadHeadOutput(
                    start_top_log_probs=start_top_log_probs,
...
...
@@ -1379,17 +1532,26 @@ class SequenceSummary(nn.Module):
        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if (
                hasattr(config, "summary_proj_to_labels")
                and config.summary_proj_to_labels
                and config.num_labels > 0
            ):
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = (
            get_activation(activation_string) if activation_string else Identity()
        )

        self.first_dropout = Identity()
        if (
            hasattr(config, "summary_first_dropout")
            and config.summary_first_dropout > 0
        ):
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
...
...
@@ -1397,7 +1559,9 @@ class SequenceSummary(nn.Module):
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        cls_index: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.
...
...
@@ -1427,9 +1591,13 @@ class SequenceSummary(nn.Module):
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand(
                    (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)
                )
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(
                -2
            )  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError
...
...
@@ -1441,7 +1609,9 @@ class SequenceSummary(nn.Module):
        return output


def prune_linear_layer(
    layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0
) -> torch.nn.Linear:
    """
    Prune a linear layer to keep only entries in index.
...
...
@@ -1464,7 +1634,9 @@ def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int
        b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(
        layer.weight.device
    )
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
...
...
@@ -1509,7 +1681,9 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) ->
def prune_layer(
    layer: Union[torch.nn.Linear, Conv1D],
    index: torch.LongTensor,
    dim: Optional[int] = None,
) -> Union[torch.nn.Linear, Conv1D]:
    """
    Prune a Conv1D or linear layer to keep only entries in index.
...
...
@@ -1534,7 +1708,10 @@ def prune_layer(
def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor],
    chunk_size: int,
    chunk_dim: int,
    *input_tensors,
) -> torch.Tensor:
    """
    This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
...
...
@@ -1568,7 +1745,9 @@ def apply_chunking_to_forward(
        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
    """

    assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(
        input_tensors
    )
    tensor_shape = input_tensors[0].shape
    assert all(input_tensor.shape == tensor_shape for input_tensor in input_tensors
...
...
@@ -1592,9 +1771,15 @@ def apply_chunking_to_forward(
        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size

        # chunk input tensor into tuples
        input_tensors_chunks = tuple(
            input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors
        )
        # apply forward fn to every tuple
        output_chunks = tuple(
            forward_fn(*input_tensors_chunk)
            for input_tensors_chunk in zip(*input_tensors_chunks)
        )
        # concatenate output at same dimension
        return torch.cat(output_chunks, dim=chunk_dim)
...
...
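A small sketch of how apply_chunking_to_forward is meant to be called; the import path is assumed, and the toy forward_chunk merely stands in for a feed-forward block:

import torch
from train.modeling_utils import apply_chunking_to_forward  # assumed import path

def forward_chunk(hidden_states):
    return hidden_states * 2                          # placeholder for a feed-forward layer

hidden_states = torch.randn(1, 8, 16)                 # (bsz, seq_len, hidden)
out = apply_chunking_to_forward(forward_chunk, 2, 1, hidden_states)  # chunk_size=2 over dim 1
assert torch.allclose(out, hidden_states * 2)         # chunked result matches the direct call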
train/utils.py
View file @
9b9ed4f
...
...
@@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
     return loss, nll_loss
 
 
-def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
+def encode_line(
+    tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"
+):
     """Only used by LegacyDataset"""
-    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
+    extra_kw = (
+        {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
+    )
     return tokenizer(
         [line],
         max_length=max_length,
...
...
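For context, a hedged usage sketch of encode_line; the model name and import path are assumptions, and the return value is simply whatever the tokenizer call above produces:

from transformers import BartTokenizer
from train.utils import encode_line  # assumed import path

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
batch = encode_line(tokenizer, "fix typo in README", max_length=32)
print(batch["input_ids"].shape)       # expected (1, 32) when padding to max_length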
@@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
 
 
 def trim_batch(
-    input_ids,
-    pad_token_id,
-    attention_mask=None,
+    input_ids, pad_token_id, attention_mask=None,
 ):
     """Remove columns that are populated exclusively by pad_token_id"""
     keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
...
...
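A toy sketch of trim_batch (import path assumed): columns made up entirely of the pad id are dropped.

import torch
from train.utils import trim_batch  # assumed import path

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 0, 0],
                          [7, 0, 0, 0]])
print(trim_batch(input_ids, pad_token_id))   # expected tensor([[5, 6], [7, 0]])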
@@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
     def __getitem__(self, index) -> Dict[str, torch.Tensor]:
         """Call tokenizer on src and tgt_lines"""
         index = index + 1  # linecache starts at 1
-        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
+        source_line = self.prefix + linecache.getline(
+            str(self.src_file), index
+        ).rstrip("\n")
         tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
         assert source_line, f"empty source line for index {index}"
         assert tgt_line, f"empty tgt line for index {index}"
...
...
@@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
         target_ids = torch.stack([x["labels"] for x in batch])
         pad_token_id = self.pad_token_id
         y = trim_batch(target_ids, pad_token_id)
-        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
+        source_ids, source_mask = trim_batch(
+            input_ids, pad_token_id, attention_mask=masks
+        )
         batch = {
             "input_ids": source_ids,
             "attention_mask": source_mask,
...
...
@@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
     def __getitem__(self, index) -> Dict[str, str]:
         index = index + 1  # linecache starts at 1
-        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
+        source_line = self.prefix + linecache.getline(
+            str(self.src_file), index
+        ).rstrip("\n")
         tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
         assert source_line, f"empty source line for index {index}"
         assert tgt_line, f"empty tgt line for index {index}"
...
...
@@ -201,12 +209,23 @@ class SortishSampler(Sampler):
         idxs = np.random.permutation(len(self.data))
         sz = self.bs * 50
         ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
-        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
+        sort_idx = np.concatenate(
+            [sorted(s, key=self.key, reverse=True) for s in ck_idx]
+        )
         sz = self.bs
         ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
-        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
-        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
-        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
+        max_ck = np.argmax(
+            [self.key(ck[0]) for ck in ck_idx]
+        )  # find the chunk with the largest key,
+        ck_idx[0], ck_idx[max_ck] = (
+            ck_idx[max_ck],
+            ck_idx[0],
+        )  # then make sure it goes first.
+        sort_idx = (
+            np.concatenate(np.random.permutation(ck_idx[1:]))
+            if len(ck_idx) > 1
+            else np.array([], dtype=np.int)
+        )
         sort_idx = np.concatenate((ck_idx[0], sort_idx))
         return iter(sort_idx)
...
...
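The reformatted block above is the "sortish" shuffle: shuffle indices, sort them within mega-batches of bs * 50 by the key (descending), re-chunk into batches of bs, and put the batch containing the largest key first. A standalone sketch with toy lengths follows; local variables replace self.bs / self.key / self.data, so it mirrors rather than imports the sampler logic:

import numpy as np

data = [list(range(np.random.randint(1, 50))) for _ in range(200)]
bs = 8
key = lambda i: len(data[i])

idxs = np.random.permutation(len(data))
sz = bs * 50
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
sort_idx = np.concatenate([sorted(s, key=key, reverse=True) for s in ck_idx])
sz = bs
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
max_ck = np.argmax([key(ck[0]) for ck in ck_idx])      # batch holding the longest example
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # that batch goes first
rest = (
    np.concatenate(np.random.permutation(ck_idx[1:]))
    if len(ck_idx) > 1
    else np.array([], dtype=np.int64)
)
order = np.concatenate((ck_idx[0], rest))
assert sorted(order.tolist()) == list(range(200))      # every index appears exactly once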
@@ -269,7 +288,9 @@ def get_git_info():
 ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
 
 
-def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
+def calculate_rouge(
+    output_lns: List[str], reference_lns: List[str], use_stemmer=True
+) -> Dict:
     scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
     aggregator = scoring.BootstrapAggregator()
...
...
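A hedged usage sketch of calculate_rouge; the import path is assumed, and it relies on the rouge_score package the module already imports:

from train.utils import calculate_rouge  # assumed import path

preds = ["add unit tests for the tokenizer"]
refs = ["add tests for the tokenizer"]
print(calculate_rouge(preds, refs))       # aggregated rouge1 / rouge2 / rougeL scores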
@@ -302,7 +323,9 @@ def assert_all_frozen(model):
     model_grads: List[bool] = list(grad_status(model))
     n_require_grad = sum(lmap(int, model_grads))
     npars = len(model_grads)
-    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
+    assert not any(
+        model_grads
+    ), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
 
 
 def assert_not_all_frozen(model):
...
...
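Finally, a toy sketch of the frozen-weight check above (import path assumed): the assertion only passes once every parameter has requires_grad disabled.

import torch.nn as nn
from train.utils import assert_all_frozen  # assumed import path

model = nn.Linear(4, 2)
for p in model.parameters():
    p.requires_grad = False
assert_all_frozen(model)                   # raises AssertionError if any grads are still on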