graykode

(refactor) black style

...@@ -68,6 +68,7 @@ def main(args): ...@@ -68,6 +68,7 @@ def main(args):
68 ) 68 )
69 print(commit_message) 69 print(commit_message)
70 70
71 +
71 if __name__ == "__main__": 72 if __name__ == "__main__":
72 parser = argparse.ArgumentParser(description="Code to collect commits on github") 73 parser = argparse.ArgumentParser(description="Code to collect commits on github")
73 parser.add_argument( 74 parser.add_argument(
......
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
15 from .gitcommit import diff_parse, truncate 15 from .gitcommit import diff_parse, truncate
16 16
17 __all__ = [ 17 __all__ = [
18 - 'diff_parse', 18 + "diff_parse",
19 - 'truncate', 19 + "truncate",
20 ] 20 ]
......
...@@ -36,9 +36,11 @@ logging.basicConfig( ...@@ -36,9 +36,11 @@ logging.basicConfig(
36 level=logging.INFO, 36 level=logging.INFO,
37 ) 37 )
38 38
39 +
39 class PATCH(enum.Enum): 40 class PATCH(enum.Enum):
40 - PLUS=1 41 + PLUS = 1
41 - MINUS=2 42 + MINUS = 2
43 +
42 44
43 def truncate(tuple, max_length, value=0): 45 def truncate(tuple, max_length, value=0):
44 ls = [] 46 ls = []
...@@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0): ...@@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0):
46 if isinstance(t, int): 48 if isinstance(t, int):
47 t = [t] 49 t = [t]
48 ls.extend(t) 50 ls.extend(t)
49 - ls = ls[:max_length - 1] 51 + ls = ls[: max_length - 1]
50 ls.insert(0, value) 52 ls.insert(0, value)
51 if len(ls) < max_length: 53 if len(ls) < max_length:
52 ls.extend([0] * (max_length - len(ls))) 54 ls.extend([0] * (max_length - len(ls)))
53 assert len(ls) == max_length 55 assert len(ls) == max_length
54 return ls 56 return ls
55 57
58 +
56 def encode_line(tokenizer, line, patch): 59 def encode_line(tokenizer, line, patch):
57 - line = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip() 60 + line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip()
58 tokens = tokenizer.tokenize(line) 61 tokens = tokenizer.tokenize(line)
59 tokens = tokenizer.convert_tokens_to_ids(tokens) 62 tokens = tokenizer.convert_tokens_to_ids(tokens)
60 - return ( 63 + return (tokens, [1] * len(tokens), len(tokens) * [patch.value])
61 - tokens, 64 +
62 - [1] * len(tokens),
63 - len(tokens) * [patch.value]
64 - )
65 65
66 def diff_parse(diff, tokenizer): 66 def diff_parse(diff, tokenizer):
67 chunks = [] 67 chunks = []
...@@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer): ...@@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer):
78 chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS)) 78 chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
79 return chunks 79 return chunks
80 80
81 +
81 def sha_parse(sha, tokenizer, max_length=1024): 82 def sha_parse(sha, tokenizer, max_length=1024):
82 83
83 chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer) 84 chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer)
...@@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024): ...@@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024):
91 92
92 return (input_ids, attention_masks, patch_ids) 93 return (input_ids, attention_masks, patch_ids)
93 94
95 +
94 def message_parse(msg, tokenizer, max_length=56): 96 def message_parse(msg, tokenizer, max_length=56):
95 - msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg) 97 + msg = re.sub(r"(\(|)#([0-9])+(\)|)", "", msg)
96 98
97 - msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip() 99 + msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip()
98 msg = tokenizer.tokenize(msg) 100 msg = tokenizer.tokenize(msg)
99 msg = tokenizer.convert_tokens_to_ids(msg) 101 msg = tokenizer.convert_tokens_to_ids(msg)
100 msg = truncate(msg, max_length, value=0) 102 msg = truncate(msg, max_length, value=0)
101 103
102 return msg 104 return msg
103 105
106 +
104 def jobs(sha_msgs, args, data_config, train=True): 107 def jobs(sha_msgs, args, data_config, train=True):
105 108
106 input_ids, attention_masks, patch_ids, targets = [], [], [], [] 109 input_ids, attention_masks, patch_ids, targets = [], [], [], []
...@@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True): ...@@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True):
110 sha, msg = sha_msg 113 sha, msg = sha_msg
111 114
112 source = sha_parse( 115 source = sha_parse(
113 - sha, 116 + sha, tokenizer=args.tokenizer, max_length=args.max_source_length
114 - tokenizer=args.tokenizer,
115 - max_length=args.max_source_length
116 ) 117 )
117 if not source: 118 if not source:
118 continue 119 continue
...@@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True): ...@@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True):
120 target = message_parse( 121 target = message_parse(
121 msg, 122 msg,
122 tokenizer=args.tokenizer, 123 tokenizer=args.tokenizer,
123 - max_length=(args.max_target_length if train else args.val_max_target_length), 124 + max_length=(
125 + args.max_target_length if train else args.val_max_target_length
126 + ),
124 ) 127 )
125 128
126 input_ids.append(input_id) 129 input_ids.append(input_id)
...@@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True): ...@@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True):
128 patch_ids.append(patch_id) 131 patch_ids.append(patch_id)
129 targets.append(target) 132 targets.append(target)
130 133
131 - data_saver({ 134 + data_saver(
135 + {
132 "input_ids": np.asarray(input_ids), 136 "input_ids": np.asarray(input_ids),
133 "attention_masks": np.asarray(attention_masks), 137 "attention_masks": np.asarray(attention_masks),
134 "patch_ids": np.asarray(patch_ids), 138 "patch_ids": np.asarray(patch_ids),
135 "targets": np.asarray(targets), 139 "targets": np.asarray(targets),
136 - }) 140 + }
141 + )
137 data_saver.disconnect() 142 data_saver.disconnect()
138 143
144 +
139 def start(chunked_sha_msgs, train=True): 145 def start(chunked_sha_msgs, train=True):
140 146
141 logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation")) 147 logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation"))
...@@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True): ...@@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True):
144 150
145 data_config = DataConfig( 151 data_config = DataConfig(
146 endpoint=args.endpoint, 152 endpoint=args.endpoint,
147 - access_key=os.environ['access_key'], 153 + access_key=os.environ["access_key"],
148 - secret_key=os.environ['secret_key'], 154 + secret_key=os.environ["secret_key"],
149 region=args.region, 155 region=args.region,
150 - dataset_name='commit-autosuggestions', 156 + dataset_name="commit-autosuggestions",
151 additional={ 157 additional={
152 - "mode" : ("training" if train else "evaluation"), 158 + "mode": ("training" if train else "evaluation"),
153 "max_source_length": args.max_source_length, 159 "max_source_length": args.max_source_length,
154 "max_target_length": max_target_length, 160 "max_target_length": max_target_length,
155 - "url" : args.url, 161 + "url": args.url,
156 }, 162 },
157 attributes=[ 163 attributes=[
158 - ('input_ids', 'int32', (args.max_source_length,)), 164 + ("input_ids", "int32", (args.max_source_length,)),
159 - ('attention_masks', 'int32', (args.max_source_length,)), 165 + ("attention_masks", "int32", (args.max_source_length,)),
160 - ('patch_ids', 'int32', (args.max_source_length,)), 166 + ("patch_ids", "int32", (args.max_source_length,)),
161 - ('targets', 'int32', (max_target_length,)) 167 + ("targets", "int32", (max_target_length,)),
162 - ] 168 + ],
163 ) 169 )
164 170
165 func = partial(jobs, args=args, data_config=data_config, train=train) 171 func = partial(jobs, args=args, data_config=data_config, train=train)
...@@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True): ...@@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True):
168 for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))): 174 for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
169 pbar.update() 175 pbar.update()
170 176
177 +
171 def main(args): 178 def main(args):
172 - if 'access_key' not in os.environ or 'secret_key' not in os.environ: 179 + if "access_key" not in os.environ or "secret_key" not in os.environ:
173 raise OSError("access_key or secret_key are not found.") 180 raise OSError("access_key or secret_key are not found.")
174 181
175 sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()] 182 sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
176 random.shuffle(sha_msgs) 183 random.shuffle(sha_msgs)
177 chunked_sha_msgs = [ 184 chunked_sha_msgs = [
178 - sha_msgs[x:x + args.matorage_batch] 185 + sha_msgs[x : x + args.matorage_batch]
179 for x in range(0, len(sha_msgs), args.matorage_batch) 186 for x in range(0, len(sha_msgs), args.matorage_batch)
180 ] 187 ]
181 188
...@@ -185,29 +192,25 @@ def main(args): ...@@ -185,29 +192,25 @@ def main(args):
185 if args.do_predict: 192 if args.do_predict:
186 start(chunked_sha_msgs[barrier:], train=False) 193 start(chunked_sha_msgs[barrier:], train=False)
187 194
195 +
188 if __name__ == "__main__": 196 if __name__ == "__main__":
189 parser = argparse.ArgumentParser(description="Code to collect commits on github") 197 parser = argparse.ArgumentParser(description="Code to collect commits on github")
190 - parser.add_argument( 198 + parser.add_argument("--url", type=str, required=True, help="github url")
191 - "--url",
192 - type=str,
193 - required=True,
194 - help="github url"
195 - )
196 parser.add_argument( 199 parser.add_argument(
197 "--endpoint", 200 "--endpoint",
198 type=str, 201 type=str,
199 required=True, 202 required=True,
200 - help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' 203 + help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
201 ) 204 )
202 parser.add_argument( 205 parser.add_argument(
203 "--region", 206 "--region",
204 type=str, 207 type=str,
205 default=None, 208 default=None,
206 - help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' 209 + help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
207 ) 210 )
208 parser.add_argument( 211 parser.add_argument(
209 "--tokenizer_name", 212 "--tokenizer_name",
210 - default='sshleifer/distilbart-xsum-6-6', 213 + default="sshleifer/distilbart-xsum-6-6",
211 type=str, 214 type=str,
212 help="Pretrained tokenizer name or path if not the same as model_name", 215 help="Pretrained tokenizer name or path if not the same as model_name",
213 ) 216 )
...@@ -215,13 +218,10 @@ if __name__ == "__main__": ...@@ -215,13 +218,10 @@ if __name__ == "__main__":
215 "--matorage_batch", 218 "--matorage_batch",
216 default=1024, 219 default=1024,
217 type=int, 220 type=int,
218 - help='The smallest batch size stored atomically in matorage.' 221 + help="The smallest batch size stored atomically in matorage.",
219 ) 222 )
220 parser.add_argument( 223 parser.add_argument(
221 - "--num_workers", 224 + "--num_workers", default=4, type=int, help="number of process",
222 - default=4,
223 - type=int,
224 - help="number of process",
225 ) 225 )
226 parser.add_argument( 226 parser.add_argument(
227 "--max_source_length", 227 "--max_source_length",
...@@ -244,12 +244,14 @@ if __name__ == "__main__": ...@@ -244,12 +244,14 @@ if __name__ == "__main__":
244 help="The maximum total input sequence length after tokenization. Sequences longer " 244 help="The maximum total input sequence length after tokenization. Sequences longer "
245 "than this will be truncated, sequences shorter will be padded.", 245 "than this will be truncated, sequences shorter will be padded.",
246 ) 246 )
247 - parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset") 247 + parser.add_argument(
248 + "--p_val", type=float, default=0.25, help="percent of validation dataset"
249 + )
248 parser.add_argument("--do_train", action="store_true", default=False) 250 parser.add_argument("--do_train", action="store_true", default=False)
249 parser.add_argument("--do_predict", action="store_true", default=False) 251 parser.add_argument("--do_predict", action="store_true", default=False)
250 args = parser.parse_args() 252 args = parser.parse_args()
251 253
252 - args.local_path = args.url.split('/')[-1] 254 + args.local_path = args.url.split("/")[-1]
253 logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}") 255 logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}")
254 repo = ( 256 repo = (
255 Repo(args.local_path) 257 Repo(args.local_path)
......
...@@ -14,6 +14,4 @@ ...@@ -14,6 +14,4 @@
14 14
15 from .modeling_bart import BartForConditionalGeneration 15 from .modeling_bart import BartForConditionalGeneration
16 16
17 -__all__ = [
18 - 'BartForConditionalGeneration'
19 -]
...\ No newline at end of file ...\ No newline at end of file
17 +__all__ = ["BartForConditionalGeneration"]
......
...@@ -20,16 +20,31 @@ logger = logging.getLogger(__name__) ...@@ -20,16 +20,31 @@ logger = logging.getLogger(__name__)
20 20
21 class Seq2SeqLoggingCallback(pl.Callback): 21 class Seq2SeqLoggingCallback(pl.Callback):
22 def on_batch_end(self, trainer, pl_module): 22 def on_batch_end(self, trainer, pl_module):
23 - lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} 23 + lrs = {
24 + f"lr_group_{i}": param["lr"]
25 + for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)
26 + }
24 pl_module.logger.log_metrics(lrs) 27 pl_module.logger.log_metrics(lrs)
25 28
26 @rank_zero_only 29 @rank_zero_only
27 def _write_logs( 30 def _write_logs(
28 - self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True 31 + self,
32 + trainer: pl.Trainer,
33 + pl_module: pl.LightningModule,
34 + type_path: str,
35 + save_generations=True,
29 ) -> None: 36 ) -> None:
30 - logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") 37 + logger.info(
38 + f"***** {type_path} results at step {trainer.global_step:05d} *****"
39 + )
31 metrics = trainer.callback_metrics 40 metrics = trainer.callback_metrics
32 - trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) 41 + trainer.logger.log_metrics(
42 + {
43 + k: v
44 + for k, v in metrics.items()
45 + if k not in ["log", "progress_bar", "preds"]
46 + }
47 + )
33 # Log results 48 # Log results
34 od = Path(pl_module.hparams.output_dir) 49 od = Path(pl_module.hparams.output_dir)
35 if type_path == "test": 50 if type_path == "test":
...@@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback): ...@@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
39 # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json 54 # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
40 # If people want this it will be easy enough to add back. 55 # If people want this it will be easy enough to add back.
41 results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" 56 results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
42 - generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt" 57 + generations_file = (
58 + od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
59 + )
43 results_file.parent.mkdir(exist_ok=True) 60 results_file.parent.mkdir(exist_ok=True)
44 generations_file.parent.mkdir(exist_ok=True) 61 generations_file.parent.mkdir(exist_ok=True)
45 with open(results_file, "a+") as writer: 62 with open(results_file, "a+") as writer:
...@@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback): ...@@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
68 85
69 n_trainable_pars = count_trainable_parameters(pl_module) 86 n_trainable_pars = count_trainable_parameters(pl_module)
70 # mp stands for million parameters 87 # mp stands for million parameters
71 - trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) 88 + trainer.logger.log_metrics(
89 + {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}
90 + )
72 91
73 @rank_zero_only 92 @rank_zero_only
74 def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 93 def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
...@@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric): ...@@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric):
98 117
99 def get_early_stopping_callback(metric, patience): 118 def get_early_stopping_callback(metric, patience):
100 return EarlyStopping( 119 return EarlyStopping(
101 - monitor=f"val_{metric}", 120 + monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,
102 - mode="max",
103 - patience=patience,
104 - verbose=True,
105 ) 121 )
......
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
...@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule): ...@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
69 config=None, 69 config=None,
70 tokenizer=None, 70 tokenizer=None,
71 model=None, 71 model=None,
72 - **config_kwargs 72 + **config_kwargs,
73 ): 73 ):
74 """Initialize a model, tokenizer and config.""" 74 """Initialize a model, tokenizer and config."""
75 super().__init__() 75 super().__init__()
...@@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule): ...@@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule):
83 cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None 83 cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
84 if config is None: 84 if config is None:
85 self.config = AutoConfig.from_pretrained( 85 self.config = AutoConfig.from_pretrained(
86 - self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, 86 + self.hparams.config_name
87 + if self.hparams.config_name
88 + else self.hparams.model_name_or_path,
87 **({"num_labels": num_labels} if num_labels is not None else {}), 89 **({"num_labels": num_labels} if num_labels is not None else {}),
88 cache_dir=cache_dir, 90 cache_dir=cache_dir,
89 **config_kwargs, 91 **config_kwargs,
...@@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule): ...@@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule):
91 else: 93 else:
92 self.config: PretrainedConfig = config 94 self.config: PretrainedConfig = config
93 95
94 - extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") 96 + extra_model_params = (
97 + "encoder_layerdrop",
98 + "decoder_layerdrop",
99 + "dropout",
100 + "attention_dropout",
101 + )
95 for p in extra_model_params: 102 for p in extra_model_params:
96 if getattr(self.hparams, p, None): 103 if getattr(self.hparams, p, None):
97 - assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute" 104 + assert hasattr(
105 + self.config, p
106 + ), f"model config doesn't have a `{p}` attribute"
98 setattr(self.config, p, getattr(self.hparams, p)) 107 setattr(self.config, p, getattr(self.hparams, p))
99 108
100 if tokenizer is None: 109 if tokenizer is None:
101 self.tokenizer = AutoTokenizer.from_pretrained( 110 self.tokenizer = AutoTokenizer.from_pretrained(
102 - self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, 111 + self.hparams.tokenizer_name
112 + if self.hparams.tokenizer_name
113 + else self.hparams.model_name_or_path,
103 cache_dir=cache_dir, 114 cache_dir=cache_dir,
104 ) 115 )
105 else: 116 else:
...@@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule): ...@@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule):
121 def get_lr_scheduler(self): 132 def get_lr_scheduler(self):
122 get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] 133 get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
123 scheduler = get_schedule_func( 134 scheduler = get_schedule_func(
124 - self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps 135 + self.opt,
136 + num_warmup_steps=self.hparams.warmup_steps,
137 + num_training_steps=self.total_steps,
125 ) 138 )
126 scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 139 scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
127 return scheduler 140 return scheduler
...@@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule): ...@@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule):
132 no_decay = ["bias", "LayerNorm.weight"] 145 no_decay = ["bias", "LayerNorm.weight"]
133 optimizer_grouped_parameters = [ 146 optimizer_grouped_parameters = [
134 { 147 {
135 - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 148 + "params": [
149 + p
150 + for n, p in model.named_parameters()
151 + if not any(nd in n for nd in no_decay)
152 + ],
136 "weight_decay": self.hparams.weight_decay, 153 "weight_decay": self.hparams.weight_decay,
137 }, 154 },
138 { 155 {
139 - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 156 + "params": [
157 + p
158 + for n, p in model.named_parameters()
159 + if any(nd in n for nd in no_decay)
160 + ],
140 "weight_decay": 0.0, 161 "weight_decay": 0.0,
141 }, 162 },
142 ] 163 ]
143 if self.hparams.adafactor: 164 if self.hparams.adafactor:
144 optimizer = Adafactor( 165 optimizer = Adafactor(
145 - optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False 166 + optimizer_grouped_parameters,
167 + lr=self.hparams.learning_rate,
168 + scale_parameter=False,
169 + relative_step=False,
146 ) 170 )
147 171
148 else: 172 else:
149 optimizer = AdamW( 173 optimizer = AdamW(
150 - optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon 174 + optimizer_grouped_parameters,
175 + lr=self.hparams.learning_rate,
176 + eps=self.hparams.adam_epsilon,
151 ) 177 )
152 self.opt = optimizer 178 self.opt = optimizer
153 179
...@@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule): ...@@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule):
165 def total_steps(self) -> int: 191 def total_steps(self) -> int:
166 """The number of total training steps that will be run. Used for lr scheduler purposes.""" 192 """The number of total training steps that will be run. Used for lr scheduler purposes."""
167 num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores 193 num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
168 - effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices 194 + effective_batch_size = (
195 + self.hparams.train_batch_size
196 + * self.hparams.accumulate_grad_batches
197 + * num_devices
198 + )
169 dataset_size = len(self.train_loader.dataset) 199 dataset_size = len(self.train_loader.dataset)
170 return (dataset_size / effective_batch_size) * self.hparams.max_epochs 200 return (dataset_size / effective_batch_size) * self.hparams.max_epochs
171 201
172 def setup(self, mode): 202 def setup(self, mode):
173 if mode == "fit": 203 if mode == "fit":
174 - self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True) 204 + self.train_loader = self.get_dataloader(
205 + "train", self.hparams.train_batch_size, shuffle=True
206 + )
175 207
176 def get_dataloader(self, type_path, batch_size, shuffle=False): 208 def get_dataloader(self, type_path, batch_size, shuffle=False):
177 raise NotImplementedError("You must implement this for your task") 209 raise NotImplementedError("You must implement this for your task")
...@@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule): ...@@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule):
212 help="Path to pretrained model or model identifier from huggingface.co/models", 244 help="Path to pretrained model or model identifier from huggingface.co/models",
213 ) 245 )
214 parser.add_argument( 246 parser.add_argument(
215 - "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 247 + "--config_name",
248 + default="",
249 + type=str,
250 + help="Pretrained config name or path if not the same as model_name",
216 ) 251 )
217 parser.add_argument( 252 parser.add_argument(
218 "--tokenizer_name", 253 "--tokenizer_name",
...@@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule): ...@@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule):
246 type=float, 281 type=float,
247 help="Attention dropout probability (Optional). Goes into model.config", 282 help="Attention dropout probability (Optional). Goes into model.config",
248 ) 283 )
249 - parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 284 + parser.add_argument(
285 + "--learning_rate",
286 + default=5e-5,
287 + type=float,
288 + help="The initial learning rate for Adam.",
289 + )
250 parser.add_argument( 290 parser.add_argument(
251 "--lr_scheduler", 291 "--lr_scheduler",
252 default="linear", 292 default="linear",
...@@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule): ...@@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule):
255 type=str, 295 type=str,
256 help="Learning rate scheduler", 296 help="Learning rate scheduler",
257 ) 297 )
258 - parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 298 + parser.add_argument(
259 - parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 299 + "--weight_decay",
260 - parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 300 + default=0.0,
261 - parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") 301 + type=float,
262 - parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int) 302 + help="Weight decay if we apply some.",
303 + )
304 + parser.add_argument(
305 + "--adam_epsilon",
306 + default=1e-8,
307 + type=float,
308 + help="Epsilon for Adam optimizer.",
309 + )
310 + parser.add_argument(
311 + "--warmup_steps",
312 + default=0,
313 + type=int,
314 + help="Linear warmup over warmup_steps.",
315 + )
316 + parser.add_argument(
317 + "--num_workers", default=4, type=int, help="kwarg passed to DataLoader"
318 + )
319 + parser.add_argument(
320 + "--num_train_epochs", dest="max_epochs", default=3, type=int
321 + )
263 parser.add_argument("--train_batch_size", default=32, type=int) 322 parser.add_argument("--train_batch_size", default=32, type=int)
264 parser.add_argument("--eval_batch_size", default=32, type=int) 323 parser.add_argument("--eval_batch_size", default=32, type=int)
265 parser.add_argument("--adafactor", action="store_true") 324 parser.add_argument("--adafactor", action="store_true")
...@@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback): ...@@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback):
283 rank_zero_info("***** Test results *****") 342 rank_zero_info("***** Test results *****")
284 metrics = trainer.callback_metrics 343 metrics = trainer.callback_metrics
285 # Log and save results to file 344 # Log and save results to file
286 - output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") 345 + output_test_results_file = os.path.join(
346 + pl_module.hparams.output_dir, "test_results.txt"
347 + )
287 with open(output_test_results_file, "w") as writer: 348 with open(output_test_results_file, "w") as writer:
288 for key in sorted(metrics): 349 for key in sorted(metrics):
289 if key not in ["log", "progress_bar"]: 350 if key not in ["log", "progress_bar"]:
...@@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None: ...@@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None:
314 "See details at https://nvidia.github.io/apex/amp.html", 375 "See details at https://nvidia.github.io/apex/amp.html",
315 ) 376 )
316 parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) 377 parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
317 - parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm") 378 + parser.add_argument(
318 - parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 379 + "--max_grad_norm",
319 - parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") 380 + dest="gradient_clip_val",
381 + default=1.0,
382 + type=float,
383 + help="Max gradient norm",
384 + )
385 + parser.add_argument(
386 + "--do_train", action="store_true", help="Whether to run training."
387 + )
388 + parser.add_argument(
389 + "--do_predict",
390 + action="store_true",
391 + help="Whether to run predictions on the test set.",
392 + )
320 parser.add_argument( 393 parser.add_argument(
321 "--gradient_accumulation_steps", 394 "--gradient_accumulation_steps",
322 dest="accumulate_grad_batches", 395 dest="accumulate_grad_batches",
...@@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None: ...@@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None:
324 default=1, 397 default=1,
325 help="Number of updates steps to accumulate before performing a backward/update pass.", 398 help="Number of updates steps to accumulate before performing a backward/update pass.",
326 ) 399 )
327 - parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 400 + parser.add_argument(
401 + "--seed", type=int, default=42, help="random seed for initialization"
402 + )
328 403
329 404
330 def generic_train( 405 def generic_train(
...@@ -335,7 +410,7 @@ def generic_train( ...@@ -335,7 +410,7 @@ def generic_train(
335 extra_callbacks=[], 410 extra_callbacks=[],
336 checkpoint_callback=None, 411 checkpoint_callback=None,
337 logging_callback=None, 412 logging_callback=None,
338 - **extra_train_kwargs 413 + **extra_train_kwargs,
339 ): 414 ):
340 pl.seed_everything(args.seed) 415 pl.seed_everything(args.seed)
341 416
...@@ -346,7 +421,11 @@ def generic_train( ...@@ -346,7 +421,11 @@ def generic_train(
346 # add custom checkpoints 421 # add custom checkpoints
347 if checkpoint_callback is None: 422 if checkpoint_callback is None:
348 checkpoint_callback = pl.callbacks.ModelCheckpoint( 423 checkpoint_callback = pl.callbacks.ModelCheckpoint(
349 - filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1 424 + filepath=args.output_dir,
425 + prefix="checkpoint",
426 + monitor="val_loss",
427 + mode="min",
428 + save_top_k=1,
350 ) 429 )
351 if logging_callback is None: 430 if logging_callback is None:
352 logging_callback = LoggingCallback() 431 logging_callback = LoggingCallback()
......
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
...@@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): ...@@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
39 return loss, nll_loss 39 return loss, nll_loss
40 40
41 41
42 -def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): 42 +def encode_line(
43 + tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"
44 +):
43 """Only used by LegacyDataset""" 45 """Only used by LegacyDataset"""
44 - extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 46 + extra_kw = (
47 + {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
48 + )
45 return tokenizer( 49 return tokenizer(
46 [line], 50 [line],
47 max_length=max_length, 51 max_length=max_length,
...@@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: ...@@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
63 67
64 68
65 def trim_batch( 69 def trim_batch(
66 - input_ids, 70 + input_ids, pad_token_id, attention_mask=None,
67 - pad_token_id,
68 - attention_mask=None,
69 ): 71 ):
70 """Remove columns that are populated exclusively by pad_token_id""" 72 """Remove columns that are populated exclusively by pad_token_id"""
71 keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 73 keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
...@@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): ...@@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
125 def __getitem__(self, index) -> Dict[str, torch.Tensor]: 127 def __getitem__(self, index) -> Dict[str, torch.Tensor]:
126 """Call tokenizer on src and tgt_lines""" 128 """Call tokenizer on src and tgt_lines"""
127 index = index + 1 # linecache starts at 1 129 index = index + 1 # linecache starts at 1
128 - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 130 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
131 + "\n"
132 + )
129 tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 133 tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
130 assert source_line, f"empty source line for index {index}" 134 assert source_line, f"empty source line for index {index}"
131 assert tgt_line, f"empty tgt line for index {index}" 135 assert tgt_line, f"empty tgt line for index {index}"
...@@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): ...@@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
147 target_ids = torch.stack([x["labels"] for x in batch]) 151 target_ids = torch.stack([x["labels"] for x in batch])
148 pad_token_id = self.pad_token_id 152 pad_token_id = self.pad_token_id
149 y = trim_batch(target_ids, pad_token_id) 153 y = trim_batch(target_ids, pad_token_id)
150 - source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) 154 + source_ids, source_mask = trim_batch(
155 + input_ids, pad_token_id, attention_mask=masks
156 + )
151 batch = { 157 batch = {
152 "input_ids": source_ids, 158 "input_ids": source_ids,
153 "attention_mask": source_mask, 159 "attention_mask": source_mask,
...@@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): ...@@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
161 167
162 def __getitem__(self, index) -> Dict[str, str]: 168 def __getitem__(self, index) -> Dict[str, str]:
163 index = index + 1 # linecache starts at 1 169 index = index + 1 # linecache starts at 1
164 - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 170 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
171 + "\n"
172 + )
165 tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 173 tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
166 assert source_line, f"empty source line for index {index}" 174 assert source_line, f"empty source line for index {index}"
167 assert tgt_line, f"empty tgt line for index {index}" 175 assert tgt_line, f"empty tgt line for index {index}"
...@@ -201,12 +209,23 @@ class SortishSampler(Sampler): ...@@ -201,12 +209,23 @@ class SortishSampler(Sampler):
201 idxs = np.random.permutation(len(self.data)) 209 idxs = np.random.permutation(len(self.data))
202 sz = self.bs * 50 210 sz = self.bs * 50
203 ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] 211 ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
204 - sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx]) 212 + sort_idx = np.concatenate(
213 + [sorted(s, key=self.key, reverse=True) for s in ck_idx]
214 + )
205 sz = self.bs 215 sz = self.bs
206 ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] 216 ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
207 - max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, 217 + max_ck = np.argmax(
208 - ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. 218 + [self.key(ck[0]) for ck in ck_idx]
209 - sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) 219 + ) # find the chunk with the largest key,
220 + ck_idx[0], ck_idx[max_ck] = (
221 + ck_idx[max_ck],
222 + ck_idx[0],
223 + ) # then make sure it goes first.
224 + sort_idx = (
225 + np.concatenate(np.random.permutation(ck_idx[1:]))
226 + if len(ck_idx) > 1
227 + else np.array([], dtype=np.int)
228 + )
210 sort_idx = np.concatenate((ck_idx[0], sort_idx)) 229 sort_idx = np.concatenate((ck_idx[0], sort_idx))
211 return iter(sort_idx) 230 return iter(sort_idx)
212 231
...@@ -269,7 +288,9 @@ def get_git_info(): ...@@ -269,7 +288,9 @@ def get_git_info():
269 ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] 288 ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
270 289
271 290
272 -def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: 291 +def calculate_rouge(
292 + output_lns: List[str], reference_lns: List[str], use_stemmer=True
293 +) -> Dict:
273 scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) 294 scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
274 aggregator = scoring.BootstrapAggregator() 295 aggregator = scoring.BootstrapAggregator()
275 296
...@@ -302,7 +323,9 @@ def assert_all_frozen(model): ...@@ -302,7 +323,9 @@ def assert_all_frozen(model):
302 model_grads: List[bool] = list(grad_status(model)) 323 model_grads: List[bool] = list(grad_status(model))
303 n_require_grad = sum(lmap(int, model_grads)) 324 n_require_grad = sum(lmap(int, model_grads))
304 npars = len(model_grads) 325 npars = len(model_grads)
305 - assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad" 326 + assert not any(
327 + model_grads
328 + ), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
306 329
307 330
308 def assert_not_all_frozen(model): 331 def assert_not_all_frozen(model):
......