graykode

(refactor) black style

......@@ -68,6 +68,7 @@ def main(args):
)
print(commit_message)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Code to collect commits on github")
parser.add_argument(
......
......@@ -15,6 +15,6 @@
from .gitcommit import diff_parse, truncate
__all__ = [
'diff_parse',
'truncate',
"diff_parse",
"truncate",
]
......
......@@ -36,9 +36,11 @@ logging.basicConfig(
level=logging.INFO,
)
class PATCH(enum.Enum):
PLUS=1
MINUS=2
PLUS = 1
MINUS = 2
def truncate(tuple, max_length, value=0):
ls = []
......@@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0):
if isinstance(t, int):
t = [t]
ls.extend(t)
ls = ls[:max_length - 1]
ls = ls[: max_length - 1]
ls.insert(0, value)
if len(ls) < max_length:
ls.extend([0] * (max_length - len(ls)))
assert len(ls) == max_length
return ls
def encode_line(tokenizer, line, patch):
line = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip()
line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip()
tokens = tokenizer.tokenize(line)
tokens = tokenizer.convert_tokens_to_ids(tokens)
return (
tokens,
[1] * len(tokens),
len(tokens) * [patch.value]
)
return (tokens, [1] * len(tokens), len(tokens) * [patch.value])
def diff_parse(diff, tokenizer):
chunks = []
......@@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer):
chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
return chunks
def sha_parse(sha, tokenizer, max_length=1024):
chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer)
......@@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024):
return (input_ids, attention_masks, patch_ids)
def message_parse(msg, tokenizer, max_length=56):
msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg)
msg = re.sub(r"(\(|)#([0-9])+(\)|)", "", msg)
msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip()
msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip()
msg = tokenizer.tokenize(msg)
msg = tokenizer.convert_tokens_to_ids(msg)
msg = truncate(msg, max_length, value=0)
return msg
def jobs(sha_msgs, args, data_config, train=True):
input_ids, attention_masks, patch_ids, targets = [], [], [], []
......@@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True):
sha, msg = sha_msg
source = sha_parse(
sha,
tokenizer=args.tokenizer,
max_length=args.max_source_length
sha, tokenizer=args.tokenizer, max_length=args.max_source_length
)
if not source:
continue
......@@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True):
target = message_parse(
msg,
tokenizer=args.tokenizer,
max_length=(args.max_target_length if train else args.val_max_target_length),
max_length=(
args.max_target_length if train else args.val_max_target_length
),
)
input_ids.append(input_id)
......@@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True):
patch_ids.append(patch_id)
targets.append(target)
data_saver({
data_saver(
{
"input_ids": np.asarray(input_ids),
"attention_masks": np.asarray(attention_masks),
"patch_ids": np.asarray(patch_ids),
"targets": np.asarray(targets),
})
}
)
data_saver.disconnect()
def start(chunked_sha_msgs, train=True):
logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation"))
......@@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True):
data_config = DataConfig(
endpoint=args.endpoint,
access_key=os.environ['access_key'],
secret_key=os.environ['secret_key'],
access_key=os.environ["access_key"],
secret_key=os.environ["secret_key"],
region=args.region,
dataset_name='commit-autosuggestions',
dataset_name="commit-autosuggestions",
additional={
"mode" : ("training" if train else "evaluation"),
"mode": ("training" if train else "evaluation"),
"max_source_length": args.max_source_length,
"max_target_length": max_target_length,
"url" : args.url,
"url": args.url,
},
attributes=[
('input_ids', 'int32', (args.max_source_length,)),
('attention_masks', 'int32', (args.max_source_length,)),
('patch_ids', 'int32', (args.max_source_length,)),
('targets', 'int32', (max_target_length,))
]
("input_ids", "int32", (args.max_source_length,)),
("attention_masks", "int32", (args.max_source_length,)),
("patch_ids", "int32", (args.max_source_length,)),
("targets", "int32", (max_target_length,)),
],
)
func = partial(jobs, args=args, data_config=data_config, train=train)
......@@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True):
for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
pbar.update()
def main(args):
if 'access_key' not in os.environ or 'secret_key' not in os.environ:
if "access_key" not in os.environ or "secret_key" not in os.environ:
raise OSError("access_key or secret_key are not found.")
sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
random.shuffle(sha_msgs)
chunked_sha_msgs = [
sha_msgs[x:x + args.matorage_batch]
sha_msgs[x : x + args.matorage_batch]
for x in range(0, len(sha_msgs), args.matorage_batch)
]
......@@ -185,29 +192,25 @@ def main(args):
if args.do_predict:
start(chunked_sha_msgs[barrier:], train=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Code to collect commits on github")
parser.add_argument(
"--url",
type=str,
required=True,
help="github url"
)
parser.add_argument("--url", type=str, required=True, help="github url")
parser.add_argument(
"--endpoint",
type=str,
required=True,
help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
)
parser.add_argument(
"--region",
type=str,
default=None,
help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
)
parser.add_argument(
"--tokenizer_name",
default='sshleifer/distilbart-xsum-6-6',
default="sshleifer/distilbart-xsum-6-6",
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
......@@ -215,13 +218,10 @@ if __name__ == "__main__":
"--matorage_batch",
default=1024,
type=int,
help='The smallest batch size stored atomically in matorage.'
help="The smallest batch size stored atomically in matorage.",
)
parser.add_argument(
"--num_workers",
default=4,
type=int,
help="number of process",
"--num_workers", default=4, type=int, help="number of process",
)
parser.add_argument(
"--max_source_length",
......@@ -244,12 +244,14 @@ if __name__ == "__main__":
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset")
parser.add_argument(
"--p_val", type=float, default=0.25, help="percent of validation dataset"
)
parser.add_argument("--do_train", action="store_true", default=False)
parser.add_argument("--do_predict", action="store_true", default=False)
args = parser.parse_args()
args.local_path = args.url.split('/')[-1]
args.local_path = args.url.split("/")[-1]
logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}")
repo = (
Repo(args.local_path)
......
......@@ -14,6 +14,4 @@
from .modeling_bart import BartForConditionalGeneration
__all__ = [
'BartForConditionalGeneration'
]
\ No newline at end of file
__all__ = ["BartForConditionalGeneration"]
......
......@@ -20,16 +20,31 @@ logger = logging.getLogger(__name__)
class Seq2SeqLoggingCallback(pl.Callback):
def on_batch_end(self, trainer, pl_module):
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
lrs = {
f"lr_group_{i}": param["lr"]
for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)
}
pl_module.logger.log_metrics(lrs)
@rank_zero_only
def _write_logs(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
type_path: str,
save_generations=True,
) -> None:
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
logger.info(
f"***** {type_path} results at step {trainer.global_step:05d} *****"
)
metrics = trainer.callback_metrics
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
trainer.logger.log_metrics(
{
k: v
for k, v in metrics.items()
if k not in ["log", "progress_bar", "preds"]
}
)
# Log results
od = Path(pl_module.hparams.output_dir)
if type_path == "test":
......@@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
# If people want this it will be easy enough to add back.
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
generations_file = (
od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
)
results_file.parent.mkdir(exist_ok=True)
generations_file.parent.mkdir(exist_ok=True)
with open(results_file, "a+") as writer:
......@@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
n_trainable_pars = count_trainable_parameters(pl_module)
# mp stands for million parameters
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
trainer.logger.log_metrics(
{"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}
)
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
......@@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric):
def get_early_stopping_callback(metric, patience):
return EarlyStopping(
monitor=f"val_{metric}",
mode="max",
patience=patience,
verbose=True,
monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,
)
......
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
......@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
config=None,
tokenizer=None,
model=None,
**config_kwargs
**config_kwargs,
):
"""Initialize a model, tokenizer and config."""
super().__init__()
......@@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule):
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
if config is None:
self.config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
self.hparams.config_name
if self.hparams.config_name
else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=cache_dir,
**config_kwargs,
......@@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule):
else:
self.config: PretrainedConfig = config
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
extra_model_params = (
"encoder_layerdrop",
"decoder_layerdrop",
"dropout",
"attention_dropout",
)
for p in extra_model_params:
if getattr(self.hparams, p, None):
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
assert hasattr(
self.config, p
), f"model config doesn't have a `{p}` attribute"
setattr(self.config, p, getattr(self.hparams, p))
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
self.hparams.tokenizer_name
if self.hparams.tokenizer_name
else self.hparams.model_name_or_path,
cache_dir=cache_dir,
)
else:
......@@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule):
def get_lr_scheduler(self):
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
scheduler = get_schedule_func(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
self.opt,
num_warmup_steps=self.hparams.warmup_steps,
num_training_steps=self.total_steps,
)
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return scheduler
......@@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule):
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"params": [
p
for n, p in model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": self.hparams.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"params": [
p
for n, p in model.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
if self.hparams.adafactor:
optimizer = Adafactor(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
optimizer_grouped_parameters,
lr=self.hparams.learning_rate,
scale_parameter=False,
relative_step=False,
)
else:
optimizer = AdamW(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
optimizer_grouped_parameters,
lr=self.hparams.learning_rate,
eps=self.hparams.adam_epsilon,
)
self.opt = optimizer
......@@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule):
def total_steps(self) -> int:
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
effective_batch_size = (
self.hparams.train_batch_size
* self.hparams.accumulate_grad_batches
* num_devices
)
dataset_size = len(self.train_loader.dataset)
return (dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, mode):
if mode == "fit":
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
self.train_loader = self.get_dataloader(
"train", self.hparams.train_batch_size, shuffle=True
)
def get_dataloader(self, type_path, batch_size, shuffle=False):
raise NotImplementedError("You must implement this for your task")
......@@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule):
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
"--config_name",
default="",
type=str,
help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
"--tokenizer_name",
......@@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule):
type=float,
help="Attention dropout probability (Optional). Goes into model.config",
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument(
"--learning_rate",
default=5e-5,
type=float,
help="The initial learning rate for Adam.",
)
parser.add_argument(
"--lr_scheduler",
default="linear",
......@@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule):
type=str,
help="Learning rate scheduler",
)
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
parser.add_argument(
"--weight_decay",
default=0.0,
type=float,
help="Weight decay if we apply some.",
)
parser.add_argument(
"--adam_epsilon",
default=1e-8,
type=float,
help="Epsilon for Adam optimizer.",
)
parser.add_argument(
"--warmup_steps",
default=0,
type=int,
help="Linear warmup over warmup_steps.",
)
parser.add_argument(
"--num_workers", default=4, type=int, help="kwarg passed to DataLoader"
)
parser.add_argument(
"--num_train_epochs", dest="max_epochs", default=3, type=int
)
parser.add_argument("--train_batch_size", default=32, type=int)
parser.add_argument("--eval_batch_size", default=32, type=int)
parser.add_argument("--adafactor", action="store_true")
......@@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback):
rank_zero_info("***** Test results *****")
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
output_test_results_file = os.path.join(
pl_module.hparams.output_dir, "test_results.txt"
)
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
......@@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None:
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
parser.add_argument(
"--max_grad_norm",
dest="gradient_clip_val",
default=1.0,
type=float,
help="Max gradient norm",
)
parser.add_argument(
"--do_train", action="store_true", help="Whether to run training."
)
parser.add_argument(
"--do_predict",
action="store_true",
help="Whether to run predictions on the test set.",
)
parser.add_argument(
"--gradient_accumulation_steps",
dest="accumulate_grad_batches",
......@@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None:
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--seed", type=int, default=42, help="random seed for initialization"
)
def generic_train(
......@@ -335,7 +410,7 @@ def generic_train(
extra_callbacks=[],
checkpoint_callback=None,
logging_callback=None,
**extra_train_kwargs
**extra_train_kwargs,
):
pl.seed_everything(args.seed)
......@@ -346,7 +421,11 @@ def generic_train(
# add custom checkpoints
if checkpoint_callback is None:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
filepath=args.output_dir,
prefix="checkpoint",
monitor="val_loss",
mode="min",
save_top_k=1,
)
if logging_callback is None:
logging_callback = LoggingCallback()
......
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
......@@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
return loss, nll_loss
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
def encode_line(
tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"
):
"""Only used by LegacyDataset"""
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
extra_kw = (
{"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
)
return tokenizer(
[line],
max_length=max_length,
......@@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
def trim_batch(
input_ids,
pad_token_id,
attention_mask=None,
input_ids, pad_token_id, attention_mask=None,
):
"""Remove columns that are populated exclusively by pad_token_id"""
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
......@@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
"""Call tokenizer on src and tgt_lines"""
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
"\n"
)
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
......@@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
target_ids = torch.stack([x["labels"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
source_ids, source_mask = trim_batch(
input_ids, pad_token_id, attention_mask=masks
)
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
......@@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, str]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
"\n"
)
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
......@@ -201,12 +209,23 @@ class SortishSampler(Sampler):
idxs = np.random.permutation(len(self.data))
sz = self.bs * 50
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
sort_idx = np.concatenate(
[sorted(s, key=self.key, reverse=True) for s in ck_idx]
)
sz = self.bs
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
max_ck = np.argmax(
[self.key(ck[0]) for ck in ck_idx]
) # find the chunk with the largest key,
ck_idx[0], ck_idx[max_ck] = (
ck_idx[max_ck],
ck_idx[0],
) # then make sure it goes first.
sort_idx = (
np.concatenate(np.random.permutation(ck_idx[1:]))
if len(ck_idx) > 1
else np.array([], dtype=np.int)
)
sort_idx = np.concatenate((ck_idx[0], sort_idx))
return iter(sort_idx)
......@@ -269,7 +288,9 @@ def get_git_info():
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
def calculate_rouge(
output_lns: List[str], reference_lns: List[str], use_stemmer=True
) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
aggregator = scoring.BootstrapAggregator()
......@@ -302,7 +323,9 @@ def assert_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
n_require_grad = sum(lmap(int, model_grads))
npars = len(model_grads)
assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
assert not any(
model_grads
), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
def assert_not_all_frozen(model):
......