graykode

(add) seq2seq example in huggingface/transformers

+import logging
+import os
+from pathlib import Path
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def count_trainable_parameters(model):
+    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+    params = sum([np.prod(p.size()) for p in model_parameters])
+    return params
+
+
+logger = logging.getLogger(__name__)
+
+
+class Seq2SeqLoggingCallback(pl.Callback):
+    def on_batch_end(self, trainer, pl_module):
+        # Log the current learning rate of every optimizer param group.
+        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
+        pl_module.logger.log_metrics(lrs)
+
+    @rank_zero_only
+    def _write_logs(
+        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
+    ) -> None:
+        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
+        metrics = trainer.callback_metrics
+        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
+        # Log results
+        od = Path(pl_module.hparams.output_dir)
+        if type_path == "test":
+            results_file = od / "test_results.txt"
+            generations_file = od / "test_generations.txt"
+        else:
+            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
+            # If people want this it will be easy enough to add back.
+            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
+            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
+            results_file.parent.mkdir(exist_ok=True)
+            generations_file.parent.mkdir(exist_ok=True)
+        # Append every scalar metric to the results file.
+        with open(results_file, "a+") as writer:
+            for key in sorted(metrics):
+                if key in ["log", "progress_bar", "preds"]:
+                    continue
+                val = metrics[key]
+                if isinstance(val, torch.Tensor):
+                    val = val.item()
+                msg = f"{key}: {val:.6f}\n"
+                writer.write(msg)
+
+        if not save_generations:
+            return
+
+        if "preds" in metrics:
+            content = "\n".join(metrics["preds"])
+            generations_file.open("w+").write(content)
+
+    @rank_zero_only
+    def on_train_start(self, trainer, pl_module):
+        # Parameter count of the underlying transformers model; fall back if there is no nested .model attribute.
+        try:
+            npars = pl_module.model.model.num_parameters()
+        except AttributeError:
+            npars = pl_module.model.num_parameters()
+
+        n_trainable_pars = count_trainable_parameters(pl_module)
+        # mp stands for million parameters
+        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
+
+    @rank_zero_only
+    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+        return self._write_logs(trainer, pl_module, "test")
+
+
+def get_checkpoint_callback(output_dir, metric):
+    """Save the best checkpoint by validation ROUGE2 or BLEU score."""
+    if metric == "rouge2":
+        exp = "{val_avg_rouge2:.4f}-{step_count}"
+    elif metric == "bleu":
+        exp = "{val_avg_bleu:.4f}-{step_count}"
+    else:
+        raise NotImplementedError(
+            f"seq2seq callbacks only support rouge2 and bleu, got {metric}. You can make your own by adding to this function."
+        )
+
+    checkpoint_callback = ModelCheckpoint(
+        filepath=os.path.join(output_dir, exp),
+        monitor=f"val_{metric}",
+        mode="max",
+        save_top_k=1,
+        period=0,  # period=0 allows saving a checkpoint every time validation runs, not just at the end of an epoch.
+    )
+    return checkpoint_callback
+
+
+def get_early_stopping_callback(metric, patience):
+    return EarlyStopping(
+        monitor=f"val_{metric}",
+        mode="max",
+        patience=patience,
+        verbose=True,
+    )
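
For context (not part of this commit), a rough sketch of how these helpers could be attached to a pytorch_lightning Trainer follows. The LightningModule (`model` below) and the training arguments are placeholders; with the older Lightning API this file targets (ModelCheckpoint still takes `filepath` and `period`), the checkpoint and early-stopping callbacks are typically passed through the dedicated `checkpoint_callback` and `early_stop_callback` Trainer arguments rather than the generic `callbacks` list.

import pytorch_lightning as pl

output_dir = "seq2seq_outputs"  # hypothetical output directory

trainer = pl.Trainer(
    max_epochs=3,  # placeholder training arguments
    callbacks=[Seq2SeqLoggingCallback()],
    checkpoint_callback=get_checkpoint_callback(output_dir, metric="rouge2"),
    early_stop_callback=get_early_stopping_callback(metric="rouge2", patience=3),
)
# trainer.fit(model)  # `model` is a seq2seq LightningModule (defined elsewhere) whose
# validation loop logs val_rouge2 / val_avg_rouge2, the keys monitored above.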