Showing 4 changed files with 105 additions and 0 deletions
callbacks.py
0 → 100644
import logging
import os
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only


def count_trainable_parameters(model):
    """Count the parameters that require gradients."""
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum(np.prod(p.size()) for p in model_parameters)
    return params


logger = logging.getLogger(__name__)


class Seq2SeqLoggingCallback(pl.Callback):
    """Logs learning rates, parameter counts, and test results for seq2seq fine-tuning."""

    def on_batch_end(self, trainer, pl_module):
        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics
        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
        # Write results to disk
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # This branch never gets hit: intermediate generations are not saved, and results live in metrics.json.
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
        results_file.parent.mkdir(exist_ok=True)
        generations_file.parent.mkdir(exist_ok=True)
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar", "preds"]:
                    continue
                val = metrics[key]
                if isinstance(val, torch.Tensor):
                    val = val.item()
                msg = f"{key}: {val:.6f}\n"
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            generations_file.write_text(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        return self._write_logs(trainer, pl_module, "test")


def get_checkpoint_callback(output_dir, metric):
    """Saves the best model by the monitored validation metric (ROUGE2 or BLEU)."""
    if metric == "rouge2":
        exp = "{val_avg_rouge2:.4f}-{step_count}"
    elif metric == "bleu":
        exp = "{val_avg_bleu:.4f}-{step_count}"
    else:
        raise NotImplementedError(
            f"seq2seq callbacks only support rouge2 and bleu, got {metric}. You can add support for a new metric by extending this function."
        )

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output_dir, exp),
        monitor=f"val_{metric}",
        mode="max",
        save_top_k=1,
        period=0,  # save a checkpoint every time validation runs, not just at the end of an epoch.
    )
    return checkpoint_callback


def get_early_stopping_callback(metric, patience):
    return EarlyStopping(
        monitor=f"val_{metric}",
        mode="max",
        patience=patience,
        verbose=True,
    )
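For reference, here is a minimal sketch of how these helpers might be wired into a pl.Trainer. It is not part of the diff: build_trainer and its defaults are hypothetical, and it assumes the older pytorch_lightning API implied by the ModelCheckpoint arguments above (filepath=, period=, and a dedicated early_stop_callback Trainer argument). The actual wiring presumably lives in the collapsed finetune.py and lightning_base.py files.

# Hypothetical helper -- not part of this diff. Assumes the older pytorch_lightning
# API implied by the ModelCheckpoint arguments above, where early stopping is passed
# via early_stop_callback rather than through callbacks=[].
import pytorch_lightning as pl

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback


def build_trainer(output_dir: str, metric: str = "rouge2", patience: int = 3, **trainer_kwargs) -> pl.Trainer:
    """Wire the three callbacks above into a Trainer; the defaults here are illustrative."""
    return pl.Trainer(
        callbacks=[Seq2SeqLoggingCallback()],
        checkpoint_callback=get_checkpoint_callback(output_dir, metric),
        early_stop_callback=get_early_stopping_callback(metric, patience),
        **trainer_kwargs,
    )


# Usage sketch, where `module` is a LightningModule with hparams.output_dir set
# (e.g. the one defined in the collapsed finetune.py):
#   trainer = build_trainer(module.hparams.output_dir, metric="rouge2", patience=3, gpus=1, max_epochs=3)
#   trainer.fit(module)
#   trainer.test(module)  # triggers on_test_end -> test_results.txt / test_generations.txt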
finetune.py
0 → 100644
This diff is collapsed.
lightning_base.py
0 → 100644
This diff is collapsed.
utils.py
0 → 100644
This diff is collapsed.