graykode

(add) seq2seq example in huggingface/transformers

1 +import logging
2 +import os
3 +from pathlib import Path
4 +
5 +import numpy as np
6 +import pytorch_lightning as pl
7 +import torch
8 +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
9 +from pytorch_lightning.utilities import rank_zero_only
10 +
11 +
12 +def count_trainable_parameters(model):
13 + model_parameters = filter(lambda p: p.requires_grad, model.parameters())
14 + params = sum([np.prod(p.size()) for p in model_parameters])
15 + return params
16 +
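A quick sanity check of count_trainable_parameters (toy module, not part of the commit; assumes the function above is in scope): only parameters with requires_grad=True contribute.

import torch.nn as nn

layer = nn.Linear(10, 10)                 # 100 weights + 10 biases
layer.bias.requires_grad = False          # freeze the bias
print(count_trainable_parameters(layer))  # 110 - 10 = 100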
17 +
18 +logger = logging.getLogger(__name__)
19 +
20 +
21 +class Seq2SeqLoggingCallback(pl.Callback):
22 + def on_batch_end(self, trainer, pl_module):
23 + lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
24 + pl_module.logger.log_metrics(lrs)
25 +
26 + @rank_zero_only
27 + def _write_logs(
28 + self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
29 + ) -> None:
30 + logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
31 + metrics = trainer.callback_metrics
32 + trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
33 + # Log results
34 + od = Path(pl_module.hparams.output_dir)
35 + if type_path == "test":
36 + results_file = od / "test_results.txt"
37 + generations_file = od / "test_generations.txt"
38 + else:
39 + # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
40 + # If people want this it will be easy enough to add back.
41 + results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
42 + generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
43 + results_file.parent.mkdir(exist_ok=True)
44 + generations_file.parent.mkdir(exist_ok=True)
45 + with open(results_file, "a+") as writer:
46 + for key in sorted(metrics):
47 + if key in ["log", "progress_bar", "preds"]:
48 + continue
49 + val = metrics[key]
50 + if isinstance(val, torch.Tensor):
51 + val = val.item()
52 + msg = f"{key}: {val:.6f}\n"
53 + writer.write(msg)
54 +
55 + if not save_generations:
56 + return
57 +
58 + if "preds" in metrics:
59 + content = "\n".join(metrics["preds"])
60 + generations_file.write_text(content)
61 +
62 + @rank_zero_only
63 + def on_train_start(self, trainer, pl_module):
64 + try:
65 + npars = pl_module.model.model.num_parameters()
66 + except AttributeError:
67 + npars = pl_module.model.num_parameters()
68 +
69 + n_trainable_pars = count_trainable_parameters(pl_module)
70 + # mp stands for million parameters
71 + trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
72 +
73 + @rank_zero_only
74 + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
75 + return self._write_logs(trainer, pl_module, "test")
76 +
77 +
78 +def get_checkpoint_callback(output_dir, metric):
79 + """Saves the best model by validation ROUGE2 score."""
80 + if metric == "rouge2":
81 + exp = "{val_avg_rouge2:.4f}-{step_count}"
82 + elif metric == "bleu":
83 + exp = "{val_avg_bleu:.4f}-{step_count}"
84 + else:
85 + raise NotImplementedError(
86 + f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
87 + )
88 +
89 + checkpoint_callback = ModelCheckpoint(
90 + filepath=os.path.join(output_dir, exp),
91 + monitor=f"val_{metric}",
92 + mode="max",
93 + save_top_k=1,
94 + period=0, # allow saving a checkpoint every time val is run, not just at the end of an epoch.
95 + )
96 + return checkpoint_callback
97 +
98 +
99 +def get_early_stopping_callback(metric, patience):
100 + return EarlyStopping(
101 + monitor=f"val_{metric}",
102 + mode="max",
103 + patience=patience,
104 + verbose=True,
105 + )
1 +import argparse
2 +import glob
3 +import logging
4 +import os
5 +import time
6 +from collections import defaultdict
7 +from pathlib import Path
8 +from typing import Dict, List, Tuple
9 +
10 +import numpy as np
11 +import pytorch_lightning as pl
12 +import torch
13 +from torch.utils.data import DataLoader
14 +
15 +from lightning_base import BaseTransformer, add_generic_args, generic_train
16 +from transformers import MBartTokenizer, T5ForConditionalGeneration
17 +from transformers.modeling_bart import shift_tokens_right
18 +
19 +
20 +try:
21 + from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
22 + from .utils import (
23 + ROUGE_KEYS,
24 + LegacySeq2SeqDataset,
25 + Seq2SeqDataset,
26 + assert_all_frozen,
27 + calculate_bleu,
28 + calculate_rouge,
29 + flatten_list,
30 + freeze_params,
31 + get_git_info,
32 + label_smoothed_nll_loss,
33 + lmap,
34 + pickle_save,
35 + save_git_info,
36 + save_json,
37 + use_task_specific_params,
38 + )
39 +except ImportError:
40 + from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
41 + from utils import (
42 + ROUGE_KEYS,
43 + LegacySeq2SeqDataset,
44 + Seq2SeqDataset,
45 + assert_all_frozen,
46 + calculate_bleu,
47 + calculate_rouge,
48 + flatten_list,
49 + freeze_params,
50 + get_git_info,
51 + label_smoothed_nll_loss,
52 + lmap,
53 + pickle_save,
54 + save_git_info,
55 + save_json,
56 + use_task_specific_params,
57 + )
58 +
59 +logger = logging.getLogger(__name__)
60 +
61 +
62 +class SummarizationModule(BaseTransformer):
63 + mode = "summarization"
64 + loss_names = ["loss"]
65 + metric_names = ROUGE_KEYS
66 + default_val_metric = "rouge2"
67 +
68 + def __init__(self, hparams, **kwargs):
69 + super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
70 + use_task_specific_params(self.model, "summarization")
71 + save_git_info(self.hparams.output_dir)
72 + self.metrics_save_path = Path(self.output_dir) / "metrics.json"
73 + self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
74 + pickle_save(self.hparams, self.hparams_save_path)
75 + self.step_count = 0
76 + self.metrics = defaultdict(list)
77 +
78 + self.dataset_kwargs: dict = dict(
79 + data_dir=self.hparams.data_dir,
80 + max_source_length=self.hparams.max_source_length,
81 + prefix=self.model.config.prefix or "",
82 + )
83 + n_observations_per_split = {
84 + "train": self.hparams.n_train,
85 + "val": self.hparams.n_val,
86 + "test": self.hparams.n_test,
87 + }
88 + self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
89 +
90 + self.target_lens = {
91 + "train": self.hparams.max_target_length,
92 + "val": self.hparams.val_max_target_length,
93 + "test": self.hparams.test_max_target_length,
94 + }
95 + assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
96 + assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
97 +
98 + if self.hparams.freeze_embeds:
99 + self.freeze_embeds()
100 + if self.hparams.freeze_encoder:
101 + freeze_params(self.model.get_encoder())
102 + assert_all_frozen(self.model.get_encoder())
103 +
104 + self.hparams.git_sha = get_git_info()["repo_sha"]
105 + self.num_workers = hparams.num_workers
106 + self.decoder_start_token_id = None # default to config
107 + if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
108 + self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
109 + self.model.config.decoder_start_token_id = self.decoder_start_token_id
110 + self.dataset_class = (
111 + Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
112 + )
113 + self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
114 + assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer >= 1"
115 + self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
116 +
117 + def freeze_embeds(self):
118 + """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
119 + try:
120 + freeze_params(self.model.model.shared)
121 + for d in [self.model.model.encoder, self.model.model.decoder]:
122 + freeze_params(d.embed_positions)
123 + freeze_params(d.embed_tokens)
124 + except AttributeError:
125 + freeze_params(self.model.shared)
126 + for d in [self.model.encoder, self.model.decoder]:
127 + freeze_params(d.embed_tokens)
128 +
129 + def forward(self, input_ids, **kwargs):
130 + return self.model(input_ids, **kwargs)
131 +
132 + def ids_to_clean_text(self, generated_ids: List[int]):
133 + gen_text = self.tokenizer.batch_decode(
134 + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
135 + )
136 + return lmap(str.strip, gen_text)
137 +
138 + def _step(self, batch: dict) -> Tuple:
139 + pad_token_id = self.tokenizer.pad_token_id
140 + src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
141 + tgt_ids = batch["labels"]
142 + if isinstance(self.model, T5ForConditionalGeneration):
143 + decoder_input_ids = self.model._shift_right(tgt_ids)
144 + else:
145 + decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
146 +
147 + outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
148 + lm_logits = outputs[0]
149 + if self.hparams.label_smoothing == 0:
150 + # Same behavior as modeling_bart.py, besides ignoring pad_token_id
151 + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
152 +
153 + assert lm_logits.shape[-1] == self.model.config.vocab_size
154 + loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
155 + else:
156 + lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
157 + loss, nll_loss = label_smoothed_nll_loss(
158 + lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
159 + )
160 + return (loss,)
161 +
162 + @property
163 + def pad(self) -> int:
164 + return self.tokenizer.pad_token_id
165 +
166 + def training_step(self, batch, batch_idx) -> Dict:
167 + loss_tensors = self._step(batch)
168 +
169 + logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
170 + # tokens per batch
171 + logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
172 + return {"loss": loss_tensors[0], "log": logs}
173 +
174 + def validation_step(self, batch, batch_idx) -> Dict:
175 + return self._generative_step(batch)
176 +
177 + def validation_epoch_end(self, outputs, prefix="val") -> Dict:
178 + self.step_count += 1
179 + losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
180 + loss = losses["loss"]
181 + rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
182 + rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
183 + rouges.update({k: v.item() for k, v in losses.items()})
184 + losses.update(rouges)
185 + metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
186 + metrics["step_count"] = self.step_count
187 + self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
188 + preds = flatten_list([x["preds"] for x in outputs])
189 + return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
190 +
191 + def save_metrics(self, latest_metrics, type_path) -> None:
192 + self.metrics[type_path].append(latest_metrics)
193 + save_json(self.metrics, self.metrics_save_path)
194 +
195 + def calc_generative_metrics(self, preds, target) -> Dict:
196 + return calculate_rouge(preds, target)
197 +
198 + def _generative_step(self, batch: dict) -> dict:
199 + t0 = time.time()
200 + generated_ids = self.model.generate(
201 + batch["input_ids"],
202 + attention_mask=batch["attention_mask"],
203 + use_cache=True,
204 + decoder_start_token_id=self.decoder_start_token_id,
205 + )
206 + gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
207 + preds: List[str] = self.ids_to_clean_text(generated_ids)
208 + target: List[str] = self.ids_to_clean_text(batch["labels"])
209 + loss_tensors = self._step(batch)
210 + base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
211 + rouge: Dict = self.calc_generative_metrics(preds, target)
212 + summ_len = np.mean(lmap(len, generated_ids))
213 + base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
214 + return base_metrics
215 +
216 + def test_step(self, batch, batch_idx):
217 + return self._generative_step(batch)
218 +
219 + def test_epoch_end(self, outputs):
220 + return self.validation_epoch_end(outputs, prefix="test")
221 +
222 + def get_dataset(self, type_path) -> Seq2SeqDataset:
223 + n_obs = self.n_obs[type_path]
224 + max_target_length = self.target_lens[type_path]
225 + dataset = self.dataset_class(
226 + self.tokenizer,
227 + type_path=type_path,
228 + n_obs=n_obs,
229 + max_target_length=max_target_length,
230 + **self.dataset_kwargs,
231 + )
232 + return dataset
233 +
234 + def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
235 + dataset = self.get_dataset(type_path)
236 + sampler = None
237 + if self.hparams.sortish_sampler and type_path == "train":
238 + assert self.hparams.gpus <= 1 # TODO: assert earlier
239 + sampler = dataset.make_sortish_sampler(batch_size)
240 + shuffle = False
241 +
242 + dataloader = DataLoader(
243 + dataset,
244 + batch_size=batch_size,
245 + collate_fn=dataset.collate_fn,
246 + shuffle=shuffle,
247 + num_workers=self.num_workers,
248 + sampler=sampler,
249 + )
250 + return dataloader
251 +
252 + def train_dataloader(self) -> DataLoader:
253 + dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
254 + return dataloader
255 +
256 + def val_dataloader(self) -> DataLoader:
257 + return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
258 +
259 + def test_dataloader(self) -> DataLoader:
260 + return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
261 +
262 + @staticmethod
263 + def add_model_specific_args(parser, root_dir):
264 + BaseTransformer.add_model_specific_args(parser, root_dir)
265 + add_generic_args(parser, root_dir)
266 + parser.add_argument(
267 + "--max_source_length",
268 + default=1024,
269 + type=int,
270 + help="The maximum total input sequence length after tokenization. Sequences longer "
271 + "than this will be truncated, sequences shorter will be padded.",
272 + )
273 + parser.add_argument(
274 + "--max_target_length",
275 + default=56,
276 + type=int,
277 + help="The maximum total input sequence length after tokenization. Sequences longer "
278 + "than this will be truncated, sequences shorter will be padded.",
279 + )
280 + parser.add_argument(
281 + "--val_max_target_length",
282 + default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
283 + type=int,
284 + help="The maximum total input sequence length after tokenization. Sequences longer "
285 + "than this will be truncated, sequences shorter will be padded.",
286 + )
287 + parser.add_argument(
288 + "--test_max_target_length",
289 + default=142,
290 + type=int,
291 + help="The maximum total input sequence length after tokenization. Sequences longer "
292 + "than this will be truncated, sequences shorter will be padded.",
293 + )
294 + parser.add_argument("--freeze_encoder", action="store_true")
295 + parser.add_argument("--freeze_embeds", action="store_true")
296 + parser.add_argument("--sortish_sampler", action="store_true", default=False)
297 + parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
298 + parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
299 + parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
300 + parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
301 + parser.add_argument(
302 + "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
303 + )
304 + parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
305 + parser.add_argument("--src_lang", type=str, default="", required=False)
306 + parser.add_argument("--tgt_lang", type=str, default="", required=False)
307 + parser.add_argument("--eval_beams", type=int, default=None, required=False)
308 + parser.add_argument("--val_metric", type=str, default=None, required=False)
309 + parser.add_argument(
310 + "--early_stopping_patience",
311 + type=int,
312 + default=-1,
313 + required=False,
314 + help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
315 + )
316 + return parser
317 +
318 +
319 +class TranslationModule(SummarizationModule):
320 + mode = "translation"
321 + loss_names = ["loss"]
322 + metric_names = ["bleu"]
323 + default_val_metric = "bleu"
324 +
325 + def __init__(self, hparams, **kwargs):
326 + super().__init__(hparams, **kwargs)
327 + self.dataset_kwargs["src_lang"] = hparams.src_lang
328 + self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
329 +
330 + def calc_generative_metrics(self, preds, target) -> dict:
331 + return calculate_bleu(preds, target)
332 +
333 +
334 +def main(args, model=None) -> SummarizationModule:
335 + Path(args.output_dir).mkdir(exist_ok=True)
336 + if len(os.listdir(args.output_dir)) > 3 and args.do_train:
337 + raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
338 + if model is None:
339 + if args.task == "summarization":
340 + model: SummarizationModule = SummarizationModule(args)
341 + else:
342 + model: SummarizationModule = TranslationModule(args)
343 +
344 + dataset = Path(args.data_dir).name
345 + if (
346 + args.logger_name == "default"
347 + or args.fast_dev_run
348 + or str(args.output_dir).startswith("/tmp")
349 + or str(args.output_dir).startswith("/var")
350 + ):
351 + logger = True # don't pollute wandb logs unnecessarily
352 + elif args.logger_name == "wandb":
353 + from pytorch_lightning.loggers import WandbLogger
354 +
355 + project = os.environ.get("WANDB_PROJECT", dataset)
356 + logger = WandbLogger(name=model.output_dir.name, project=project)
357 +
358 + elif args.logger_name == "wandb_shared":
359 + from pytorch_lightning.loggers import WandbLogger
360 +
361 + logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
362 +
363 + if args.early_stopping_patience >= 0:
364 + es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
365 + else:
366 + es_callback = False
367 + trainer: pl.Trainer = generic_train(
368 + model,
369 + args,
370 + logging_callback=Seq2SeqLoggingCallback(),
371 + checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
372 + early_stopping_callback=es_callback,
373 + logger=logger,
374 + # TODO: early stopping callback seems messed up
375 + )
376 + pickle_save(model.hparams, model.output_dir / "hparams.pkl")
377 + if not args.do_predict:
378 + return model
379 +
380 + model.hparams.test_checkpoint = ""
381 + checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
382 + if checkpoints:
383 + model.hparams.test_checkpoint = checkpoints[-1]
384 + trainer.resume_from_checkpoint = checkpoints[-1]
385 + trainer.logger.log_hyperparams(model.hparams)
386 +
387 + # test() without a model tests using the best checkpoint automatically
388 + trainer.test()
389 + return model
390 +
391 +
392 +if __name__ == "__main__":
393 + parser = argparse.ArgumentParser()
394 + parser = pl.Trainer.add_argparse_args(parser)
395 + parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
396 +
397 + args = parser.parse_args()
398 +
399 + main(args)
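That closes the fine-tuning script. A minimal sketch of driving it programmatically (the model identifier, data directory, and flag values below are illustrative assumptions, and the script is assumed to be saved as finetune.py):

import argparse
import os

import pytorch_lightning as pl

from finetune import SummarizationModule, main

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args(
    [
        "--model_name_or_path", "sshleifer/distilbart-cnn-12-6",  # assumed checkpoint name
        "--data_dir", "cnn_dm",          # expects {train,val,test}.{source,target} text files
        "--output_dir", "dbart_cnn_out",
        "--do_train",
        "--gpus", "0",
        "--train_batch_size", "4",
    ]
)
model = main(args)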
1 +import argparse
2 +import logging
3 +import os
4 +from pathlib import Path
5 +from typing import Any, Dict
6 +
7 +import pytorch_lightning as pl
8 +from pytorch_lightning.utilities import rank_zero_info
9 +
10 +from transformers import (
11 + AdamW,
12 + AutoConfig,
13 + AutoModel,
14 + AutoModelForPreTraining,
15 + AutoModelForQuestionAnswering,
16 + AutoModelForSeq2SeqLM,
17 + AutoModelForSequenceClassification,
18 + AutoModelForTokenClassification,
19 + AutoModelWithLMHead,
20 + AutoTokenizer,
21 + PretrainedConfig,
22 + PreTrainedTokenizer,
23 +)
24 +from transformers.optimization import (
25 + Adafactor,
26 + get_cosine_schedule_with_warmup,
27 + get_cosine_with_hard_restarts_schedule_with_warmup,
28 + get_linear_schedule_with_warmup,
29 + get_polynomial_decay_schedule_with_warmup,
30 +)
31 +
32 +
33 +logger = logging.getLogger(__name__)
34 +
35 +
36 +MODEL_MODES = {
37 + "base": AutoModel,
38 + "sequence-classification": AutoModelForSequenceClassification,
39 + "question-answering": AutoModelForQuestionAnswering,
40 + "pretraining": AutoModelForPreTraining,
41 + "token-classification": AutoModelForTokenClassification,
42 + "language-modeling": AutoModelWithLMHead,
43 + "summarization": AutoModelForSeq2SeqLM,
44 + "translation": AutoModelForSeq2SeqLM,
45 +}
46 +
47 +
48 +# update this and the import above to support new schedulers from transformers.optimization
49 +arg_to_scheduler = {
50 + "linear": get_linear_schedule_with_warmup,
51 + "cosine": get_cosine_schedule_with_warmup,
52 + "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
53 + "polynomial": get_polynomial_decay_schedule_with_warmup,
54 + # '': get_constant_schedule, # not supported for now
55 + # '': get_constant_schedule_with_warmup, # not supported for now
56 +}
57 +arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
58 +arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
59 +
60 +
61 +class BaseTransformer(pl.LightningModule):
62 + def __init__(
63 + self,
64 + hparams: argparse.Namespace,
65 + num_labels=None,
66 + mode="base",
67 + config=None,
68 + tokenizer=None,
69 + model=None,
70 + **config_kwargs
71 + ):
72 + """Initialize a model, tokenizer and config."""
73 + super().__init__()
74 + # TODO: move to self.save_hyperparameters()
75 + # self.save_hyperparameters()
76 + # can also expand arguments into trainer signature for easier reading
77 +
78 + self.save_hyperparameters(hparams)
79 + self.step_count = 0
80 + self.output_dir = Path(self.hparams.output_dir)
81 + cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
82 + if config is None:
83 + self.config = AutoConfig.from_pretrained(
84 + self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
85 + **({"num_labels": num_labels} if num_labels is not None else {}),
86 + cache_dir=cache_dir,
87 + **config_kwargs,
88 + )
89 + else:
90 + self.config: PretrainedConfig = config
91 +
92 + extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
93 + for p in extra_model_params:
94 + if getattr(self.hparams, p, None):
95 + assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
96 + setattr(self.config, p, getattr(self.hparams, p))
97 +
98 + if tokenizer is None:
99 + self.tokenizer = AutoTokenizer.from_pretrained(
100 + self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
101 + cache_dir=cache_dir,
102 + )
103 + else:
104 + self.tokenizer: PreTrainedTokenizer = tokenizer
105 + self.model_type = MODEL_MODES[mode]
106 + if model is None:
107 + self.model = self.model_type.from_pretrained(
108 + self.hparams.model_name_or_path,
109 + from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
110 + config=self.config,
111 + cache_dir=cache_dir,
112 + )
113 + else:
114 + self.model = model
115 +
116 + def load_hf_checkpoint(self, *args, **kwargs):
117 + self.model = self.model_type.from_pretrained(*args, **kwargs)
118 +
119 + def get_lr_scheduler(self):
120 + get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
121 + scheduler = get_schedule_func(
122 + self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
123 + )
124 + scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
125 + return scheduler
126 +
127 + def configure_optimizers(self):
128 + """Prepare optimizer and schedule (linear warmup and decay)"""
129 + model = self.model
130 + no_decay = ["bias", "LayerNorm.weight"]
131 + optimizer_grouped_parameters = [
132 + {
133 + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
134 + "weight_decay": self.hparams.weight_decay,
135 + },
136 + {
137 + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
138 + "weight_decay": 0.0,
139 + },
140 + ]
141 + if self.hparams.adafactor:
142 + optimizer = Adafactor(
143 + optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
144 + )
145 +
146 + else:
147 + optimizer = AdamW(
148 + optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
149 + )
150 + self.opt = optimizer
151 +
152 + scheduler = self.get_lr_scheduler()
153 +
154 + return [optimizer], [scheduler]
155 +
156 + def test_step(self, batch, batch_nb):
157 + return self.validation_step(batch, batch_nb)
158 +
159 + def test_epoch_end(self, outputs):
160 + return self.validation_end(outputs)
161 +
162 + @property
163 + def total_steps(self) -> int:
164 + """The number of total training steps that will be run. Used for lr scheduler purposes."""
165 + num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
166 + effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
167 + dataset_size = len(self.train_loader.dataset)
168 + return (dataset_size / effective_batch_size) * self.hparams.max_epochs
169 +
170 + def setup(self, mode):
171 + if mode == "fit":
172 + self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
173 +
174 + def get_dataloader(self, type_path, batch_size, shuffle=False):
175 + raise NotImplementedError("You must implement this for your task")
176 +
177 + def train_dataloader(self):
178 + return self.train_loader
179 +
180 + def val_dataloader(self):
181 + return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)
182 +
183 + def test_dataloader(self):
184 + return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)
185 +
186 + def _feature_file(self, mode):
187 + return os.path.join(
188 + self.hparams.data_dir,
189 + "cached_{}_{}_{}".format(
190 + mode,
191 + list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
192 + str(self.hparams.max_seq_length),
193 + ),
194 + )
195 +
196 + @pl.utilities.rank_zero_only
197 + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
198 + save_path = self.output_dir.joinpath("best_tfmr")
199 + self.model.config.save_step = self.step_count
200 + self.model.save_pretrained(save_path)
201 + self.tokenizer.save_pretrained(save_path)
202 +
203 + @staticmethod
204 + def add_model_specific_args(parser, root_dir):
205 + parser.add_argument(
206 + "--model_name_or_path",
207 + default=None,
208 + type=str,
209 + required=True,
210 + help="Path to pretrained model or model identifier from huggingface.co/models",
211 + )
212 + parser.add_argument(
213 + "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
214 + )
215 + parser.add_argument(
216 + "--tokenizer_name",
217 + default=None,
218 + type=str,
219 + help="Pretrained tokenizer name or path if not the same as model_name",
220 + )
221 + parser.add_argument(
222 + "--cache_dir",
223 + default="",
224 + type=str,
225 + help="Where do you want to store the pre-trained models downloaded from s3",
226 + )
227 + parser.add_argument(
228 + "--encoder_layerdrop",
229 + type=float,
230 + help="Encoder layer dropout probability (Optional). Goes into model.config",
231 + )
232 + parser.add_argument(
233 + "--decoder_layerdrop",
234 + type=float,
235 + help="Decoder layer dropout probability (Optional). Goes into model.config",
236 + )
237 + parser.add_argument(
238 + "--dropout",
239 + type=float,
240 + help="Dropout probability (Optional). Goes into model.config",
241 + )
242 + parser.add_argument(
243 + "--attention_dropout",
244 + type=float,
245 + help="Attention dropout probability (Optional). Goes into model.config",
246 + )
247 + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
248 + parser.add_argument(
249 + "--lr_scheduler",
250 + default="linear",
251 + choices=arg_to_scheduler_choices,
252 + metavar=arg_to_scheduler_metavar,
253 + type=str,
254 + help="Learning rate scheduler",
255 + )
256 + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
257 + parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
258 + parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
259 + parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
260 + parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
261 + parser.add_argument("--train_batch_size", default=32, type=int)
262 + parser.add_argument("--eval_batch_size", default=32, type=int)
263 + parser.add_argument("--adafactor", action="store_true")
264 +
265 +
266 +class LoggingCallback(pl.Callback):
267 + def on_batch_end(self, trainer, pl_module):
268 + lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
269 + lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
270 + pl_module.logger.log_metrics(lrs)
271 +
272 + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
273 + rank_zero_info("***** Validation results *****")
274 + metrics = trainer.callback_metrics
275 + # Log results
276 + for key in sorted(metrics):
277 + if key not in ["log", "progress_bar"]:
278 + rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
279 +
280 + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
281 + rank_zero_info("***** Test results *****")
282 + metrics = trainer.callback_metrics
283 + # Log and save results to file
284 + output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
285 + with open(output_test_results_file, "w") as writer:
286 + for key in sorted(metrics):
287 + if key not in ["log", "progress_bar"]:
288 + rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
289 + writer.write("{} = {}\n".format(key, str(metrics[key])))
290 +
291 +
292 +def add_generic_args(parser, root_dir) -> None:
293 + # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
294 + parser.add_argument(
295 + "--output_dir",
296 + default=None,
297 + type=str,
298 + required=True,
299 + help="The output directory where the model predictions and checkpoints will be written.",
300 + )
301 + parser.add_argument(
302 + "--fp16",
303 + action="store_true",
304 + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
305 + )
306 +
307 + parser.add_argument(
308 + "--fp16_opt_level",
309 + type=str,
310 + default="O2",
311 + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
312 + "See details at https://nvidia.github.io/apex/amp.html",
313 + )
314 + parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
315 + parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
316 + parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
317 + parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
318 + parser.add_argument(
319 + "--gradient_accumulation_steps",
320 + dest="accumulate_grad_batches",
321 + type=int,
322 + default=1,
323 + help="Number of updates steps to accumulate before performing a backward/update pass.",
324 + )
325 + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
326 + parser.add_argument(
327 + "--data_dir",
328 + default=None,
329 + type=str,
330 + required=True,
331 + help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
332 + )
333 +
334 +
335 +def generic_train(
336 + model: BaseTransformer,
337 + args: argparse.Namespace,
338 + early_stopping_callback=False,
339 + logger=True, # can pass WandbLogger() here
340 + extra_callbacks=[],
341 + checkpoint_callback=None,
342 + logging_callback=None,
343 + **extra_train_kwargs
344 +):
345 + pl.seed_everything(args.seed)
346 +
347 + # init model
348 + odir = Path(model.hparams.output_dir)
349 + odir.mkdir(exist_ok=True)
350 +
351 + # add custom checkpoints
352 + if checkpoint_callback is None:
353 + checkpoint_callback = pl.callbacks.ModelCheckpoint(
354 + filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
355 + )
356 + if logging_callback is None:
357 + logging_callback = LoggingCallback()
358 +
359 + train_params = {}
360 +
361 + # TODO: remove with PyTorch 1.6 since pl uses native amp
362 + if args.fp16:
363 + train_params["precision"] = 16
364 + train_params["amp_level"] = args.fp16_opt_level
365 +
366 + if args.gpus > 1:
367 + train_params["distributed_backend"] = "ddp"
368 +
369 + trainer = pl.Trainer.from_argparse_args(
370 + args,
371 + weights_summary=None,
372 + callbacks=[logging_callback] + extra_callbacks,
373 + logger=logger,
374 + checkpoint_callback=checkpoint_callback,
375 + early_stop_callback=early_stopping_callback,
376 + **train_params,
377 + )
378 +
379 + if args.do_train:
380 + trainer.fit(model)
381 +
382 + return trainer
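End of the shared Lightning wrapper. The only method a task must supply is get_dataloader; config/tokenizer/model loading, the optimizer, the scheduler, and checkpoint saving are inherited. A hedged sketch of that contract with dummy data and illustrative names (a real module would also define training_step/validation_step):

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning_base import BaseTransformer


class TinyModule(BaseTransformer):
    mode = "sequence-classification"  # one of the MODEL_MODES keys above

    def __init__(self, hparams):
        super().__init__(hparams, num_labels=2, mode=self.mode)

    def get_dataloader(self, type_path, batch_size, shuffle=False):
        # Dummy tensors just to satisfy the interface; a real task would
        # tokenize its files with self.tokenizer here.
        ids = torch.randint(0, self.config.vocab_size, (8, 16))
        labels = torch.randint(0, 2, (8,))
        return DataLoader(TensorDataset(ids, labels), batch_size=batch_size, shuffle=shuffle)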
1 +import itertools
2 +import json
3 +import linecache
4 +import os
5 +import pickle
6 +from logging import getLogger
7 +from pathlib import Path
8 +from typing import Callable, Dict, Iterable, List
9 +
10 +import git
11 +import numpy as np
12 +import torch
13 +from rouge_score import rouge_scorer, scoring
14 +from sacrebleu import corpus_bleu
15 +from torch import nn
16 +from torch.utils.data import Dataset, Sampler
17 +
18 +from transformers import BartTokenizer
19 +
20 +
21 +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
22 + """From fairseq"""
23 + if target.dim() == lprobs.dim() - 1:
24 + target = target.unsqueeze(-1)
25 + nll_loss = -lprobs.gather(dim=-1, index=target)
26 + smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
27 + if ignore_index is not None:
28 + pad_mask = target.eq(ignore_index)
29 + nll_loss.masked_fill_(pad_mask, 0.0)
30 + smooth_loss.masked_fill_(pad_mask, 0.0)
31 + else:
32 + nll_loss = nll_loss.squeeze(-1)
33 + smooth_loss = smooth_loss.squeeze(-1)
34 +
35 + nll_loss = nll_loss.sum() # mean()? Scared to break other math.
36 + smooth_loss = smooth_loss.sum()
37 + eps_i = epsilon / lprobs.size(-1)
38 + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
39 + return loss, nll_loss
40 +
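A quick numeric check of the label smoothing above (toy values; assumes label_smoothed_nll_loss from this file is in scope): with epsilon=0.1 the returned loss blends the gold-token NLL with the mean NLL over the whole vocabulary, so it comes out somewhat larger than the plain NLL.

import torch

lprobs = torch.log_softmax(torch.tensor([[2.0, 0.5, -1.0]]), dim=-1)  # one token, vocab of 3
target = torch.tensor([0])
loss, nll = label_smoothed_nll_loss(lprobs, target, epsilon=0.1, ignore_index=-100)
print(loss.item(), nll.item())  # ~0.39 vs ~0.24 for these logits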
41 +
42 +def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
43 + """Only used by LegacyDataset"""
44 + extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
45 + return tokenizer(
46 + [line],
47 + max_length=max_length,
48 + padding="max_length" if pad_to_max_length else None,
49 + truncation=True,
50 + return_tensors=return_tensors,
51 + **extra_kw,
52 + )
53 +
54 +
55 +def lmap(f: Callable, x: Iterable) -> List:
56 + """list(map(f, x))"""
57 + return list(map(f, x))
58 +
59 +
60 +def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
61 + """Uses sacrebleu's corpus_bleu implementation."""
62 + return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
63 +
64 +
65 +def trim_batch(
66 + input_ids,
67 + pad_token_id,
68 + attention_mask=None,
69 +):
70 + """Remove columns that are populated exclusively by pad_token_id"""
71 + keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
72 + if attention_mask is None:
73 + return input_ids[:, keep_column_mask]
74 + else:
75 + return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
76 +
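A small illustration of trim_batch (made-up tensors, pad id 0): columns containing only padding are dropped from both the ids and the attention mask.

import torch

ids = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])
mask = ids.ne(0).long()
trimmed_ids, trimmed_mask = trim_batch(ids, pad_token_id=0, attention_mask=mask)
print(trimmed_ids)  # tensor([[5, 6], [7, 0]]) -- the two all-pad columns are gone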
77 +
78 +class AbstractSeq2SeqDataset(Dataset):
79 + def __init__(
80 + self,
81 + tokenizer,
82 + data_dir,
83 + max_source_length,
84 + max_target_length,
85 + type_path="train",
86 + n_obs=None,
87 + src_lang=None,
88 + tgt_lang=None,
89 + prefix="",
90 + ):
91 + super().__init__()
92 + self.src_file = Path(data_dir).joinpath(type_path + ".source")
93 + self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
94 + self.src_lens = self.get_char_lens(self.src_file)
95 + self.max_source_length = max_source_length
96 + self.max_target_length = max_target_length
97 + assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
98 + self.tokenizer = tokenizer
99 + self.prefix = prefix
100 + if n_obs is not None:
101 + self.src_lens = self.src_lens[:n_obs]
102 + self.pad_token_id = self.tokenizer.pad_token_id
103 + self.src_lang = src_lang
104 + self.tgt_lang = tgt_lang
105 + self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
106 +
107 + def __len__(self):
108 + return len(self.src_lens)
109 +
110 + @staticmethod
111 + def get_char_lens(data_file):
112 + return [len(x) for x in Path(data_file).open().readlines()]
113 +
114 + def make_sortish_sampler(self, batch_size):
115 + return SortishSampler(self.src_lens, batch_size)
116 +
117 + def __getitem__(self, item):
118 + raise NotImplementedError("You must implement this")
119 +
120 + def collate_fn(self, batch):
121 + raise NotImplementedError("You must implement this")
122 +
123 +
124 +class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
125 + def __getitem__(self, index) -> Dict[str, torch.Tensor]:
126 + """Call tokenizer on src and tgt_lines"""
127 + index = index + 1 # linecache starts at 1
128 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
129 + tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
130 + assert source_line, f"empty source line for index {index}"
131 + assert tgt_line, f"empty tgt line for index {index}"
132 + source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
133 + target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
134 +
135 + source_ids = source_inputs["input_ids"].squeeze()
136 + target_ids = target_inputs["input_ids"].squeeze()
137 + src_mask = source_inputs["attention_mask"].squeeze()
138 + return {
139 + "input_ids": source_ids,
140 + "attention_mask": src_mask,
141 + "labels": target_ids,
142 + }
143 +
144 + def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
145 + input_ids = torch.stack([x["input_ids"] for x in batch])
146 + masks = torch.stack([x["attention_mask"] for x in batch])
147 + target_ids = torch.stack([x["labels"] for x in batch])
148 + pad_token_id = self.pad_token_id
149 + y = trim_batch(target_ids, pad_token_id)
150 + source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
151 + batch = {
152 + "input_ids": source_ids,
153 + "attention_mask": source_mask,
154 + "labels": y,
155 + }
156 + return batch
157 +
158 +
159 +class Seq2SeqDataset(AbstractSeq2SeqDataset):
160 + """A dataset that calls prepare_seq2seq_batch."""
161 +
162 + def __getitem__(self, index) -> Dict[str, str]:
163 + index = index + 1 # linecache starts at 1
164 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
165 + tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
166 + assert source_line, f"empty source line for index {index}"
167 + assert tgt_line, f"empty tgt line for index {index}"
168 + return {
169 + "tgt_texts": tgt_line,
170 + "src_texts": source_line,
171 + }
172 +
173 + def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
174 + """Call prepare_seq2seq_batch."""
175 + batch_encoding = self.tokenizer.prepare_seq2seq_batch(
176 + [x["src_texts"] for x in batch],
177 + src_lang=self.src_lang,
178 + tgt_texts=[x["tgt_texts"] for x in batch],
179 + tgt_lang=self.tgt_lang,
180 + max_length=self.max_source_length,
181 + max_target_length=self.max_target_length,
182 + return_tensors="pt",
183 + add_prefix_space=self.add_prefix_space,
184 + )
185 + return batch_encoding.data
186 +
187 +
188 +class SortishSampler(Sampler):
189 + "Go through the text data by order of src length with a bit of randomness. From fastai repo."
190 +
191 + def __init__(self, data, batch_size):
192 + self.data, self.bs = data, batch_size
193 +
194 + def key(self, i):
195 + return self.data[i]
196 +
197 + def __len__(self) -> int:
198 + return len(self.data)
199 +
200 + def __iter__(self):
201 + idxs = np.random.permutation(len(self.data))
202 + sz = self.bs * 50
203 + ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
204 + sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
205 + sz = self.bs
206 + ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
207 + max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
208 + ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
209 + sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int64)
210 + sort_idx = np.concatenate((ck_idx[0], sort_idx))
211 + return iter(sort_idx)
212 +
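To see what the sortish sampler produces (toy lengths; the exact order depends on the random permutation of chunks): indices come back grouped so that similarly long examples share a batch, with the chunk holding the longest example first.

import numpy as np

np.random.seed(0)
lengths = [3, 10, 1, 7, 5, 2]  # pretend character lengths of six source lines
print(list(SortishSampler(lengths, batch_size=2)))
# e.g. [1, 3, 4, 0, 5, 2]: descending-length pairs, with the longest example's chunk leading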
213 +
214 +logger = getLogger(__name__)
215 +
216 +
217 +def use_task_specific_params(model, task):
218 + """Update config with summarization specific params."""
219 + task_specific_params = model.config.task_specific_params
220 +
221 + if task_specific_params is not None:
222 + pars = task_specific_params.get(task, {})
223 + logger.info(f"using task specific params for {task}: {pars}")
224 + model.config.update(pars)
225 +
226 +
227 +def pickle_load(path):
228 + """pickle.load(path)"""
229 + with open(path, "rb") as f:
230 + return pickle.load(f)
231 +
232 +
233 +def pickle_save(obj, path):
234 + """pickle.dump(obj, path)"""
235 + with open(path, "wb") as f:
236 + return pickle.dump(obj, f)
237 +
238 +
239 +def flatten_list(summary_ids: List[List]):
240 + return [x for x in itertools.chain.from_iterable(summary_ids)]
241 +
242 +
243 +def save_git_info(folder_path: str) -> None:
244 + """Save git information to output_dir/git_log.json"""
245 + repo_infos = get_git_info()
246 + save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
247 +
248 +
249 +def save_json(content, path):
250 + with open(path, "w") as f:
251 + json.dump(content, f, indent=4)
252 +
253 +
254 +def load_json(path):
255 + with open(path) as f:
256 + return json.load(f)
257 +
258 +
259 +def get_git_info():
260 + repo = git.Repo(search_parent_directories=True)
261 + repo_infos = {
262 + "repo_id": str(repo),
263 + "repo_sha": str(repo.head.object.hexsha),
264 + "repo_branch": str(repo.active_branch),
265 + }
266 + return repo_infos
267 +
268 +
269 +ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
270 +
271 +
272 +def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
273 + scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
274 + aggregator = scoring.BootstrapAggregator()
275 +
276 + for reference_ln, output_ln in zip(reference_lns, output_lns):
277 + scores = scorer.score(reference_ln, output_ln)
278 + aggregator.add_scores(scores)
279 +
280 + result = aggregator.aggregate()
281 + return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
282 +
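For reference, both metric helpers return plain dicts of rounded scores; a toy call (exact values depend on the installed rouge_score and sacrebleu packages):

preds = ["the cat sat on the mat"]
refs = ["the cat sat on a mat"]
print(calculate_rouge(preds, refs))  # {'rouge1': ..., 'rouge2': ..., 'rougeL': ...}, F-measures scaled to 0-100
print(calculate_bleu(preds, refs))   # {'bleu': ...}, corpus BLEU from sacrebleu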
283 +
284 +# Utilities for freezing parameters and checking whether they are frozen
285 +
286 +
287 +def freeze_params(model: nn.Module):
288 + """Set requires_grad=False for each of model.parameters()"""
289 + for par in model.parameters():
290 + par.requires_grad = False
291 +
292 +
293 +def grad_status(model: nn.Module) -> Iterable:
294 + return (par.requires_grad for par in model.parameters())
295 +
296 +
297 +def any_requires_grad(model: nn.Module) -> bool:
298 + return any(grad_status(model))
299 +
300 +
301 +def assert_all_frozen(model):
302 + model_grads: List[bool] = list(grad_status(model))
303 + n_require_grad = sum(lmap(int, model_grads))
304 + npars = len(model_grads)
305 + assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
306 +
307 +
308 +def assert_not_all_frozen(model):
309 + model_grads: List[bool] = list(grad_status(model))
310 + npars = len(model_grads)
311 + assert any(model_grads), f"none of {npars} weights require grad"
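Finally, the freezing helpers compose as below (a small self-check, illustrative only):

import torch.nn as nn

m = nn.Linear(4, 4)
freeze_params(m)              # sets requires_grad=False on weight and bias
assert_all_frozen(m)          # passes: nothing requires grad
print(any_requires_grad(m))   # False

head = nn.Linear(4, 2)        # a fresh, trainable layer
assert_not_all_frozen(head)   # passes: its parameters still require grad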