Showing 3 changed files with 44 additions and 65 deletions
@@ -16,6 +16,9 @@ from lightning_base import BaseTransformer, add_generic_args, generic_train
 from transformers import MBartTokenizer, T5ForConditionalGeneration
 from transformers.modeling_bart import shift_tokens_right
 
+from matorage import DataConfig
+from matorage.torch import Dataset
+
 
 try:
     from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
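These two imports are the entire matorage surface the script touches: `DataConfig` describes where the dataset lives and which tensors it contains, and `matorage.torch.Dataset` exposes it as an ordinary PyTorch dataset. A minimal sketch of how they plug into a stock `DataLoader` follows; the endpoint, credentials, and sequence lengths are placeholders, and the assumption that each item comes back as a list of tensors in declared-attribute order is inferred from the indexing used later in this patch.

```python
from torch.utils.data import DataLoader

from matorage import DataConfig
from matorage.torch import Dataset

# Placeholder values: endpoint, keys, and sequence lengths are illustrative only.
MAX_SOURCE_LENGTH, MAX_TARGET_LENGTH = 1024, 128

config = DataConfig(
    endpoint="127.0.0.1:9000",               # MinIO/S3 endpoint holding the dataset
    access_key="minio",
    secret_key="miniosecretkey",
    dataset_name="commit-autosuggestions",   # same dataset name the patch uses in get_dataset()
    attributes=[
        ("input_ids", "int32", (MAX_SOURCE_LENGTH,)),
        ("attention_masks", "int32", (MAX_SOURCE_LENGTH,)),
        ("patch_ids", "int32", (MAX_SOURCE_LENGTH,)),
        ("targets", "int32", (MAX_TARGET_LENGTH,)),
    ],
)

# The matorage Dataset drops into a standard DataLoader; each batch is assumed to be a
# list of tensors ordered exactly as the attributes above are declared.
loader = DataLoader(Dataset(config=config, clear=True), batch_size=8, shuffle=True)
```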
@@ -75,18 +78,6 @@ class SummarizationModule(BaseTransformer):
         self.step_count = 0
         self.metrics = defaultdict(list)
 
-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=self.model.config.prefix or "",
-        )
-        n_observations_per_split = {
-            "train": self.hparams.n_train,
-            "val": self.hparams.n_val,
-            "test": self.hparams.n_test,
-        }
-        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
-
         self.target_lens = {
             "train": self.hparams.max_target_length,
             "val": self.hparams.val_max_target_length,
@@ -107,9 +98,7 @@ class SummarizationModule(BaseTransformer):
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
             self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        self.dataset_class = (
-            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
-        )
+
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
         assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
@@ -137,8 +126,8 @@ class SummarizationModule(BaseTransformer):
 
     def _step(self, batch: dict) -> Tuple:
         pad_token_id = self.tokenizer.pad_token_id
-        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
-        tgt_ids = batch["labels"]
+        src_ids, src_mask, src_patch = batch[0].long(), batch[1].long(), batch[2].long()
+        tgt_ids = batch[3].long()
         if isinstance(self.model, T5ForConditionalGeneration):
             decoder_input_ids = self.model._shift_right(tgt_ids)
         else:
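Because the matorage dataset yields positional tensors instead of the tokenizer's named dict, every `batch["..."]` lookup becomes an index. The mapping below mirrors the order of the `attributes` declared later in `get_dataset()`; `unpack_batch` is a hypothetical helper for illustration, not part of this change.

```python
import torch

# Implied positional layout, matching the declared matorage attributes:
#   batch[0] -> input_ids          batch[2] -> patch_ids
#   batch[1] -> attention_masks    batch[3] -> targets (labels)
INPUT_IDS, ATTENTION_MASK, PATCH_IDS, TARGETS = range(4)


def unpack_batch(batch):
    """Hypothetical helper: unpack a positional matorage batch, cast to int64 for the model."""
    return tuple(batch[i].long() for i in (INPUT_IDS, ATTENTION_MASK, PATCH_IDS, TARGETS))


# Smoke test with dummy int32 tensors shaped like a (batch=2, seq_len=8) batch.
dummy = [torch.zeros(2, 8, dtype=torch.int32) for _ in range(4)]
src_ids, src_mask, src_patch, tgt_ids = unpack_batch(dummy)
```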
@@ -168,7 +157,7 @@ class SummarizationModule(BaseTransformer):
 
         logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         # tokens per batch
-        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
+        logs["tpb"] = batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
         return {"loss": loss_tensors[0], "log": logs}
 
     def validation_step(self, batch, batch_idx) -> Dict:
@@ -198,14 +187,15 @@ class SummarizationModule(BaseTransformer):
     def _generative_step(self, batch: dict) -> dict:
         t0 = time.time()
         generated_ids = self.model.generate(
-            batch["input_ids"],
-            attention_mask=batch["attention_mask"],
+            batch[0].long(),
+            attention_mask=batch[1].long(),
+            # patch_ids=batch[2].long(),
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
         )
-        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
+        gen_time = (time.time() - t0) / batch[0].shape[0]
         preds: List[str] = self.ids_to_clean_text(generated_ids)
-        target: List[str] = self.ids_to_clean_text(batch["labels"])
+        target: List[str] = self.ids_to_clean_text(batch[3])
         loss_tensors = self._step(batch)
         base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         rouge: Dict = self.calc_generative_metrics(preds, target)
@@ -220,29 +210,34 @@ class SummarizationModule(BaseTransformer):
         return self.validation_epoch_end(outputs, prefix="test")
 
     def get_dataset(self, type_path) -> Seq2SeqDataset:
-        n_obs = self.n_obs[type_path]
         max_target_length = self.target_lens[type_path]
-        dataset = self.dataset_class(
-            self.tokenizer,
-            type_path=type_path,
-            n_obs=n_obs,
-            max_target_length=max_target_length,
-            **self.dataset_kwargs,
+        data_config = DataConfig(
+            endpoint=args.matorage_dir,
+            access_key=os.environ['access_key'],
+            secret_key=os.environ['secret_key'],
+            dataset_name='commit-autosuggestions',
+            additional={
+                "mode": ("training" if type_path == "train" else "evaluation"),
+                "max_source_length": self.hparams.max_source_length,
+                "max_target_length": max_target_length,
+                "url": args.url,
+            },
+            attributes=[
+                ('input_ids', 'int32', (self.hparams.max_source_length,)),
+                ('attention_masks', 'int32', (self.hparams.max_source_length,)),
+                ('patch_ids', 'int32', (self.hparams.max_source_length,)),
+                ('targets', 'int32', (max_target_length,))
+            ]
         )
-        return dataset
+        return Dataset(config=data_config, clear=True)
 
     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         dataset = self.get_dataset(type_path)
         sampler = None
-        if self.hparams.sortish_sampler and type_path == "train":
-            assert self.hparams.gpus <= 1  # TODO: assert earlier
-            sampler = dataset.make_sortish_sampler(batch_size)
-            shuffle = False
 
         dataloader = DataLoader(
             dataset,
             batch_size=batch_size,
-            collate_fn=dataset.collate_fn,
             shuffle=shuffle,
             num_workers=self.num_workers,
             sampler=sampler,
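At run time the credentials come from the `access_key`/`secret_key` environment variables and the endpoint from `--matorage_dir`; this patch only covers the read side. A rough sketch of the write side that would populate the `commit-autosuggestions` dataset, assuming matorage's `DataSaver` API as shown in its README (endpoint, keys, lengths, and the dummy arrays are illustrative, not part of this change):

```python
import numpy as np

from matorage import DataConfig, DataSaver

# Same schema as get_dataset(); endpoint, keys, and lengths are placeholders.
config = DataConfig(
    endpoint="127.0.0.1:9000",
    access_key="minio",
    secret_key="miniosecretkey",
    dataset_name="commit-autosuggestions",
    attributes=[
        ("input_ids", "int32", (1024,)),
        ("attention_masks", "int32", (1024,)),
        ("patch_ids", "int32", (1024,)),
        ("targets", "int32", (128,)),
    ],
)

saver = DataSaver(config=config)
# One chunk of 16 already-tokenized examples (dummy arrays here).
saver({
    "input_ids": np.zeros((16, 1024), dtype=np.int32),
    "attention_masks": np.ones((16, 1024), dtype=np.int32),
    "patch_ids": np.zeros((16, 1024), dtype=np.int32),
    "targets": np.zeros((16, 128), dtype=np.int32),
})
saver.disconnect()  # flush the buffered chunk to object storage
```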
@@ -264,6 +259,18 @@ class SummarizationModule(BaseTransformer):
         BaseTransformer.add_model_specific_args(parser, root_dir)
         add_generic_args(parser, root_dir)
         parser.add_argument(
+            "--url",
+            type=str,
+            required=True,
+            help="github url"
+        )
+        parser.add_argument(
+            "--matorage_dir",
+            type=str,
+            required=True,
+            help='matorage saved directory.'
+        )
+        parser.add_argument(
             "--max_source_length",
             default=1024,
             type=int,
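Both new flags are `required=True`, so every invocation has to supply them. A small stand-alone sketch of how they parse; the URL and endpoint values are placeholders, and `--matorage_dir` is what `get_dataset()` forwards to `DataConfig(endpoint=...)`.

```python
import argparse

# Mirrors the two arguments added above; values are illustrative.
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, required=True, help="github url")
parser.add_argument("--matorage_dir", type=str, required=True, help="matorage saved directory.")

args = parser.parse_args([
    "--url", "https://github.com/<owner>/<repo>",
    "--matorage_dir", "127.0.0.1:9000",  # forwarded to DataConfig(endpoint=...)
])
print(args.url, args.matorage_dir)
```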
@@ -341,28 +348,7 @@ def main(args, model=None) -> SummarizationModule:
     else:
         model: SummarizationModule = TranslationModule(args)
 
-    dataset = Path(args.data_dir).name
-    if (
-        args.logger_name == "default"
-        or args.fast_dev_run
-        or str(args.output_dir).startswith("/tmp")
-        or str(args.output_dir).startswith("/var")
-    ):
-        logger = True  # don't pollute wandb logs unnecessarily
-    elif args.logger_name == "wandb":
-        from pytorch_lightning.loggers import WandbLogger
-
-        project = os.environ.get("WANDB_PROJECT", dataset)
-        logger = WandbLogger(name=model.output_dir.name, project=project)
-
-    elif args.logger_name == "wandb_shared":
-        from pytorch_lightning.loggers import WandbLogger
-
-        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
-
-    if args.early_stopping_patience >= 0:
-        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
-    else:
+    logger = True
     es_callback = False
     trainer: pl.Trainer = generic_train(
         model,
@@ -323,13 +323,6 @@ def add_generic_args(parser, root_dir) -> None:
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
-    parser.add_argument(
-        "--data_dir",
-        default=None,
-        type=str,
-        required=True,
-        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
-    )
 
 
 def generic_train(