Showing 4 changed files with 1197 additions and 0 deletions.
callbacks.py (new file, mode 100644)
import logging
import os
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only


def count_trainable_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params


logger = logging.getLogger(__name__)


class Seq2SeqLoggingCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics
        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
        # Log results
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
        results_file.parent.mkdir(exist_ok=True)
        generations_file.parent.mkdir(exist_ok=True)
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar", "preds"]:
                    continue
                val = metrics[key]
                if isinstance(val, torch.Tensor):
                    val = val.item()
                msg = f"{key}: {val:.6f}\n"
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            generations_file.open("w+").write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        return self._write_logs(trainer, pl_module, "test")


def get_checkpoint_callback(output_dir, metric):
    """Saves the best model by the chosen validation metric (ROUGE2 or BLEU)."""
    if metric == "rouge2":
        exp = "{val_avg_rouge2:.4f}-{step_count}"
    elif metric == "bleu":
        exp = "{val_avg_bleu:.4f}-{step_count}"
    else:
        raise NotImplementedError(
            f"seq2seq callbacks only support rouge2 and bleu, got {metric}. You can make your own by adding to this function."
        )

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output_dir, exp),
        monitor=f"val_{metric}",
        mode="max",
        save_top_k=1,
        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
    )
    return checkpoint_callback


def get_early_stopping_callback(metric, patience):
    return EarlyStopping(
        monitor=f"val_{metric}",
        mode="max",
        patience=patience,
        verbose=True,
    )
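For context, a minimal sketch (not part of the PR) of how these helpers could be wired into a PyTorch Lightning trainer of the same vintage; the output directory and patience value are placeholders, and finetune.py below does the real wiring through generic_train:

# Hypothetical wiring of the callbacks above; paths and values are placeholders.
import pytorch_lightning as pl

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback

checkpoint = get_checkpoint_callback("output_dir", metric="rouge2")  # keeps the best val_rouge2 checkpoint
early_stop = get_early_stopping_callback("rouge2", patience=3)       # stop after 3 stagnant validation checks

trainer = pl.Trainer(
    callbacks=[Seq2SeqLoggingCallback()],
    checkpoint_callback=checkpoint,
    early_stop_callback=early_stop,  # pre-1.0 Lightning argument, matching generic_train in lightning_base.py
)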
finetune.py (new file, mode 100644)
import argparse
import glob
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right


try:
    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
    from .utils import (
        ROUGE_KEYS,
        LegacySeq2SeqDataset,
        Seq2SeqDataset,
        assert_all_frozen,
        calculate_bleu,
        calculate_rouge,
        flatten_list,
        freeze_params,
        get_git_info,
        label_smoothed_nll_loss,
        lmap,
        pickle_save,
        save_git_info,
        save_json,
        use_task_specific_params,
    )
except ImportError:
    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
    from utils import (
        ROUGE_KEYS,
        LegacySeq2SeqDataset,
        Seq2SeqDataset,
        assert_all_frozen,
        calculate_bleu,
        calculate_rouge,
        flatten_list,
        freeze_params,
        get_git_info,
        label_smoothed_nll_loss,
        lmap,
        pickle_save,
        save_git_info,
        save_json,
        use_task_specific_params,
    )

logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer >= 1"
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        try:
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        except AttributeError:
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.model.config.vocab_size
            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
        rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
        rouges.update({k: v.item() for k, v in losses.items()})
        losses.update(rouges)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}

    def save_metrics(self, latest_metrics, type_path) -> None:
        self.metrics[type_path].append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="Task to run: summarization (default) or translation."
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument("--val_metric", type=str, default=None, required=False)
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
        )
        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if args.task == "summarization":
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
        early_stopping_callback=es_callback,
        logger=logger,
        # TODO: early stopping callback seems messed up
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    args = parser.parse_args()

    main(args)
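For illustration only, a hedged sketch of driving finetune.main programmatically; it mirrors the __main__ block above, and the model id, data directory, and hyperparameter values are placeholders rather than recommendations from the PR:

# Hypothetical invocation of finetune.main(); values are placeholders.
import argparse
import os

import pytorch_lightning as pl

from finetune import SummarizationModule, main

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args(
    [
        "--model_name_or_path", "facebook/bart-large",  # placeholder model id
        "--data_dir", "cnn_dm",  # expects {train,val,test}.source / .target files (see utils.py)
        "--output_dir", "bart_cnn_out",
        "--do_train",
        "--do_predict",
        "--gpus", "1",
        "--learning_rate", "3e-5",
    ]
)
model = main(args)  # trains, then tests on the latest checkpoint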
lightning_base.py (new file, mode 100644)
import argparse
import logging
import os
from pathlib import Path
from typing import Any, Dict

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelWithLMHead,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import (
    Adafactor,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)


logger = logging.getLogger(__name__)


MODEL_MODES = {
    "base": AutoModel,
    "sequence-classification": AutoModelForSequenceClassification,
    "question-answering": AutoModelForQuestionAnswering,
    "pretraining": AutoModelForPreTraining,
    "token-classification": AutoModelForTokenClassification,
    "language-modeling": AutoModelWithLMHead,
    "summarization": AutoModelForSeq2SeqLM,
    "translation": AutoModelForSeq2SeqLM,
}


# update this and the import above to support new schedulers from transformers.optimization
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    # '': get_constant_schedule,  # not supported for now
    # '': get_constant_schedule_with_warmup,  # not supported for now
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"


class BaseTransformer(pl.LightningModule):
    def __init__(
        self,
        hparams: argparse.Namespace,
        num_labels=None,
        mode="base",
        config=None,
        tokenizer=None,
        model=None,
        **config_kwargs
    ):
        """Initialize a model, tokenizer and config."""
        super().__init__()
        # TODO: move to self.save_hyperparameters()
        # self.save_hyperparameters()
        # can also expand arguments into trainer signature for easier reading

        self.save_hyperparameters(hparams)
        self.step_count = 0
        self.output_dir = Path(self.hparams.output_dir)
        cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
        if config is None:
            self.config = AutoConfig.from_pretrained(
                self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
                **({"num_labels": num_labels} if num_labels is not None else {}),
                cache_dir=cache_dir,
                **config_kwargs,
            )
        else:
            self.config: PretrainedConfig = config

        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
        for p in extra_model_params:
            if getattr(self.hparams, p, None):
                assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
                setattr(self.config, p, getattr(self.hparams, p))

        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
                cache_dir=cache_dir,
            )
        else:
            self.tokenizer: PreTrainedTokenizer = tokenizer
        self.model_type = MODEL_MODES[mode]
        if model is None:
            self.model = self.model_type.from_pretrained(
                self.hparams.model_name_or_path,
                from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
                config=self.config,
                cache_dir=cache_dir,
            )
        else:
            self.model = model

    def load_hf_checkpoint(self, *args, **kwargs):
        self.model = self.model_type.from_pretrained(*args, **kwargs)

    def get_lr_scheduler(self):
        get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
        scheduler = get_schedule_func(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return scheduler

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        if self.hparams.adafactor:
            optimizer = Adafactor(
                optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
            )

        else:
            optimizer = AdamW(
                optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
            )
        self.opt = optimizer

        scheduler = self.get_lr_scheduler()

        return [optimizer], [scheduler]

    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        return self.validation_end(outputs)

    @property
    def total_steps(self) -> int:
        """The number of total training steps that will be run. Used for lr scheduler purposes."""
        num_devices = max(1, self.hparams.gpus)  # TODO: consider num_tpu_cores
        effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
        dataset_size = len(self.train_loader.dataset)
        return (dataset_size / effective_batch_size) * self.hparams.max_epochs

    def setup(self, mode):
        if mode == "fit":
            self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)

    def get_dataloader(self, type_path, batch_size, shuffle=False):
        raise NotImplementedError("You must implement this for your task")

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)

    def test_dataloader(self):
        return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)

    def _feature_file(self, mode):
        return os.path.join(
            self.hparams.data_dir,
            "cached_{}_{}_{}".format(
                mode,
                list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
                str(self.hparams.max_seq_length),
            ),
        )

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        save_path = self.output_dir.joinpath("best_tfmr")
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        parser.add_argument(
            "--model_name_or_path",
            default=None,
            type=str,
            required=True,
            help="Path to pretrained model or model identifier from huggingface.co/models",
        )
        parser.add_argument(
            "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
        )
        parser.add_argument(
            "--tokenizer_name",
            default=None,
            type=str,
            help="Pretrained tokenizer name or path if not the same as model_name",
        )
        parser.add_argument(
            "--cache_dir",
            default="",
            type=str,
            help="Where do you want to store the pre-trained models downloaded from s3",
        )
        parser.add_argument(
            "--encoder_layerdrop",
            type=float,
            help="Encoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--decoder_layerdrop",
            type=float,
            help="Decoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            help="Dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--attention_dropout",
            type=float,
            help="Attention dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
        parser.add_argument(
            "--lr_scheduler",
            default="linear",
            choices=arg_to_scheduler_choices,
            metavar=arg_to_scheduler_metavar,
            type=str,
            help="Learning rate scheduler",
        )
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
        parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
        parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
        parser.add_argument("--train_batch_size", default=32, type=int)
        parser.add_argument("--eval_batch_size", default=32, type=int)
        parser.add_argument("--adafactor", action="store_true")


class LoggingCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
        pl_module.logger.log_metrics(lrs)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics
        # Log results
        for key in sorted(metrics):
            if key not in ["log", "progress_bar"]:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        rank_zero_info("***** Test results *****")
        metrics = trainer.callback_metrics
        # Log and save results to file
        output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
                    rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
                    writer.write("{} = {}\n".format(key, str(metrics[key])))


def add_generic_args(parser, root_dir) -> None:
    # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )

    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O2",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
    parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        dest="accumulate_grad_batches",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
    )


def generic_train(
    model: BaseTransformer,
    args: argparse.Namespace,
    early_stopping_callback=False,
    logger=True,  # can pass WandbLogger() here
    extra_callbacks=[],
    checkpoint_callback=None,
    logging_callback=None,
    **extra_train_kwargs
):
    pl.seed_everything(args.seed)

    # init model
    odir = Path(model.hparams.output_dir)
    odir.mkdir(exist_ok=True)

    # add custom checkpoints
    if checkpoint_callback is None:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
        )
    if logging_callback is None:
        logging_callback = LoggingCallback()

    train_params = {}

    # TODO: remove with PyTorch 1.6 since pl uses native amp
    if args.fp16:
        train_params["precision"] = 16
        train_params["amp_level"] = args.fp16_opt_level

    if args.gpus > 1:
        train_params["distributed_backend"] = "ddp"

    trainer = pl.Trainer.from_argparse_args(
        args,
        weights_summary=None,
        callbacks=[logging_callback] + extra_callbacks,
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stopping_callback,
        **train_params,
    )

    if args.do_train:
        trainer.fit(model)

    return trainer
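As a quick sanity check of the total_steps property above, a tiny worked example with invented numbers (not part of the PR):

# Illustrative arithmetic for BaseTransformer.total_steps, with made-up numbers.
dataset_size = 100_000       # len(self.train_loader.dataset)
train_batch_size = 32
accumulate_grad_batches = 4
gpus = 1
max_epochs = 3

effective_batch_size = train_batch_size * accumulate_grad_batches * max(1, gpus)  # 128
total_steps = (dataset_size / effective_batch_size) * max_epochs                  # 2343.75 scheduler steps
print(effective_batch_size, total_steps)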
utils.py (new file, mode 100644)
import itertools
import json
import linecache
import os
import pickle
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List

import git
import numpy as np
import torch
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler

from transformers import BartTokenizer


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Only used by LegacySeq2SeqDataset"""
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    return tokenizer(
        [line],
        max_length=max_length,
        padding="max_length" if pad_to_max_length else None,
        truncation=True,
        return_tensors=return_tensors,
        **extra_kw,
    )


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}


def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])


class AbstractSeq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.src_lens = self.get_char_lens(self.src_file)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix
        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    def make_sortish_sampler(self, batch_size):
        return SortishSampler(self.src_lens, batch_size)

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")


class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch


class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch."""

    def __getitem__(self, index) -> Dict[str, str]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {
            "tgt_texts": tgt_line,
            "src_texts": source_line,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch."""
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            add_prefix_space=self.add_prefix_space,
        )
        return batch_encoding.data


class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size):
        self.data, self.bs = data, batch_size

    def key(self, i):
        return self.data[i]

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        idxs = np.random.permutation(len(self.data))
        sz = self.bs * 50
        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        sz = self.bs
        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
        sort_idx = np.concatenate((ck_idx[0], sort_idx))
        return iter(sort_idx)


logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update config with summarization specific params."""
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"using task specific params for {task}: {pars}")
        model.config.update(pars)


def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path):
    with open(path, "w") as f:
        json.dump(content, f, indent=4)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
    }
    return repo_infos


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}


# Utilities for freezing parameters and checking whether they are frozen


def freeze_params(model: nn.Module):
    """Set requires_grad=False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"
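Finally, a small usage sketch for the metric helpers in utils.py; the strings are toy examples, and the exact printed values depend on the installed rouge_score and sacrebleu versions:

# Hypothetical usage of calculate_rouge / calculate_bleu with toy strings.
from utils import calculate_bleu, calculate_rouge

preds = ["the cat sat on the mat"]
refs = ["a cat was sitting on the mat"]

print(calculate_rouge(preds, refs))  # e.g. {"rouge1": ..., "rouge2": ..., "rougeL": ...}
print(calculate_bleu(preds, refs))   # e.g. {"bleu": ...}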