graykode

(refactor) black style
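
Reformat the repository with the black code formatter: string literals are normalized to double quotes, long calls and comprehensions are wrapped to black's default 88-column limit with trailing commas, and spacing around complex slice expressions is adjusted. Since black only rewrites layout and quoting, the hunks below are not expected to change behavior.

As a rough illustration only (not part of the commit), the snippet below runs black's Python API on two of the patterns touched in this diff. It assumes the black package is installed and that the default black.Mode() matches the settings used for this refactor; the commit itself was presumably produced by running the black command line over the repository.

# Minimal sketch: reproduce the kind of rewrite this commit applies by pushing
# a couple of representative lines through black's programmatic entry point.
import black

before = """
__all__ = [
    'diff_parse',
    'truncate',
]
ls = ls[:max_length - 1]
data_saver({'input_ids': np.asarray(input_ids), 'attention_masks': np.asarray(attention_masks)})
"""

# format_str parses (but does not execute) the source and reprints it in black
# style: double quotes, a space around the slice colon once the slice holds an
# expression, and an exploded call with a trailing comma when a line exceeds
# the 88-column limit.
print(black.format_str(before, mode=black.Mode()))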

......@@ -68,6 +68,7 @@ def main(args):
)
print(commit_message)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Code to collect commits on github")
parser.add_argument(
......
......@@ -15,6 +15,6 @@
from .gitcommit import diff_parse, truncate
__all__ = [
'diff_parse',
'truncate',
"diff_parse",
"truncate",
]
......
......@@ -36,9 +36,11 @@ logging.basicConfig(
level=logging.INFO,
)
class PATCH(enum.Enum):
PLUS=1
MINUS=2
PLUS = 1
MINUS = 2
def truncate(tuple, max_length, value=0):
ls = []
......@@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0):
if isinstance(t, int):
t = [t]
ls.extend(t)
ls = ls[:max_length - 1]
ls = ls[: max_length - 1]
ls.insert(0, value)
if len(ls) < max_length:
ls.extend([0] * (max_length - len(ls)))
assert len(ls) == max_length
return ls
def encode_line(tokenizer, line, patch):
line = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip()
line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip()
tokens = tokenizer.tokenize(line)
tokens = tokenizer.convert_tokens_to_ids(tokens)
return (
tokens,
[1] * len(tokens),
len(tokens) * [patch.value]
)
return (tokens, [1] * len(tokens), len(tokens) * [patch.value])
def diff_parse(diff, tokenizer):
chunks = []
......@@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer):
chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
return chunks
def sha_parse(sha, tokenizer, max_length=1024):
chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer)
......@@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024):
return (input_ids, attention_masks, patch_ids)
def message_parse(msg, tokenizer, max_length=56):
msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg)
msg = re.sub(r"(\(|)#([0-9])+(\)|)", "", msg)
msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip()
msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip()
msg = tokenizer.tokenize(msg)
msg = tokenizer.convert_tokens_to_ids(msg)
msg = truncate(msg, max_length, value=0)
return msg
def jobs(sha_msgs, args, data_config, train=True):
input_ids, attention_masks, patch_ids, targets = [], [], [], []
......@@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True):
sha, msg = sha_msg
source = sha_parse(
sha,
tokenizer=args.tokenizer,
max_length=args.max_source_length
sha, tokenizer=args.tokenizer, max_length=args.max_source_length
)
if not source:
continue
......@@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True):
target = message_parse(
msg,
tokenizer=args.tokenizer,
max_length=(args.max_target_length if train else args.val_max_target_length),
max_length=(
args.max_target_length if train else args.val_max_target_length
),
)
input_ids.append(input_id)
......@@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True):
patch_ids.append(patch_id)
targets.append(target)
data_saver({
data_saver(
{
"input_ids": np.asarray(input_ids),
"attention_masks": np.asarray(attention_masks),
"patch_ids": np.asarray(patch_ids),
"targets": np.asarray(targets),
})
}
)
data_saver.disconnect()
def start(chunked_sha_msgs, train=True):
logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation"))
......@@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True):
data_config = DataConfig(
endpoint=args.endpoint,
access_key=os.environ['access_key'],
secret_key=os.environ['secret_key'],
access_key=os.environ["access_key"],
secret_key=os.environ["secret_key"],
region=args.region,
dataset_name='commit-autosuggestions',
dataset_name="commit-autosuggestions",
additional={
"mode" : ("training" if train else "evaluation"),
"mode": ("training" if train else "evaluation"),
"max_source_length": args.max_source_length,
"max_target_length": max_target_length,
"url" : args.url,
"url": args.url,
},
attributes=[
('input_ids', 'int32', (args.max_source_length,)),
('attention_masks', 'int32', (args.max_source_length,)),
('patch_ids', 'int32', (args.max_source_length,)),
('targets', 'int32', (max_target_length,))
]
("input_ids", "int32", (args.max_source_length,)),
("attention_masks", "int32", (args.max_source_length,)),
("patch_ids", "int32", (args.max_source_length,)),
("targets", "int32", (max_target_length,)),
],
)
func = partial(jobs, args=args, data_config=data_config, train=train)
......@@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True):
for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
pbar.update()
def main(args):
if 'access_key' not in os.environ or 'secret_key' not in os.environ:
if "access_key" not in os.environ or "secret_key" not in os.environ:
raise OSError("access_key or secret_key not found.")
sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
random.shuffle(sha_msgs)
chunked_sha_msgs = [
sha_msgs[x:x + args.matorage_batch]
sha_msgs[x : x + args.matorage_batch]
for x in range(0, len(sha_msgs), args.matorage_batch)
]
......@@ -185,29 +192,25 @@ def main(args):
if args.do_predict:
start(chunked_sha_msgs[barrier:], train=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Code to collect commits on github")
parser.add_argument(
"--url",
type=str,
required=True,
help="github url"
)
parser.add_argument("--url", type=str, required=True, help="github url")
parser.add_argument(
"--endpoint",
type=str,
required=True,
help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
)
parser.add_argument(
"--region",
type=str,
default=None,
help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
)
parser.add_argument(
"--tokenizer_name",
default='sshleifer/distilbart-xsum-6-6',
default="sshleifer/distilbart-xsum-6-6",
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
......@@ -215,13 +218,10 @@ if __name__ == "__main__":
"--matorage_batch",
default=1024,
type=int,
help='The smallest batch size stored atomically in matorage.'
help="The smallest batch size stored atomically in matorage.",
)
parser.add_argument(
"--num_workers",
default=4,
type=int,
help="number of process",
"--num_workers", default=4, type=int, help="number of process",
)
parser.add_argument(
"--max_source_length",
......@@ -244,12 +244,14 @@ if __name__ == "__main__":
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset")
parser.add_argument(
"--p_val", type=float, default=0.25, help="percent of validation dataset"
)
parser.add_argument("--do_train", action="store_true", default=False)
parser.add_argument("--do_predict", action="store_true", default=False)
args = parser.parse_args()
args.local_path = args.url.split('/')[-1]
args.local_path = args.url.split("/")[-1]
logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}")
repo = (
Repo(args.local_path)
......
......@@ -14,6 +14,4 @@
from .modeling_bart import BartForConditionalGeneration
__all__ = [
'BartForConditionalGeneration'
]
\ No newline at end of file
__all__ = ["BartForConditionalGeneration"]
......
......@@ -20,16 +20,31 @@ logger = logging.getLogger(__name__)
class Seq2SeqLoggingCallback(pl.Callback):
def on_batch_end(self, trainer, pl_module):
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
lrs = {
f"lr_group_{i}": param["lr"]
for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)
}
pl_module.logger.log_metrics(lrs)
@rank_zero_only
def _write_logs(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
type_path: str,
save_generations=True,
) -> None:
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
logger.info(
f"***** {type_path} results at step {trainer.global_step:05d} *****"
)
metrics = trainer.callback_metrics
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
trainer.logger.log_metrics(
{
k: v
for k, v in metrics.items()
if k not in ["log", "progress_bar", "preds"]
}
)
# Log results
od = Path(pl_module.hparams.output_dir)
if type_path == "test":
......@@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
# If people want this it will be easy enough to add back.
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
generations_file = (
od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
)
results_file.parent.mkdir(exist_ok=True)
generations_file.parent.mkdir(exist_ok=True)
with open(results_file, "a+") as writer:
......@@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
n_trainable_pars = count_trainable_parameters(pl_module)
# mp stands for million parameters
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
trainer.logger.log_metrics(
{"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}
)
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
......@@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric):
def get_early_stopping_callback(metric, patience):
return EarlyStopping(
monitor=f"val_{metric}",
mode="max",
patience=patience,
verbose=True,
monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,
)
......
......@@ -21,7 +21,11 @@ from matorage.torch import Dataset
try:
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from .callbacks import (
Seq2SeqLoggingCallback,
get_checkpoint_callback,
get_early_stopping_callback,
)
from .utils import (
ROUGE_KEYS,
LegacySeq2SeqDataset,
......@@ -40,7 +44,11 @@ try:
use_task_specific_params,
)
except ImportError:
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from callbacks import (
Seq2SeqLoggingCallback,
get_checkpoint_callback,
get_early_stopping_callback,
)
from utils import (
ROUGE_KEYS,
LegacySeq2SeqDataset,
......@@ -83,8 +91,12 @@ class SummarizationModule(BaseTransformer):
"val": self.hparams.val_max_target_length,
"test": self.hparams.test_max_target_length,
}
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
assert (
self.target_lens["train"] <= self.target_lens["val"]
), f"target_lens: {self.target_lens}"
assert (
self.target_lens["train"] <= self.target_lens["test"]
), f"target_lens: {self.target_lens}"
if self.hparams.freeze_embeds:
self.freeze_embeds()
......@@ -95,13 +107,27 @@ class SummarizationModule(BaseTransformer):
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.decoder_start_token_id = None # default to config
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
if self.model.config.decoder_start_token_id is None and isinstance(
self.tokenizer, MBartTokenizer
):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[
hparams.tgt_lang
]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
self.eval_beams = (
self.model.config.num_beams
if self.hparams.eval_beams is None
else self.hparams.eval_beams
)
assert (
self.eval_beams >= 1
), f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
self.val_metric = (
self.default_val_metric
if self.hparams.val_metric is None
else self.hparams.val_metric
)
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
......@@ -133,7 +159,13 @@ class SummarizationModule(BaseTransformer):
else:
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
outputs = self(src_ids, src_patch, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
outputs = self(
src_ids,
src_patch,
attention_mask=src_mask,
decoder_input_ids=decoder_input_ids,
use_cache=False,
)
lm_logits = outputs[0]
if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
......@@ -157,7 +189,9 @@ class SummarizationModule(BaseTransformer):
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
# tokens per batch
logs["tpb"] = batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
logs["tpb"] = (
batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
)
return {"loss": loss_tensors[0], "log": logs}
def validation_step(self, batch, batch_idx) -> Dict:
......@@ -165,17 +199,29 @@ class SummarizationModule(BaseTransformer):
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
losses = {
k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names
}
loss = losses["loss"]
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
rouges = {
k: np.array([x[k] for x in outputs]).mean()
for k in self.metric_names + ["gen_time", "gen_len"]
}
rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(
loss
)
rouges.update({k: v.item() for k, v in losses.items()})
losses.update(rouges)
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
metrics["step_count"] = self.step_count
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
preds = flatten_list([x["preds"] for x in outputs])
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
return {
"log": metrics,
"preds": preds,
f"{prefix}_loss": loss,
f"{prefix}_{self.val_metric}": rouge_tensor,
}
def save_metrics(self, latest_metrics, type_path) -> None:
self.metrics[type_path].append(latest_metrics)
......@@ -200,7 +246,9 @@ class SummarizationModule(BaseTransformer):
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
rouge: Dict = self.calc_generative_metrics(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
base_metrics.update(
gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge
)
return base_metrics
def test_step(self, batch, batch_idx):
......@@ -213,10 +261,10 @@ class SummarizationModule(BaseTransformer):
max_target_length = self.target_lens[type_path]
data_config = DataConfig(
endpoint=args.endpoint,
access_key=os.environ['access_key'],
secret_key=os.environ['secret_key'],
access_key=os.environ["access_key"],
secret_key=os.environ["secret_key"],
region=args.region,
dataset_name='commit-autosuggestions',
dataset_name="commit-autosuggestions",
additional={
"mode": ("training" if type_path == "train" else "evaluation"),
"max_source_length": self.hparams.max_source_length,
......@@ -224,15 +272,17 @@ class SummarizationModule(BaseTransformer):
"url": args.url,
},
attributes=[
('input_ids', 'int32', (self.hparams.max_source_length,)),
('attention_masks', 'int32', (self.hparams.max_source_length,)),
('patch_ids', 'int32', (self.hparams.max_source_length,)),
('targets', 'int32', (max_target_length,))
]
("input_ids", "int32", (self.hparams.max_source_length,)),
("attention_masks", "int32", (self.hparams.max_source_length,)),
("patch_ids", "int32", (self.hparams.max_source_length,)),
("targets", "int32", (max_target_length,)),
],
)
return Dataset(config=data_config, clear=True)
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
def get_dataloader(
self, type_path: str, batch_size: int, shuffle: bool = False
) -> DataLoader:
dataset = self.get_dataset(type_path)
sampler = None
......@@ -246,7 +296,9 @@ class SummarizationModule(BaseTransformer):
return dataloader
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
dataloader = self.get_dataloader(
"train", batch_size=self.hparams.train_batch_size, shuffle=True
)
return dataloader
def val_dataloader(self) -> DataLoader:
......@@ -259,23 +311,18 @@ class SummarizationModule(BaseTransformer):
def add_model_specific_args(parser, root_dir):
BaseTransformer.add_model_specific_args(parser, root_dir)
add_generic_args(parser, root_dir)
parser.add_argument(
"--url",
type=str,
required=True,
help="github url"
)
parser.add_argument("--url", type=str, required=True, help="github url")
parser.add_argument(
"--endpoint",
type=str,
required=True,
help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
)
parser.add_argument(
"--region",
type=str,
default=None,
help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
)
parser.add_argument(
"--max_source_length",
......@@ -308,14 +355,43 @@ class SummarizationModule(BaseTransformer):
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
"--logger_name",
type=str,
choices=["default", "wandb", "wandb_shared"],
default="default",
)
parser.add_argument(
"--n_train",
type=int,
default=-1,
required=False,
help="# examples. -1 means use all.",
)
parser.add_argument(
"--n_val",
type=int,
default=500,
required=False,
help="# examples. -1 means use all.",
)
parser.add_argument(
"--n_test",
type=int,
default=-1,
required=False,
help="# examples. -1 means use all.",
)
parser.add_argument(
"--task",
type=str,
default="summarization",
required=False,
help="# examples. -1 means use all.",
)
parser.add_argument(
"--label_smoothing", type=float, default=0.0, required=False
)
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
parser.add_argument("--eval_beams", type=int, default=None, required=False)
......@@ -348,7 +424,11 @@ class TranslationModule(SummarizationModule):
def main(args, model=None) -> SummarizationModule:
Path(args.output_dir).mkdir(exist_ok=True)
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
raise ValueError(
"Output directory ({}) already exists and is not empty.".format(
args.output_dir
)
)
if model is None:
if args.task == "summarization":
model: SummarizationModule = SummarizationModule(args)
......@@ -371,7 +451,9 @@ def main(args, model=None) -> SummarizationModule:
return model
model.hparams.test_checkpoint = ""
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
checkpoints = list(
sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
)
if checkpoints:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1]
......
......@@ -30,6 +30,7 @@ logging.basicConfig(
level=logging.INFO,
)
class GenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in
......@@ -50,7 +51,9 @@ class GenerationMixin:
"""
return logits
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
def enforce_repetition_penalty_(
self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty
):
"""
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
"""
......@@ -79,11 +82,7 @@ class GenerationMixin:
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(
scores,
batch_size,
num_beams,
input_ids,
repetition_penalty,
scores, batch_size, num_beams, input_ids, repetition_penalty,
)
# set eos token prob to zero if min_length is not reached
......@@ -102,7 +101,11 @@ class GenerationMixin:
if bad_words_ids is not None:
# Exclude EOS token (already processed)
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
bad_words_ids = list(
filter(
lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids
)
)
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
# Modify the scores in place by setting the banned tokens logits to `-inf`
......@@ -134,7 +137,7 @@ class GenerationMixin:
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_kwargs
**model_kwargs,
) -> torch.LongTensor:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
......@@ -262,26 +265,50 @@ class GenerationMixin:
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
early_stopping = (
early_stopping if early_stopping is not None else self.config.early_stopping
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
temperature = (
temperature if temperature is not None else self.config.temperature
)
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
repetition_penalty = (
repetition_penalty
if repetition_penalty is not None
else self.config.repetition_penalty
)
bos_token_id = (
bos_token_id if bos_token_id is not None else self.config.bos_token_id
)
pad_token_id = (
pad_token_id if pad_token_id is not None else self.config.pad_token_id
)
eos_token_id = (
eos_token_id if eos_token_id is not None else self.config.eos_token_id
)
length_penalty = (
length_penalty if length_penalty is not None else self.config.length_penalty
)
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
no_repeat_ngram_size
if no_repeat_ngram_size is not None
else self.config.no_repeat_ngram_size
)
bad_words_ids = (
bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
num_return_sequences
if num_return_sequences is not None
else self.config.num_return_sequences
)
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
decoder_start_token_id
if decoder_start_token_id is not None
else self.config.decoder_start_token_id
)
if input_ids is not None:
......@@ -289,14 +316,22 @@ class GenerationMixin:
else:
batch_size = 1
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
assert (
isinstance(max_length, int) and max_length > 0
), "`max_length` should be a strictly positive integer."
assert (
isinstance(min_length, int) and min_length >= 0
), "`min_length` should be a positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert (
isinstance(num_beams, int) and num_beams > 0
), "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
assert (
isinstance(top_k, int) and top_k >= 0
), "`top_k` should be a positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert input_ids is not None or (
......@@ -316,7 +351,9 @@ class GenerationMixin:
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a strictly positive integer."
assert (
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
bad_words_ids is None
or isinstance(bad_words_ids, list)
and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
if input_ids is None:
......@@ -331,7 +368,9 @@ class GenerationMixin:
device=next(self.parameters()).device,
)
else:
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
assert (
input_ids.dim() == 2
), "Input prompt should be of shape (batch_size, sequence length)."
# not allow to duplicate outputs when greedy decoding
if do_sample is False:
......@@ -349,7 +388,11 @@ class GenerationMixin:
# create attention mask if necessary
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
if (
(attention_mask is None)
and (pad_token_id is not None)
and (pad_token_id in input_ids)
):
attention_mask = input_ids.ne(pad_token_id).long()
elif attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
......@@ -358,7 +401,9 @@ class GenerationMixin:
# attention_mask is created
if pad_token_id is None and eos_token_id is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(
eos_token_id
)
)
pad_token_id = eos_token_id
......@@ -385,25 +430,37 @@ class GenerationMixin:
# see if BOS token can be used for decoder_start_token_id
if bos_token_id is not None:
decoder_start_token_id = bos_token_id
elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
elif hasattr(self.config, "decoder") and hasattr(
self.config.decoder, "bos_token_id"
):
decoder_start_token_id = self.config.decoder.bos_token_id
else:
raise ValueError(
"decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
)
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
assert hasattr(
self, "get_encoder"
), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(
self.get_encoder
)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs: ModelOutput = encoder(input_ids, patch_ids, attention_mask=attention_mask, return_dict=True)
encoder_outputs: ModelOutput = encoder(
input_ids, patch_ids, attention_mask=attention_mask, return_dict=True
)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1]
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
patch_ids = patch_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
input_ids = input_ids.unsqueeze(1).expand(
batch_size, effective_batch_mult * num_beams, input_ids_len
)
patch_ids = patch_ids.unsqueeze(1).expand(
batch_size, effective_batch_mult * num_beams, input_ids_len
)
attention_mask = attention_mask.unsqueeze(1).expand(
batch_size, effective_batch_mult * num_beams, input_ids_len
)
......@@ -442,9 +499,9 @@ class GenerationMixin:
)
# expand encoder_outputs
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
0, expanded_batch_idxs
)
encoder_outputs[
"last_hidden_state"
] = encoder_outputs.last_hidden_state.index_select(0, expanded_batch_idxs)
# save encoder_outputs in `model_kwargs`
model_kwargs["encoder_outputs"] = encoder_outputs
......@@ -534,7 +591,11 @@ class GenerationMixin:
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
input_ids,
past=past,
attention_mask=attention_mask,
use_cache=use_cache,
**model_kwargs,
)
outputs = self(**model_inputs, return_dict=True)
......@@ -565,7 +626,9 @@ class GenerationMixin:
if temperature != 1.0:
scores = scores / temperature
# Top-p/top-k filtering
next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
next_token_logscores = top_k_top_p_filtering(
scores, top_k=top_k, top_p=top_p
)
# Sample
probs = F.softmax(next_token_logscores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
......@@ -576,7 +639,9 @@ class GenerationMixin:
# update generations and finished sentences
if eos_token_id is not None:
# pad finished sentences if eos_token_id exist
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (
1 - unfinished_sents
)
else:
tokens_to_add = next_token
......@@ -587,8 +652,12 @@ class GenerationMixin:
if eos_token_id is not None:
eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(
eos_in_sents.long()
).bool()
sent_lengths.masked_fill_(
is_sents_unfinished_and_token_to_add_is_eos, cur_len
)
# unfinished_sents is set to zero if eos in sentence
unfinished_sents.mul_((~eos_in_sents).long())
......@@ -599,7 +668,11 @@ class GenerationMixin:
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
return input_ids
......@@ -633,12 +706,16 @@ class GenerationMixin:
# generated hypotheses
generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
BeamHypotheses(
num_beams, max_length, length_penalty, early_stopping=early_stopping
)
for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores = torch.zeros(
(batch_size, num_beams), dtype=torch.float, device=input_ids.device
)
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
if do_sample is False:
......@@ -653,10 +730,18 @@ class GenerationMixin:
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
input_ids,
past=past,
attention_mask=attention_mask,
use_cache=use_cache,
**model_kwargs,
)
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
outputs = self(
**model_inputs, return_dict=True
) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs.logits[
:, -1, :
] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if "past_key_values" in outputs:
......@@ -670,7 +755,9 @@ class GenerationMixin:
next_token_logits, cur_len=cur_len, max_length=max_length
)
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
scores = F.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
scores = self.postprocess_next_token_scores(
scores=scores,
......@@ -686,12 +773,17 @@ class GenerationMixin:
num_beams=num_beams,
)
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
assert scores.shape == (
batch_size * num_beams,
vocab_size,
), "Shapes of scores: {} != {}".format(
scores.shape, (batch_size * num_beams, vocab_size)
)
if do_sample:
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None].expand_as(
scores
) # (batch_size * num_beams, vocab_size)
# Temperature
if temperature != 1.0:
_scores = _scores / temperature
......@@ -706,24 +798,38 @@ class GenerationMixin:
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
probs = F.softmax(_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
next_tokens = torch.multinomial(
probs, num_samples=2 * num_beams
) # (batch_size, num_beams * 2)
# Compute next scores
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
next_scores = torch.gather(
_scores, -1, next_tokens
) # (batch_size, num_beams * 2)
# sort the sampled vector to make sure that the first num_beams samples are the best
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
next_scores, next_scores_indices = torch.sort(
next_scores, descending=True, dim=1
)
next_tokens = torch.gather(
next_tokens, -1, next_scores_indices
) # (batch_size, num_beams * 2)
else:
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
next_scores = scores + beam_scores[:, None].expand_as(
scores
) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
next_scores = next_scores.view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
next_scores, next_tokens = torch.topk(
next_scores, 2 * num_beams, dim=1, largest=True, sorted=True
)
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
assert (
next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
)
# next batch beam content
next_batch_beam = []
......@@ -735,11 +841,15 @@ class GenerationMixin:
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
), "Batch can only be done if at least {} beams have been generated".format(
num_beams
)
assert (
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
next_batch_beam.extend(
[(0, pad_token_id, 0)] * num_beams
) # pad the batch
continue
# next sentence beam content, this will get added to next_batch_beam
......@@ -757,7 +867,9 @@ class GenerationMixin:
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
is_beam_token_worse_than_top_num_beams = (
beam_token_rank >= num_beams
)
if is_beam_token_worse_than_top_num_beams:
continue
generated_hyps[batch_idx].add(
......@@ -766,7 +878,9 @@ class GenerationMixin:
)
else:
# add next predicted token since it is not eos_token
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
next_sent_beam.append(
(beam_token_score, token_id, effective_beam_id)
)
# once the beam for next step is full, don't add more tokens to it.
if len(next_sent_beam) == num_beams:
......@@ -780,7 +894,9 @@ class GenerationMixin:
# update next beam content
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
assert len(next_batch_beam) == num_beams * (
batch_idx + 1
), "We should have added num_beams each step"
# stop when we are done with each sentence
if all(done):
......@@ -804,7 +920,11 @@ class GenerationMixin:
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
# finalize all open beam hypotheses and add to generated hypotheses
......@@ -814,10 +934,12 @@ class GenerationMixin:
# test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_id is not None and all(
(token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
(token_id % vocab_size).item() != eos_token_id
for token_id in next_tokens[batch_idx]
):
assert torch.all(
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
next_scores[batch_idx, :num_beams]
== beam_scores.view(batch_size, num_beams)[batch_idx]
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
next_scores[:, :num_beams][batch_idx],
beam_scores.view(batch_size, num_beams)[batch_idx],
......@@ -831,7 +953,9 @@ class GenerationMixin:
generated_hyps[batch_idx].add(final_tokens, final_score)
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
output_batch_size = (
batch_size if do_sample else batch_size * num_return_sequences
)
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
# select the best hypotheses
......@@ -861,7 +985,9 @@ class GenerationMixin:
else:
# none of the hypotheses have an eos_token
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
decoded = (
torch.stack(best).type(torch.long).to(next(self.parameters()).device)
)
return decoded
......@@ -870,7 +996,9 @@ class GenerationMixin:
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
def calc_banned_ngram_tokens(
prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int
) -> None:
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
......@@ -881,7 +1009,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
generated_ngram[prev_ngram_tuple] = generated_ngram.get(
prev_ngram_tuple, []
) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
......@@ -893,7 +1023,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n
return banned_tokens
def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
def calc_banned_bad_words_ids(
prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]
) -> Iterable[int]:
banned_tokens = []
def _tokens_match(prev_tokens, tokens):
......@@ -914,7 +1046,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
banned_tokens_slice = []
for banned_token_seq in bad_words_ids:
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
assert (
len(banned_token_seq) > 0
), "Banned words token sequences {} cannot have an empty list".format(
bad_words_ids
)
......@@ -929,7 +1063,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
return banned_tokens
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
def set_scores_to_inf_for_banned_tokens(
scores: torch.Tensor, banned_tokens: List[List[int]]
) -> None:
"""Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
Args:
......@@ -949,7 +1085,12 @@ def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: Lis
# [ 0 0 0 ]
# [ 1 0 0 ]
banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
banned_mask = (
torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
.to(scores.device)
.to_dense()
.bool()
)
scores.masked_fill_(banned_mask, -float("inf"))
......@@ -989,7 +1130,9 @@ def top_k_top_p_filtering(
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
......@@ -1020,7 +1163,9 @@ class BeamHypotheses(object):
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
sorted_scores = sorted(
[(s, idx) for idx, (s, _) in enumerate(self.beams)]
)
del self.beams[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
......
......@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
config=None,
tokenizer=None,
model=None,
**config_kwargs
**config_kwargs,
):
"""Initialize a model, tokenizer and config."""
super().__init__()
......@@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule):
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
if config is None:
self.config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
self.hparams.config_name
if self.hparams.config_name
else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=cache_dir,
**config_kwargs,
......@@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule):
else:
self.config: PretrainedConfig = config
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
extra_model_params = (
"encoder_layerdrop",
"decoder_layerdrop",
"dropout",
"attention_dropout",
)
for p in extra_model_params:
if getattr(self.hparams, p, None):
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
assert hasattr(
self.config, p
), f"model config doesn't have a `{p}` attribute"
setattr(self.config, p, getattr(self.hparams, p))
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
self.hparams.tokenizer_name
if self.hparams.tokenizer_name
else self.hparams.model_name_or_path,
cache_dir=cache_dir,
)
else:
......@@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule):
def get_lr_scheduler(self):
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
scheduler = get_schedule_func(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
self.opt,
num_warmup_steps=self.hparams.warmup_steps,
num_training_steps=self.total_steps,
)
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return scheduler
......@@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule):
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"params": [
p
for n, p in model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": self.hparams.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"params": [
p
for n, p in model.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
if self.hparams.adafactor:
optimizer = Adafactor(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
optimizer_grouped_parameters,
lr=self.hparams.learning_rate,
scale_parameter=False,
relative_step=False,
)
else:
optimizer = AdamW(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
optimizer_grouped_parameters,
lr=self.hparams.learning_rate,
eps=self.hparams.adam_epsilon,
)
self.opt = optimizer
......@@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule):
def total_steps(self) -> int:
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
effective_batch_size = (
self.hparams.train_batch_size
* self.hparams.accumulate_grad_batches
* num_devices
)
dataset_size = len(self.train_loader.dataset)
return (dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, mode):
if mode == "fit":
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
self.train_loader = self.get_dataloader(
"train", self.hparams.train_batch_size, shuffle=True
)
def get_dataloader(self, type_path, batch_size, shuffle=False):
raise NotImplementedError("You must implement this for your task")
......@@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule):
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
"--config_name",
default="",
type=str,
help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
"--tokenizer_name",
......@@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule):
type=float,
help="Attention dropout probability (Optional). Goes into model.config",
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument(
"--learning_rate",
default=5e-5,
type=float,
help="The initial learning rate for Adam.",
)
parser.add_argument(
"--lr_scheduler",
default="linear",
......@@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule):
type=str,
help="Learning rate scheduler",
)
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
parser.add_argument(
"--weight_decay",
default=0.0,
type=float,
help="Weight decay if we apply some.",
)
parser.add_argument(
"--adam_epsilon",
default=1e-8,
type=float,
help="Epsilon for Adam optimizer.",
)
parser.add_argument(
"--warmup_steps",
default=0,
type=int,
help="Linear warmup over warmup_steps.",
)
parser.add_argument(
"--num_workers", default=4, type=int, help="kwarg passed to DataLoader"
)
parser.add_argument(
"--num_train_epochs", dest="max_epochs", default=3, type=int
)
parser.add_argument("--train_batch_size", default=32, type=int)
parser.add_argument("--eval_batch_size", default=32, type=int)
parser.add_argument("--adafactor", action="store_true")
......@@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback):
rank_zero_info("***** Test results *****")
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
output_test_results_file = os.path.join(
pl_module.hparams.output_dir, "test_results.txt"
)
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
......@@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None:
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
parser.add_argument(
"--max_grad_norm",
dest="gradient_clip_val",
default=1.0,
type=float,
help="Max gradient norm",
)
parser.add_argument(
"--do_train", action="store_true", help="Whether to run training."
)
parser.add_argument(
"--do_predict",
action="store_true",
help="Whether to run predictions on the test set.",
)
parser.add_argument(
"--gradient_accumulation_steps",
dest="accumulate_grad_batches",
......@@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None:
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--seed", type=int, default=42, help="random seed for initialization"
)
def generic_train(
......@@ -335,7 +410,7 @@ def generic_train(
extra_callbacks=[],
checkpoint_callback=None,
logging_callback=None,
**extra_train_kwargs
**extra_train_kwargs,
):
pl.seed_everything(args.seed)
......@@ -346,7 +421,11 @@ def generic_train(
# add custom checkpoints
if checkpoint_callback is None:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
filepath=args.output_dir,
prefix="checkpoint",
monitor="val_loss",
mode="min",
save_top_k=1,
)
if logging_callback is None:
logging_callback = LoggingCallback()
......
......@@ -141,7 +141,11 @@ def invert_mask(attention_mask):
def _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
config,
input_ids,
decoder_input_ids=None,
decoder_padding_mask=None,
causal_mask_dtype=torch.float32,
):
"""Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
......@@ -184,7 +188,9 @@ class PretrainedBartModel(PreTrainedModel):
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
input_ids = torch.tensor(
[[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device
)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
......@@ -229,7 +235,11 @@ class EncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
self.self_attn = Attention(
self.embed_dim,
config.encoder_attention_heads,
dropout=config.attention_dropout,
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
......@@ -255,7 +265,10 @@ class EncoderLayer(nn.Module):
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, attn_weights = self.self_attn(
query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
query=x,
key=x,
key_padding_mask=encoder_padding_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
......@@ -308,13 +321,23 @@ class BartEncoder(nn.Module):
config.extra_pos_embeddings,
)
self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0)
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
self.layers = nn.ModuleList(
[EncoderLayer(config) for _ in range(config.encoder_layers)]
)
self.layernorm_embedding = (
LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
)
# mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
def forward(
self, input_ids, patch_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
self,
input_ids,
patch_ids,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
):
"""
Args:
......@@ -352,10 +375,14 @@ class BartEncoder(nn.Module):
encoder_states.append(x)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
if self.training and (
dropout_probability < self.layerdrop
): # skip the layer
attn = None
else:
x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
x, attn = encoder_layer(
x, attention_mask, output_attentions=output_attentions
)
if output_attentions:
all_attentions = all_attentions + (attn,)
......@@ -365,14 +392,20 @@ class BartEncoder(nn.Module):
if output_hidden_states:
encoder_states.append(x)
# T x B x C -> B x T x C
encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)
encoder_states = tuple(
hidden_state.transpose(0, 1) for hidden_state in encoder_states
)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if not return_dict:
return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
return tuple(
v for v in [x, encoder_states, all_attentions] if v is not None
)
return BaseModelOutput(
last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions
)
class DecoderLayer(nn.Module):
......@@ -498,8 +531,12 @@ class BartDecoder(nn.Module):
self.layers = nn.ModuleList(
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
self.layernorm_embedding = (
LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
)
self.layer_norm = (
LayerNorm(config.d_model) if config.add_final_layer_norm else None
)
def forward(
self,
......@@ -595,23 +632,34 @@ class BartDecoder(nn.Module):
if use_cache:
next_decoder_cache.append(layer_past.copy())
if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART)
if self.layer_norm and (
idx == len(self.layers) - 1
): # if config.add_final_layer_norm (mBART)
x = self.layer_norm(x)
if output_attentions:
all_self_attns += (layer_self_attn,)
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
if output_hidden_states:
all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)
all_hidden_states = tuple(
hidden_state.transpose(0, 1) for hidden_state in all_hidden_states
)
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(
v
for v in [x, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
last_hidden_state=x,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
......@@ -638,7 +686,9 @@ class Attention(nn.Module):
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.encoder_decoder_attention = encoder_decoder_attention
......@@ -649,7 +699,11 @@ class Attention(nn.Module):
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
def _shape(self, tensor, seq_len, bsz):
return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
return (
tensor.contiguous()
.view(seq_len, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
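# Sketch of what the `_shape` helper above does to a projected tensor, assuming
# embed_dim = num_heads * head_dim. Shapes only; no learned weights involved.
import torch

seq_len, bsz, num_heads, head_dim = 7, 2, 4, 16
x = torch.randn(seq_len, bsz, num_heads * head_dim)  # (T, B, E)
shaped = (
    x.contiguous()
    .view(seq_len, bsz * num_heads, head_dim)
    .transpose(0, 1)
)  # (B*H, T, head_dim)
assert shaped.shape == (bsz * num_heads, seq_len, head_dim)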
def forward(
self,
......@@ -693,7 +747,9 @@ class Attention(nn.Module):
v = self._shape(v, -1, bsz)
if saved_state is not None:
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
k, v, key_padding_mask = self._use_saved_state(
k, v, saved_state, key_padding_mask, static_kv, bsz
)
# Update cache
layer_state[self.cache_key] = {
......@@ -708,7 +764,9 @@ class Attention(nn.Module):
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
if attn_mask is not None:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
......@@ -725,16 +783,14 @@ class Attention(nn.Module):
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(
attn_weights,
p=self.dropout,
training=self.training,
)
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
assert v is not None
attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)
attn_output = self.out_proj(attn_output)
if output_attentions:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
......@@ -763,12 +819,16 @@ class Attention(nn.Module):
assert v is not None
v = torch.cat([prev_value, v], dim=1)
assert k is not None and v is not None
prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
prev_key_padding_mask: Optional[Tensor] = saved_state.get(
"prev_key_padding_mask", None
)
if prev_key_padding_mask is not None:
if static_kv:
new_key_padding_mask = prev_key_padding_mask
else:
new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask, key_padding_mask], dim=1
)
else:
new_key_padding_mask = key_padding_mask
return k, v, new_key_padding_mask
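# Toy sketch of the incremental decoding cache handled above: previously computed
# keys/values are concatenated with the current step's keys/values along time.
import torch

bsz, num_heads, head_dim = 2, 4, 8
prev_key = torch.randn(bsz * num_heads, 3, head_dim)  # 3 cached timesteps
new_key = torch.randn(bsz * num_heads, 1, head_dim)   # current step
k = torch.cat([prev_key, new_key], dim=1)             # now 4 timesteps
assert k.shape == (bsz * num_heads, 4, head_dim)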
......@@ -780,11 +840,7 @@ class BartClassificationHead(nn.Module):
# This can trivially be shared with RobertaClassificationHead
def __init__(
self,
input_dim,
inner_dim,
num_classes,
pooler_dropout,
self, input_dim, inner_dim, num_classes, pooler_dropout,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
......@@ -808,7 +864,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
position ids are passed to the forward function.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
def __init__(
self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset
):
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = offset
......@@ -820,10 +878,14 @@ class LearnedPositionalEmbedding(nn.Embedding):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
if use_cache:
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
positions = input_ids.data.new(1, 1).fill_(
seq_len - 1
) # called before slicing
else:
# starts at 0, ends at seq_len - 1
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
positions = torch.arange(
seq_len, dtype=torch.long, device=self.weight.device
)
return super().forward(positions + self.offset)
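# Hedged illustration of the offset trick described above: position ids are shifted
# by `offset` before the lookup, so the table needs num_positions + offset rows.
import torch
import torch.nn as nn

offset, max_positions, dim = 2, 10, 8
table = nn.Embedding(max_positions + offset, dim)
positions = torch.arange(5)                 # 0..4 for a length-5 input
pos_embeddings = table(positions + offset)  # rows 2..6 of the table are read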
......@@ -896,16 +958,28 @@ class BartModel(PretrainedBartModel):
if decoder_input_ids is None:
use_cache = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# make masks if user doesn't supply
if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
(
decoder_input_ids,
decoder_padding_mask,
causal_mask,
) = _prepare_bart_decoder_inputs(
self.config,
input_ids,
decoder_input_ids=decoder_input_ids,
......@@ -974,17 +1048,24 @@ class BartModel(PretrainedBartModel):
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
"The BART Model with a language modeling head. Can be used for summarization.",
BART_START_DOCSTRING,
)
class BartForConditionalGeneration(PretrainedBartModel):
base_model_prefix = "model"
authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
authorized_missing_keys = [
r"final_logits_bias",
r"encoder\.version",
r"decoder\.version",
]
def __init__(self, config: BartConfig):
super().__init__(config)
base_model = BartModel(config)
self.model = base_model
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
self.register_buffer(
"final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
)
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
old_num_tokens = self.model.shared.num_embeddings
......@@ -993,16 +1074,23 @@ class BartForConditionalGeneration(PretrainedBartModel):
self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
def _resize_final_logits_bias(
self, new_num_tokens: int, old_num_tokens: int
) -> None:
if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens]
else:
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
extra_bias = torch.zeros(
(1, new_num_tokens - old_num_tokens),
device=self.final_logits_bias.device,
)
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(
output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
@add_end_docstrings(BART_GENERATION_EXAMPLE)
def forward(
self,
......@@ -1065,7 +1153,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
FutureWarning,
)
past_key_values = unused.pop("decoder_past_key_values")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if labels is not None:
use_cache = False
......@@ -1085,17 +1175,23 @@ class BartForConditionalGeneration(PretrainedBartModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
lm_logits = F.linear(
outputs[0], self.model.shared.weight, bias=self.final_logits_bias
)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# TODO(SS): do we need to ignore pad tokens in labels?
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
masked_lm_loss = loss_fct(
lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return (
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
)
return Seq2SeqLMOutput(
loss=masked_lm_loss,
......@@ -1109,7 +1205,13 @@ class BartForConditionalGeneration(PretrainedBartModel):
)
def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
self,
decoder_input_ids,
past,
attention_mask,
use_cache,
encoder_outputs,
**kwargs,
):
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
......@@ -1130,7 +1232,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
def _force_token_ids_generation(self, scores, token_id) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float(
"inf"
)
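# Small sketch of the masking trick above: every vocabulary column except `token_id`
# is set to -inf, so softmax/argmax can only select the forced token.
import torch

scores = torch.randn(3, 10)  # (batch, vocab_size) logits
token_id = 4
mask = torch.ones(10, dtype=torch.bool)
mask[token_id] = False
scores[:, mask] = float("-inf")
assert (scores.argmax(dim=-1) == token_id).all()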
@staticmethod
def _reorder_cache(past, beam_idx):
......@@ -1138,7 +1242,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
for layer_past in past:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
attn_key: _reorder_buffer(attn_cache, beam_idx)
for attn_key, attn_cache in layer_past.items()
}
reordered_past.append(layer_past_new)
return reordered_past
......@@ -1159,10 +1264,7 @@ class BartForSequenceClassification(PretrainedBartModel):
super().__init__(config, **kwargs)
self.model = BartModel(config)
self.classification_head = BartClassificationHead(
config.d_model,
config.d_model,
config.num_labels,
config.classif_dropout,
config.d_model, config.d_model, config.num_labels, config.classif_dropout,
)
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj)
......@@ -1193,7 +1295,9 @@ class BartForSequenceClassification(PretrainedBartModel):
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if labels is not None:
use_cache = False
......@@ -1212,7 +1316,9 @@ class BartForSequenceClassification(PretrainedBartModel):
eos_mask = input_ids.eq(self.config.eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[
:, -1, :
]
logits = self.classification_head(sentence_representation)
loss = None
......@@ -1284,7 +1390,9 @@ class BartForQuestionAnswering(PretrainedBartModel):
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if start_positions is not None and end_positions is not None:
use_cache = False
......@@ -1325,10 +1433,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (
start_logits,
end_logits,
) + outputs[1:]
output = (start_logits, end_logits,) + outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return Seq2SeqQuestionAnsweringModelOutput(
......@@ -1350,7 +1455,9 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions, embedding_dim, padding_idx=None):
super().__init__(num_positions, embedding_dim)
if embedding_dim % 2 != 0:
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
raise NotImplementedError(
f"odd embedding_dim {embedding_dim} not supported"
)
self.weight = self._init_weight(self.weight)
@staticmethod
......@@ -1360,9 +1467,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
"""
n_pos, dim = out.shape
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
[
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
for pos in range(n_pos)
]
)
out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos
out[:, 0 : dim // 2] = torch.FloatTensor(
np.sin(position_enc[:, 0::2])
) # This line breaks for odd n_pos
out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
out.requires_grad = False
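# Hedged sketch of the sinusoidal table built in `_init_weight` above (even dim only):
# the first dim//2 output columns hold sines and the last dim//2 hold cosines of the
# usual pos / 10000^(2i/dim) angles.
import numpy as np

def sinusoid_table_sketch(n_pos, dim):
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out = np.zeros((n_pos, dim))
    out[:, : dim // 2] = np.sin(position_enc[:, 0::2])
    out[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
    return out

table = sinusoid_table_sketch(n_pos=16, dim=8)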
......@@ -1373,8 +1485,12 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
if use_cache:
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
positions = input_ids.data.new(1, 1).fill_(
seq_len - 1
) # called before slicing
else:
# starts at 0, ends at seq_len - 1
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
positions = torch.arange(
seq_len, dtype=torch.long, device=self.weight.device
)
return super().forward(positions)
......
......@@ -80,7 +80,9 @@ def find_pruneable_heads_and_indices(
:obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
"""
mask = torch.ones(n_heads, head_size)
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
heads = (
set(heads) - already_pruned_heads
) # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
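# Toy example of the index shift above: if head 1 was already pruned, "head 3" in the
# original numbering corresponds to row 2 of the current (already shrunken) weight.
already_pruned_heads = {1}
head = 3
shifted = head - sum(1 if h < head else 0 for h in already_pruned_heads)
assert shifted == 2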
......@@ -106,7 +108,11 @@ class ModuleUtilsMixin:
Returns:
:obj:`int`: The number of parameters.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
params = (
filter(lambda x: x.requires_grad, self.parameters())
if only_trainable
else self.parameters()
)
return sum(p.numel() for p in params)
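# Minimal usage sketch (toy model assumed): counting all vs. only trainable parameters.
import torch.nn as nn

model = nn.Linear(4, 3)
model.bias.requires_grad = False
total = sum(p.numel() for p in model.parameters())                         # 15
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)  # 12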
@staticmethod
......@@ -114,7 +120,9 @@ class ModuleUtilsMixin:
try:
import psutil
except (ImportError):
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
raise ImportError(
"You need to install psutil (pip install psutil) to use memory tracing."
)
process = psutil.Process(os.getpid())
mem = process.memory_info()
......@@ -126,13 +134,17 @@ class ModuleUtilsMixin:
try:
import psutil
except (ImportError):
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
raise ImportError(
"You need to install psutil (pip install psutil) to use memory tracing."
)
process = psutil.Process(os.getpid())
mem = process.memory_info()
module.mem_rss_post_forward = mem.rss
mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
module.mem_rss_diff = mem_rss_diff + (
module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0
)
return None
def add_memory_hooks(self):
......@@ -169,7 +181,9 @@ class ModuleUtilsMixin:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
tuples = [
(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)
]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
......@@ -187,7 +201,9 @@ class ModuleUtilsMixin:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
tuples = [
(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)
]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
......@@ -213,12 +229,18 @@ class ModuleUtilsMixin:
# /transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
dtype=self.dtype
) # fp16 compatibility
if self.dtype == torch.float16:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
encoder_extended_attention_mask = (
1.0 - encoder_extended_attention_mask
) * -1e4
elif self.dtype == torch.float32:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
encoder_extended_attention_mask = (
1.0 - encoder_extended_attention_mask
) * -1e9
else:
raise ValueError(
"{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
......@@ -228,7 +250,9 @@ class ModuleUtilsMixin:
return encoder_extended_attention_mask
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
def get_extended_attention_mask(
self, attention_mask: Tensor, input_shape: Tuple[int], device: device
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
......@@ -254,10 +278,15 @@ class ModuleUtilsMixin:
if self.config.is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
causal_mask = (
seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
<= seq_ids[None, :, None]
)
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
extended_attention_mask = (
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
)
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
......@@ -272,12 +301,17 @@ class ModuleUtilsMixin:
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(
dtype=self.dtype
) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
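# Sketch of the additive mask trick above: 1.0 for visible positions becomes 0.0,
# 0.0 for masked positions becomes a large negative number added to the raw scores.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                  # (bsz, seq_len)
extended = attention_mask[:, None, None, :].to(torch.float32)  # (bsz, 1, 1, seq_len)
extended = (1.0 - extended) * -10000.0
# extended is now [[[[0., 0., 0., -10000.]]]] and can be added to attention scores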
def get_head_mask(
self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
self,
head_mask: Optional[Tensor],
num_hidden_layers: int,
is_attention_chunked: bool = False,
) -> Tensor:
"""
Prepare the head mask if needed.
......@@ -309,9 +343,13 @@ class ModuleUtilsMixin:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = (
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
) # We can specify head_mask for each layer
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype)  # switch to float if needed + fp16 compatibility
head_mask = head_mask.to(
dtype=self.dtype
)  # switch to float if needed + fp16 compatibility
return head_mask
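# Hedged sketch of the head-mask broadcasting above: a 1-D per-head vector is expanded
# to a 5-D shape that broadcasts over (batch, num_heads, seq_len, seq_len) per layer.
import torch

num_hidden_layers, num_heads = 6, 4
head_mask = torch.ones(num_heads)  # same mask reused for every layer
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
assert head_mask.shape == (num_hidden_layers, 1, num_heads, 1, 1)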
......@@ -420,12 +458,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
self._tie_encoder_decoder_weights(
self.encoder, self.decoder, self.base_model_prefix
)
@staticmethod
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
def _tie_encoder_decoder_weights(
encoder: nn.Module, decoder: nn.Module, base_model_prefix: str
):
uninitialized_encoder_weights: List[str] = []
assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal."
assert (
decoder.__class__ == encoder.__class__
), f"{decoder.__class__} and {encoder.__class__} have to be equal."
def tie_encoder_to_decoder_recursively(
decoder_pointer: nn.Module,
......@@ -452,13 +496,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
len(encoder_modules) > 0
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
all_encoder_weights = set(
[
module_name + "/" + sub_name
for sub_name in encoder_modules.keys()
]
)
encoder_layer_pos = 0
for name, module in decoder_modules.items():
if name.isdigit():
encoder_name = str(int(name) + encoder_layer_pos)
decoder_name = name
if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])):
if not isinstance(
decoder_modules[decoder_name],
type(encoder_modules[encoder_name]),
):
# this can happen if the name corresponds to the position in a module list of layers
# in this case the decoder has added a cross-attention that the encoder does not have
# thus skip this step and subtract one layer pos from encoder
......@@ -484,7 +536,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
uninitialized_encoder_weights += list(all_encoder_weights)
# tie weights recursively
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
tie_encoder_to_decoder_recursively(
decoder, encoder, base_model_prefix, uninitialized_encoder_weights
)
if len(uninitialized_encoder_weights) > 0:
logger.warning(
f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
......@@ -507,10 +561,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
"constant",
0,
)
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
if hasattr(output_embeddings, "out_features") and hasattr(
input_embeddings, "num_embeddings"
):
output_embeddings.out_features = input_embeddings.num_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None
) -> torch.nn.Embedding:
"""
Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
......@@ -526,7 +584,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
Return:
:obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
base_model = getattr(
self, self.base_model_prefix, self
) # get the base model if needed
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
if new_num_tokens is None:
return model_embeds
......@@ -583,7 +643,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# Copy token embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[
:num_tokens_to_copy, :
]
return new_embeddings
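# Illustrative-only sketch of growing an embedding matrix while keeping old rows:
# copy the first min(old, new) rows, leave the remaining rows freshly initialized.
import torch.nn as nn

old = nn.Embedding(10, 8)
new = nn.Embedding(12, 8)
num_to_copy = min(old.num_embeddings, new.num_embeddings)
new.weight.data[:num_to_copy, :] = old.weight.data[:num_to_copy, :]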
......@@ -614,7 +676,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
for layer, heads in heads_to_prune.items():
union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
self.config.pruned_heads[layer] = list(
union_heads
) # Unfortunately we have to store it as list for JSON
self.base_model._prune_heads(heads_to_prune)
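# Small sketch of the bookkeeping above: newly pruned heads are merged with any heads
# already recorded for a layer, then stored as a list so the config stays JSON-serializable.
pruned_heads = {0: [1]}               # heads already pruned, per layer
heads_to_prune = {0: [2], 3: [0, 1]}  # new request
for layer, heads in heads_to_prune.items():
    union_heads = set(pruned_heads.get(layer, [])) | set(heads)
    pruned_heads[layer] = sorted(union_heads)
# pruned_heads == {0: [1, 2], 3: [0, 1]}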
......@@ -628,7 +692,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
Directory to which to save. Will be created if it doesn't exist.
"""
if os.path.isfile(save_directory):
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
logger.error(
"Provided path ({}) should be a directory, not a file".format(
save_directory
)
)
return
os.makedirs(save_directory, exist_ok=True)
......@@ -775,7 +843,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
config_path = (
config if config is not None else pretrained_model_name_or_path
)
config, model_kwargs = cls.config_class.from_pretrained(
config_path,
*model_args,
......@@ -793,23 +863,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# Load model
if pretrained_model_name_or_path is not None:
if os.path.isdir(pretrained_model_name_or_path):
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
if from_tf and os.path.isfile(
os.path.join(
pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index"
)
):
# Load from a TF 1.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
archive_file = os.path.join(
pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index"
)
elif from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
archive_file = os.path.join(
pretrained_model_name_or_path, TF2_WEIGHTS_NAME
)
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
archive_file = os.path.join(
pretrained_model_name_or_path, WEIGHTS_NAME
)
else:
raise EnvironmentError(
"Error no file named {} found in directory {} or `from_tf` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
[
WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME + ".index",
],
pretrained_model_name_or_path,
)
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
pretrained_model_name_or_path
):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
assert (
......@@ -848,7 +938,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:
logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
logger.info(
"loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file
)
)
else:
resolved_archive_file = None
......@@ -871,13 +965,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
model = cls.load_tf_weights(
model, config, resolved_archive_file[:-6]
) # Remove the '.index'
else:
# Load from our TensorFlow 2.0 checkpoints
try:
from transformers import load_tf2_checkpoint_in_pytorch_model
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
model = load_tf2_checkpoint_in_pytorch_model(
model, resolved_archive_file, allow_missing_keys=True
)
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
......@@ -909,7 +1007,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
local_metadata = (
{} if metadata is None else metadata.get(prefix[:-1], {})
)
module._load_from_state_dict(
state_dict,
prefix,
......@@ -926,7 +1026,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
has_prefix_module = any(
s.startswith(cls.base_model_prefix) for s in state_dict.keys()
)
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
start_prefix = cls.base_model_prefix + "."
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
......@@ -937,15 +1039,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
if model.__class__.__name__ != model_to_load.__class__.__name__:
base_model_state_dict = model_to_load.state_dict().keys()
head_model_state_dict_without_base_prefix = [
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
key.split(cls.base_model_prefix + ".")[-1]
for key in model.state_dict().keys()
]
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
missing_keys.extend(
head_model_state_dict_without_base_prefix - base_model_state_dict
)
# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls.authorized_missing_keys is not None:
for pat in cls.authorized_missing_keys:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
missing_keys = [
k for k in missing_keys if re.search(pat, k) is None
]
if len(unexpected_keys) > 0:
logger.warning(
......@@ -957,7 +1064,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
logger.info(
f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
)
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
......@@ -990,7 +1099,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
}
return model, loading_info
if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
if (
hasattr(config, "xla_device")
and config.xla_device
and is_torch_tpu_available()
):
import torch_xla.core.xla_model as xm
model = xm.send_cpu_data_to_device(model, xm.xla_device())
......@@ -1039,7 +1152,9 @@ class PoolerStartLogits(nn.Module):
self.dense = nn.Linear(config.hidden_size, 1)
def forward(
self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
self,
hidden_states: torch.FloatTensor,
p_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Args:
......@@ -1112,8 +1227,12 @@ class PoolerEndLogits(nn.Module):
), "One of start_states, start_positions should be not None"
if start_positions is not None:
slen, hsz = hidden_states.shape[-2:]
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
start_positions = start_positions[:, None, None].expand(
-1, -1, hsz
) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(
-2, start_positions
) # shape (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
......@@ -1177,12 +1296,20 @@ class PoolerAnswerClass(nn.Module):
start_states is not None or start_positions is not None
), "One of start_states, start_positions should be not None"
if start_positions is not None:
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
start_positions = start_positions[:, None, None].expand(
-1, -1, hsz
) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(
-2
) # shape (bsz, hsz)
if cls_index is not None:
cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
cls_index = cls_index[:, None, None].expand(
-1, -1, hsz
) # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(
-2
) # shape (bsz, hsz)
else:
cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
......@@ -1241,7 +1368,9 @@ class SQuADHead(nn.Module):
self.end_logits = PoolerEndLogits(config)
self.answer_class = PoolerAnswerClass(config)
@replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
@replace_return_docstrings(
output_type=SquadHeadOutput, config_class=PretrainedConfig
)
def forward(
self,
hidden_states: torch.FloatTensor,
......@@ -1281,7 +1410,9 @@ class SQuADHead(nn.Module):
x.squeeze_(-1)
# during training, compute the end logits based on the ground truth of the start position
end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
end_logits = self.end_logits(
hidden_states, start_positions=start_positions, p_mask=p_mask
)
loss_fct = CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
......@@ -1290,7 +1421,9 @@ class SQuADHead(nn.Module):
if cls_index is not None and is_impossible is not None:
# Predict answerability from the representation of CLS and START
cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
cls_logits = self.answer_class(
hidden_states, start_positions=start_positions, cls_index=cls_index
)
loss_fct_cls = nn.BCEWithLogitsLoss()
cls_loss = loss_fct_cls(cls_logits, is_impossible)
......@@ -1307,28 +1440,48 @@ class SQuADHead(nn.Module):
start_top_log_probs, start_top_index = torch.topk(
start_log_probs, self.start_n_top, dim=-1
) # shape (bsz, start_n_top)
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
start_top_index_exp = start_top_index.unsqueeze(-1).expand(
-1, -1, hsz
) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(
hidden_states, -2, start_top_index_exp
) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(
-1, slen, -1, -1
) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
start_states
) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_logits = self.end_logits(
hidden_states_expanded, start_states=start_states, p_mask=p_mask
)
end_log_probs = F.softmax(
end_logits, dim=1
) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk(
end_log_probs, self.end_n_top, dim=1
) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
end_top_log_probs = end_top_log_probs.view(
-1, self.start_n_top * self.end_n_top
)
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
cls_logits = self.answer_class(
hidden_states, start_states=start_states, cls_index=cls_index
)
if not return_dict:
return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
return (
start_top_log_probs,
start_top_index,
end_top_log_probs,
end_top_index,
cls_logits,
)
else:
return SquadHeadOutput(
start_top_log_probs=start_top_log_probs,
......@@ -1379,17 +1532,26 @@ class SequenceSummary(nn.Module):
self.summary = Identity()
if hasattr(config, "summary_use_proj") and config.summary_use_proj:
if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
if (
hasattr(config, "summary_proj_to_labels")
and config.summary_proj_to_labels
and config.num_labels > 0
):
num_classes = config.num_labels
else:
num_classes = config.hidden_size
self.summary = nn.Linear(config.hidden_size, num_classes)
activation_string = getattr(config, "summary_activation", None)
self.activation: Callable = get_activation(activation_string) if activation_string else Identity()
self.activation: Callable = get_activation(
activation_string
) if activation_string else Identity()
self.first_dropout = Identity()
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
if (
hasattr(config, "summary_first_dropout")
and config.summary_first_dropout > 0
):
self.first_dropout = nn.Dropout(config.summary_first_dropout)
self.last_dropout = Identity()
......@@ -1397,7 +1559,9 @@ class SequenceSummary(nn.Module):
self.last_dropout = nn.Dropout(config.summary_last_dropout)
def forward(
self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
self,
hidden_states: torch.FloatTensor,
cls_index: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
"""
Compute a single vector summary of a sequence's hidden states.
......@@ -1427,9 +1591,13 @@ class SequenceSummary(nn.Module):
)
else:
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
cls_index = cls_index.expand(
(-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)
)
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
output = hidden_states.gather(-2, cls_index).squeeze(
-2
) # shape (bsz, XX, hidden_size)
elif self.summary_type == "attn":
raise NotImplementedError
......@@ -1441,7 +1609,9 @@ class SequenceSummary(nn.Module):
return output
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
def prune_linear_layer(
layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0
) -> torch.nn.Linear:
"""
Prune a linear layer to keep only entries in index.
......@@ -1464,7 +1634,9 @@ def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(
layer.weight.device
)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
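# Hedged usage sketch of the pruning above: pruning a Linear layer along its output
# dimension (dim=0) keeps only the selected rows of the weight and the matching bias entries.
import torch
import torch.nn as nn

layer = nn.Linear(8, 4)
index = torch.tensor([0, 2])  # keep output units 0 and 2
W = layer.weight.index_select(0, index).clone().detach()
b = layer.bias.index_select(0, index).clone().detach()
pruned = nn.Linear(8, 2)
pruned.weight.data.copy_(W)
pruned.bias.data.copy_(b)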
......@@ -1509,7 +1681,9 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) ->
def prune_layer(
layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
layer: Union[torch.nn.Linear, Conv1D],
index: torch.LongTensor,
dim: Optional[int] = None,
) -> Union[torch.nn.Linear, Conv1D]:
"""
Prune a Conv1D or linear layer to keep only entries in index.
......@@ -1534,7 +1708,10 @@ def prune_layer(
def apply_chunking_to_forward(
forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
forward_fn: Callable[..., torch.Tensor],
chunk_size: int,
chunk_dim: int,
*input_tensors,
) -> torch.Tensor:
"""
This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
......@@ -1568,7 +1745,9 @@ def apply_chunking_to_forward(
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
"""
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(
input_tensors
)
tensor_shape = input_tensors[0].shape
assert all(
input_tensor.shape == tensor_shape for input_tensor in input_tensors
......@@ -1592,9 +1771,15 @@ def apply_chunking_to_forward(
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
# chunk input tensor into tuples
input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
input_tensors_chunks = tuple(
input_tensor.chunk(num_chunks, dim=chunk_dim)
for input_tensor in input_tensors
)
# apply forward fn to every tuple
output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
output_chunks = tuple(
forward_fn(*input_tensors_chunk)
for input_tensors_chunk in zip(*input_tensors_chunks)
)
# concatenate output at same dimension
return torch.cat(output_chunks, dim=chunk_dim)
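# Toy demonstration of the chunking pattern above: the forward fn is applied to equal
# slices along `chunk_dim` and the results are concatenated back together.
import torch

def forward_fn(t):
    return t * 2

x = torch.arange(12.0).view(4, 3)
chunk_size, chunk_dim = 2, 0
chunks = x.chunk(x.shape[chunk_dim] // chunk_size, dim=chunk_dim)
out = torch.cat([forward_fn(c) for c in chunks], dim=chunk_dim)
assert torch.equal(out, forward_fn(x))  # same result as one unchunked call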
......
......@@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
return loss, nll_loss
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
def encode_line(
tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"
):
"""Only used by LegacyDataset"""
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
extra_kw = (
{"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
)
return tokenizer(
[line],
max_length=max_length,
......@@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
def trim_batch(
input_ids,
pad_token_id,
attention_mask=None,
input_ids, pad_token_id, attention_mask=None,
):
"""Remove columns that are populated exclusively by pad_token_id"""
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
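# Hedged sketch of the column trimming described above: drop columns that contain
# only pad tokens in every row of the batch.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 0, 0],
                          [7, 0, 0, 0]])
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)  # [True, True, False, False]
trimmed = input_ids[:, keep_column_mask]                  # shape (2, 2)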
......@@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
"""Call tokenizer on src and tgt_lines"""
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
"\n"
)
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
......@@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
target_ids = torch.stack([x["labels"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
source_ids, source_mask = trim_batch(
input_ids, pad_token_id, attention_mask=masks
)
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
......@@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, str]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
"\n"
)
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
......@@ -201,12 +209,23 @@ class SortishSampler(Sampler):
idxs = np.random.permutation(len(self.data))
sz = self.bs * 50
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
sort_idx = np.concatenate(
[sorted(s, key=self.key, reverse=True) for s in ck_idx]
)
sz = self.bs
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
max_ck = np.argmax(
[self.key(ck[0]) for ck in ck_idx]
) # find the chunk with the largest key,
ck_idx[0], ck_idx[max_ck] = (
ck_idx[max_ck],
ck_idx[0],
) # then make sure it goes first.
sort_idx = (
np.concatenate(np.random.permutation(ck_idx[1:]))
if len(ck_idx) > 1
else np.array([], dtype=np.int)
)
sort_idx = np.concatenate((ck_idx[0], sort_idx))
return iter(sort_idx)
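# Rough sketch of the "sortish" idea above: shuffle globally, sort by length within
# large chunks, then batch, so each batch has similar lengths but order still varies.
import numpy as np

lengths = np.random.randint(1, 50, size=20)
bs = 4
idxs = np.random.permutation(len(lengths))
sz = bs * 5
chunks = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
sort_idx = np.concatenate(
    [sorted(c, key=lambda i: lengths[i], reverse=True) for c in chunks]
)
batches = [sort_idx[i : i + bs] for i in range(0, len(sort_idx), bs)]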
......@@ -269,7 +288,9 @@ def get_git_info():
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
def calculate_rouge(
output_lns: List[str], reference_lns: List[str], use_stemmer=True
) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
aggregator = scoring.BootstrapAggregator()
......@@ -302,7 +323,9 @@ def assert_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
n_require_grad = sum(lmap(int, model_grads))
npars = len(model_grads)
assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
assert not any(
model_grads
), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
def assert_not_all_frozen(model):
......