graykode

(add) matorage runnable, edit cpython to pandas in gitignore

.gitignore:

@@ -137,5 +137,5 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
 
-cpython
+pandas
 .idea/
\ No newline at end of file
finetune.py:

@@ -16,6 +16,9 @@ from lightning_base import BaseTransformer, add_generic_args, generic_train
 from transformers import MBartTokenizer, T5ForConditionalGeneration
 from transformers.modeling_bart import shift_tokens_right
 
+from matorage import DataConfig
+from matorage.torch import Dataset
+
 
 try:
     from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
@@ -75,18 +78,6 @@ class SummarizationModule(BaseTransformer)
         self.step_count = 0
         self.metrics = defaultdict(list)
 
-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=self.model.config.prefix or "",
-        )
-        n_observations_per_split = {
-            "train": self.hparams.n_train,
-            "val": self.hparams.n_val,
-            "test": self.hparams.n_test,
-        }
-        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
-
         self.target_lens = {
             "train": self.hparams.max_target_length,
             "val": self.hparams.val_max_target_length,
@@ -107,9 +98,7 @@ class SummarizationModule(BaseTransformer)
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
             self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        self.dataset_class = (
-            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
-        )
+
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
         assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
@@ -137,8 +126,8 @@ class SummarizationModule(BaseTransformer)
 
     def _step(self, batch: dict) -> Tuple:
         pad_token_id = self.tokenizer.pad_token_id
-        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
-        tgt_ids = batch["labels"]
+        src_ids, src_mask, src_patch = batch[0].long(), batch[1].long(), batch[2].long()
+        tgt_ids = batch[3].long()
         if isinstance(self.model, T5ForConditionalGeneration):
             decoder_input_ids = self.model._shift_right(tgt_ids)
         else:
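With matorage, `_step` no longer receives a dict keyed by `"input_ids"`/`"attention_mask"`/`"labels"`; the `DataLoader` yields the tensors positionally, and the indices here assume they follow the order of the `attributes` list declared in `get_dataset` (`input_ids`, `attention_masks`, `patch_ids`, `targets`). The `.long()` casts are needed because the tensors are stored as `int32` while the model expects `int64` token ids. A small helper like the following (a suggestion, not part of the patch) makes that mapping explicit:

```python
from typing import Tuple
import torch

def unpack_batch(batch) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Name the positional tensors of a matorage batch, assuming the attribute order
    declared in get_dataset(): input_ids, attention_masks, patch_ids, targets.
    Stored dtype is int32, so cast to int64 (long) for the model."""
    input_ids, attention_masks, patch_ids, targets = (t.long() for t in batch)
    return input_ids, attention_masks, patch_ids, targets
```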
@@ -168,7 +157,7 @@ class SummarizationModule(BaseTransformer)
 
         logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         # tokens per batch
-        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
+        logs["tpb"] = batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
         return {"loss": loss_tensors[0], "log": logs}
 
     def validation_step(self, batch, batch_idx) -> Dict:
@@ -198,14 +187,15 @@ class SummarizationModule(BaseTransformer)
     def _generative_step(self, batch: dict) -> dict:
         t0 = time.time()
         generated_ids = self.model.generate(
-            batch["input_ids"],
-            attention_mask=batch["attention_mask"],
+            batch[0].long(),
+            attention_mask=batch[1].long(),
+            # patch_ids=batch[2].long(),
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
         )
-        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
+        gen_time = (time.time() - t0) / batch[0].shape[0]
         preds: List[str] = self.ids_to_clean_text(generated_ids)
-        target: List[str] = self.ids_to_clean_text(batch["labels"])
+        target: List[str] = self.ids_to_clean_text(batch[3])
         loss_tensors = self._step(batch)
         base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         rouge: Dict = self.calc_generative_metrics(preds, target)
@@ -220,29 +210,34 @@ class SummarizationModule(BaseTransformer)
         return self.validation_epoch_end(outputs, prefix="test")
 
     def get_dataset(self, type_path) -> Seq2SeqDataset:
-        n_obs = self.n_obs[type_path]
         max_target_length = self.target_lens[type_path]
-        dataset = self.dataset_class(
-            self.tokenizer,
-            type_path=type_path,
-            n_obs=n_obs,
-            max_target_length=max_target_length,
-            **self.dataset_kwargs,
+        data_config = DataConfig(
+            endpoint=args.matorage_dir,
+            access_key=os.environ['access_key'],
+            secret_key=os.environ['secret_key'],
+            dataset_name='commit-autosuggestions',
+            additional={
+                "mode": ("training" if type_path == "train" else "evaluation"),
+                "max_source_length": self.hparams.max_source_length,
+                "max_target_length": max_target_length,
+                "url": args.url,
+            },
+            attributes=[
+                ('input_ids', 'int32', (self.hparams.max_source_length,)),
+                ('attention_masks', 'int32', (self.hparams.max_source_length,)),
+                ('patch_ids', 'int32', (self.hparams.max_source_length,)),
+                ('targets', 'int32', (max_target_length,))
+            ]
         )
-        return dataset
+        return Dataset(config=data_config, clear=True)
 
     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         dataset = self.get_dataset(type_path)
         sampler = None
-        if self.hparams.sortish_sampler and type_path == "train":
-            assert self.hparams.gpus <= 1  # TODO: assert earlier
-            sampler = dataset.make_sortish_sampler(batch_size)
-            shuffle = False
 
         dataloader = DataLoader(
             dataset,
             batch_size=batch_size,
-            collate_fn=dataset.collate_fn,
             shuffle=shuffle,
             num_workers=self.num_workers,
             sampler=sampler,
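`get_dataset` now builds a matorage `DataConfig` (storage endpoint from `--matorage_dir`, MinIO credentials from the `access_key`/`secret_key` environment variables, dataset name `commit-autosuggestions`, and four fixed-shape `int32` attributes) and returns a `matorage.torch.Dataset`, so the custom `collate_fn` and the sortish sampler are gone. Note that `args` inside the method resolves to the module-level `args` parsed in the `__main__` block, so this only works when the script is run directly; `self.hparams.matorage_dir` / `self.hparams.url` would be the safer references. For this to load anything, the tensors must have been written to the same bucket beforehand, roughly as in the sketch below, which assumes matorage's `DataSaver` API and uses placeholder values and dummy arrays in place of the project's real preprocessing:

```python
# Writer-side sketch (not part of this diff): store examples with exactly the
# attribute names/shapes that get_dataset() later declares.
import numpy as np
from matorage import DataConfig, DataSaver

config = DataConfig(
    endpoint="127.0.0.1:9000",              # later passed to finetuning as --matorage_dir
    access_key="minio",                      # placeholder credentials
    secret_key="miniosecretkey",
    dataset_name="commit-autosuggestions",
    additional={
        "mode": "training",                  # "evaluation" for the val/test split
        "max_source_length": 1024,
        "max_target_length": 56,
        "url": "<github url>",               # left as a placeholder
    },
    attributes=[
        ("input_ids", "int32", (1024,)),
        ("attention_masks", "int32", (1024,)),
        ("patch_ids", "int32", (1024,)),
        ("targets", "int32", (56,)),
    ],
)

saver = DataSaver(config=config)
saver({
    "input_ids": np.zeros((1, 1024), dtype=np.int32),       # tokenized source (dummy)
    "attention_masks": np.ones((1, 1024), dtype=np.int32),  # attention mask (dummy)
    "patch_ids": np.zeros((1, 1024), dtype=np.int32),       # per-token patch markers (dummy)
    "targets": np.zeros((1, 56), dtype=np.int32),           # tokenized commit message (dummy)
})
saver.disconnect()
```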
@@ -264,6 +259,18 @@ class SummarizationModule(BaseTransformer)
         BaseTransformer.add_model_specific_args(parser, root_dir)
         add_generic_args(parser, root_dir)
         parser.add_argument(
+            "--url",
+            type=str,
+            required=True,
+            help="github url"
+        )
+        parser.add_argument(
+            "--matorage_dir",
+            type=str,
+            required=True,
+            help='matorage saved directory.'
+        )
+        parser.add_argument(
             "--max_source_length",
             default=1024,
             type=int,
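The two new required flags wire up the storage: `--matorage_dir` becomes the `DataConfig` endpoint and `--url` is recorded in the dataset's `additional` metadata, while the old `--data_dir` argument is removed from `lightning_base.py` below. Because `get_dataset` also reads MinIO credentials from the environment, those variables have to be set before launching; a sketch with placeholder values (the script name and endpoint are assumptions):

```python
# Sketch only: the environment get_dataset() expects. Variable names match the
# os.environ lookups in get_dataset(); the values are placeholders for a local MinIO.
import os

os.environ["access_key"] = "minio"           # MinIO access key (placeholder)
os.environ["secret_key"] = "miniosecretkey"  # MinIO secret key (placeholder)

# The trainer would then be launched with the new flags alongside the usual ones, e.g.
#   python finetune.py --url <github repo url> --matorage_dir 127.0.0.1:9000 ...
```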
@@ -341,28 +348,7 @@ def main(args, model=None) -> SummarizationModule:
     else:
         model: SummarizationModule = TranslationModule(args)
 
-    dataset = Path(args.data_dir).name
-    if (
-        args.logger_name == "default"
-        or args.fast_dev_run
-        or str(args.output_dir).startswith("/tmp")
-        or str(args.output_dir).startswith("/var")
-    ):
-        logger = True  # don't pollute wandb logs unnecessarily
-    elif args.logger_name == "wandb":
-        from pytorch_lightning.loggers import WandbLogger
-
-        project = os.environ.get("WANDB_PROJECT", dataset)
-        logger = WandbLogger(name=model.output_dir.name, project=project)
-
-    elif args.logger_name == "wandb_shared":
-        from pytorch_lightning.loggers import WandbLogger
-
-        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
-
-    if args.early_stopping_patience >= 0:
-        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
-    else:
+    logger = True
     es_callback = False
     trainer: pl.Trainer = generic_train(
         model,
lightning_base.py:

@@ -323,13 +323,6 @@ def add_generic_args(parser, root_dir) -> None:
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
-    parser.add_argument(
-        "--data_dir",
-        default=None,
-        type=str,
-        required=True,
-        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
-    )
 
 
 def generic_train(