Showing
11 changed files
with
219 additions
and
100 deletions
... | @@ -68,6 +68,7 @@ def main(args): | ... | @@ -68,6 +68,7 @@ def main(args): |
68 | ) | 68 | ) |
69 | print(commit_message) | 69 | print(commit_message) |
70 | 70 | ||
71 | + | ||
71 | if __name__ == "__main__": | 72 | if __name__ == "__main__": |
72 | parser = argparse.ArgumentParser(description="Code to collect commits on github") | 73 | parser = argparse.ArgumentParser(description="Code to collect commits on github") |
73 | parser.add_argument( | 74 | parser.add_argument( | ... | ... |
... | @@ -36,9 +36,11 @@ logging.basicConfig( | ... | @@ -36,9 +36,11 @@ logging.basicConfig( |
36 | level=logging.INFO, | 36 | level=logging.INFO, |
37 | ) | 37 | ) |
38 | 38 | ||
39 | + | ||
39 | class PATCH(enum.Enum): | 40 | class PATCH(enum.Enum): |
40 | - PLUS=1 | 41 | + PLUS = 1 |
41 | - MINUS=2 | 42 | + MINUS = 2 |
43 | + | ||
42 | 44 | ||
43 | def truncate(tuple, max_length, value=0): | 45 | def truncate(tuple, max_length, value=0): |
44 | ls = [] | 46 | ls = [] |
... | @@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0): | ... | @@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0): |
46 | if isinstance(t, int): | 48 | if isinstance(t, int): |
47 | t = [t] | 49 | t = [t] |
48 | ls.extend(t) | 50 | ls.extend(t) |
49 | - ls = ls[:max_length - 1] | 51 | + ls = ls[: max_length - 1] |
50 | ls.insert(0, value) | 52 | ls.insert(0, value) |
51 | if len(ls) < max_length: | 53 | if len(ls) < max_length: |
52 | ls.extend([0] * (max_length - len(ls))) | 54 | ls.extend([0] * (max_length - len(ls))) |
53 | assert len(ls) == max_length | 55 | assert len(ls) == max_length |
54 | return ls | 56 | return ls |
55 | 57 | ||
58 | + | ||
56 | def encode_line(tokenizer, line, patch): | 59 | def encode_line(tokenizer, line, patch): |
57 | - line = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip() | 60 | + line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip() |
58 | tokens = tokenizer.tokenize(line) | 61 | tokens = tokenizer.tokenize(line) |
59 | tokens = tokenizer.convert_tokens_to_ids(tokens) | 62 | tokens = tokenizer.convert_tokens_to_ids(tokens) |
60 | - return ( | 63 | + return (tokens, [1] * len(tokens), len(tokens) * [patch.value]) |
61 | - tokens, | 64 | + |
62 | - [1] * len(tokens), | ||
63 | - len(tokens) * [patch.value] | ||
64 | - ) | ||
65 | 65 | ||
66 | def diff_parse(diff, tokenizer): | 66 | def diff_parse(diff, tokenizer): |
67 | chunks = [] | 67 | chunks = [] |
... | @@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer): | ... | @@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer): |
78 | chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS)) | 78 | chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS)) |
79 | return chunks | 79 | return chunks |
80 | 80 | ||
81 | + | ||
81 | def sha_parse(sha, tokenizer, max_length=1024): | 82 | def sha_parse(sha, tokenizer, max_length=1024): |
82 | 83 | ||
83 | chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer) | 84 | chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer) |
... | @@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024): | ... | @@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024): |
91 | 92 | ||
92 | return (input_ids, attention_masks, patch_ids) | 93 | return (input_ids, attention_masks, patch_ids) |
93 | 94 | ||
95 | + | ||
94 | def message_parse(msg, tokenizer, max_length=56): | 96 | def message_parse(msg, tokenizer, max_length=56): |
95 | - msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg) | 97 | + msg = re.sub(r"(\(|)#([0-9])+(\)|)", "", msg) |
96 | 98 | ||
97 | - msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip() | 99 | + msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip() |
98 | msg = tokenizer.tokenize(msg) | 100 | msg = tokenizer.tokenize(msg) |
99 | msg = tokenizer.convert_tokens_to_ids(msg) | 101 | msg = tokenizer.convert_tokens_to_ids(msg) |
100 | msg = truncate(msg, max_length, value=0) | 102 | msg = truncate(msg, max_length, value=0) |
101 | 103 | ||
102 | return msg | 104 | return msg |
103 | 105 | ||
106 | + | ||
104 | def jobs(sha_msgs, args, data_config, train=True): | 107 | def jobs(sha_msgs, args, data_config, train=True): |
105 | 108 | ||
106 | input_ids, attention_masks, patch_ids, targets = [], [], [], [] | 109 | input_ids, attention_masks, patch_ids, targets = [], [], [], [] |
... | @@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True): | ... | @@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True): |
110 | sha, msg = sha_msg | 113 | sha, msg = sha_msg |
111 | 114 | ||
112 | source = sha_parse( | 115 | source = sha_parse( |
113 | - sha, | 116 | + sha, tokenizer=args.tokenizer, max_length=args.max_source_length |
114 | - tokenizer=args.tokenizer, | ||
115 | - max_length=args.max_source_length | ||
116 | ) | 117 | ) |
117 | if not source: | 118 | if not source: |
118 | continue | 119 | continue |
... | @@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True): | ... | @@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True): |
120 | target = message_parse( | 121 | target = message_parse( |
121 | msg, | 122 | msg, |
122 | tokenizer=args.tokenizer, | 123 | tokenizer=args.tokenizer, |
123 | - max_length=(args.max_target_length if train else args.val_max_target_length), | 124 | + max_length=( |
125 | + args.max_target_length if train else args.val_max_target_length | ||
126 | + ), | ||
124 | ) | 127 | ) |
125 | 128 | ||
126 | input_ids.append(input_id) | 129 | input_ids.append(input_id) |
... | @@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True): | ... | @@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True): |
128 | patch_ids.append(patch_id) | 131 | patch_ids.append(patch_id) |
129 | targets.append(target) | 132 | targets.append(target) |
130 | 133 | ||
131 | - data_saver({ | 134 | + data_saver( |
135 | + { | ||
132 | "input_ids": np.asarray(input_ids), | 136 | "input_ids": np.asarray(input_ids), |
133 | "attention_masks": np.asarray(attention_masks), | 137 | "attention_masks": np.asarray(attention_masks), |
134 | "patch_ids": np.asarray(patch_ids), | 138 | "patch_ids": np.asarray(patch_ids), |
135 | "targets": np.asarray(targets), | 139 | "targets": np.asarray(targets), |
136 | - }) | 140 | + } |
141 | + ) | ||
137 | data_saver.disconnect() | 142 | data_saver.disconnect() |
138 | 143 | ||
144 | + | ||
139 | def start(chunked_sha_msgs, train=True): | 145 | def start(chunked_sha_msgs, train=True): |
140 | 146 | ||
141 | logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation")) | 147 | logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation")) |
... | @@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True): | ... | @@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True): |
144 | 150 | ||
145 | data_config = DataConfig( | 151 | data_config = DataConfig( |
146 | endpoint=args.endpoint, | 152 | endpoint=args.endpoint, |
147 | - access_key=os.environ['access_key'], | 153 | + access_key=os.environ["access_key"], |
148 | - secret_key=os.environ['secret_key'], | 154 | + secret_key=os.environ["secret_key"], |
149 | region=args.region, | 155 | region=args.region, |
150 | - dataset_name='commit-autosuggestions', | 156 | + dataset_name="commit-autosuggestions", |
151 | additional={ | 157 | additional={ |
152 | - "mode" : ("training" if train else "evaluation"), | 158 | + "mode": ("training" if train else "evaluation"), |
153 | "max_source_length": args.max_source_length, | 159 | "max_source_length": args.max_source_length, |
154 | "max_target_length": max_target_length, | 160 | "max_target_length": max_target_length, |
155 | - "url" : args.url, | 161 | + "url": args.url, |
156 | }, | 162 | }, |
157 | attributes=[ | 163 | attributes=[ |
158 | - ('input_ids', 'int32', (args.max_source_length,)), | 164 | + ("input_ids", "int32", (args.max_source_length,)), |
159 | - ('attention_masks', 'int32', (args.max_source_length,)), | 165 | + ("attention_masks", "int32", (args.max_source_length,)), |
160 | - ('patch_ids', 'int32', (args.max_source_length,)), | 166 | + ("patch_ids", "int32", (args.max_source_length,)), |
161 | - ('targets', 'int32', (max_target_length,)) | 167 | + ("targets", "int32", (max_target_length,)), |
162 | - ] | 168 | + ], |
163 | ) | 169 | ) |
164 | 170 | ||
165 | func = partial(jobs, args=args, data_config=data_config, train=train) | 171 | func = partial(jobs, args=args, data_config=data_config, train=train) |
... | @@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True): | ... | @@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True): |
168 | for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))): | 174 | for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))): |
169 | pbar.update() | 175 | pbar.update() |
170 | 176 | ||
177 | + | ||
171 | def main(args): | 178 | def main(args): |
172 | - if 'access_key' not in os.environ or 'secret_key' not in os.environ: | 179 | + if "access_key" not in os.environ or "secret_key" not in os.environ: |
173 | raise OSError("access_key or secret_key are not found.") | 180 | raise OSError("access_key or secret_key are not found.") |
174 | 181 | ||
175 | sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()] | 182 | sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()] |
176 | random.shuffle(sha_msgs) | 183 | random.shuffle(sha_msgs) |
177 | chunked_sha_msgs = [ | 184 | chunked_sha_msgs = [ |
178 | - sha_msgs[x:x + args.matorage_batch] | 185 | + sha_msgs[x : x + args.matorage_batch] |
179 | for x in range(0, len(sha_msgs), args.matorage_batch) | 186 | for x in range(0, len(sha_msgs), args.matorage_batch) |
180 | ] | 187 | ] |
181 | 188 | ||
... | @@ -185,29 +192,25 @@ def main(args): | ... | @@ -185,29 +192,25 @@ def main(args): |
185 | if args.do_predict: | 192 | if args.do_predict: |
186 | start(chunked_sha_msgs[barrier:], train=False) | 193 | start(chunked_sha_msgs[barrier:], train=False) |
187 | 194 | ||
195 | + | ||
188 | if __name__ == "__main__": | 196 | if __name__ == "__main__": |
189 | parser = argparse.ArgumentParser(description="Code to collect commits on github") | 197 | parser = argparse.ArgumentParser(description="Code to collect commits on github") |
190 | - parser.add_argument( | 198 | + parser.add_argument("--url", type=str, required=True, help="github url") |
191 | - "--url", | ||
192 | - type=str, | ||
193 | - required=True, | ||
194 | - help="github url" | ||
195 | - ) | ||
196 | parser.add_argument( | 199 | parser.add_argument( |
197 | "--endpoint", | 200 | "--endpoint", |
198 | type=str, | 201 | type=str, |
199 | required=True, | 202 | required=True, |
200 | - help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' | 203 | + help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", |
201 | ) | 204 | ) |
202 | parser.add_argument( | 205 | parser.add_argument( |
203 | "--region", | 206 | "--region", |
204 | type=str, | 207 | type=str, |
205 | default=None, | 208 | default=None, |
206 | - help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' | 209 | + help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", |
207 | ) | 210 | ) |
208 | parser.add_argument( | 211 | parser.add_argument( |
209 | "--tokenizer_name", | 212 | "--tokenizer_name", |
210 | - default='sshleifer/distilbart-xsum-6-6', | 213 | + default="sshleifer/distilbart-xsum-6-6", |
211 | type=str, | 214 | type=str, |
212 | help="Pretrained tokenizer name or path if not the same as model_name", | 215 | help="Pretrained tokenizer name or path if not the same as model_name", |
213 | ) | 216 | ) |
... | @@ -215,13 +218,10 @@ if __name__ == "__main__": | ... | @@ -215,13 +218,10 @@ if __name__ == "__main__": |
215 | "--matorage_batch", | 218 | "--matorage_batch", |
216 | default=1024, | 219 | default=1024, |
217 | type=int, | 220 | type=int, |
218 | - help='The smallest batch size stored atomically in matorage.' | 221 | + help="The smallest batch size stored atomically in matorage.", |
219 | ) | 222 | ) |
220 | parser.add_argument( | 223 | parser.add_argument( |
221 | - "--num_workers", | 224 | + "--num_workers", default=4, type=int, help="number of process", |
222 | - default=4, | ||
223 | - type=int, | ||
224 | - help="number of process", | ||
225 | ) | 225 | ) |
226 | parser.add_argument( | 226 | parser.add_argument( |
227 | "--max_source_length", | 227 | "--max_source_length", |
... | @@ -244,12 +244,14 @@ if __name__ == "__main__": | ... | @@ -244,12 +244,14 @@ if __name__ == "__main__": |
244 | help="The maximum total input sequence length after tokenization. Sequences longer " | 244 | help="The maximum total input sequence length after tokenization. Sequences longer " |
245 | "than this will be truncated, sequences shorter will be padded.", | 245 | "than this will be truncated, sequences shorter will be padded.", |
246 | ) | 246 | ) |
247 | - parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset") | 247 | + parser.add_argument( |
248 | + "--p_val", type=float, default=0.25, help="percent of validation dataset" | ||
249 | + ) | ||
248 | parser.add_argument("--do_train", action="store_true", default=False) | 250 | parser.add_argument("--do_train", action="store_true", default=False) |
249 | parser.add_argument("--do_predict", action="store_true", default=False) | 251 | parser.add_argument("--do_predict", action="store_true", default=False) |
250 | args = parser.parse_args() | 252 | args = parser.parse_args() |
251 | 253 | ||
252 | - args.local_path = args.url.split('/')[-1] | 254 | + args.local_path = args.url.split("/")[-1] |
253 | logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}") | 255 | logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}") |
254 | repo = ( | 256 | repo = ( |
255 | Repo(args.local_path) | 257 | Repo(args.local_path) | ... | ... |
... | @@ -14,6 +14,4 @@ | ... | @@ -14,6 +14,4 @@ |
14 | 14 | ||
15 | from .modeling_bart import BartForConditionalGeneration | 15 | from .modeling_bart import BartForConditionalGeneration |
16 | 16 | ||
17 | -__all__ = [ | ||
18 | - 'BartForConditionalGeneration' | ||
19 | -] | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
17 | +__all__ = ["BartForConditionalGeneration"] | ... | ... |
... | @@ -20,16 +20,31 @@ logger = logging.getLogger(__name__) | ... | @@ -20,16 +20,31 @@ logger = logging.getLogger(__name__) |
20 | 20 | ||
21 | class Seq2SeqLoggingCallback(pl.Callback): | 21 | class Seq2SeqLoggingCallback(pl.Callback): |
22 | def on_batch_end(self, trainer, pl_module): | 22 | def on_batch_end(self, trainer, pl_module): |
23 | - lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} | 23 | + lrs = { |
24 | + f"lr_group_{i}": param["lr"] | ||
25 | + for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups) | ||
26 | + } | ||
24 | pl_module.logger.log_metrics(lrs) | 27 | pl_module.logger.log_metrics(lrs) |
25 | 28 | ||
26 | @rank_zero_only | 29 | @rank_zero_only |
27 | def _write_logs( | 30 | def _write_logs( |
28 | - self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True | 31 | + self, |
32 | + trainer: pl.Trainer, | ||
33 | + pl_module: pl.LightningModule, | ||
34 | + type_path: str, | ||
35 | + save_generations=True, | ||
29 | ) -> None: | 36 | ) -> None: |
30 | - logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") | 37 | + logger.info( |
38 | + f"***** {type_path} results at step {trainer.global_step:05d} *****" | ||
39 | + ) | ||
31 | metrics = trainer.callback_metrics | 40 | metrics = trainer.callback_metrics |
32 | - trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) | 41 | + trainer.logger.log_metrics( |
42 | + { | ||
43 | + k: v | ||
44 | + for k, v in metrics.items() | ||
45 | + if k not in ["log", "progress_bar", "preds"] | ||
46 | + } | ||
47 | + ) | ||
33 | # Log results | 48 | # Log results |
34 | od = Path(pl_module.hparams.output_dir) | 49 | od = Path(pl_module.hparams.output_dir) |
35 | if type_path == "test": | 50 | if type_path == "test": |
... | @@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback): | ... | @@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback): |
39 | # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json | 54 | # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json |
40 | # If people want this it will be easy enough to add back. | 55 | # If people want this it will be easy enough to add back. |
41 | results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" | 56 | results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" |
42 | - generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt" | 57 | + generations_file = ( |
58 | + od / f"{type_path}_generations/{trainer.global_step:05d}.txt" | ||
59 | + ) | ||
43 | results_file.parent.mkdir(exist_ok=True) | 60 | results_file.parent.mkdir(exist_ok=True) |
44 | generations_file.parent.mkdir(exist_ok=True) | 61 | generations_file.parent.mkdir(exist_ok=True) |
45 | with open(results_file, "a+") as writer: | 62 | with open(results_file, "a+") as writer: |
... | @@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback): | ... | @@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback): |
68 | 85 | ||
69 | n_trainable_pars = count_trainable_parameters(pl_module) | 86 | n_trainable_pars = count_trainable_parameters(pl_module) |
70 | # mp stands for million parameters | 87 | # mp stands for million parameters |
71 | - trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) | 88 | + trainer.logger.log_metrics( |
89 | + {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6} | ||
90 | + ) | ||
72 | 91 | ||
73 | @rank_zero_only | 92 | @rank_zero_only |
74 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): | 93 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): |
... | @@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric): | ... | @@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric): |
98 | 117 | ||
99 | def get_early_stopping_callback(metric, patience): | 118 | def get_early_stopping_callback(metric, patience): |
100 | return EarlyStopping( | 119 | return EarlyStopping( |
101 | - monitor=f"val_{metric}", | 120 | + monitor=f"val_{metric}", mode="max", patience=patience, verbose=True, |
102 | - mode="max", | ||
103 | - patience=patience, | ||
104 | - verbose=True, | ||
105 | ) | 121 | ) | ... | ... |
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
... | @@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule): |
69 | config=None, | 69 | config=None, |
70 | tokenizer=None, | 70 | tokenizer=None, |
71 | model=None, | 71 | model=None, |
72 | - **config_kwargs | 72 | + **config_kwargs, |
73 | ): | 73 | ): |
74 | """Initialize a model, tokenizer and config.""" | 74 | """Initialize a model, tokenizer and config.""" |
75 | super().__init__() | 75 | super().__init__() |
... | @@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule): |
83 | cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None | 83 | cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None |
84 | if config is None: | 84 | if config is None: |
85 | self.config = AutoConfig.from_pretrained( | 85 | self.config = AutoConfig.from_pretrained( |
86 | - self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, | 86 | + self.hparams.config_name |
87 | + if self.hparams.config_name | ||
88 | + else self.hparams.model_name_or_path, | ||
87 | **({"num_labels": num_labels} if num_labels is not None else {}), | 89 | **({"num_labels": num_labels} if num_labels is not None else {}), |
88 | cache_dir=cache_dir, | 90 | cache_dir=cache_dir, |
89 | **config_kwargs, | 91 | **config_kwargs, |
... | @@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule): |
91 | else: | 93 | else: |
92 | self.config: PretrainedConfig = config | 94 | self.config: PretrainedConfig = config |
93 | 95 | ||
94 | - extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") | 96 | + extra_model_params = ( |
97 | + "encoder_layerdrop", | ||
98 | + "decoder_layerdrop", | ||
99 | + "dropout", | ||
100 | + "attention_dropout", | ||
101 | + ) | ||
95 | for p in extra_model_params: | 102 | for p in extra_model_params: |
96 | if getattr(self.hparams, p, None): | 103 | if getattr(self.hparams, p, None): |
97 | - assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute" | 104 | + assert hasattr( |
105 | + self.config, p | ||
106 | + ), f"model config doesn't have a `{p}` attribute" | ||
98 | setattr(self.config, p, getattr(self.hparams, p)) | 107 | setattr(self.config, p, getattr(self.hparams, p)) |
99 | 108 | ||
100 | if tokenizer is None: | 109 | if tokenizer is None: |
101 | self.tokenizer = AutoTokenizer.from_pretrained( | 110 | self.tokenizer = AutoTokenizer.from_pretrained( |
102 | - self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, | 111 | + self.hparams.tokenizer_name |
112 | + if self.hparams.tokenizer_name | ||
113 | + else self.hparams.model_name_or_path, | ||
103 | cache_dir=cache_dir, | 114 | cache_dir=cache_dir, |
104 | ) | 115 | ) |
105 | else: | 116 | else: |
... | @@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule): |
121 | def get_lr_scheduler(self): | 132 | def get_lr_scheduler(self): |
122 | get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] | 133 | get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] |
123 | scheduler = get_schedule_func( | 134 | scheduler = get_schedule_func( |
124 | - self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps | 135 | + self.opt, |
136 | + num_warmup_steps=self.hparams.warmup_steps, | ||
137 | + num_training_steps=self.total_steps, | ||
125 | ) | 138 | ) |
126 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} | 139 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} |
127 | return scheduler | 140 | return scheduler |
... | @@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule): |
132 | no_decay = ["bias", "LayerNorm.weight"] | 145 | no_decay = ["bias", "LayerNorm.weight"] |
133 | optimizer_grouped_parameters = [ | 146 | optimizer_grouped_parameters = [ |
134 | { | 147 | { |
135 | - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], | 148 | + "params": [ |
149 | + p | ||
150 | + for n, p in model.named_parameters() | ||
151 | + if not any(nd in n for nd in no_decay) | ||
152 | + ], | ||
136 | "weight_decay": self.hparams.weight_decay, | 153 | "weight_decay": self.hparams.weight_decay, |
137 | }, | 154 | }, |
138 | { | 155 | { |
139 | - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], | 156 | + "params": [ |
157 | + p | ||
158 | + for n, p in model.named_parameters() | ||
159 | + if any(nd in n for nd in no_decay) | ||
160 | + ], | ||
140 | "weight_decay": 0.0, | 161 | "weight_decay": 0.0, |
141 | }, | 162 | }, |
142 | ] | 163 | ] |
143 | if self.hparams.adafactor: | 164 | if self.hparams.adafactor: |
144 | optimizer = Adafactor( | 165 | optimizer = Adafactor( |
145 | - optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False | 166 | + optimizer_grouped_parameters, |
167 | + lr=self.hparams.learning_rate, | ||
168 | + scale_parameter=False, | ||
169 | + relative_step=False, | ||
146 | ) | 170 | ) |
147 | 171 | ||
148 | else: | 172 | else: |
149 | optimizer = AdamW( | 173 | optimizer = AdamW( |
150 | - optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon | 174 | + optimizer_grouped_parameters, |
175 | + lr=self.hparams.learning_rate, | ||
176 | + eps=self.hparams.adam_epsilon, | ||
151 | ) | 177 | ) |
152 | self.opt = optimizer | 178 | self.opt = optimizer |
153 | 179 | ||
... | @@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule): |
165 | def total_steps(self) -> int: | 191 | def total_steps(self) -> int: |
166 | """The number of total training steps that will be run. Used for lr scheduler purposes.""" | 192 | """The number of total training steps that will be run. Used for lr scheduler purposes.""" |
167 | num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores | 193 | num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores |
168 | - effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices | 194 | + effective_batch_size = ( |
195 | + self.hparams.train_batch_size | ||
196 | + * self.hparams.accumulate_grad_batches | ||
197 | + * num_devices | ||
198 | + ) | ||
169 | dataset_size = len(self.train_loader.dataset) | 199 | dataset_size = len(self.train_loader.dataset) |
170 | return (dataset_size / effective_batch_size) * self.hparams.max_epochs | 200 | return (dataset_size / effective_batch_size) * self.hparams.max_epochs |
171 | 201 | ||
172 | def setup(self, mode): | 202 | def setup(self, mode): |
173 | if mode == "fit": | 203 | if mode == "fit": |
174 | - self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True) | 204 | + self.train_loader = self.get_dataloader( |
205 | + "train", self.hparams.train_batch_size, shuffle=True | ||
206 | + ) | ||
175 | 207 | ||
176 | def get_dataloader(self, type_path, batch_size, shuffle=False): | 208 | def get_dataloader(self, type_path, batch_size, shuffle=False): |
177 | raise NotImplementedError("You must implement this for your task") | 209 | raise NotImplementedError("You must implement this for your task") |
... | @@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule): |
212 | help="Path to pretrained model or model identifier from huggingface.co/models", | 244 | help="Path to pretrained model or model identifier from huggingface.co/models", |
213 | ) | 245 | ) |
214 | parser.add_argument( | 246 | parser.add_argument( |
215 | - "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" | 247 | + "--config_name", |
248 | + default="", | ||
249 | + type=str, | ||
250 | + help="Pretrained config name or path if not the same as model_name", | ||
216 | ) | 251 | ) |
217 | parser.add_argument( | 252 | parser.add_argument( |
218 | "--tokenizer_name", | 253 | "--tokenizer_name", |
... | @@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule): |
246 | type=float, | 281 | type=float, |
247 | help="Attention dropout probability (Optional). Goes into model.config", | 282 | help="Attention dropout probability (Optional). Goes into model.config", |
248 | ) | 283 | ) |
249 | - parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") | 284 | + parser.add_argument( |
285 | + "--learning_rate", | ||
286 | + default=5e-5, | ||
287 | + type=float, | ||
288 | + help="The initial learning rate for Adam.", | ||
289 | + ) | ||
250 | parser.add_argument( | 290 | parser.add_argument( |
251 | "--lr_scheduler", | 291 | "--lr_scheduler", |
252 | default="linear", | 292 | default="linear", |
... | @@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule): |
255 | type=str, | 295 | type=str, |
256 | help="Learning rate scheduler", | 296 | help="Learning rate scheduler", |
257 | ) | 297 | ) |
258 | - parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") | 298 | + parser.add_argument( |
259 | - parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") | 299 | + "--weight_decay", |
260 | - parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") | 300 | + default=0.0, |
261 | - parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") | 301 | + type=float, |
262 | - parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int) | 302 | + help="Weight decay if we apply some.", |
303 | + ) | ||
304 | + parser.add_argument( | ||
305 | + "--adam_epsilon", | ||
306 | + default=1e-8, | ||
307 | + type=float, | ||
308 | + help="Epsilon for Adam optimizer.", | ||
309 | + ) | ||
310 | + parser.add_argument( | ||
311 | + "--warmup_steps", | ||
312 | + default=0, | ||
313 | + type=int, | ||
314 | + help="Linear warmup over warmup_steps.", | ||
315 | + ) | ||
316 | + parser.add_argument( | ||
317 | + "--num_workers", default=4, type=int, help="kwarg passed to DataLoader" | ||
318 | + ) | ||
319 | + parser.add_argument( | ||
320 | + "--num_train_epochs", dest="max_epochs", default=3, type=int | ||
321 | + ) | ||
263 | parser.add_argument("--train_batch_size", default=32, type=int) | 322 | parser.add_argument("--train_batch_size", default=32, type=int) |
264 | parser.add_argument("--eval_batch_size", default=32, type=int) | 323 | parser.add_argument("--eval_batch_size", default=32, type=int) |
265 | parser.add_argument("--adafactor", action="store_true") | 324 | parser.add_argument("--adafactor", action="store_true") |
... | @@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback): | ... | @@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback): |
283 | rank_zero_info("***** Test results *****") | 342 | rank_zero_info("***** Test results *****") |
284 | metrics = trainer.callback_metrics | 343 | metrics = trainer.callback_metrics |
285 | # Log and save results to file | 344 | # Log and save results to file |
286 | - output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") | 345 | + output_test_results_file = os.path.join( |
346 | + pl_module.hparams.output_dir, "test_results.txt" | ||
347 | + ) | ||
287 | with open(output_test_results_file, "w") as writer: | 348 | with open(output_test_results_file, "w") as writer: |
288 | for key in sorted(metrics): | 349 | for key in sorted(metrics): |
289 | if key not in ["log", "progress_bar"]: | 350 | if key not in ["log", "progress_bar"]: |
... | @@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None: | ... | @@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None: |
314 | "See details at https://nvidia.github.io/apex/amp.html", | 375 | "See details at https://nvidia.github.io/apex/amp.html", |
315 | ) | 376 | ) |
316 | parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) | 377 | parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) |
317 | - parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm") | 378 | + parser.add_argument( |
318 | - parser.add_argument("--do_train", action="store_true", help="Whether to run training.") | 379 | + "--max_grad_norm", |
319 | - parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") | 380 | + dest="gradient_clip_val", |
381 | + default=1.0, | ||
382 | + type=float, | ||
383 | + help="Max gradient norm", | ||
384 | + ) | ||
385 | + parser.add_argument( | ||
386 | + "--do_train", action="store_true", help="Whether to run training." | ||
387 | + ) | ||
388 | + parser.add_argument( | ||
389 | + "--do_predict", | ||
390 | + action="store_true", | ||
391 | + help="Whether to run predictions on the test set.", | ||
392 | + ) | ||
320 | parser.add_argument( | 393 | parser.add_argument( |
321 | "--gradient_accumulation_steps", | 394 | "--gradient_accumulation_steps", |
322 | dest="accumulate_grad_batches", | 395 | dest="accumulate_grad_batches", |
... | @@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None: | ... | @@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None: |
324 | default=1, | 397 | default=1, |
325 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 398 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
326 | ) | 399 | ) |
327 | - parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") | 400 | + parser.add_argument( |
401 | + "--seed", type=int, default=42, help="random seed for initialization" | ||
402 | + ) | ||
328 | 403 | ||
329 | 404 | ||
330 | def generic_train( | 405 | def generic_train( |
... | @@ -335,7 +410,7 @@ def generic_train( | ... | @@ -335,7 +410,7 @@ def generic_train( |
335 | extra_callbacks=[], | 410 | extra_callbacks=[], |
336 | checkpoint_callback=None, | 411 | checkpoint_callback=None, |
337 | logging_callback=None, | 412 | logging_callback=None, |
338 | - **extra_train_kwargs | 413 | + **extra_train_kwargs, |
339 | ): | 414 | ): |
340 | pl.seed_everything(args.seed) | 415 | pl.seed_everything(args.seed) |
341 | 416 | ||
... | @@ -346,7 +421,11 @@ def generic_train( | ... | @@ -346,7 +421,11 @@ def generic_train( |
346 | # add custom checkpoints | 421 | # add custom checkpoints |
347 | if checkpoint_callback is None: | 422 | if checkpoint_callback is None: |
348 | checkpoint_callback = pl.callbacks.ModelCheckpoint( | 423 | checkpoint_callback = pl.callbacks.ModelCheckpoint( |
349 | - filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1 | 424 | + filepath=args.output_dir, |
425 | + prefix="checkpoint", | ||
426 | + monitor="val_loss", | ||
427 | + mode="min", | ||
428 | + save_top_k=1, | ||
350 | ) | 429 | ) |
351 | if logging_callback is None: | 430 | if logging_callback is None: |
352 | logging_callback = LoggingCallback() | 431 | logging_callback = LoggingCallback() | ... | ... |
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
... | @@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): | ... | @@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): |
39 | return loss, nll_loss | 39 | return loss, nll_loss |
40 | 40 | ||
41 | 41 | ||
42 | -def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): | 42 | +def encode_line( |
43 | + tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt" | ||
44 | +): | ||
43 | """Only used by LegacyDataset""" | 45 | """Only used by LegacyDataset""" |
44 | - extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} | 46 | + extra_kw = ( |
47 | + {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} | ||
48 | + ) | ||
45 | return tokenizer( | 49 | return tokenizer( |
46 | [line], | 50 | [line], |
47 | max_length=max_length, | 51 | max_length=max_length, |
... | @@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: | ... | @@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: |
63 | 67 | ||
64 | 68 | ||
65 | def trim_batch( | 69 | def trim_batch( |
66 | - input_ids, | 70 | + input_ids, pad_token_id, attention_mask=None, |
67 | - pad_token_id, | ||
68 | - attention_mask=None, | ||
69 | ): | 71 | ): |
70 | """Remove columns that are populated exclusively by pad_token_id""" | 72 | """Remove columns that are populated exclusively by pad_token_id""" |
71 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) | 73 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) |
... | @@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): | ... | @@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): |
125 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: | 127 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: |
126 | """Call tokenizer on src and tgt_lines""" | 128 | """Call tokenizer on src and tgt_lines""" |
127 | index = index + 1 # linecache starts at 1 | 129 | index = index + 1 # linecache starts at 1 |
128 | - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") | 130 | + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( |
131 | + "\n" | ||
132 | + ) | ||
129 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") | 133 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") |
130 | assert source_line, f"empty source line for index {index}" | 134 | assert source_line, f"empty source line for index {index}" |
131 | assert tgt_line, f"empty tgt line for index {index}" | 135 | assert tgt_line, f"empty tgt line for index {index}" |
... | @@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): | ... | @@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): |
147 | target_ids = torch.stack([x["labels"] for x in batch]) | 151 | target_ids = torch.stack([x["labels"] for x in batch]) |
148 | pad_token_id = self.pad_token_id | 152 | pad_token_id = self.pad_token_id |
149 | y = trim_batch(target_ids, pad_token_id) | 153 | y = trim_batch(target_ids, pad_token_id) |
150 | - source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) | 154 | + source_ids, source_mask = trim_batch( |
155 | + input_ids, pad_token_id, attention_mask=masks | ||
156 | + ) | ||
151 | batch = { | 157 | batch = { |
152 | "input_ids": source_ids, | 158 | "input_ids": source_ids, |
153 | "attention_mask": source_mask, | 159 | "attention_mask": source_mask, |
... | @@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): | ... | @@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): |
161 | 167 | ||
162 | def __getitem__(self, index) -> Dict[str, str]: | 168 | def __getitem__(self, index) -> Dict[str, str]: |
163 | index = index + 1 # linecache starts at 1 | 169 | index = index + 1 # linecache starts at 1 |
164 | - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") | 170 | + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( |
171 | + "\n" | ||
172 | + ) | ||
165 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") | 173 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") |
166 | assert source_line, f"empty source line for index {index}" | 174 | assert source_line, f"empty source line for index {index}" |
167 | assert tgt_line, f"empty tgt line for index {index}" | 175 | assert tgt_line, f"empty tgt line for index {index}" |
... | @@ -201,12 +209,23 @@ class SortishSampler(Sampler): | ... | @@ -201,12 +209,23 @@ class SortishSampler(Sampler): |
201 | idxs = np.random.permutation(len(self.data)) | 209 | idxs = np.random.permutation(len(self.data)) |
202 | sz = self.bs * 50 | 210 | sz = self.bs * 50 |
203 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] | 211 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] |
204 | - sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx]) | 212 | + sort_idx = np.concatenate( |
213 | + [sorted(s, key=self.key, reverse=True) for s in ck_idx] | ||
214 | + ) | ||
205 | sz = self.bs | 215 | sz = self.bs |
206 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] | 216 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] |
207 | - max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, | 217 | + max_ck = np.argmax( |
208 | - ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. | 218 | + [self.key(ck[0]) for ck in ck_idx] |
209 | - sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) | 219 | + ) # find the chunk with the largest key, |
220 | + ck_idx[0], ck_idx[max_ck] = ( | ||
221 | + ck_idx[max_ck], | ||
222 | + ck_idx[0], | ||
223 | + ) # then make sure it goes first. | ||
224 | + sort_idx = ( | ||
225 | + np.concatenate(np.random.permutation(ck_idx[1:])) | ||
226 | + if len(ck_idx) > 1 | ||
227 | + else np.array([], dtype=np.int) | ||
228 | + ) | ||
210 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) | 229 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) |
211 | return iter(sort_idx) | 230 | return iter(sort_idx) |
212 | 231 | ||
... | @@ -269,7 +288,9 @@ def get_git_info(): | ... | @@ -269,7 +288,9 @@ def get_git_info(): |
269 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] | 288 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] |
270 | 289 | ||
271 | 290 | ||
272 | -def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: | 291 | +def calculate_rouge( |
292 | + output_lns: List[str], reference_lns: List[str], use_stemmer=True | ||
293 | +) -> Dict: | ||
273 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) | 294 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) |
274 | aggregator = scoring.BootstrapAggregator() | 295 | aggregator = scoring.BootstrapAggregator() |
275 | 296 | ||
... | @@ -302,7 +323,9 @@ def assert_all_frozen(model): | ... | @@ -302,7 +323,9 @@ def assert_all_frozen(model): |
302 | model_grads: List[bool] = list(grad_status(model)) | 323 | model_grads: List[bool] = list(grad_status(model)) |
303 | n_require_grad = sum(lmap(int, model_grads)) | 324 | n_require_grad = sum(lmap(int, model_grads)) |
304 | npars = len(model_grads) | 325 | npars = len(model_grads) |
305 | - assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad" | 326 | + assert not any( |
327 | + model_grads | ||
328 | + ), f"{n_require_grad/npars:.1%} of {npars} weights require grad" | ||
306 | 329 | ||
307 | 330 | ||
308 | def assert_not_all_frozen(model): | 331 | def assert_not_all_frozen(model): | ... | ... |
-
Please register or login to post a comment