Showing 11 changed files with 1028 additions and 381 deletions
... | @@ -68,6 +68,7 @@ def main(args): | ... | @@ -68,6 +68,7 @@ def main(args): |
68 | ) | 68 | ) |
69 | print(commit_message) | 69 | print(commit_message) |
70 | 70 | ||
71 | + | ||
71 | if __name__ == "__main__": | 72 | if __name__ == "__main__": |
72 | parser = argparse.ArgumentParser(description="Code to collect commits on github") | 73 | parser = argparse.ArgumentParser(description="Code to collect commits on github") |
73 | parser.add_argument( | 74 | parser.add_argument( | ... | ... |
1 | # Copyright 2020-present Tae Hwan Jung | 1 | # Copyright 2020-present Tae Hwan Jung |
2 | -# | 2 | +# |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | # you may not use this file except in compliance with the License. | 4 | # you may not use this file except in compliance with the License. |
5 | # You may obtain a copy of the License at | 5 | # You may obtain a copy of the License at |
6 | -# | 6 | +# |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 |
8 | -# | 8 | +# |
9 | # Unless required by applicable law or agreed to in writing, software | 9 | # Unless required by applicable law or agreed to in writing, software |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | 10 | # distributed under the License is distributed on an "AS IS" BASIS, |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
... | @@ -15,6 +15,6 @@ | ... | @@ -15,6 +15,6 @@ |
15 | from .gitcommit import diff_parse, truncate | 15 | from .gitcommit import diff_parse, truncate |
16 | 16 | ||
17 | __all__ = [ | 17 | __all__ = [ |
18 | - 'diff_parse', | ||
19 | - 'truncate', | ||
20 | -] | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
18 | + "diff_parse", | ||
19 | + "truncate", | ||
20 | +] | ... | ... |
... | @@ -36,9 +36,11 @@ logging.basicConfig( | ... | @@ -36,9 +36,11 @@ logging.basicConfig( |
36 | level=logging.INFO, | 36 | level=logging.INFO, |
37 | ) | 37 | ) |
38 | 38 | ||
39 | + | ||
39 | class PATCH(enum.Enum): | 40 | class PATCH(enum.Enum): |
40 | - PLUS=1 | 41 | + PLUS = 1 |
41 | - MINUS=2 | 42 | + MINUS = 2 |
43 | + | ||
42 | 44 | ||
43 | def truncate(tuple, max_length, value=0): | 45 | def truncate(tuple, max_length, value=0): |
44 | ls = [] | 46 | ls = [] |
... | @@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0): | ... | @@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0): |
46 | if isinstance(t, int): | 48 | if isinstance(t, int): |
47 | t = [t] | 49 | t = [t] |
48 | ls.extend(t) | 50 | ls.extend(t) |
49 | - ls = ls[:max_length - 1] | 51 | + ls = ls[: max_length - 1] |
50 | ls.insert(0, value) | 52 | ls.insert(0, value) |
51 | if len(ls) < max_length: | 53 | if len(ls) < max_length: |
52 | ls.extend([0] * (max_length - len(ls))) | 54 | ls.extend([0] * (max_length - len(ls))) |
53 | assert len(ls) == max_length | 55 | assert len(ls) == max_length |
54 | return ls | 56 | return ls |
55 | 57 | ||
58 | + | ||
56 | def encode_line(tokenizer, line, patch): | 59 | def encode_line(tokenizer, line, patch): |
57 | - line = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip() | 60 | + line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip() |
58 | tokens = tokenizer.tokenize(line) | 61 | tokens = tokenizer.tokenize(line) |
59 | tokens = tokenizer.convert_tokens_to_ids(tokens) | 62 | tokens = tokenizer.convert_tokens_to_ids(tokens) |
60 | - return ( | 63 | + return (tokens, [1] * len(tokens), len(tokens) * [patch.value]) |
61 | - tokens, | 64 | + |
62 | - [1] * len(tokens), | ||
63 | - len(tokens) * [patch.value] | ||
64 | - ) | ||
65 | 65 | ||
66 | def diff_parse(diff, tokenizer): | 66 | def diff_parse(diff, tokenizer): |
67 | chunks = [] | 67 | chunks = [] |
... | @@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer): | ... | @@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer): |
78 | chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS)) | 78 | chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS)) |
79 | return chunks | 79 | return chunks |
80 | 80 | ||
81 | + | ||
81 | def sha_parse(sha, tokenizer, max_length=1024): | 82 | def sha_parse(sha, tokenizer, max_length=1024): |
82 | 83 | ||
83 | chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer) | 84 | chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer) |
... | @@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024): | ... | @@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024): |
91 | 92 | ||
92 | return (input_ids, attention_masks, patch_ids) | 93 | return (input_ids, attention_masks, patch_ids) |
93 | 94 | ||
95 | + | ||
94 | def message_parse(msg, tokenizer, max_length=56): | 96 | def message_parse(msg, tokenizer, max_length=56): |
95 | - msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg) | 97 | + msg = re.sub(r"(\(|)#([0-9])+(\)|)", "", msg) |
96 | 98 | ||
97 | - msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip() | 99 | + msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip() |
98 | msg = tokenizer.tokenize(msg) | 100 | msg = tokenizer.tokenize(msg) |
99 | msg = tokenizer.convert_tokens_to_ids(msg) | 101 | msg = tokenizer.convert_tokens_to_ids(msg) |
100 | msg = truncate(msg, max_length, value=0) | 102 | msg = truncate(msg, max_length, value=0) |
101 | 103 | ||
102 | return msg | 104 | return msg |
103 | 105 | ||
106 | + | ||
104 | def jobs(sha_msgs, args, data_config, train=True): | 107 | def jobs(sha_msgs, args, data_config, train=True): |
105 | 108 | ||
106 | input_ids, attention_masks, patch_ids, targets = [], [], [], [] | 109 | input_ids, attention_masks, patch_ids, targets = [], [], [], [] |
... | @@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True): | ... | @@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True): |
110 | sha, msg = sha_msg | 113 | sha, msg = sha_msg |
111 | 114 | ||
112 | source = sha_parse( | 115 | source = sha_parse( |
113 | - sha, | 116 | + sha, tokenizer=args.tokenizer, max_length=args.max_source_length |
114 | - tokenizer=args.tokenizer, | ||
115 | - max_length=args.max_source_length | ||
116 | ) | 117 | ) |
117 | if not source: | 118 | if not source: |
118 | continue | 119 | continue |
... | @@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True): | ... | @@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True): |
120 | target = message_parse( | 121 | target = message_parse( |
121 | msg, | 122 | msg, |
122 | tokenizer=args.tokenizer, | 123 | tokenizer=args.tokenizer, |
123 | - max_length=(args.max_target_length if train else args.val_max_target_length), | 124 | + max_length=( |
125 | + args.max_target_length if train else args.val_max_target_length | ||
126 | + ), | ||
124 | ) | 127 | ) |
125 | 128 | ||
126 | input_ids.append(input_id) | 129 | input_ids.append(input_id) |
... | @@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True): | ... | @@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True): |
128 | patch_ids.append(patch_id) | 131 | patch_ids.append(patch_id) |
129 | targets.append(target) | 132 | targets.append(target) |
130 | 133 | ||
131 | - data_saver({ | 134 | + data_saver( |
132 | - "input_ids": np.asarray(input_ids), | 135 | + { |
133 | - "attention_masks": np.asarray(attention_masks), | 136 | + "input_ids": np.asarray(input_ids), |
134 | - "patch_ids": np.asarray(patch_ids), | 137 | + "attention_masks": np.asarray(attention_masks), |
135 | - "targets": np.asarray(targets), | 138 | + "patch_ids": np.asarray(patch_ids), |
136 | - }) | 139 | + "targets": np.asarray(targets), |
140 | + } | ||
141 | + ) | ||
137 | data_saver.disconnect() | 142 | data_saver.disconnect() |
138 | 143 | ||
144 | + | ||
139 | def start(chunked_sha_msgs, train=True): | 145 | def start(chunked_sha_msgs, train=True): |
140 | 146 | ||
141 | logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation")) | 147 | logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation")) |
... | @@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True): | ... | @@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True): |
144 | 150 | ||
145 | data_config = DataConfig( | 151 | data_config = DataConfig( |
146 | endpoint=args.endpoint, | 152 | endpoint=args.endpoint, |
147 | - access_key=os.environ['access_key'], | 153 | + access_key=os.environ["access_key"], |
148 | - secret_key=os.environ['secret_key'], | 154 | + secret_key=os.environ["secret_key"], |
149 | region=args.region, | 155 | region=args.region, |
150 | - dataset_name='commit-autosuggestions', | 156 | + dataset_name="commit-autosuggestions", |
151 | additional={ | 157 | additional={ |
152 | - "mode" : ("training" if train else "evaluation"), | 158 | + "mode": ("training" if train else "evaluation"), |
153 | "max_source_length": args.max_source_length, | 159 | "max_source_length": args.max_source_length, |
154 | "max_target_length": max_target_length, | 160 | "max_target_length": max_target_length, |
155 | - "url" : args.url, | 161 | + "url": args.url, |
156 | }, | 162 | }, |
157 | attributes=[ | 163 | attributes=[ |
158 | - ('input_ids', 'int32', (args.max_source_length,)), | 164 | + ("input_ids", "int32", (args.max_source_length,)), |
159 | - ('attention_masks', 'int32', (args.max_source_length,)), | 165 | + ("attention_masks", "int32", (args.max_source_length,)), |
160 | - ('patch_ids', 'int32', (args.max_source_length,)), | 166 | + ("patch_ids", "int32", (args.max_source_length,)), |
161 | - ('targets', 'int32', (max_target_length,)) | 167 | + ("targets", "int32", (max_target_length,)), |
162 | - ] | 168 | + ], |
163 | ) | 169 | ) |
164 | 170 | ||
165 | func = partial(jobs, args=args, data_config=data_config, train=train) | 171 | func = partial(jobs, args=args, data_config=data_config, train=train) |
... | @@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True): | ... | @@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True): |
168 | for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))): | 174 | for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))): |
169 | pbar.update() | 175 | pbar.update() |
170 | 176 | ||
177 | + | ||
171 | def main(args): | 178 | def main(args): |
172 | - if 'access_key' not in os.environ or 'secret_key' not in os.environ: | 179 | + if "access_key" not in os.environ or "secret_key" not in os.environ: |
173 | raise OSError("access_key or secret_key are not found.") | 180 | raise OSError("access_key or secret_key are not found.") |
174 | 181 | ||
175 | sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()] | 182 | sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()] |
176 | random.shuffle(sha_msgs) | 183 | random.shuffle(sha_msgs) |
177 | chunked_sha_msgs = [ | 184 | chunked_sha_msgs = [ |
178 | - sha_msgs[x:x + args.matorage_batch] | 185 | + sha_msgs[x : x + args.matorage_batch] |
179 | for x in range(0, len(sha_msgs), args.matorage_batch) | 186 | for x in range(0, len(sha_msgs), args.matorage_batch) |
180 | ] | 187 | ] |
181 | 188 | ||
... | @@ -185,29 +192,25 @@ def main(args): | ... | @@ -185,29 +192,25 @@ def main(args): |
185 | if args.do_predict: | 192 | if args.do_predict: |
186 | start(chunked_sha_msgs[barrier:], train=False) | 193 | start(chunked_sha_msgs[barrier:], train=False) |
187 | 194 | ||
195 | + | ||
188 | if __name__ == "__main__": | 196 | if __name__ == "__main__": |
189 | parser = argparse.ArgumentParser(description="Code to collect commits on github") | 197 | parser = argparse.ArgumentParser(description="Code to collect commits on github") |
190 | - parser.add_argument( | 198 | + parser.add_argument("--url", type=str, required=True, help="github url") |
191 | - "--url", | ||
192 | - type=str, | ||
193 | - required=True, | ||
194 | - help="github url" | ||
195 | - ) | ||
196 | parser.add_argument( | 199 | parser.add_argument( |
197 | "--endpoint", | 200 | "--endpoint", |
198 | type=str, | 201 | type=str, |
199 | required=True, | 202 | required=True, |
200 | - help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' | 203 | + help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", |
201 | ) | 204 | ) |
202 | parser.add_argument( | 205 | parser.add_argument( |
203 | "--region", | 206 | "--region", |
204 | type=str, | 207 | type=str, |
205 | default=None, | 208 | default=None, |
206 | - help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' | 209 | + help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", |
207 | ) | 210 | ) |
208 | parser.add_argument( | 211 | parser.add_argument( |
209 | "--tokenizer_name", | 212 | "--tokenizer_name", |
210 | - default='sshleifer/distilbart-xsum-6-6', | 213 | + default="sshleifer/distilbart-xsum-6-6", |
211 | type=str, | 214 | type=str, |
212 | help="Pretrained tokenizer name or path if not the same as model_name", | 215 | help="Pretrained tokenizer name or path if not the same as model_name", |
213 | ) | 216 | ) |
... | @@ -215,41 +218,40 @@ if __name__ == "__main__": | ... | @@ -215,41 +218,40 @@ if __name__ == "__main__": |
215 | "--matorage_batch", | 218 | "--matorage_batch", |
216 | default=1024, | 219 | default=1024, |
217 | type=int, | 220 | type=int, |
218 | - help='The smallest batch size stored atomically in matorage.' | 221 | + help="The smallest batch size stored atomically in matorage.", |
219 | ) | 222 | ) |
220 | parser.add_argument( | 223 | parser.add_argument( |
221 | - "--num_workers", | 224 | + "--num_workers", default=4, type=int, help="number of process", |
222 | - default=4, | ||
223 | - type=int, | ||
224 | - help="number of process", | ||
225 | ) | 225 | ) |
226 | parser.add_argument( | 226 | parser.add_argument( |
227 | "--max_source_length", | 227 | "--max_source_length", |
228 | default=1024, | 228 | default=1024, |
229 | type=int, | 229 | type=int, |
230 | help="The maximum total input sequence length after tokenization. Sequences longer " | 230 | help="The maximum total input sequence length after tokenization. Sequences longer " |
231 | - "than this will be truncated, sequences shorter will be padded.", | 231 | + "than this will be truncated, sequences shorter will be padded.", |
232 | ) | 232 | ) |
233 | parser.add_argument( | 233 | parser.add_argument( |
234 | "--max_target_length", | 234 | "--max_target_length", |
235 | default=56, | 235 | default=56, |
236 | type=int, | 236 | type=int, |
237 | help="The maximum total input sequence length after tokenization. Sequences longer " | 237 | help="The maximum total input sequence length after tokenization. Sequences longer " |
238 | - "than this will be truncated, sequences shorter will be padded.", | 238 | + "than this will be truncated, sequences shorter will be padded.", |
239 | ) | 239 | ) |
240 | parser.add_argument( | 240 | parser.add_argument( |
241 | "--val_max_target_length", | 241 | "--val_max_target_length", |
242 | default=142, # these defaults are optimized for CNNDM. For xsum, see README.md. | 242 | default=142, # these defaults are optimized for CNNDM. For xsum, see README.md. |
243 | type=int, | 243 | type=int, |
244 | help="The maximum total input sequence length after tokenization. Sequences longer " | 244 | help="The maximum total input sequence length after tokenization. Sequences longer " |
245 | - "than this will be truncated, sequences shorter will be padded.", | 245 | + "than this will be truncated, sequences shorter will be padded.", |
246 | + ) | ||
247 | + parser.add_argument( | ||
248 | + "--p_val", type=float, default=0.25, help="percent of validation dataset" | ||
246 | ) | 249 | ) |
247 | - parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset") | ||
248 | parser.add_argument("--do_train", action="store_true", default=False) | 250 | parser.add_argument("--do_train", action="store_true", default=False) |
249 | parser.add_argument("--do_predict", action="store_true", default=False) | 251 | parser.add_argument("--do_predict", action="store_true", default=False) |
250 | args = parser.parse_args() | 252 | args = parser.parse_args() |
251 | 253 | ||
252 | - args.local_path = args.url.split('/')[-1] | 254 | + args.local_path = args.url.split("/")[-1] |
253 | logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}") | 255 | logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}") |
254 | repo = ( | 256 | repo = ( |
255 | Repo(args.local_path) | 257 | Repo(args.local_path) | ... | ... |
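For orientation, the gitcommit.py helpers reformatted above turn a single commit into fixed-length model inputs: encode_line returns a (token_ids, attention_mask, patch_ids) triple for one changed line, diff_parse collects those triples per commit, and truncate prepends a start value and pads or trims each list to max_length. The sketch below restates truncate as it appears in the hunk (the "for t in tuple:" loop header falls outside the shown context and is assumed) and calls it on hypothetical token ids.

def truncate(tuple, max_length, value=0):  # parameter name shadows the builtin, as in the source
    ls = []
    for t in tuple:  # assumed loop header; each element is an int or a list of ints
        if isinstance(t, int):
            t = [t]
        ls.extend(t)
    ls = ls[: max_length - 1]  # leave room for the leading `value`
    ls.insert(0, value)
    if len(ls) < max_length:  # right-pad with zeros up to max_length
        ls.extend([0] * (max_length - len(ls)))
    assert len(ls) == max_length
    return ls

# Hypothetical token ids: two encoded diff lines plus a stray int.
print(truncate([[101, 7, 8], [9, 10], 11], max_length=8))
# -> [0, 101, 7, 8, 9, 10, 11, 0]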
1 | # Copyright 2020-present Tae Hwan Jung | 1 | # Copyright 2020-present Tae Hwan Jung |
2 | -# | 2 | +# |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | # you may not use this file except in compliance with the License. | 4 | # you may not use this file except in compliance with the License. |
5 | # You may obtain a copy of the License at | 5 | # You may obtain a copy of the License at |
6 | -# | 6 | +# |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 |
8 | -# | 8 | +# |
9 | # Unless required by applicable law or agreed to in writing, software | 9 | # Unless required by applicable law or agreed to in writing, software |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | 10 | # distributed under the License is distributed on an "AS IS" BASIS, |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
... | @@ -14,6 +14,4 @@ | ... | @@ -14,6 +14,4 @@ |
14 | 14 | ||
15 | from .modeling_bart import BartForConditionalGeneration | 15 | from .modeling_bart import BartForConditionalGeneration |
16 | 16 | ||
17 | -__all__ = [ | ||
18 | - 'BartForConditionalGeneration' | ||
19 | -] | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
17 | +__all__ = ["BartForConditionalGeneration"] | ... | ... |
... | @@ -20,16 +20,31 @@ logger = logging.getLogger(__name__) | ... | @@ -20,16 +20,31 @@ logger = logging.getLogger(__name__) |
20 | 20 | ||
21 | class Seq2SeqLoggingCallback(pl.Callback): | 21 | class Seq2SeqLoggingCallback(pl.Callback): |
22 | def on_batch_end(self, trainer, pl_module): | 22 | def on_batch_end(self, trainer, pl_module): |
23 | - lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} | 23 | + lrs = { |
24 | + f"lr_group_{i}": param["lr"] | ||
25 | + for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups) | ||
26 | + } | ||
24 | pl_module.logger.log_metrics(lrs) | 27 | pl_module.logger.log_metrics(lrs) |
25 | 28 | ||
26 | @rank_zero_only | 29 | @rank_zero_only |
27 | def _write_logs( | 30 | def _write_logs( |
28 | - self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True | 31 | + self, |
32 | + trainer: pl.Trainer, | ||
33 | + pl_module: pl.LightningModule, | ||
34 | + type_path: str, | ||
35 | + save_generations=True, | ||
29 | ) -> None: | 36 | ) -> None: |
30 | - logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") | 37 | + logger.info( |
38 | + f"***** {type_path} results at step {trainer.global_step:05d} *****" | ||
39 | + ) | ||
31 | metrics = trainer.callback_metrics | 40 | metrics = trainer.callback_metrics |
32 | - trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) | 41 | + trainer.logger.log_metrics( |
42 | + { | ||
43 | + k: v | ||
44 | + for k, v in metrics.items() | ||
45 | + if k not in ["log", "progress_bar", "preds"] | ||
46 | + } | ||
47 | + ) | ||
33 | # Log results | 48 | # Log results |
34 | od = Path(pl_module.hparams.output_dir) | 49 | od = Path(pl_module.hparams.output_dir) |
35 | if type_path == "test": | 50 | if type_path == "test": |
... | @@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback): | ... | @@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback): |
39 | # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json | 54 | # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json |
40 | # If people want this it will be easy enough to add back. | 55 | # If people want this it will be easy enough to add back. |
41 | results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" | 56 | results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" |
42 | - generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt" | 57 | + generations_file = ( |
58 | + od / f"{type_path}_generations/{trainer.global_step:05d}.txt" | ||
59 | + ) | ||
43 | results_file.parent.mkdir(exist_ok=True) | 60 | results_file.parent.mkdir(exist_ok=True) |
44 | generations_file.parent.mkdir(exist_ok=True) | 61 | generations_file.parent.mkdir(exist_ok=True) |
45 | with open(results_file, "a+") as writer: | 62 | with open(results_file, "a+") as writer: |
... | @@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback): | ... | @@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback): |
68 | 85 | ||
69 | n_trainable_pars = count_trainable_parameters(pl_module) | 86 | n_trainable_pars = count_trainable_parameters(pl_module) |
70 | # mp stands for million parameters | 87 | # mp stands for million parameters |
71 | - trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) | 88 | + trainer.logger.log_metrics( |
89 | + {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6} | ||
90 | + ) | ||
72 | 91 | ||
73 | @rank_zero_only | 92 | @rank_zero_only |
74 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): | 93 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): |
... | @@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric): | ... | @@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric): |
98 | 117 | ||
99 | def get_early_stopping_callback(metric, patience): | 118 | def get_early_stopping_callback(metric, patience): |
100 | return EarlyStopping( | 119 | return EarlyStopping( |
101 | - monitor=f"val_{metric}", | 120 | + monitor=f"val_{metric}", mode="max", patience=patience, verbose=True, |
102 | - mode="max", | ||
103 | - patience=patience, | ||
104 | - verbose=True, | ||
105 | ) | 121 | ) | ... | ... |
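The callbacks.py hunks are pure reflow; get_early_stopping_callback still builds a pytorch_lightning EarlyStopping that watches val_{metric} with mode="max". A minimal sketch of the equivalent construction, with a hypothetical metric name; note the monitored key has to match the f"{prefix}_{self.val_metric}" entry that validation_epoch_end returns in train.py below.

from pytorch_lightning.callbacks import EarlyStopping

# Equivalent to get_early_stopping_callback("rouge2", patience=3) in the hunk above;
# "rouge2" is a hypothetical choice of validation metric.
early_stop = EarlyStopping(monitor="val_rouge2", mode="max", patience=3, verbose=True)
print(early_stop.monitor)  # "val_rouge2"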
... | @@ -21,7 +21,11 @@ from matorage.torch import Dataset | ... | @@ -21,7 +21,11 @@ from matorage.torch import Dataset |
21 | 21 | ||
22 | 22 | ||
23 | try: | 23 | try: |
24 | - from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback | 24 | + from .callbacks import ( |
25 | + Seq2SeqLoggingCallback, | ||
26 | + get_checkpoint_callback, | ||
27 | + get_early_stopping_callback, | ||
28 | + ) | ||
25 | from .utils import ( | 29 | from .utils import ( |
26 | ROUGE_KEYS, | 30 | ROUGE_KEYS, |
27 | LegacySeq2SeqDataset, | 31 | LegacySeq2SeqDataset, |
... | @@ -40,7 +44,11 @@ try: | ... | @@ -40,7 +44,11 @@ try: |
40 | use_task_specific_params, | 44 | use_task_specific_params, |
41 | ) | 45 | ) |
42 | except ImportError: | 46 | except ImportError: |
43 | - from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback | 47 | + from callbacks import ( |
48 | + Seq2SeqLoggingCallback, | ||
49 | + get_checkpoint_callback, | ||
50 | + get_early_stopping_callback, | ||
51 | + ) | ||
44 | from utils import ( | 52 | from utils import ( |
45 | ROUGE_KEYS, | 53 | ROUGE_KEYS, |
46 | LegacySeq2SeqDataset, | 54 | LegacySeq2SeqDataset, |
... | @@ -83,8 +91,12 @@ class SummarizationModule(BaseTransformer): | ... | @@ -83,8 +91,12 @@ class SummarizationModule(BaseTransformer): |
83 | "val": self.hparams.val_max_target_length, | 91 | "val": self.hparams.val_max_target_length, |
84 | "test": self.hparams.test_max_target_length, | 92 | "test": self.hparams.test_max_target_length, |
85 | } | 93 | } |
86 | - assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}" | 94 | + assert ( |
87 | - assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}" | 95 | + self.target_lens["train"] <= self.target_lens["val"] |
96 | + ), f"target_lens: {self.target_lens}" | ||
97 | + assert ( | ||
98 | + self.target_lens["train"] <= self.target_lens["test"] | ||
99 | + ), f"target_lens: {self.target_lens}" | ||
88 | 100 | ||
89 | if self.hparams.freeze_embeds: | 101 | if self.hparams.freeze_embeds: |
90 | self.freeze_embeds() | 102 | self.freeze_embeds() |
... | @@ -95,13 +107,27 @@ class SummarizationModule(BaseTransformer): | ... | @@ -95,13 +107,27 @@ class SummarizationModule(BaseTransformer): |
95 | self.hparams.git_sha = get_git_info()["repo_sha"] | 107 | self.hparams.git_sha = get_git_info()["repo_sha"] |
96 | self.num_workers = hparams.num_workers | 108 | self.num_workers = hparams.num_workers |
97 | self.decoder_start_token_id = None # default to config | 109 | self.decoder_start_token_id = None # default to config |
98 | - if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer): | 110 | + if self.model.config.decoder_start_token_id is None and isinstance( |
99 | - self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] | 111 | + self.tokenizer, MBartTokenizer |
112 | + ): | ||
113 | + self.decoder_start_token_id = self.tokenizer.lang_code_to_id[ | ||
114 | + hparams.tgt_lang | ||
115 | + ] | ||
100 | self.model.config.decoder_start_token_id = self.decoder_start_token_id | 116 | self.model.config.decoder_start_token_id = self.decoder_start_token_id |
101 | 117 | ||
102 | - self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams | 118 | + self.eval_beams = ( |
103 | - assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1" | 119 | + self.model.config.num_beams |
104 | - self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric | 120 | + if self.hparams.eval_beams is None |
121 | + else self.hparams.eval_beams | ||
122 | + ) | ||
123 | + assert ( | ||
124 | + self.eval_beams >= 1 | ||
125 | + ), f"got self.eval_beams={self.eval_beams}. Need an integer > 1" | ||
126 | + self.val_metric = ( | ||
127 | + self.default_val_metric | ||
128 | + if self.hparams.val_metric is None | ||
129 | + else self.hparams.val_metric | ||
130 | + ) | ||
105 | 131 | ||
106 | def freeze_embeds(self): | 132 | def freeze_embeds(self): |
107 | """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" | 133 | """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" |
... | @@ -133,7 +159,13 @@ class SummarizationModule(BaseTransformer): | ... | @@ -133,7 +159,13 @@ class SummarizationModule(BaseTransformer): |
133 | else: | 159 | else: |
134 | decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) | 160 | decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) |
135 | 161 | ||
136 | - outputs = self(src_ids, src_patch, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False) | 162 | + outputs = self( |
163 | + src_ids, | ||
164 | + src_patch, | ||
165 | + attention_mask=src_mask, | ||
166 | + decoder_input_ids=decoder_input_ids, | ||
167 | + use_cache=False, | ||
168 | + ) | ||
137 | lm_logits = outputs[0] | 169 | lm_logits = outputs[0] |
138 | if self.hparams.label_smoothing == 0: | 170 | if self.hparams.label_smoothing == 0: |
139 | # Same behavior as modeling_bart.py, besides ignoring pad_token_id | 171 | # Same behavior as modeling_bart.py, besides ignoring pad_token_id |
... | @@ -157,7 +189,9 @@ class SummarizationModule(BaseTransformer): | ... | @@ -157,7 +189,9 @@ class SummarizationModule(BaseTransformer): |
157 | 189 | ||
158 | logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} | 190 | logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} |
159 | # tokens per batch | 191 | # tokens per batch |
160 | - logs["tpb"] = batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum() | 192 | + logs["tpb"] = ( |
193 | + batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum() | ||
194 | + ) | ||
161 | return {"loss": loss_tensors[0], "log": logs} | 195 | return {"loss": loss_tensors[0], "log": logs} |
162 | 196 | ||
163 | def validation_step(self, batch, batch_idx) -> Dict: | 197 | def validation_step(self, batch, batch_idx) -> Dict: |
... | @@ -165,17 +199,29 @@ class SummarizationModule(BaseTransformer): | ... | @@ -165,17 +199,29 @@ class SummarizationModule(BaseTransformer): |
165 | 199 | ||
166 | def validation_epoch_end(self, outputs, prefix="val") -> Dict: | 200 | def validation_epoch_end(self, outputs, prefix="val") -> Dict: |
167 | self.step_count += 1 | 201 | self.step_count += 1 |
168 | - losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names} | 202 | + losses = { |
203 | + k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names | ||
204 | + } | ||
169 | loss = losses["loss"] | 205 | loss = losses["loss"] |
170 | - rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]} | 206 | + rouges = { |
171 | - rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss) | 207 | + k: np.array([x[k] for x in outputs]).mean() |
208 | + for k in self.metric_names + ["gen_time", "gen_len"] | ||
209 | + } | ||
210 | + rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as( | ||
211 | + loss | ||
212 | + ) | ||
172 | rouges.update({k: v.item() for k, v in losses.items()}) | 213 | rouges.update({k: v.item() for k, v in losses.items()}) |
173 | losses.update(rouges) | 214 | losses.update(rouges) |
174 | metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} | 215 | metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} |
175 | metrics["step_count"] = self.step_count | 216 | metrics["step_count"] = self.step_count |
176 | self.save_metrics(metrics, prefix) # writes to self.metrics_save_path | 217 | self.save_metrics(metrics, prefix) # writes to self.metrics_save_path |
177 | preds = flatten_list([x["preds"] for x in outputs]) | 218 | preds = flatten_list([x["preds"] for x in outputs]) |
178 | - return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor} | 219 | + return { |
220 | + "log": metrics, | ||
221 | + "preds": preds, | ||
222 | + f"{prefix}_loss": loss, | ||
223 | + f"{prefix}_{self.val_metric}": rouge_tensor, | ||
224 | + } | ||
179 | 225 | ||
180 | def save_metrics(self, latest_metrics, type_path) -> None: | 226 | def save_metrics(self, latest_metrics, type_path) -> None: |
181 | self.metrics[type_path].append(latest_metrics) | 227 | self.metrics[type_path].append(latest_metrics) |
... | @@ -200,7 +246,9 @@ class SummarizationModule(BaseTransformer): | ... | @@ -200,7 +246,9 @@ class SummarizationModule(BaseTransformer): |
200 | base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} | 246 | base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} |
201 | rouge: Dict = self.calc_generative_metrics(preds, target) | 247 | rouge: Dict = self.calc_generative_metrics(preds, target) |
202 | summ_len = np.mean(lmap(len, generated_ids)) | 248 | summ_len = np.mean(lmap(len, generated_ids)) |
203 | - base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge) | 249 | + base_metrics.update( |
250 | + gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge | ||
251 | + ) | ||
204 | return base_metrics | 252 | return base_metrics |
205 | 253 | ||
206 | def test_step(self, batch, batch_idx): | 254 | def test_step(self, batch, batch_idx): |
... | @@ -213,10 +261,10 @@ class SummarizationModule(BaseTransformer): | ... | @@ -213,10 +261,10 @@ class SummarizationModule(BaseTransformer): |
213 | max_target_length = self.target_lens[type_path] | 261 | max_target_length = self.target_lens[type_path] |
214 | data_config = DataConfig( | 262 | data_config = DataConfig( |
215 | endpoint=args.endpoint, | 263 | endpoint=args.endpoint, |
216 | - access_key=os.environ['access_key'], | 264 | + access_key=os.environ["access_key"], |
217 | - secret_key=os.environ['secret_key'], | 265 | + secret_key=os.environ["secret_key"], |
218 | region=args.region, | 266 | region=args.region, |
219 | - dataset_name='commit-autosuggestions', | 267 | + dataset_name="commit-autosuggestions", |
220 | additional={ | 268 | additional={ |
221 | "mode": ("training" if type_path == "train" else "evaluation"), | 269 | "mode": ("training" if type_path == "train" else "evaluation"), |
222 | "max_source_length": self.hparams.max_source_length, | 270 | "max_source_length": self.hparams.max_source_length, |
... | @@ -224,15 +272,17 @@ class SummarizationModule(BaseTransformer): | ... | @@ -224,15 +272,17 @@ class SummarizationModule(BaseTransformer): |
224 | "url": args.url, | 272 | "url": args.url, |
225 | }, | 273 | }, |
226 | attributes=[ | 274 | attributes=[ |
227 | - ('input_ids', 'int32', (self.hparams.max_source_length,)), | 275 | + ("input_ids", "int32", (self.hparams.max_source_length,)), |
228 | - ('attention_masks', 'int32', (self.hparams.max_source_length,)), | 276 | + ("attention_masks", "int32", (self.hparams.max_source_length,)), |
229 | - ('patch_ids', 'int32', (self.hparams.max_source_length,)), | 277 | + ("patch_ids", "int32", (self.hparams.max_source_length,)), |
230 | - ('targets', 'int32', (max_target_length,)) | 278 | + ("targets", "int32", (max_target_length,)), |
231 | - ] | 279 | + ], |
232 | ) | 280 | ) |
233 | return Dataset(config=data_config, clear=True) | 281 | return Dataset(config=data_config, clear=True) |
234 | 282 | ||
235 | - def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader: | 283 | + def get_dataloader( |
284 | + self, type_path: str, batch_size: int, shuffle: bool = False | ||
285 | + ) -> DataLoader: | ||
236 | dataset = self.get_dataset(type_path) | 286 | dataset = self.get_dataset(type_path) |
237 | sampler = None | 287 | sampler = None |
238 | 288 | ||
... | @@ -246,7 +296,9 @@ class SummarizationModule(BaseTransformer): | ... | @@ -246,7 +296,9 @@ class SummarizationModule(BaseTransformer): |
246 | return dataloader | 296 | return dataloader |
247 | 297 | ||
248 | def train_dataloader(self) -> DataLoader: | 298 | def train_dataloader(self) -> DataLoader: |
249 | - dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True) | 299 | + dataloader = self.get_dataloader( |
300 | + "train", batch_size=self.hparams.train_batch_size, shuffle=True | ||
301 | + ) | ||
250 | return dataloader | 302 | return dataloader |
251 | 303 | ||
252 | def val_dataloader(self) -> DataLoader: | 304 | def val_dataloader(self) -> DataLoader: |
... | @@ -259,23 +311,18 @@ class SummarizationModule(BaseTransformer): | ... | @@ -259,23 +311,18 @@ class SummarizationModule(BaseTransformer): |
259 | def add_model_specific_args(parser, root_dir): | 311 | def add_model_specific_args(parser, root_dir): |
260 | BaseTransformer.add_model_specific_args(parser, root_dir) | 312 | BaseTransformer.add_model_specific_args(parser, root_dir) |
261 | add_generic_args(parser, root_dir) | 313 | add_generic_args(parser, root_dir) |
262 | - parser.add_argument( | 314 | + parser.add_argument("--url", type=str, required=True, help="github url") |
263 | - "--url", | ||
264 | - type=str, | ||
265 | - required=True, | ||
266 | - help="github url" | ||
267 | - ) | ||
268 | parser.add_argument( | 315 | parser.add_argument( |
269 | "--endpoint", | 316 | "--endpoint", |
270 | type=str, | 317 | type=str, |
271 | required=True, | 318 | required=True, |
272 | - help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' | 319 | + help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", |
273 | ) | 320 | ) |
274 | parser.add_argument( | 321 | parser.add_argument( |
275 | "--region", | 322 | "--region", |
276 | type=str, | 323 | type=str, |
277 | default=None, | 324 | default=None, |
278 | - help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' | 325 | + help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", |
279 | ) | 326 | ) |
280 | parser.add_argument( | 327 | parser.add_argument( |
281 | "--max_source_length", | 328 | "--max_source_length", |
... | @@ -308,14 +355,43 @@ class SummarizationModule(BaseTransformer): | ... | @@ -308,14 +355,43 @@ class SummarizationModule(BaseTransformer): |
308 | parser.add_argument("--freeze_encoder", action="store_true") | 355 | parser.add_argument("--freeze_encoder", action="store_true") |
309 | parser.add_argument("--freeze_embeds", action="store_true") | 356 | parser.add_argument("--freeze_embeds", action="store_true") |
310 | parser.add_argument("--sortish_sampler", action="store_true", default=False) | 357 | parser.add_argument("--sortish_sampler", action="store_true", default=False) |
311 | - parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default") | ||
312 | - parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.") | ||
313 | - parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.") | ||
314 | - parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.") | ||
315 | parser.add_argument( | 358 | parser.add_argument( |
316 | - "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all." | 359 | + "--logger_name", |
360 | + type=str, | ||
361 | + choices=["default", "wandb", "wandb_shared"], | ||
362 | + default="default", | ||
363 | + ) | ||
364 | + parser.add_argument( | ||
365 | + "--n_train", | ||
366 | + type=int, | ||
367 | + default=-1, | ||
368 | + required=False, | ||
369 | + help="# examples. -1 means use all.", | ||
370 | + ) | ||
371 | + parser.add_argument( | ||
372 | + "--n_val", | ||
373 | + type=int, | ||
374 | + default=500, | ||
375 | + required=False, | ||
376 | + help="# examples. -1 means use all.", | ||
377 | + ) | ||
378 | + parser.add_argument( | ||
379 | + "--n_test", | ||
380 | + type=int, | ||
381 | + default=-1, | ||
382 | + required=False, | ||
383 | + help="# examples. -1 means use all.", | ||
384 | + ) | ||
385 | + parser.add_argument( | ||
386 | + "--task", | ||
387 | + type=str, | ||
388 | + default="summarization", | ||
389 | + required=False, | ||
390 | + help="# examples. -1 means use all.", | ||
391 | + ) | ||
392 | + parser.add_argument( | ||
393 | + "--label_smoothing", type=float, default=0.0, required=False | ||
317 | ) | 394 | ) |
318 | - parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) | ||
319 | parser.add_argument("--src_lang", type=str, default="", required=False) | 395 | parser.add_argument("--src_lang", type=str, default="", required=False) |
320 | parser.add_argument("--tgt_lang", type=str, default="", required=False) | 396 | parser.add_argument("--tgt_lang", type=str, default="", required=False) |
321 | parser.add_argument("--eval_beams", type=int, default=None, required=False) | 397 | parser.add_argument("--eval_beams", type=int, default=None, required=False) |
... | @@ -348,7 +424,11 @@ class TranslationModule(SummarizationModule): | ... | @@ -348,7 +424,11 @@ class TranslationModule(SummarizationModule): |
348 | def main(args, model=None) -> SummarizationModule: | 424 | def main(args, model=None) -> SummarizationModule: |
349 | Path(args.output_dir).mkdir(exist_ok=True) | 425 | Path(args.output_dir).mkdir(exist_ok=True) |
350 | if len(os.listdir(args.output_dir)) > 3 and args.do_train: | 426 | if len(os.listdir(args.output_dir)) > 3 and args.do_train: |
351 | - raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) | 427 | + raise ValueError( |
428 | + "Output directory ({}) already exists and is not empty.".format( | ||
429 | + args.output_dir | ||
430 | + ) | ||
431 | + ) | ||
352 | if model is None: | 432 | if model is None: |
353 | if args.task == "summarization": | 433 | if args.task == "summarization": |
354 | model: SummarizationModule = SummarizationModule(args) | 434 | model: SummarizationModule = SummarizationModule(args) |
... | @@ -371,7 +451,9 @@ def main(args, model=None) -> SummarizationModule: | ... | @@ -371,7 +451,9 @@ def main(args, model=None) -> SummarizationModule: |
371 | return model | 451 | return model |
372 | 452 | ||
373 | model.hparams.test_checkpoint = "" | 453 | model.hparams.test_checkpoint = "" |
374 | - checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))) | 454 | + checkpoints = list( |
455 | + sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)) | ||
456 | + ) | ||
375 | if checkpoints: | 457 | if checkpoints: |
376 | model.hparams.test_checkpoint = checkpoints[-1] | 458 | model.hparams.test_checkpoint = checkpoints[-1] |
377 | trainer.resume_from_checkpoint = checkpoints[-1] | 459 | trainer.resume_from_checkpoint = checkpoints[-1] | ... | ... |
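Most of the train.py changes are line-length reflow with no behavioral change. As a reference point, the aggregation in validation_epoch_end reduces the list of per-batch result dicts to one averaged metrics dict keyed {prefix}_avg_{name}; a toy, self-contained sketch with hypothetical values:

import numpy as np
import torch

# Per-batch results, as produced by _generative_step; values are made up.
outputs = [
    {"loss": torch.tensor(2.0), "rouge2": 0.21, "gen_len": 12.0},
    {"loss": torch.tensor(1.6), "rouge2": 0.25, "gen_len": 14.0},
]
loss_names = ["loss"]
metric_names = ["rouge2"]

losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in loss_names}
rouges = {
    k: np.array([x[k] for x in outputs]).mean() for k in metric_names + ["gen_len"]
}
metrics = {f"val_avg_{k}": v for k, v in {**losses, **rouges}.items()}
print(metrics)  # averaged loss 1.8, rouge2 about 0.23, gen_len 13.0, keys prefixed val_avg_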
... | @@ -30,6 +30,7 @@ logging.basicConfig( | ... | @@ -30,6 +30,7 @@ logging.basicConfig( |
30 | level=logging.INFO, | 30 | level=logging.INFO, |
31 | ) | 31 | ) |
32 | 32 | ||
33 | + | ||
33 | class GenerationMixin: | 34 | class GenerationMixin: |
34 | """ | 35 | """ |
35 | A class contraining all of the functions supporting generation, to be used as a mixin in | 36 | A class contraining all of the functions supporting generation, to be used as a mixin in |
... | @@ -50,7 +51,9 @@ class GenerationMixin: | ... | @@ -50,7 +51,9 @@ class GenerationMixin: |
50 | """ | 51 | """ |
51 | return logits | 52 | return logits |
52 | 53 | ||
53 | - def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): | 54 | + def enforce_repetition_penalty_( |
55 | + self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty | ||
56 | + ): | ||
54 | """ | 57 | """ |
55 | Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__). | 58 | Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__). |
56 | """ | 59 | """ |
... | @@ -79,11 +82,7 @@ class GenerationMixin: | ... | @@ -79,11 +82,7 @@ class GenerationMixin: |
79 | # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) | 82 | # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) |
80 | if repetition_penalty != 1.0: | 83 | if repetition_penalty != 1.0: |
81 | self.enforce_repetition_penalty_( | 84 | self.enforce_repetition_penalty_( |
82 | - scores, | 85 | + scores, batch_size, num_beams, input_ids, repetition_penalty, |
83 | - batch_size, | ||
84 | - num_beams, | ||
85 | - input_ids, | ||
86 | - repetition_penalty, | ||
87 | ) | 86 | ) |
88 | 87 | ||
89 | # set eos token prob to zero if min_length is not reached | 88 | # set eos token prob to zero if min_length is not reached |
... | @@ -102,7 +101,11 @@ class GenerationMixin: | ... | @@ -102,7 +101,11 @@ class GenerationMixin: |
102 | 101 | ||
103 | if bad_words_ids is not None: | 102 | if bad_words_ids is not None: |
104 | # Exclude EOS token (already processed) | 103 | # Exclude EOS token (already processed) |
105 | - bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) | 104 | + bad_words_ids = list( |
105 | + filter( | ||
106 | + lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids | ||
107 | + ) | ||
108 | + ) | ||
106 | # calculate a list of banned tokens according to bad words | 109 | # calculate a list of banned tokens according to bad words |
107 | banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) | 110 | banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) |
108 | # Modify the scores in place by setting the banned tokens logits to `-inf` | 111 | # Modify the scores in place by setting the banned tokens logits to `-inf` |
... | @@ -134,7 +137,7 @@ class GenerationMixin: | ... | @@ -134,7 +137,7 @@ class GenerationMixin: |
134 | attention_mask: Optional[torch.LongTensor] = None, | 137 | attention_mask: Optional[torch.LongTensor] = None, |
135 | decoder_start_token_id: Optional[int] = None, | 138 | decoder_start_token_id: Optional[int] = None, |
136 | use_cache: Optional[bool] = None, | 139 | use_cache: Optional[bool] = None, |
137 | - **model_kwargs | 140 | + **model_kwargs, |
138 | ) -> torch.LongTensor: | 141 | ) -> torch.LongTensor: |
139 | r""" | 142 | r""" |
140 | Generates sequences for models with a language modeling head. The method currently supports greedy decoding, | 143 | Generates sequences for models with a language modeling head. The method currently supports greedy decoding, |
... | @@ -262,26 +265,50 @@ class GenerationMixin: | ... | @@ -262,26 +265,50 @@ class GenerationMixin: |
262 | max_length = max_length if max_length is not None else self.config.max_length | 265 | max_length = max_length if max_length is not None else self.config.max_length |
263 | min_length = min_length if min_length is not None else self.config.min_length | 266 | min_length = min_length if min_length is not None else self.config.min_length |
264 | do_sample = do_sample if do_sample is not None else self.config.do_sample | 267 | do_sample = do_sample if do_sample is not None else self.config.do_sample |
265 | - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping | 268 | + early_stopping = ( |
269 | + early_stopping if early_stopping is not None else self.config.early_stopping | ||
270 | + ) | ||
266 | use_cache = use_cache if use_cache is not None else self.config.use_cache | 271 | use_cache = use_cache if use_cache is not None else self.config.use_cache |
267 | num_beams = num_beams if num_beams is not None else self.config.num_beams | 272 | num_beams = num_beams if num_beams is not None else self.config.num_beams |
268 | - temperature = temperature if temperature is not None else self.config.temperature | 273 | + temperature = ( |
274 | + temperature if temperature is not None else self.config.temperature | ||
275 | + ) | ||
269 | top_k = top_k if top_k is not None else self.config.top_k | 276 | top_k = top_k if top_k is not None else self.config.top_k |
270 | top_p = top_p if top_p is not None else self.config.top_p | 277 | top_p = top_p if top_p is not None else self.config.top_p |
271 | - repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty | 278 | + repetition_penalty = ( |
272 | - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id | 279 | + repetition_penalty |
273 | - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | 280 | + if repetition_penalty is not None |
274 | - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | 281 | + else self.config.repetition_penalty |
275 | - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty | 282 | + ) |
283 | + bos_token_id = ( | ||
284 | + bos_token_id if bos_token_id is not None else self.config.bos_token_id | ||
285 | + ) | ||
286 | + pad_token_id = ( | ||
287 | + pad_token_id if pad_token_id is not None else self.config.pad_token_id | ||
288 | + ) | ||
289 | + eos_token_id = ( | ||
290 | + eos_token_id if eos_token_id is not None else self.config.eos_token_id | ||
291 | + ) | ||
292 | + length_penalty = ( | ||
293 | + length_penalty if length_penalty is not None else self.config.length_penalty | ||
294 | + ) | ||
276 | no_repeat_ngram_size = ( | 295 | no_repeat_ngram_size = ( |
277 | - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size | 296 | + no_repeat_ngram_size |
297 | + if no_repeat_ngram_size is not None | ||
298 | + else self.config.no_repeat_ngram_size | ||
299 | + ) | ||
300 | + bad_words_ids = ( | ||
301 | + bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids | ||
278 | ) | 302 | ) |
279 | - bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids | ||
280 | num_return_sequences = ( | 303 | num_return_sequences = ( |
281 | - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences | 304 | + num_return_sequences |
305 | + if num_return_sequences is not None | ||
306 | + else self.config.num_return_sequences | ||
282 | ) | 307 | ) |
283 | decoder_start_token_id = ( | 308 | decoder_start_token_id = ( |
284 | - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id | 309 | + decoder_start_token_id |
310 | + if decoder_start_token_id is not None | ||
311 | + else self.config.decoder_start_token_id | ||
285 | ) | 312 | ) |
286 | 313 | ||
287 | if input_ids is not None: | 314 | if input_ids is not None: |
... | @@ -289,14 +316,22 @@ class GenerationMixin: | ... | @@ -289,14 +316,22 @@ class GenerationMixin: |
289 | else: | 316 | else: |
290 | batch_size = 1 | 317 | batch_size = 1 |
291 | 318 | ||
292 | - assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." | 319 | + assert ( |
293 | - assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." | 320 | + isinstance(max_length, int) and max_length > 0 |
321 | + ), "`max_length` should be a strictly positive integer." | ||
322 | + assert ( | ||
323 | + isinstance(min_length, int) and min_length >= 0 | ||
324 | + ), "`min_length` should be a positive integer." | ||
294 | assert isinstance(do_sample, bool), "`do_sample` should be a boolean." | 325 | assert isinstance(do_sample, bool), "`do_sample` should be a boolean." |
295 | assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." | 326 | assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." |
296 | assert isinstance(use_cache, bool), "`use_cache` should be a boolean." | 327 | assert isinstance(use_cache, bool), "`use_cache` should be a boolean." |
297 | - assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." | 328 | + assert ( |
329 | + isinstance(num_beams, int) and num_beams > 0 | ||
330 | + ), "`num_beams` should be a strictly positive integer." | ||
298 | assert temperature > 0, "`temperature` should be strictly positive." | 331 | assert temperature > 0, "`temperature` should be strictly positive." |
299 | - assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." | 332 | + assert ( |
333 | + isinstance(top_k, int) and top_k >= 0 | ||
334 | + ), "`top_k` should be a positive integer." | ||
300 | assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." | 335 | assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." |
301 | assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." | 336 | assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." |
302 | assert input_ids is not None or ( | 337 | assert input_ids is not None or ( |
... | @@ -316,7 +351,9 @@ class GenerationMixin: | ... | @@ -316,7 +351,9 @@ class GenerationMixin: |
316 | isinstance(num_return_sequences, int) and num_return_sequences > 0 | 351 | isinstance(num_return_sequences, int) and num_return_sequences > 0 |
317 | ), "`num_return_sequences` should be a strictly positive integer." | 352 | ), "`num_return_sequences` should be a strictly positive integer." |
318 | assert ( | 353 | assert ( |
319 | - bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) | 354 | + bad_words_ids is None |
355 | + or isinstance(bad_words_ids, list) | ||
356 | + and isinstance(bad_words_ids[0], list) | ||
320 | ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" | 357 | ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" |
321 | 358 | ||
322 | if input_ids is None: | 359 | if input_ids is None: |
... | @@ -331,7 +368,9 @@ class GenerationMixin: | ... | @@ -331,7 +368,9 @@ class GenerationMixin: |
331 | device=next(self.parameters()).device, | 368 | device=next(self.parameters()).device, |
332 | ) | 369 | ) |
333 | else: | 370 | else: |
334 | - assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." | 371 | + assert ( |
372 | + input_ids.dim() == 2 | ||
373 | + ), "Input prompt should be of shape (batch_size, sequence length)." | ||
335 | 374 | ||
336 | # not allow to duplicate outputs when greedy decoding | 375 | # not allow to duplicate outputs when greedy decoding |
337 | if do_sample is False: | 376 | if do_sample is False: |
... | @@ -349,7 +388,11 @@ class GenerationMixin: | ... | @@ -349,7 +388,11 @@ class GenerationMixin: |
349 | 388 | ||
350 | # create attention mask if necessary | 389 | # create attention mask if necessary |
351 | # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 | 390 | # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 |
352 | - if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): | 391 | + if ( |
392 | + (attention_mask is None) | ||
393 | + and (pad_token_id is not None) | ||
394 | + and (pad_token_id in input_ids) | ||
395 | + ): | ||
353 | attention_mask = input_ids.ne(pad_token_id).long() | 396 | attention_mask = input_ids.ne(pad_token_id).long() |
354 | elif attention_mask is None: | 397 | elif attention_mask is None: |
355 | attention_mask = input_ids.new_ones(input_ids.shape) | 398 | attention_mask = input_ids.new_ones(input_ids.shape) |
... | @@ -358,7 +401,9 @@ class GenerationMixin: | ... | @@ -358,7 +401,9 @@ class GenerationMixin: |
358 | # attention_mask is created | 401 | # attention_mask is created |
359 | if pad_token_id is None and eos_token_id is not None: | 402 | if pad_token_id is None and eos_token_id is not None: |
360 | logger.warning( | 403 | logger.warning( |
361 | - "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) | 404 | + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format( |
405 | + eos_token_id | ||
406 | + ) | ||
362 | ) | 407 | ) |
363 | pad_token_id = eos_token_id | 408 | pad_token_id = eos_token_id |
364 | 409 | ||
... | @@ -385,25 +430,37 @@ class GenerationMixin: | ... | @@ -385,25 +430,37 @@ class GenerationMixin: |
385 | # see if BOS token can be used for decoder_start_token_id | 430 | # see if BOS token can be used for decoder_start_token_id |
386 | if bos_token_id is not None: | 431 | if bos_token_id is not None: |
387 | decoder_start_token_id = bos_token_id | 432 | decoder_start_token_id = bos_token_id |
388 | - elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"): | 433 | + elif hasattr(self.config, "decoder") and hasattr( |
434 | + self.config.decoder, "bos_token_id" | ||
435 | + ): | ||
389 | decoder_start_token_id = self.config.decoder.bos_token_id | 436 | decoder_start_token_id = self.config.decoder.bos_token_id |
390 | else: | 437 | else: |
391 | raise ValueError( | 438 | raise ValueError( |
392 | "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" | 439 | "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" |
393 | ) | 440 | ) |
394 | 441 | ||
395 | - assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) | 442 | + assert hasattr( |
396 | - assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) | 443 | + self, "get_encoder" |
444 | + ), "{} should have a 'get_encoder' function defined".format(self) | ||
445 | + assert callable(self.get_encoder), "{} should be a method".format( | ||
446 | + self.get_encoder | ||
447 | + ) | ||
397 | 448 | ||
398 | # get encoder and store encoder outputs | 449 | # get encoder and store encoder outputs |
399 | encoder = self.get_encoder() | 450 | encoder = self.get_encoder() |
400 | - encoder_outputs: ModelOutput = encoder(input_ids, patch_ids, attention_mask=attention_mask, return_dict=True) | 451 | + encoder_outputs: ModelOutput = encoder( |
452 | + input_ids, patch_ids, attention_mask=attention_mask, return_dict=True | ||
453 | + ) | ||
401 | 454 | ||
402 | # Expand input ids if num_beams > 1 or num_return_sequences > 1 | 455 | # Expand input ids if num_beams > 1 or num_return_sequences > 1 |
403 | if num_return_sequences > 1 or num_beams > 1: | 456 | if num_return_sequences > 1 or num_beams > 1: |
404 | input_ids_len = input_ids.shape[-1] | 457 | input_ids_len = input_ids.shape[-1] |
405 | - input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) | 458 | + input_ids = input_ids.unsqueeze(1).expand( |
406 | - patch_ids = patch_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) | 459 | + batch_size, effective_batch_mult * num_beams, input_ids_len |
460 | + ) | ||
461 | + patch_ids = patch_ids.unsqueeze(1).expand( | ||
462 | + batch_size, effective_batch_mult * num_beams, input_ids_len | ||
463 | + ) | ||
407 | attention_mask = attention_mask.unsqueeze(1).expand( | 464 | attention_mask = attention_mask.unsqueeze(1).expand( |
408 | batch_size, effective_batch_mult * num_beams, input_ids_len | 465 | batch_size, effective_batch_mult * num_beams, input_ids_len |
409 | ) | 466 | ) |
... | @@ -442,9 +499,9 @@ class GenerationMixin: | ... | @@ -442,9 +499,9 @@ class GenerationMixin: |
442 | ) | 499 | ) |
443 | 500 | ||
444 | # expand encoder_outputs | 501 | # expand encoder_outputs |
445 | - encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( | 502 | + encoder_outputs[ |
446 | - 0, expanded_batch_idxs | 503 | + "last_hidden_state" |
447 | - ) | 504 | + ] = encoder_outputs.last_hidden_state.index_select(0, expanded_batch_idxs) |
448 | 505 | ||
449 | # save encoder_outputs in `model_kwargs` | 506 | # save encoder_outputs in `model_kwargs` |
450 | model_kwargs["encoder_outputs"] = encoder_outputs | 507 | model_kwargs["encoder_outputs"] = encoder_outputs |
... | @@ -534,7 +591,11 @@ class GenerationMixin: | ... | @@ -534,7 +591,11 @@ class GenerationMixin: |
534 | past = None | 591 | past = None |
535 | while cur_len < max_length: | 592 | while cur_len < max_length: |
536 | model_inputs = self.prepare_inputs_for_generation( | 593 | model_inputs = self.prepare_inputs_for_generation( |
537 | - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs | 594 | + input_ids, |
595 | + past=past, | ||
596 | + attention_mask=attention_mask, | ||
597 | + use_cache=use_cache, | ||
598 | + **model_kwargs, | ||
538 | ) | 599 | ) |
539 | 600 | ||
540 | outputs = self(**model_inputs, return_dict=True) | 601 | outputs = self(**model_inputs, return_dict=True) |
... | @@ -565,7 +626,9 @@ class GenerationMixin: | ... | @@ -565,7 +626,9 @@ class GenerationMixin: |
565 | if temperature != 1.0: | 626 | if temperature != 1.0: |
566 | scores = scores / temperature | 627 | scores = scores / temperature |
567 | # Top-p/top-k filtering | 628 | # Top-p/top-k filtering |
568 | - next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) | 629 | + next_token_logscores = top_k_top_p_filtering( |
630 | + scores, top_k=top_k, top_p=top_p | ||
631 | + ) | ||
569 | # Sample | 632 | # Sample |
570 | probs = F.softmax(next_token_logscores, dim=-1) | 633 | probs = F.softmax(next_token_logscores, dim=-1) |
571 | next_token = torch.multinomial(probs, num_samples=1).squeeze(1) | 634 | next_token = torch.multinomial(probs, num_samples=1).squeeze(1) |
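The hunk above only re-wraps the sampling path; the underlying flow (temperature scaling, then top-k/top-p filtering, then a multinomial draw) is easier to see in isolation. A minimal sketch using the module's own top_k_top_p_filtering (defined further down in this file); the helper name and default sizes are illustrative assumptions, not part of the diff:

import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, top_k=50, top_p=0.95):
    # logits: (batch_size, vocab_size) scores for the next position
    if temperature != 1.0:
        logits = logits / temperature
    # push everything outside the top-k / nucleus to -inf
    filtered = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(filtered, dim=-1)
    # one sampled token id per sequence in the batch
    return torch.multinomial(probs, num_samples=1).squeeze(1)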
... | @@ -576,7 +639,9 @@ class GenerationMixin: | ... | @@ -576,7 +639,9 @@ class GenerationMixin: |
576 | # update generations and finished sentences | 639 | # update generations and finished sentences |
577 | if eos_token_id is not None: | 640 | if eos_token_id is not None: |
578 | # pad finished sentences if eos_token_id exists | 641 | # pad finished sentences if eos_token_id exists |
579 | - tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) | 642 | + tokens_to_add = next_token * unfinished_sents + (pad_token_id) * ( |
643 | + 1 - unfinished_sents | ||
644 | + ) | ||
580 | else: | 645 | else: |
581 | tokens_to_add = next_token | 646 | tokens_to_add = next_token |
582 | 647 | ||
... | @@ -587,8 +652,12 @@ class GenerationMixin: | ... | @@ -587,8 +652,12 @@ class GenerationMixin: |
587 | if eos_token_id is not None: | 652 | if eos_token_id is not None: |
588 | eos_in_sents = tokens_to_add == eos_token_id | 653 | eos_in_sents = tokens_to_add == eos_token_id |
589 | # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length | 654 | # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length |
590 | - is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() | 655 | + is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul( |
591 | - sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len) | 656 | + eos_in_sents.long() |
657 | + ).bool() | ||
658 | + sent_lengths.masked_fill_( | ||
659 | + is_sents_unfinished_and_token_to_add_is_eos, cur_len | ||
660 | + ) | ||
592 | # unfinished_sents is set to zero if eos in sentence | 661 | # unfinished_sents is set to zero if eos in sentence |
593 | unfinished_sents.mul_((~eos_in_sents).long()) | 662 | unfinished_sents.mul_((~eos_in_sents).long()) |
594 | 663 | ||
... | @@ -599,7 +668,11 @@ class GenerationMixin: | ... | @@ -599,7 +668,11 @@ class GenerationMixin: |
599 | # extend attention_mask for the newly generated token when the model is decoder-only | 668 | # extend attention_mask for the newly generated token when the model is decoder-only |
600 | if self.config.is_encoder_decoder is False: | 669 | if self.config.is_encoder_decoder is False: |
601 | attention_mask = torch.cat( | 670 | attention_mask = torch.cat( |
602 | - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 | 671 | + [ |
672 | + attention_mask, | ||
673 | + attention_mask.new_ones((attention_mask.shape[0], 1)), | ||
674 | + ], | ||
675 | + dim=-1, | ||
603 | ) | 676 | ) |
604 | 677 | ||
605 | return input_ids | 678 | return input_ids |
... | @@ -633,12 +706,16 @@ class GenerationMixin: | ... | @@ -633,12 +706,16 @@ class GenerationMixin: |
633 | 706 | ||
634 | # generated hypotheses | 707 | # generated hypotheses |
635 | generated_hyps = [ | 708 | generated_hyps = [ |
636 | - BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) | 709 | + BeamHypotheses( |
710 | + num_beams, max_length, length_penalty, early_stopping=early_stopping | ||
711 | + ) | ||
637 | for _ in range(batch_size) | 712 | for _ in range(batch_size) |
638 | ] | 713 | ] |
639 | 714 | ||
640 | # scores for each sentence in the beam | 715 | # scores for each sentence in the beam |
641 | - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) | 716 | + beam_scores = torch.zeros( |
717 | + (batch_size, num_beams), dtype=torch.float, device=input_ids.device | ||
718 | + ) | ||
642 | 719 | ||
643 | # for greedy decoding, make sure that only tokens of the first beam are considered, to avoid sampling the exact same tokens three times | 720 | # for greedy decoding, make sure that only tokens of the first beam are considered, to avoid sampling the exact same tokens three times |
644 | if do_sample is False: | 721 | if do_sample is False: |
... | @@ -653,10 +730,18 @@ class GenerationMixin: | ... | @@ -653,10 +730,18 @@ class GenerationMixin: |
653 | 730 | ||
654 | while cur_len < max_length: | 731 | while cur_len < max_length: |
655 | model_inputs = self.prepare_inputs_for_generation( | 732 | model_inputs = self.prepare_inputs_for_generation( |
656 | - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs | 733 | + input_ids, |
734 | + past=past, | ||
735 | + attention_mask=attention_mask, | ||
736 | + use_cache=use_cache, | ||
737 | + **model_kwargs, | ||
657 | ) | 738 | ) |
658 | - outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size) | 739 | + outputs = self( |
659 | - next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size) | 740 | + **model_inputs, return_dict=True |
741 | + ) # (batch_size * num_beams, cur_len, vocab_size) | ||
742 | + next_token_logits = outputs.logits[ | ||
743 | + :, -1, : | ||
744 | + ] # (batch_size * num_beams, vocab_size) | ||
660 | 745 | ||
661 | # if model has past, then set the past variable to speed up decoding | 746 | # if model has past, then set the past variable to speed up decoding |
662 | if "past_key_values" in outputs: | 747 | if "past_key_values" in outputs: |
... | @@ -670,7 +755,9 @@ class GenerationMixin: | ... | @@ -670,7 +755,9 @@ class GenerationMixin: |
670 | next_token_logits, cur_len=cur_len, max_length=max_length | 755 | next_token_logits, cur_len=cur_len, max_length=max_length |
671 | ) | 756 | ) |
672 | 757 | ||
673 | - scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) | 758 | + scores = F.log_softmax( |
759 | + next_token_logits, dim=-1 | ||
760 | + ) # (batch_size * num_beams, vocab_size) | ||
674 | 761 | ||
675 | scores = self.postprocess_next_token_scores( | 762 | scores = self.postprocess_next_token_scores( |
676 | scores=scores, | 763 | scores=scores, |
... | @@ -686,12 +773,17 @@ class GenerationMixin: | ... | @@ -686,12 +773,17 @@ class GenerationMixin: |
686 | num_beams=num_beams, | 773 | num_beams=num_beams, |
687 | ) | 774 | ) |
688 | 775 | ||
689 | - assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format( | 776 | + assert scores.shape == ( |
777 | + batch_size * num_beams, | ||
778 | + vocab_size, | ||
779 | + ), "Shapes of scores: {} != {}".format( | ||
690 | scores.shape, (batch_size * num_beams, vocab_size) | 780 | scores.shape, (batch_size * num_beams, vocab_size) |
691 | ) | 781 | ) |
692 | 782 | ||
693 | if do_sample: | 783 | if do_sample: |
694 | - _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) | 784 | + _scores = scores + beam_scores[:, None].expand_as( |
785 | + scores | ||
786 | + ) # (batch_size * num_beams, vocab_size) | ||
695 | # Temperature | 787 | # Temperature |
696 | if temperature != 1.0: | 788 | if temperature != 1.0: |
697 | _scores = _scores / temperature | 789 | _scores = _scores / temperature |
... | @@ -706,24 +798,38 @@ class GenerationMixin: | ... | @@ -706,24 +798,38 @@ class GenerationMixin: |
706 | 798 | ||
707 | # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) | 799 | # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) |
708 | probs = F.softmax(_scores, dim=-1) | 800 | probs = F.softmax(_scores, dim=-1) |
709 | - next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2) | 801 | + next_tokens = torch.multinomial( |
802 | + probs, num_samples=2 * num_beams | ||
803 | + ) # (batch_size, num_beams * 2) | ||
710 | # Compute next scores | 804 | # Compute next scores |
711 | - next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) | 805 | + next_scores = torch.gather( |
806 | + _scores, -1, next_tokens | ||
807 | + ) # (batch_size, num_beams * 2) | ||
712 | # sort the sampled vector to make sure that the first num_beams samples are the best | 808 | # sort the sampled vector to make sure that the first num_beams samples are the best |
713 | - next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1) | 809 | + next_scores, next_scores_indices = torch.sort( |
714 | - next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2) | 810 | + next_scores, descending=True, dim=1 |
811 | + ) | ||
812 | + next_tokens = torch.gather( | ||
813 | + next_tokens, -1, next_scores_indices | ||
814 | + ) # (batch_size, num_beams * 2) | ||
715 | 815 | ||
716 | else: | 816 | else: |
717 | - next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) | 817 | + next_scores = scores + beam_scores[:, None].expand_as( |
818 | + scores | ||
819 | + ) # (batch_size * num_beams, vocab_size) | ||
718 | 820 | ||
719 | # re-organize to group the beams together (we are keeping top hypotheses across beams) | 821 | # re-organize to group the beams together (we are keeping top hypotheses across beams) |
720 | next_scores = next_scores.view( | 822 | next_scores = next_scores.view( |
721 | batch_size, num_beams * vocab_size | 823 | batch_size, num_beams * vocab_size |
722 | ) # (batch_size, num_beams * vocab_size) | 824 | ) # (batch_size, num_beams * vocab_size) |
723 | 825 | ||
724 | - next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True) | 826 | + next_scores, next_tokens = torch.topk( |
827 | + next_scores, 2 * num_beams, dim=1, largest=True, sorted=True | ||
828 | + ) | ||
725 | 829 | ||
726 | - assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) | 830 | + assert ( |
831 | + next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) | ||
832 | + ) | ||
727 | 833 | ||
728 | # next batch beam content | 834 | # next batch beam content |
729 | next_batch_beam = [] | 835 | next_batch_beam = [] |
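The reshapes above are the core of the beam re-ranking step: per-beam log-probabilities are added to the accumulated beam scores, flattened to one row per batch element, and the 2 * num_beams best candidates are kept. A toy shape walk-through (sizes are arbitrary illustrations, not values from the PR):

import torch

batch_size, num_beams, vocab_size = 2, 3, 5
scores = torch.randn(batch_size * num_beams, vocab_size)            # next-token log-probs per beam
beam_scores = torch.zeros(batch_size * num_beams)                   # accumulated score of each beam
next_scores = scores + beam_scores[:, None].expand_as(scores)       # (6, 5)
next_scores = next_scores.view(batch_size, num_beams * vocab_size)  # (2, 15)
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
beam_id = next_tokens // vocab_size   # which of the 3 beams each candidate extends
token_id = next_tokens % vocab_size   # which vocabulary token it proposes
print(next_scores.shape, beam_id.shape, token_id.shape)  # all torch.Size([2, 6])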
... | @@ -735,11 +841,15 @@ class GenerationMixin: | ... | @@ -735,11 +841,15 @@ class GenerationMixin: |
735 | if done[batch_idx]: | 841 | if done[batch_idx]: |
736 | assert ( | 842 | assert ( |
737 | len(generated_hyps[batch_idx]) >= num_beams | 843 | len(generated_hyps[batch_idx]) >= num_beams |
738 | - ), "Batch can only be done if at least {} beams have been generated".format(num_beams) | 844 | + ), "Batch can only be done if at least {} beams have been generated".format( |
845 | + num_beams | ||
846 | + ) | ||
739 | assert ( | 847 | assert ( |
740 | eos_token_id is not None and pad_token_id is not None | 848 | eos_token_id is not None and pad_token_id is not None |
741 | ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" | 849 | ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" |
742 | - next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch | 850 | + next_batch_beam.extend( |
851 | + [(0, pad_token_id, 0)] * num_beams | ||
852 | + ) # pad the batch | ||
743 | continue | 853 | continue |
744 | 854 | ||
745 | # next sentence beam content, this will get added to next_batch_beam | 855 | # next sentence beam content, this will get added to next_batch_beam |
... | @@ -757,7 +867,9 @@ class GenerationMixin: | ... | @@ -757,7 +867,9 @@ class GenerationMixin: |
757 | # add to generated hypotheses if end of sentence | 867 | # add to generated hypotheses if end of sentence |
758 | if (eos_token_id is not None) and (token_id.item() == eos_token_id): | 868 | if (eos_token_id is not None) and (token_id.item() == eos_token_id): |
759 | # if beam_token does not belong to top num_beams tokens, it should not be added | 869 | # if beam_token does not belong to top num_beams tokens, it should not be added |
760 | - is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams | 870 | + is_beam_token_worse_than_top_num_beams = ( |
871 | + beam_token_rank >= num_beams | ||
872 | + ) | ||
761 | if is_beam_token_worse_than_top_num_beams: | 873 | if is_beam_token_worse_than_top_num_beams: |
762 | continue | 874 | continue |
763 | generated_hyps[batch_idx].add( | 875 | generated_hyps[batch_idx].add( |
... | @@ -766,7 +878,9 @@ class GenerationMixin: | ... | @@ -766,7 +878,9 @@ class GenerationMixin: |
766 | ) | 878 | ) |
767 | else: | 879 | else: |
768 | # add next predicted token since it is not eos_token | 880 | # add next predicted token since it is not eos_token |
769 | - next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) | 881 | + next_sent_beam.append( |
882 | + (beam_token_score, token_id, effective_beam_id) | ||
883 | + ) | ||
770 | 884 | ||
771 | # once the beam for next step is full, don't add more tokens to it. | 885 | # once the beam for next step is full, don't add more tokens to it. |
772 | if len(next_sent_beam) == num_beams: | 886 | if len(next_sent_beam) == num_beams: |
... | @@ -780,7 +894,9 @@ class GenerationMixin: | ... | @@ -780,7 +894,9 @@ class GenerationMixin: |
780 | # update next beam content | 894 | # update next beam content |
781 | assert len(next_sent_beam) == num_beams, "Beam should always be full" | 895 | assert len(next_sent_beam) == num_beams, "Beam should always be full" |
782 | next_batch_beam.extend(next_sent_beam) | 896 | next_batch_beam.extend(next_sent_beam) |
783 | - assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step" | 897 | + assert len(next_batch_beam) == num_beams * ( |
898 | + batch_idx + 1 | ||
899 | + ), "We should have added num_beams each step" | ||
784 | 900 | ||
785 | # stop when we are done with each sentence | 901 | # stop when we are done with each sentence |
786 | if all(done): | 902 | if all(done): |
... | @@ -804,7 +920,11 @@ class GenerationMixin: | ... | @@ -804,7 +920,11 @@ class GenerationMixin: |
804 | # extend attention_mask for the newly generated token when the model is decoder-only | 920 | # extend attention_mask for the newly generated token when the model is decoder-only |
805 | if self.config.is_encoder_decoder is False: | 921 | if self.config.is_encoder_decoder is False: |
806 | attention_mask = torch.cat( | 922 | attention_mask = torch.cat( |
807 | - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 | 923 | + [ |
924 | + attention_mask, | ||
925 | + attention_mask.new_ones((attention_mask.shape[0], 1)), | ||
926 | + ], | ||
927 | + dim=-1, | ||
808 | ) | 928 | ) |
809 | 929 | ||
810 | # finalize all open beam hypotheses and add to generated hypotheses | 930 | # finalize all open beam hypotheses and add to generated hypotheses |
... | @@ -814,10 +934,12 @@ class GenerationMixin: | ... | @@ -814,10 +934,12 @@ class GenerationMixin: |
814 | 934 | ||
815 | # test that beam scores match previously calculated scores if not eos and batch_idx not done | 935 | # test that beam scores match previously calculated scores if not eos and batch_idx not done |
816 | if eos_token_id is not None and all( | 936 | if eos_token_id is not None and all( |
817 | - (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx] | 937 | + (token_id % vocab_size).item() != eos_token_id |
938 | + for token_id in next_tokens[batch_idx] | ||
818 | ): | 939 | ): |
819 | assert torch.all( | 940 | assert torch.all( |
820 | - next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx] | 941 | + next_scores[batch_idx, :num_beams] |
942 | + == beam_scores.view(batch_size, num_beams)[batch_idx] | ||
821 | ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( | 943 | ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( |
822 | next_scores[:, :num_beams][batch_idx], | 944 | next_scores[:, :num_beams][batch_idx], |
823 | beam_scores.view(batch_size, num_beams)[batch_idx], | 945 | beam_scores.view(batch_size, num_beams)[batch_idx], |
... | @@ -831,7 +953,9 @@ class GenerationMixin: | ... | @@ -831,7 +953,9 @@ class GenerationMixin: |
831 | generated_hyps[batch_idx].add(final_tokens, final_score) | 953 | generated_hyps[batch_idx].add(final_tokens, final_score) |
832 | 954 | ||
833 | # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch | 955 | # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch |
834 | - output_batch_size = batch_size if do_sample else batch_size * num_return_sequences | 956 | + output_batch_size = ( |
957 | + batch_size if do_sample else batch_size * num_return_sequences | ||
958 | + ) | ||
835 | output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences | 959 | output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences |
836 | 960 | ||
837 | # select the best hypotheses | 961 | # select the best hypotheses |
... | @@ -861,7 +985,9 @@ class GenerationMixin: | ... | @@ -861,7 +985,9 @@ class GenerationMixin: |
861 | else: | 985 | else: |
862 | # none of the hypotheses have an eos_token | 986 | # none of the hypotheses have an eos_token |
863 | assert (len(hypo) == max_length for hypo in best) | 987 | assert (len(hypo) == max_length for hypo in best) |
864 | - decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device) | 988 | + decoded = ( |
989 | + torch.stack(best).type(torch.long).to(next(self.parameters()).device) | ||
990 | + ) | ||
865 | 991 | ||
866 | return decoded | 992 | return decoded |
867 | 993 | ||
... | @@ -870,7 +996,9 @@ class GenerationMixin: | ... | @@ -870,7 +996,9 @@ class GenerationMixin: |
870 | return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) | 996 | return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) |
871 | 997 | ||
872 | 998 | ||
873 | -def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None: | 999 | +def calc_banned_ngram_tokens( |
1000 | + prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int | ||
1001 | +) -> None: | ||
874 | """Copied from fairseq for no_repeat_ngram in beam_search""" | 1002 | """Copied from fairseq for no_repeat_ngram in beam_search""" |
875 | if cur_len + 1 < no_repeat_ngram_size: | 1003 | if cur_len + 1 < no_repeat_ngram_size: |
876 | # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet | 1004 | # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet |
... | @@ -881,7 +1009,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n | ... | @@ -881,7 +1009,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n |
881 | generated_ngram = generated_ngrams[idx] | 1009 | generated_ngram = generated_ngrams[idx] |
882 | for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): | 1010 | for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): |
883 | prev_ngram_tuple = tuple(ngram[:-1]) | 1011 | prev_ngram_tuple = tuple(ngram[:-1]) |
884 | - generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] | 1012 | + generated_ngram[prev_ngram_tuple] = generated_ngram.get( |
1013 | + prev_ngram_tuple, [] | ||
1014 | + ) + [ngram[-1]] | ||
885 | 1015 | ||
886 | def _get_generated_ngrams(hypo_idx): | 1016 | def _get_generated_ngrams(hypo_idx): |
887 | # Before decoding the next token, prevent decoding of ngrams that have already appeared | 1017 | # Before decoding the next token, prevent decoding of ngrams that have already appeared |
... | @@ -893,7 +1023,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n | ... | @@ -893,7 +1023,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n |
893 | return banned_tokens | 1023 | return banned_tokens |
894 | 1024 | ||
895 | 1025 | ||
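The reformatted block above builds, for every hypothesis, a dict from (n-1)-gram prefixes to the tokens that followed them, and _get_generated_ngrams then bans whichever tokens would complete an already-seen n-gram. A self-contained sketch of the same idea (a hypothetical helper, not the exact implementation above):

def banned_next_tokens(generated, n):
    """Tokens that would complete an n-gram already present in `generated`."""
    if len(generated) + 1 < n:
        return []
    banned = set()
    # the last n-1 generated tokens form the prefix the next token would extend
    prefix = tuple(generated[len(generated) - n + 1:]) if n > 1 else tuple()
    for i in range(len(generated) - n + 1):
        ngram = tuple(generated[i:i + n])
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return sorted(banned)

# with no_repeat_ngram_size=2 and generated tokens [5, 7, 5], proposing 7 again is banned
assert banned_next_tokens([5, 7, 5], 2) == [7]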
896 | -def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]: | 1026 | +def calc_banned_bad_words_ids( |
1027 | + prev_input_ids: Iterable[int], bad_words_ids: Iterable[int] | ||
1028 | +) -> Iterable[int]: | ||
897 | banned_tokens = [] | 1029 | banned_tokens = [] |
898 | 1030 | ||
899 | def _tokens_match(prev_tokens, tokens): | 1031 | def _tokens_match(prev_tokens, tokens): |
... | @@ -914,7 +1046,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter | ... | @@ -914,7 +1046,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter |
914 | banned_tokens_slice = [] | 1046 | banned_tokens_slice = [] |
915 | 1047 | ||
916 | for banned_token_seq in bad_words_ids: | 1048 | for banned_token_seq in bad_words_ids: |
917 | - assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format( | 1049 | + assert ( |
1050 | + len(banned_token_seq) > 0 | ||
1051 | + ), "Banned words token sequences {} cannot have an empty list".format( | ||
918 | bad_words_ids | 1052 | bad_words_ids |
919 | ) | 1053 | ) |
920 | 1054 | ||
... | @@ -929,7 +1063,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter | ... | @@ -929,7 +1063,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter |
929 | return banned_tokens | 1063 | return banned_tokens |
930 | 1064 | ||
931 | 1065 | ||
932 | -def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: | 1066 | +def set_scores_to_inf_for_banned_tokens( |
1067 | + scores: torch.Tensor, banned_tokens: List[List[int]] | ||
1068 | +) -> None: | ||
933 | """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be | 1069 | """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be |
934 | a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...] | 1070 | a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...] |
935 | Args: | 1071 | Args: |
... | @@ -949,7 +1085,12 @@ def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: Lis | ... | @@ -949,7 +1085,12 @@ def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: Lis |
949 | # [ 0 0 0 ] | 1085 | # [ 0 0 0 ] |
950 | # [ 1 0 0 ] | 1086 | # [ 1 0 0 ] |
951 | 1087 | ||
952 | - banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() | 1088 | + banned_mask = ( |
1089 | + torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()) | ||
1090 | + .to(scores.device) | ||
1091 | + .to_dense() | ||
1092 | + .bool() | ||
1093 | + ) | ||
953 | scores.masked_fill_(banned_mask, -float("inf")) | 1094 | scores.masked_fill_(banned_mask, -float("inf")) |
954 | 1095 | ||
955 | 1096 | ||
... | @@ -989,7 +1130,9 @@ def top_k_top_p_filtering( | ... | @@ -989,7 +1130,9 @@ def top_k_top_p_filtering( |
989 | sorted_indices_to_remove[..., 0] = 0 | 1130 | sorted_indices_to_remove[..., 0] = 0 |
990 | 1131 | ||
991 | # scatter sorted tensors to original indexing | 1132 | # scatter sorted tensors to original indexing |
992 | - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) | 1133 | + indices_to_remove = sorted_indices_to_remove.scatter( |
1134 | + 1, sorted_indices, sorted_indices_to_remove | ||
1135 | + ) | ||
993 | logits[indices_to_remove] = filter_value | 1136 | logits[indices_to_remove] = filter_value |
994 | return logits | 1137 | return logits |
995 | 1138 | ||
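As a quick sanity check of what the function above does, a toy call (values are illustrative, and filter_value is assumed to default to -inf as in the transformers version this file follows):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
# keep only the 2 highest-scoring tokens; the rest are pushed to filter_value (-inf)
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=1.0)
print(filtered)  # expected: tensor([[2., 1., -inf, -inf]])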
... | @@ -1020,7 +1163,9 @@ class BeamHypotheses(object): | ... | @@ -1020,7 +1163,9 @@ class BeamHypotheses(object): |
1020 | if len(self) < self.num_beams or score > self.worst_score: | 1163 | if len(self) < self.num_beams or score > self.worst_score: |
1021 | self.beams.append((score, hyp)) | 1164 | self.beams.append((score, hyp)) |
1022 | if len(self) > self.num_beams: | 1165 | if len(self) > self.num_beams: |
1023 | - sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) | 1166 | + sorted_scores = sorted( |
1167 | + [(s, idx) for idx, (s, _) in enumerate(self.beams)] | ||
1168 | + ) | ||
1024 | del self.beams[sorted_scores[0][1]] | 1169 | del self.beams[sorted_scores[0][1]] |
1025 | self.worst_score = sorted_scores[1][0] | 1170 | self.worst_score = sorted_scores[1][0] |
1026 | else: | 1171 | else: | ... | ... |
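Beyond the black reformatting, the substantive change threaded through this file is the extra patch_ids tensor that travels with input_ids into the encoder and is expanded alongside it for beam search. A hedged usage sketch, assuming the modified generate() accepts patch_ids (as the encoder call above suggests); the token ids, the marker values, and the `model` object are placeholders, with 0 taken as padding because embed_patches is built with 3 rows and padding_idx=0:

import torch

# Hypothetical inputs: token ids from a tokenized diff and matching patch markers
# (patch_ids share input_ids' shape; the non-zero values are assumed to distinguish
# added from removed diff lines).
input_ids = torch.tensor([[0, 713, 16, 10, 1296, 2]])
patch_ids = torch.tensor([[0, 1, 1, 1, 1, 0]])

# `model` stands in for the BART-based commit-message model built in this PR.
generated = model.generate(
    input_ids=input_ids,
    patch_ids=patch_ids,
    num_beams=5,
    max_length=32,
    early_stopping=True,
)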
... | @@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule): |
69 | config=None, | 69 | config=None, |
70 | tokenizer=None, | 70 | tokenizer=None, |
71 | model=None, | 71 | model=None, |
72 | - **config_kwargs | 72 | + **config_kwargs, |
73 | ): | 73 | ): |
74 | """Initialize a model, tokenizer and config.""" | 74 | """Initialize a model, tokenizer and config.""" |
75 | super().__init__() | 75 | super().__init__() |
... | @@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule): |
83 | cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None | 83 | cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None |
84 | if config is None: | 84 | if config is None: |
85 | self.config = AutoConfig.from_pretrained( | 85 | self.config = AutoConfig.from_pretrained( |
86 | - self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, | 86 | + self.hparams.config_name |
87 | + if self.hparams.config_name | ||
88 | + else self.hparams.model_name_or_path, | ||
87 | **({"num_labels": num_labels} if num_labels is not None else {}), | 89 | **({"num_labels": num_labels} if num_labels is not None else {}), |
88 | cache_dir=cache_dir, | 90 | cache_dir=cache_dir, |
89 | **config_kwargs, | 91 | **config_kwargs, |
... | @@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule): |
91 | else: | 93 | else: |
92 | self.config: PretrainedConfig = config | 94 | self.config: PretrainedConfig = config |
93 | 95 | ||
94 | - extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") | 96 | + extra_model_params = ( |
97 | + "encoder_layerdrop", | ||
98 | + "decoder_layerdrop", | ||
99 | + "dropout", | ||
100 | + "attention_dropout", | ||
101 | + ) | ||
95 | for p in extra_model_params: | 102 | for p in extra_model_params: |
96 | if getattr(self.hparams, p, None): | 103 | if getattr(self.hparams, p, None): |
97 | - assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute" | 104 | + assert hasattr( |
105 | + self.config, p | ||
106 | + ), f"model config doesn't have a `{p}` attribute" | ||
98 | setattr(self.config, p, getattr(self.hparams, p)) | 107 | setattr(self.config, p, getattr(self.hparams, p)) |
99 | 108 | ||
100 | if tokenizer is None: | 109 | if tokenizer is None: |
101 | self.tokenizer = AutoTokenizer.from_pretrained( | 110 | self.tokenizer = AutoTokenizer.from_pretrained( |
102 | - self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, | 111 | + self.hparams.tokenizer_name |
112 | + if self.hparams.tokenizer_name | ||
113 | + else self.hparams.model_name_or_path, | ||
103 | cache_dir=cache_dir, | 114 | cache_dir=cache_dir, |
104 | ) | 115 | ) |
105 | else: | 116 | else: |
... | @@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule): |
121 | def get_lr_scheduler(self): | 132 | def get_lr_scheduler(self): |
122 | get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] | 133 | get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] |
123 | scheduler = get_schedule_func( | 134 | scheduler = get_schedule_func( |
124 | - self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps | 135 | + self.opt, |
136 | + num_warmup_steps=self.hparams.warmup_steps, | ||
137 | + num_training_steps=self.total_steps, | ||
125 | ) | 138 | ) |
126 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} | 139 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} |
127 | return scheduler | 140 | return scheduler |
... | @@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule): |
132 | no_decay = ["bias", "LayerNorm.weight"] | 145 | no_decay = ["bias", "LayerNorm.weight"] |
133 | optimizer_grouped_parameters = [ | 146 | optimizer_grouped_parameters = [ |
134 | { | 147 | { |
135 | - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], | 148 | + "params": [ |
149 | + p | ||
150 | + for n, p in model.named_parameters() | ||
151 | + if not any(nd in n for nd in no_decay) | ||
152 | + ], | ||
136 | "weight_decay": self.hparams.weight_decay, | 153 | "weight_decay": self.hparams.weight_decay, |
137 | }, | 154 | }, |
138 | { | 155 | { |
139 | - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], | 156 | + "params": [ |
157 | + p | ||
158 | + for n, p in model.named_parameters() | ||
159 | + if any(nd in n for nd in no_decay) | ||
160 | + ], | ||
140 | "weight_decay": 0.0, | 161 | "weight_decay": 0.0, |
141 | }, | 162 | }, |
142 | ] | 163 | ] |
143 | if self.hparams.adafactor: | 164 | if self.hparams.adafactor: |
144 | optimizer = Adafactor( | 165 | optimizer = Adafactor( |
145 | - optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False | 166 | + optimizer_grouped_parameters, |
167 | + lr=self.hparams.learning_rate, | ||
168 | + scale_parameter=False, | ||
169 | + relative_step=False, | ||
146 | ) | 170 | ) |
147 | 171 | ||
148 | else: | 172 | else: |
149 | optimizer = AdamW( | 173 | optimizer = AdamW( |
150 | - optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon | 174 | + optimizer_grouped_parameters, |
175 | + lr=self.hparams.learning_rate, | ||
176 | + eps=self.hparams.adam_epsilon, | ||
151 | ) | 177 | ) |
152 | self.opt = optimizer | 178 | self.opt = optimizer |
153 | 179 | ||
... | @@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule): |
165 | def total_steps(self) -> int: | 191 | def total_steps(self) -> int: |
166 | """The number of total training steps that will be run. Used for lr scheduler purposes.""" | 192 | """The number of total training steps that will be run. Used for lr scheduler purposes.""" |
167 | num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores | 193 | num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores |
168 | - effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices | 194 | + effective_batch_size = ( |
195 | + self.hparams.train_batch_size | ||
196 | + * self.hparams.accumulate_grad_batches | ||
197 | + * num_devices | ||
198 | + ) | ||
169 | dataset_size = len(self.train_loader.dataset) | 199 | dataset_size = len(self.train_loader.dataset) |
170 | return (dataset_size / effective_batch_size) * self.hparams.max_epochs | 200 | return (dataset_size / effective_batch_size) * self.hparams.max_epochs |
171 | 201 | ||
172 | def setup(self, mode): | 202 | def setup(self, mode): |
173 | if mode == "fit": | 203 | if mode == "fit": |
174 | - self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True) | 204 | + self.train_loader = self.get_dataloader( |
205 | + "train", self.hparams.train_batch_size, shuffle=True | ||
206 | + ) | ||
175 | 207 | ||
176 | def get_dataloader(self, type_path, batch_size, shuffle=False): | 208 | def get_dataloader(self, type_path, batch_size, shuffle=False): |
177 | raise NotImplementedError("You must implement this for your task") | 209 | raise NotImplementedError("You must implement this for your task") |
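The total_steps property above only gained line breaks, but its arithmetic is what drives the warmup scheduler, so a worked example may help (all numbers are made up):

effective_batch_size = 32 * 4 * 2                    # train_batch_size * accumulate_grad_batches * num_devices
total_steps = (25_600 / effective_batch_size) * 3    # dataset_size / effective batch size * max_epochs
print(total_steps)  # 300.0 optimizer steps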
... | @@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule): |
212 | help="Path to pretrained model or model identifier from huggingface.co/models", | 244 | help="Path to pretrained model or model identifier from huggingface.co/models", |
213 | ) | 245 | ) |
214 | parser.add_argument( | 246 | parser.add_argument( |
215 | - "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" | 247 | + "--config_name", |
248 | + default="", | ||
249 | + type=str, | ||
250 | + help="Pretrained config name or path if not the same as model_name", | ||
216 | ) | 251 | ) |
217 | parser.add_argument( | 252 | parser.add_argument( |
218 | "--tokenizer_name", | 253 | "--tokenizer_name", |
... | @@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule): |
246 | type=float, | 281 | type=float, |
247 | help="Attention dropout probability (Optional). Goes into model.config", | 282 | help="Attention dropout probability (Optional). Goes into model.config", |
248 | ) | 283 | ) |
249 | - parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") | 284 | + parser.add_argument( |
285 | + "--learning_rate", | ||
286 | + default=5e-5, | ||
287 | + type=float, | ||
288 | + help="The initial learning rate for Adam.", | ||
289 | + ) | ||
250 | parser.add_argument( | 290 | parser.add_argument( |
251 | "--lr_scheduler", | 291 | "--lr_scheduler", |
252 | default="linear", | 292 | default="linear", |
... | @@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule): | ... | @@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule): |
255 | type=str, | 295 | type=str, |
256 | help="Learning rate scheduler", | 296 | help="Learning rate scheduler", |
257 | ) | 297 | ) |
258 | - parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") | 298 | + parser.add_argument( |
259 | - parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") | 299 | + "--weight_decay", |
260 | - parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") | 300 | + default=0.0, |
261 | - parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") | 301 | + type=float, |
262 | - parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int) | 302 | + help="Weight decay if we apply some.", |
303 | + ) | ||
304 | + parser.add_argument( | ||
305 | + "--adam_epsilon", | ||
306 | + default=1e-8, | ||
307 | + type=float, | ||
308 | + help="Epsilon for Adam optimizer.", | ||
309 | + ) | ||
310 | + parser.add_argument( | ||
311 | + "--warmup_steps", | ||
312 | + default=0, | ||
313 | + type=int, | ||
314 | + help="Linear warmup over warmup_steps.", | ||
315 | + ) | ||
316 | + parser.add_argument( | ||
317 | + "--num_workers", default=4, type=int, help="kwarg passed to DataLoader" | ||
318 | + ) | ||
319 | + parser.add_argument( | ||
320 | + "--num_train_epochs", dest="max_epochs", default=3, type=int | ||
321 | + ) | ||
263 | parser.add_argument("--train_batch_size", default=32, type=int) | 322 | parser.add_argument("--train_batch_size", default=32, type=int) |
264 | parser.add_argument("--eval_batch_size", default=32, type=int) | 323 | parser.add_argument("--eval_batch_size", default=32, type=int) |
265 | parser.add_argument("--adafactor", action="store_true") | 324 | parser.add_argument("--adafactor", action="store_true") |
... | @@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback): | ... | @@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback): |
283 | rank_zero_info("***** Test results *****") | 342 | rank_zero_info("***** Test results *****") |
284 | metrics = trainer.callback_metrics | 343 | metrics = trainer.callback_metrics |
285 | # Log and save results to file | 344 | # Log and save results to file |
286 | - output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") | 345 | + output_test_results_file = os.path.join( |
346 | + pl_module.hparams.output_dir, "test_results.txt" | ||
347 | + ) | ||
287 | with open(output_test_results_file, "w") as writer: | 348 | with open(output_test_results_file, "w") as writer: |
288 | for key in sorted(metrics): | 349 | for key in sorted(metrics): |
289 | if key not in ["log", "progress_bar"]: | 350 | if key not in ["log", "progress_bar"]: |
... | @@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None: | ... | @@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None: |
314 | "See details at https://nvidia.github.io/apex/amp.html", | 375 | "See details at https://nvidia.github.io/apex/amp.html", |
315 | ) | 376 | ) |
316 | parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) | 377 | parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) |
317 | - parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm") | 378 | + parser.add_argument( |
318 | - parser.add_argument("--do_train", action="store_true", help="Whether to run training.") | 379 | + "--max_grad_norm", |
319 | - parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") | 380 | + dest="gradient_clip_val", |
381 | + default=1.0, | ||
382 | + type=float, | ||
383 | + help="Max gradient norm", | ||
384 | + ) | ||
385 | + parser.add_argument( | ||
386 | + "--do_train", action="store_true", help="Whether to run training." | ||
387 | + ) | ||
388 | + parser.add_argument( | ||
389 | + "--do_predict", | ||
390 | + action="store_true", | ||
391 | + help="Whether to run predictions on the test set.", | ||
392 | + ) | ||
320 | parser.add_argument( | 393 | parser.add_argument( |
321 | "--gradient_accumulation_steps", | 394 | "--gradient_accumulation_steps", |
322 | dest="accumulate_grad_batches", | 395 | dest="accumulate_grad_batches", |
... | @@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None: | ... | @@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None: |
324 | default=1, | 397 | default=1, |
325 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 398 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
326 | ) | 399 | ) |
327 | - parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") | 400 | + parser.add_argument( |
401 | + "--seed", type=int, default=42, help="random seed for initialization" | ||
402 | + ) | ||
328 | 403 | ||
329 | 404 | ||
330 | def generic_train( | 405 | def generic_train( |
... | @@ -335,7 +410,7 @@ def generic_train( | ... | @@ -335,7 +410,7 @@ def generic_train( |
335 | extra_callbacks=[], | 410 | extra_callbacks=[], |
336 | checkpoint_callback=None, | 411 | checkpoint_callback=None, |
337 | logging_callback=None, | 412 | logging_callback=None, |
338 | - **extra_train_kwargs | 413 | + **extra_train_kwargs, |
339 | ): | 414 | ): |
340 | pl.seed_everything(args.seed) | 415 | pl.seed_everything(args.seed) |
341 | 416 | ||
... | @@ -346,7 +421,11 @@ def generic_train( | ... | @@ -346,7 +421,11 @@ def generic_train( |
346 | # add custom checkpoints | 421 | # add custom checkpoints |
347 | if checkpoint_callback is None: | 422 | if checkpoint_callback is None: |
348 | checkpoint_callback = pl.callbacks.ModelCheckpoint( | 423 | checkpoint_callback = pl.callbacks.ModelCheckpoint( |
349 | - filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1 | 424 | + filepath=args.output_dir, |
425 | + prefix="checkpoint", | ||
426 | + monitor="val_loss", | ||
427 | + mode="min", | ||
428 | + save_top_k=1, | ||
350 | ) | 429 | ) |
351 | if logging_callback is None: | 430 | if logging_callback is None: |
352 | logging_callback = LoggingCallback() | 431 | logging_callback = LoggingCallback() | ... | ... |
... | @@ -141,7 +141,11 @@ def invert_mask(attention_mask): | ... | @@ -141,7 +141,11 @@ def invert_mask(attention_mask): |
141 | 141 | ||
142 | 142 | ||
143 | def _prepare_bart_decoder_inputs( | 143 | def _prepare_bart_decoder_inputs( |
144 | - config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32 | 144 | + config, |
145 | + input_ids, | ||
146 | + decoder_input_ids=None, | ||
147 | + decoder_padding_mask=None, | ||
148 | + causal_mask_dtype=torch.float32, | ||
145 | ): | 149 | ): |
146 | """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if | 150 | """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if |
147 | none are provided. This mimics the default behavior in fairseq. To override it pass in masks. | 151 | none are provided. This mimics the default behavior in fairseq. To override it pass in masks. |
... | @@ -184,7 +188,9 @@ class PretrainedBartModel(PreTrainedModel): | ... | @@ -184,7 +188,9 @@ class PretrainedBartModel(PreTrainedModel): |
184 | @property | 188 | @property |
185 | def dummy_inputs(self): | 189 | def dummy_inputs(self): |
186 | pad_token = self.config.pad_token_id | 190 | pad_token = self.config.pad_token_id |
187 | - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) | 191 | + input_ids = torch.tensor( |
192 | + [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device | ||
193 | + ) | ||
188 | dummy_inputs = { | 194 | dummy_inputs = { |
189 | "attention_mask": input_ids.ne(pad_token), | 195 | "attention_mask": input_ids.ne(pad_token), |
190 | "input_ids": input_ids, | 196 | "input_ids": input_ids, |
... | @@ -229,7 +235,11 @@ class EncoderLayer(nn.Module): | ... | @@ -229,7 +235,11 @@ class EncoderLayer(nn.Module): |
229 | def __init__(self, config: BartConfig): | 235 | def __init__(self, config: BartConfig): |
230 | super().__init__() | 236 | super().__init__() |
231 | self.embed_dim = config.d_model | 237 | self.embed_dim = config.d_model |
232 | - self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout) | 238 | + self.self_attn = Attention( |
239 | + self.embed_dim, | ||
240 | + config.encoder_attention_heads, | ||
241 | + dropout=config.attention_dropout, | ||
242 | + ) | ||
233 | self.normalize_before = config.normalize_before | 243 | self.normalize_before = config.normalize_before |
234 | self.self_attn_layer_norm = LayerNorm(self.embed_dim) | 244 | self.self_attn_layer_norm = LayerNorm(self.embed_dim) |
235 | self.dropout = config.dropout | 245 | self.dropout = config.dropout |
... | @@ -255,7 +265,10 @@ class EncoderLayer(nn.Module): | ... | @@ -255,7 +265,10 @@ class EncoderLayer(nn.Module): |
255 | if self.normalize_before: | 265 | if self.normalize_before: |
256 | x = self.self_attn_layer_norm(x) | 266 | x = self.self_attn_layer_norm(x) |
257 | x, attn_weights = self.self_attn( | 267 | x, attn_weights = self.self_attn( |
258 | - query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions | 268 | + query=x, |
269 | + key=x, | ||
270 | + key_padding_mask=encoder_padding_mask, | ||
271 | + output_attentions=output_attentions, | ||
259 | ) | 272 | ) |
260 | x = F.dropout(x, p=self.dropout, training=self.training) | 273 | x = F.dropout(x, p=self.dropout, training=self.training) |
261 | x = residual + x | 274 | x = residual + x |
... | @@ -308,13 +321,23 @@ class BartEncoder(nn.Module): | ... | @@ -308,13 +321,23 @@ class BartEncoder(nn.Module): |
308 | config.extra_pos_embeddings, | 321 | config.extra_pos_embeddings, |
309 | ) | 322 | ) |
310 | self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0) | 323 | self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0) |
311 | - self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) | 324 | + self.layers = nn.ModuleList( |
312 | - self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() | 325 | + [EncoderLayer(config) for _ in range(config.encoder_layers)] |
326 | + ) | ||
327 | + self.layernorm_embedding = ( | ||
328 | + LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() | ||
329 | + ) | ||
313 | # mbart has one extra layer_norm | 330 | # mbart has one extra layer_norm |
314 | self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None | 331 | self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None |
315 | 332 | ||
316 | def forward( | 333 | def forward( |
317 | - self, input_ids, patch_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False | 334 | + self, |
335 | + input_ids, | ||
336 | + patch_ids, | ||
337 | + attention_mask=None, | ||
338 | + output_attentions=False, | ||
339 | + output_hidden_states=False, | ||
340 | + return_dict=False, | ||
318 | ): | 341 | ): |
319 | """ | 342 | """ |
320 | Args: | 343 | Args: |
... | @@ -352,10 +375,14 @@ class BartEncoder(nn.Module): | ... | @@ -352,10 +375,14 @@ class BartEncoder(nn.Module): |
352 | encoder_states.append(x) | 375 | encoder_states.append(x) |
353 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) | 376 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) |
354 | dropout_probability = random.uniform(0, 1) | 377 | dropout_probability = random.uniform(0, 1) |
355 | - if self.training and (dropout_probability < self.layerdrop): # skip the layer | 378 | + if self.training and ( |
379 | + dropout_probability < self.layerdrop | ||
380 | + ): # skip the layer | ||
356 | attn = None | 381 | attn = None |
357 | else: | 382 | else: |
358 | - x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions) | 383 | + x, attn = encoder_layer( |
384 | + x, attention_mask, output_attentions=output_attentions | ||
385 | + ) | ||
359 | 386 | ||
360 | if output_attentions: | 387 | if output_attentions: |
361 | all_attentions = all_attentions + (attn,) | 388 | all_attentions = all_attentions + (attn,) |
... | @@ -365,14 +392,20 @@ class BartEncoder(nn.Module): | ... | @@ -365,14 +392,20 @@ class BartEncoder(nn.Module): |
365 | if output_hidden_states: | 392 | if output_hidden_states: |
366 | encoder_states.append(x) | 393 | encoder_states.append(x) |
367 | # T x B x C -> B x T x C | 394 | # T x B x C -> B x T x C |
368 | - encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states) | 395 | + encoder_states = tuple( |
396 | + hidden_state.transpose(0, 1) for hidden_state in encoder_states | ||
397 | + ) | ||
369 | 398 | ||
370 | # T x B x C -> B x T x C | 399 | # T x B x C -> B x T x C |
371 | x = x.transpose(0, 1) | 400 | x = x.transpose(0, 1) |
372 | 401 | ||
373 | if not return_dict: | 402 | if not return_dict: |
374 | - return tuple(v for v in [x, encoder_states, all_attentions] if v is not None) | 403 | + return tuple( |
375 | - return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions) | 404 | + v for v in [x, encoder_states, all_attentions] if v is not None |
405 | + ) | ||
406 | + return BaseModelOutput( | ||
407 | + last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions | ||
408 | + ) | ||
376 | 409 | ||
377 | 410 | ||
378 | class DecoderLayer(nn.Module): | 411 | class DecoderLayer(nn.Module): |
... | @@ -498,8 +531,12 @@ class BartDecoder(nn.Module): | ... | @@ -498,8 +531,12 @@ class BartDecoder(nn.Module): |
498 | self.layers = nn.ModuleList( | 531 | self.layers = nn.ModuleList( |
499 | [DecoderLayer(config) for _ in range(config.decoder_layers)] | 532 | [DecoderLayer(config) for _ in range(config.decoder_layers)] |
500 | ) # type: List[DecoderLayer] | 533 | ) # type: List[DecoderLayer] |
501 | - self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity() | 534 | + self.layernorm_embedding = ( |
502 | - self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None | 535 | + LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity() |
536 | + ) | ||
537 | + self.layer_norm = ( | ||
538 | + LayerNorm(config.d_model) if config.add_final_layer_norm else None | ||
539 | + ) | ||
503 | 540 | ||
504 | def forward( | 541 | def forward( |
505 | self, | 542 | self, |
... | @@ -595,23 +632,34 @@ class BartDecoder(nn.Module): | ... | @@ -595,23 +632,34 @@ class BartDecoder(nn.Module): |
595 | if use_cache: | 632 | if use_cache: |
596 | next_decoder_cache.append(layer_past.copy()) | 633 | next_decoder_cache.append(layer_past.copy()) |
597 | 634 | ||
598 | - if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART) | 635 | + if self.layer_norm and ( |
636 | + idx == len(self.layers) - 1 | ||
637 | + ): # if config.add_final_layer_norm (mBART) | ||
599 | x = self.layer_norm(x) | 638 | x = self.layer_norm(x) |
600 | if output_attentions: | 639 | if output_attentions: |
601 | all_self_attns += (layer_self_attn,) | 640 | all_self_attns += (layer_self_attn,) |
602 | 641 | ||
603 | # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) | 642 | # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) |
604 | if output_hidden_states: | 643 | if output_hidden_states: |
605 | - all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states) | 644 | + all_hidden_states = tuple( |
645 | + hidden_state.transpose(0, 1) for hidden_state in all_hidden_states | ||
646 | + ) | ||
606 | x = x.transpose(0, 1) | 647 | x = x.transpose(0, 1) |
607 | encoder_hidden_states = encoder_hidden_states.transpose(0, 1) | 648 | encoder_hidden_states = encoder_hidden_states.transpose(0, 1) |
608 | 649 | ||
609 | next_cache = next_decoder_cache if use_cache else None | 650 | next_cache = next_decoder_cache if use_cache else None |
610 | 651 | ||
611 | if not return_dict: | 652 | if not return_dict: |
612 | - return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None) | 653 | + return tuple( |
654 | + v | ||
655 | + for v in [x, next_cache, all_hidden_states, all_self_attns] | ||
656 | + if v is not None | ||
657 | + ) | ||
613 | return BaseModelOutputWithPast( | 658 | return BaseModelOutputWithPast( |
614 | - last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns | 659 | + last_hidden_state=x, |
660 | + past_key_values=next_cache, | ||
661 | + hidden_states=all_hidden_states, | ||
662 | + attentions=all_self_attns, | ||
615 | ) | 663 | ) |
616 | 664 | ||
617 | 665 | ||
... | @@ -638,7 +686,9 @@ class Attention(nn.Module): | ... | @@ -638,7 +686,9 @@ class Attention(nn.Module): |
638 | self.num_heads = num_heads | 686 | self.num_heads = num_heads |
639 | self.dropout = dropout | 687 | self.dropout = dropout |
640 | self.head_dim = embed_dim // num_heads | 688 | self.head_dim = embed_dim // num_heads |
641 | - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" | 689 | + assert ( |
690 | + self.head_dim * num_heads == self.embed_dim | ||
691 | + ), "embed_dim must be divisible by num_heads" | ||
642 | self.scaling = self.head_dim ** -0.5 | 692 | self.scaling = self.head_dim ** -0.5 |
643 | 693 | ||
644 | self.encoder_decoder_attention = encoder_decoder_attention | 694 | self.encoder_decoder_attention = encoder_decoder_attention |
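The assert above guards the head split that the _shape helper (reformatted in the next hunk) performs. A small shape walk-through with illustrative numbers:

import torch

seq_len, bsz, embed_dim, num_heads = 7, 2, 16, 4
head_dim = embed_dim // num_heads            # 4; must divide evenly, hence the assert
scaling = head_dim ** -0.5                   # queries are scaled by this before the bmm
x = torch.randn(seq_len, bsz, embed_dim)     # (time, batch, model_dim) layout used in this file
heads = x.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
print(heads.shape)  # torch.Size([8, 7, 4]) == (bsz * num_heads, seq_len, head_dim)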
... | @@ -649,7 +699,11 @@ class Attention(nn.Module): | ... | @@ -649,7 +699,11 @@ class Attention(nn.Module): |
649 | self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" | 699 | self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" |
650 | 700 | ||
651 | def _shape(self, tensor, seq_len, bsz): | 701 | def _shape(self, tensor, seq_len, bsz): |
652 | - return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) | 702 | + return ( |
703 | + tensor.contiguous() | ||
704 | + .view(seq_len, bsz * self.num_heads, self.head_dim) | ||
705 | + .transpose(0, 1) | ||
706 | + ) | ||
653 | 707 | ||
654 | def forward( | 708 | def forward( |
655 | self, | 709 | self, |
... | @@ -693,7 +747,9 @@ class Attention(nn.Module): | ... | @@ -693,7 +747,9 @@ class Attention(nn.Module): |
693 | v = self._shape(v, -1, bsz) | 747 | v = self._shape(v, -1, bsz) |
694 | 748 | ||
695 | if saved_state is not None: | 749 | if saved_state is not None: |
696 | - k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz) | 750 | + k, v, key_padding_mask = self._use_saved_state( |
751 | + k, v, saved_state, key_padding_mask, static_kv, bsz | ||
752 | + ) | ||
697 | 753 | ||
698 | # Update cache | 754 | # Update cache |
699 | layer_state[self.cache_key] = { | 755 | layer_state[self.cache_key] = { |
... | @@ -708,7 +764,9 @@ class Attention(nn.Module): | ... | @@ -708,7 +764,9 @@ class Attention(nn.Module): |
708 | assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) | 764 | assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) |
709 | 765 | ||
710 | if attn_mask is not None: | 766 | if attn_mask is not None: |
711 | - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask | 767 | + attn_weights = ( |
768 | + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask | ||
769 | + ) | ||
712 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | 770 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
713 | 771 | ||
714 | # This is part of a workaround to get around fork/join parallelism not supporting Optional types. | 772 | # This is part of a workaround to get around fork/join parallelism not supporting Optional types. |
... | @@ -725,16 +783,14 @@ class Attention(nn.Module): | ... | @@ -725,16 +783,14 @@ class Attention(nn.Module): |
725 | attn_weights = attn_weights.masked_fill(reshaped, float("-inf")) | 783 | attn_weights = attn_weights.masked_fill(reshaped, float("-inf")) |
726 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | 784 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
727 | attn_weights = F.softmax(attn_weights, dim=-1) | 785 | attn_weights = F.softmax(attn_weights, dim=-1) |
728 | - attn_probs = F.dropout( | 786 | + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,) |
729 | - attn_weights, | ||
730 | - p=self.dropout, | ||
731 | - training=self.training, | ||
732 | - ) | ||
733 | 787 | ||
734 | assert v is not None | 788 | assert v is not None |
735 | attn_output = torch.bmm(attn_probs, v) | 789 | attn_output = torch.bmm(attn_probs, v) |
736 | assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) | 790 | assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) |
737 | - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) | 791 | + attn_output = ( |
792 | + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) | ||
793 | + ) | ||
738 | attn_output = self.out_proj(attn_output) | 794 | attn_output = self.out_proj(attn_output) |
739 | if output_attentions: | 795 | if output_attentions: |
740 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) | 796 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
... | @@ -763,12 +819,16 @@ class Attention(nn.Module): | ... | @@ -763,12 +819,16 @@ class Attention(nn.Module): |
763 | assert v is not None | 819 | assert v is not None |
764 | v = torch.cat([prev_value, v], dim=1) | 820 | v = torch.cat([prev_value, v], dim=1) |
765 | assert k is not None and v is not None | 821 | assert k is not None and v is not None |
766 | - prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None) | 822 | + prev_key_padding_mask: Optional[Tensor] = saved_state.get( |
823 | + "prev_key_padding_mask", None | ||
824 | + ) | ||
767 | if prev_key_padding_mask is not None: | 825 | if prev_key_padding_mask is not None: |
768 | if static_kv: | 826 | if static_kv: |
769 | new_key_padding_mask = prev_key_padding_mask | 827 | new_key_padding_mask = prev_key_padding_mask |
770 | else: | 828 | else: |
771 | - new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1) | 829 | + new_key_padding_mask = torch.cat( |
830 | + [prev_key_padding_mask, key_padding_mask], dim=1 | ||
831 | + ) | ||
772 | else: | 832 | else: |
773 | new_key_padding_mask = key_padding_mask | 833 | new_key_padding_mask = key_padding_mask |
774 | return k, v, new_key_padding_mask | 834 | return k, v, new_key_padding_mask |
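A note on the `_use_saved_state` hunk above: during incremental decoding the layer keeps the keys and values projected at earlier steps and appends the projection for the newest token along the time axis. A minimal standalone sketch of that append (hypothetical shapes, plain tensors rather than the module's cache dict):

import torch

bsz_times_heads, head_dim = 2 * 16, 64                # hypothetical: batch * num_heads, per-head dim
prev_key = torch.randn(bsz_times_heads, 5, head_dim)  # keys cached from the first five decoded tokens
new_key = torch.randn(bsz_times_heads, 1, head_dim)   # key projected for the token decoded this step
k = torch.cat([prev_key, new_key], dim=1)             # time axis grows by one, mirroring torch.cat([prev_key, k], dim=1) above
print(k.shape)                                        # torch.Size([32, 6, 64])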
... | @@ -780,11 +840,7 @@ class BartClassificationHead(nn.Module): | ... | @@ -780,11 +840,7 @@ class BartClassificationHead(nn.Module): |
780 | # This can trivially be shared with RobertaClassificationHead | 840 | # This can trivially be shared with RobertaClassificationHead |
781 | 841 | ||
782 | def __init__( | 842 | def __init__( |
783 | - self, | 843 | + self, input_dim, inner_dim, num_classes, pooler_dropout, |
784 | - input_dim, | ||
785 | - inner_dim, | ||
786 | - num_classes, | ||
787 | - pooler_dropout, | ||
788 | ): | 844 | ): |
789 | super().__init__() | 845 | super().__init__() |
790 | self.dense = nn.Linear(input_dim, inner_dim) | 846 | self.dense = nn.Linear(input_dim, inner_dim) |
... | @@ -808,7 +864,9 @@ class LearnedPositionalEmbedding(nn.Embedding): | ... | @@ -808,7 +864,9 @@ class LearnedPositionalEmbedding(nn.Embedding): |
808 | position ids are passed to the forward function. | 864 | position ids are passed to the forward function. |
809 | """ | 865 | """ |
810 | 866 | ||
811 | - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset): | 867 | + def __init__( |
868 | + self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset | ||
869 | + ): | ||
812 | # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 | 870 | # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 |
813 | # and adjust num_embeddings appropriately. Other models dont have this hack | 871 | # and adjust num_embeddings appropriately. Other models dont have this hack |
814 | self.offset = offset | 872 | self.offset = offset |
... | @@ -820,10 +878,14 @@ class LearnedPositionalEmbedding(nn.Embedding): | ... | @@ -820,10 +878,14 @@ class LearnedPositionalEmbedding(nn.Embedding): |
820 | """Input is expected to be of size [bsz x seqlen].""" | 878 | """Input is expected to be of size [bsz x seqlen].""" |
821 | bsz, seq_len = input_ids.shape[:2] | 879 | bsz, seq_len = input_ids.shape[:2] |
822 | if use_cache: | 880 | if use_cache: |
823 | - positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing | 881 | + positions = input_ids.data.new(1, 1).fill_( |
882 | + seq_len - 1 | ||
883 | + ) # called before slicing | ||
824 | else: | 884 | else: |
825 | # starts at 0, ends at 1-seq_len | 885 | # starts at 0, ends at 1-seq_len |
826 | - positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device) | 886 | + positions = torch.arange( |
887 | + seq_len, dtype=torch.long, device=self.weight.device | ||
888 | + ) | ||
827 | return super().forward(positions + self.offset) | 889 | return super().forward(positions + self.offset) |
828 | 890 | ||
829 | 891 | ||
... | @@ -896,16 +958,28 @@ class BartModel(PretrainedBartModel): | ... | @@ -896,16 +958,28 @@ class BartModel(PretrainedBartModel): |
896 | if decoder_input_ids is None: | 958 | if decoder_input_ids is None: |
897 | use_cache = False | 959 | use_cache = False |
898 | 960 | ||
899 | - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | 961 | + output_attentions = ( |
962 | + output_attentions | ||
963 | + if output_attentions is not None | ||
964 | + else self.config.output_attentions | ||
965 | + ) | ||
900 | output_hidden_states = ( | 966 | output_hidden_states = ( |
901 | - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | 967 | + output_hidden_states |
968 | + if output_hidden_states is not None | ||
969 | + else self.config.output_hidden_states | ||
902 | ) | 970 | ) |
903 | use_cache = use_cache if use_cache is not None else self.config.use_cache | 971 | use_cache = use_cache if use_cache is not None else self.config.use_cache |
904 | - return_dict = return_dict if return_dict is not None else self.config.use_return_dict | 972 | + return_dict = ( |
973 | + return_dict if return_dict is not None else self.config.use_return_dict | ||
974 | + ) | ||
905 | 975 | ||
906 | # make masks if user doesn't supply | 976 | # make masks if user doesn't supply |
907 | if not use_cache: | 977 | if not use_cache: |
908 | - decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs( | 978 | + ( |
979 | + decoder_input_ids, | ||
980 | + decoder_padding_mask, | ||
981 | + causal_mask, | ||
982 | + ) = _prepare_bart_decoder_inputs( | ||
909 | self.config, | 983 | self.config, |
910 | input_ids, | 984 | input_ids, |
911 | decoder_input_ids=decoder_input_ids, | 985 | decoder_input_ids=decoder_input_ids, |
... | @@ -974,17 +1048,24 @@ class BartModel(PretrainedBartModel): | ... | @@ -974,17 +1048,24 @@ class BartModel(PretrainedBartModel): |
974 | 1048 | ||
975 | 1049 | ||
976 | @add_start_docstrings( | 1050 | @add_start_docstrings( |
977 | - "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING | 1051 | + "The BART Model with a language modeling head. Can be used for summarization.", |
1052 | + BART_START_DOCSTRING, | ||
978 | ) | 1053 | ) |
979 | class BartForConditionalGeneration(PretrainedBartModel): | 1054 | class BartForConditionalGeneration(PretrainedBartModel): |
980 | base_model_prefix = "model" | 1055 | base_model_prefix = "model" |
981 | - authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"] | 1056 | + authorized_missing_keys = [ |
1057 | + r"final_logits_bias", | ||
1058 | + r"encoder\.version", | ||
1059 | + r"decoder\.version", | ||
1060 | + ] | ||
982 | 1061 | ||
983 | def __init__(self, config: BartConfig): | 1062 | def __init__(self, config: BartConfig): |
984 | super().__init__(config) | 1063 | super().__init__(config) |
985 | base_model = BartModel(config) | 1064 | base_model = BartModel(config) |
986 | self.model = base_model | 1065 | self.model = base_model |
987 | - self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) | 1066 | + self.register_buffer( |
1067 | + "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)) | ||
1068 | + ) | ||
988 | 1069 | ||
989 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: | 1070 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: |
990 | old_num_tokens = self.model.shared.num_embeddings | 1071 | old_num_tokens = self.model.shared.num_embeddings |
... | @@ -993,16 +1074,23 @@ class BartForConditionalGeneration(PretrainedBartModel): | ... | @@ -993,16 +1074,23 @@ class BartForConditionalGeneration(PretrainedBartModel): |
993 | self._resize_final_logits_bias(new_num_tokens, old_num_tokens) | 1074 | self._resize_final_logits_bias(new_num_tokens, old_num_tokens) |
994 | return new_embeddings | 1075 | return new_embeddings |
995 | 1076 | ||
996 | - def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None: | 1077 | + def _resize_final_logits_bias( |
1078 | + self, new_num_tokens: int, old_num_tokens: int | ||
1079 | + ) -> None: | ||
997 | if new_num_tokens <= old_num_tokens: | 1080 | if new_num_tokens <= old_num_tokens: |
998 | new_bias = self.final_logits_bias[:, :new_num_tokens] | 1081 | new_bias = self.final_logits_bias[:, :new_num_tokens] |
999 | else: | 1082 | else: |
1000 | - extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) | 1083 | + extra_bias = torch.zeros( |
1084 | + (1, new_num_tokens - old_num_tokens), | ||
1085 | + device=self.final_logits_bias.device, | ||
1086 | + ) | ||
1001 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) | 1087 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) |
1002 | self.register_buffer("final_logits_bias", new_bias) | 1088 | self.register_buffer("final_logits_bias", new_bias) |
1003 | 1089 | ||
1004 | @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) | 1090 | @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) |
1005 | - @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) | 1091 | + @replace_return_docstrings( |
1092 | + output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC | ||
1093 | + ) | ||
1006 | @add_end_docstrings(BART_GENERATION_EXAMPLE) | 1094 | @add_end_docstrings(BART_GENERATION_EXAMPLE) |
1007 | def forward( | 1095 | def forward( |
1008 | self, | 1096 | self, |
... | @@ -1065,7 +1153,9 @@ class BartForConditionalGeneration(PretrainedBartModel): | ... | @@ -1065,7 +1153,9 @@ class BartForConditionalGeneration(PretrainedBartModel): |
1065 | FutureWarning, | 1153 | FutureWarning, |
1066 | ) | 1154 | ) |
1067 | past_key_values = unused.pop("decoder_past_key_values") | 1155 | past_key_values = unused.pop("decoder_past_key_values") |
1068 | - return_dict = return_dict if return_dict is not None else self.config.use_return_dict | 1156 | + return_dict = ( |
1157 | + return_dict if return_dict is not None else self.config.use_return_dict | ||
1158 | + ) | ||
1069 | 1159 | ||
1070 | if labels is not None: | 1160 | if labels is not None: |
1071 | use_cache = False | 1161 | use_cache = False |
... | @@ -1085,17 +1175,23 @@ class BartForConditionalGeneration(PretrainedBartModel): | ... | @@ -1085,17 +1175,23 @@ class BartForConditionalGeneration(PretrainedBartModel): |
1085 | output_hidden_states=output_hidden_states, | 1175 | output_hidden_states=output_hidden_states, |
1086 | return_dict=return_dict, | 1176 | return_dict=return_dict, |
1087 | ) | 1177 | ) |
1088 | - lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias) | 1178 | + lm_logits = F.linear( |
1179 | + outputs[0], self.model.shared.weight, bias=self.final_logits_bias | ||
1180 | + ) | ||
1089 | 1181 | ||
1090 | masked_lm_loss = None | 1182 | masked_lm_loss = None |
1091 | if labels is not None: | 1183 | if labels is not None: |
1092 | loss_fct = CrossEntropyLoss() | 1184 | loss_fct = CrossEntropyLoss() |
1093 | # TODO(SS): do we need to ignore pad tokens in labels? | 1185 | # TODO(SS): do we need to ignore pad tokens in labels? |
1094 | - masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) | 1186 | + masked_lm_loss = loss_fct( |
1187 | + lm_logits.view(-1, self.config.vocab_size), labels.view(-1) | ||
1188 | + ) | ||
1095 | 1189 | ||
1096 | if not return_dict: | 1190 | if not return_dict: |
1097 | output = (lm_logits,) + outputs[1:] | 1191 | output = (lm_logits,) + outputs[1:] |
1098 | - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output | 1192 | + return ( |
1193 | + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output | ||
1194 | + ) | ||
1099 | 1195 | ||
1100 | return Seq2SeqLMOutput( | 1196 | return Seq2SeqLMOutput( |
1101 | loss=masked_lm_loss, | 1197 | loss=masked_lm_loss, |
... | @@ -1109,7 +1205,13 @@ class BartForConditionalGeneration(PretrainedBartModel): | ... | @@ -1109,7 +1205,13 @@ class BartForConditionalGeneration(PretrainedBartModel): |
1109 | ) | 1205 | ) |
1110 | 1206 | ||
1111 | def prepare_inputs_for_generation( | 1207 | def prepare_inputs_for_generation( |
1112 | - self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs | 1208 | + self, |
1209 | + decoder_input_ids, | ||
1210 | + past, | ||
1211 | + attention_mask, | ||
1212 | + use_cache, | ||
1213 | + encoder_outputs, | ||
1214 | + **kwargs, | ||
1113 | ): | 1215 | ): |
1114 | return { | 1216 | return { |
1115 | "input_ids": None, # encoder_outputs is defined. input_ids not needed | 1217 | "input_ids": None, # encoder_outputs is defined. input_ids not needed |
... | @@ -1130,7 +1232,9 @@ class BartForConditionalGeneration(PretrainedBartModel): | ... | @@ -1130,7 +1232,9 @@ class BartForConditionalGeneration(PretrainedBartModel): |
1130 | 1232 | ||
1131 | def _force_token_ids_generation(self, scores, token_id) -> None: | 1233 | def _force_token_ids_generation(self, scores, token_id) -> None: |
1132 | """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" | 1234 | """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" |
1133 | - scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf") | 1235 | + scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float( |
1236 | + "inf" | ||
1237 | + ) | ||
1134 | 1238 | ||
1135 | @staticmethod | 1239 | @staticmethod |
1136 | def _reorder_cache(past, beam_idx): | 1240 | def _reorder_cache(past, beam_idx): |
... | @@ -1138,7 +1242,8 @@ class BartForConditionalGeneration(PretrainedBartModel): | ... | @@ -1138,7 +1242,8 @@ class BartForConditionalGeneration(PretrainedBartModel): |
1138 | for layer_past in past: | 1242 | for layer_past in past: |
1139 | # get the correct batch idx from decoder layer's batch dim for cross and self-attn | 1243 | # get the correct batch idx from decoder layer's batch dim for cross and self-attn |
1140 | layer_past_new = { | 1244 | layer_past_new = { |
1141 | - attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() | 1245 | + attn_key: _reorder_buffer(attn_cache, beam_idx) |
1246 | + for attn_key, attn_cache in layer_past.items() | ||
1142 | } | 1247 | } |
1143 | reordered_past.append(layer_past_new) | 1248 | reordered_past.append(layer_past_new) |
1144 | return reordered_past | 1249 | return reordered_past |
... | @@ -1159,10 +1264,7 @@ class BartForSequenceClassification(PretrainedBartModel): | ... | @@ -1159,10 +1264,7 @@ class BartForSequenceClassification(PretrainedBartModel): |
1159 | super().__init__(config, **kwargs) | 1264 | super().__init__(config, **kwargs) |
1160 | self.model = BartModel(config) | 1265 | self.model = BartModel(config) |
1161 | self.classification_head = BartClassificationHead( | 1266 | self.classification_head = BartClassificationHead( |
1162 | - config.d_model, | 1267 | + config.d_model, config.d_model, config.num_labels, config.classif_dropout, |
1163 | - config.d_model, | ||
1164 | - config.num_labels, | ||
1165 | - config.classif_dropout, | ||
1166 | ) | 1268 | ) |
1167 | self.model._init_weights(self.classification_head.dense) | 1269 | self.model._init_weights(self.classification_head.dense) |
1168 | self.model._init_weights(self.classification_head.out_proj) | 1270 | self.model._init_weights(self.classification_head.out_proj) |
... | @@ -1193,7 +1295,9 @@ class BartForSequenceClassification(PretrainedBartModel): | ... | @@ -1193,7 +1295,9 @@ class BartForSequenceClassification(PretrainedBartModel): |
1193 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`. | 1295 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`. |
1194 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). | 1296 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). |
1195 | """ | 1297 | """ |
1196 | - return_dict = return_dict if return_dict is not None else self.config.use_return_dict | 1298 | + return_dict = ( |
1299 | + return_dict if return_dict is not None else self.config.use_return_dict | ||
1300 | + ) | ||
1197 | if labels is not None: | 1301 | if labels is not None: |
1198 | use_cache = False | 1302 | use_cache = False |
1199 | 1303 | ||
... | @@ -1212,7 +1316,9 @@ class BartForSequenceClassification(PretrainedBartModel): | ... | @@ -1212,7 +1316,9 @@ class BartForSequenceClassification(PretrainedBartModel): |
1212 | eos_mask = input_ids.eq(self.config.eos_token_id) | 1316 | eos_mask = input_ids.eq(self.config.eos_token_id) |
1213 | if len(torch.unique(eos_mask.sum(1))) > 1: | 1317 | if len(torch.unique(eos_mask.sum(1))) > 1: |
1214 | raise ValueError("All examples must have the same number of <eos> tokens.") | 1318 | raise ValueError("All examples must have the same number of <eos> tokens.") |
1215 | - sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] | 1319 | + sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[ |
1320 | + :, -1, : | ||
1321 | + ] | ||
1216 | logits = self.classification_head(sentence_representation) | 1322 | logits = self.classification_head(sentence_representation) |
1217 | 1323 | ||
1218 | loss = None | 1324 | loss = None |
... | @@ -1284,7 +1390,9 @@ class BartForQuestionAnswering(PretrainedBartModel): | ... | @@ -1284,7 +1390,9 @@ class BartForQuestionAnswering(PretrainedBartModel): |
1284 | Positions are clamped to the length of the sequence (`sequence_length`). | 1390 | Positions are clamped to the length of the sequence (`sequence_length`). |
1285 | Position outside of the sequence are not taken into account for computing the loss. | 1391 | Position outside of the sequence are not taken into account for computing the loss. |
1286 | """ | 1392 | """ |
1287 | - return_dict = return_dict if return_dict is not None else self.config.use_return_dict | 1393 | + return_dict = ( |
1394 | + return_dict if return_dict is not None else self.config.use_return_dict | ||
1395 | + ) | ||
1288 | if start_positions is not None and end_positions is not None: | 1396 | if start_positions is not None and end_positions is not None: |
1289 | use_cache = False | 1397 | use_cache = False |
1290 | 1398 | ||
... | @@ -1325,10 +1433,7 @@ class BartForQuestionAnswering(PretrainedBartModel): | ... | @@ -1325,10 +1433,7 @@ class BartForQuestionAnswering(PretrainedBartModel): |
1325 | total_loss = (start_loss + end_loss) / 2 | 1433 | total_loss = (start_loss + end_loss) / 2 |
1326 | 1434 | ||
1327 | if not return_dict: | 1435 | if not return_dict: |
1328 | - output = ( | 1436 | + output = (start_logits, end_logits,) + outputs[1:] |
1329 | - start_logits, | ||
1330 | - end_logits, | ||
1331 | - ) + outputs[1:] | ||
1332 | return ((total_loss,) + output) if total_loss is not None else output | 1437 | return ((total_loss,) + output) if total_loss is not None else output |
1333 | 1438 | ||
1334 | return Seq2SeqQuestionAnsweringModelOutput( | 1439 | return Seq2SeqQuestionAnsweringModelOutput( |
... | @@ -1350,7 +1455,9 @@ class SinusoidalPositionalEmbedding(nn.Embedding): | ... | @@ -1350,7 +1455,9 @@ class SinusoidalPositionalEmbedding(nn.Embedding): |
1350 | def __init__(self, num_positions, embedding_dim, padding_idx=None): | 1455 | def __init__(self, num_positions, embedding_dim, padding_idx=None): |
1351 | super().__init__(num_positions, embedding_dim) | 1456 | super().__init__(num_positions, embedding_dim) |
1352 | if embedding_dim % 2 != 0: | 1457 | if embedding_dim % 2 != 0: |
1353 | - raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") | 1458 | + raise NotImplementedError( |
1459 | + f"odd embedding_dim {embedding_dim} not supported" | ||
1460 | + ) | ||
1354 | self.weight = self._init_weight(self.weight) | 1461 | self.weight = self._init_weight(self.weight) |
1355 | 1462 | ||
1356 | @staticmethod | 1463 | @staticmethod |
... | @@ -1360,9 +1467,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding): | ... | @@ -1360,9 +1467,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding): |
1360 | """ | 1467 | """ |
1361 | n_pos, dim = out.shape | 1468 | n_pos, dim = out.shape |
1362 | position_enc = np.array( | 1469 | position_enc = np.array( |
1363 | - [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] | 1470 | + [ |
1471 | + [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] | ||
1472 | + for pos in range(n_pos) | ||
1473 | + ] | ||
1364 | ) | 1474 | ) |
1365 | - out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos | 1475 | + out[:, 0 : dim // 2] = torch.FloatTensor( |
1476 | + np.sin(position_enc[:, 0::2]) | ||
1477 | + ) # This line breaks for odd n_pos | ||
1366 | out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) | 1478 | out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) |
1367 | out.detach_() | 1479 | out.detach_() |
1368 | out.requires_grad = False | 1480 | out.requires_grad = False |
... | @@ -1373,8 +1485,12 @@ class SinusoidalPositionalEmbedding(nn.Embedding): | ... | @@ -1373,8 +1485,12 @@ class SinusoidalPositionalEmbedding(nn.Embedding): |
1373 | """Input is expected to be of size [bsz x seqlen].""" | 1485 | """Input is expected to be of size [bsz x seqlen].""" |
1374 | bsz, seq_len = input_ids.shape[:2] | 1486 | bsz, seq_len = input_ids.shape[:2] |
1375 | if use_cache: | 1487 | if use_cache: |
1376 | - positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing | 1488 | + positions = input_ids.data.new(1, 1).fill_( |
1489 | + seq_len - 1 | ||
1490 | + ) # called before slicing | ||
1377 | else: | 1491 | else: |
1378 | # starts at 0, ends at 1-seq_len | 1492 | # starts at 0, ends at 1-seq_len |
1379 | - positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device) | 1493 | + positions = torch.arange( |
1494 | + seq_len, dtype=torch.long, device=self.weight.device | ||
1495 | + ) | ||
1380 | return super().forward(positions) | 1496 | return super().forward(positions) | ... | ... |
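The sinusoidal table that `SinusoidalPositionalEmbedding._init_weight` fills above follows the usual angle rule pos / 10000^(2 * (j // 2) / dim), with sines written to the first half of the channels and cosines to the second half. A minimal standalone sketch of the same computation (the helper name `build_sinusoidal_table` is hypothetical):

import numpy as np
import torch

def build_sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    # Mirrors _init_weight above; assumes an even `dim`, as the NotImplementedError there enforces.
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out = torch.zeros(n_pos, dim)
    out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # even-indexed angles -> sin, first half
    out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))    # odd-indexed angles -> cos, second half
    return out

print(build_sinusoidal_table(4, 6).shape)  # torch.Size([4, 6])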
... | @@ -80,7 +80,9 @@ def find_pruneable_heads_and_indices( | ... | @@ -80,7 +80,9 @@ def find_pruneable_heads_and_indices( |
80 | :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices. | 80 | :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices. |
81 | """ | 81 | """ |
82 | mask = torch.ones(n_heads, head_size) | 82 | mask = torch.ones(n_heads, head_size) |
83 | - heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads | 83 | + heads = ( |
84 | + set(heads) - already_pruned_heads | ||
85 | + ) # Convert to set and remove already pruned heads | ||
84 | for head in heads: | 86 | for head in heads: |
85 | # Compute how many pruned heads are before the head and move the index accordingly | 87 | # Compute how many pruned heads are before the head and move the index accordingly |
86 | head = head - sum(1 if h < head else 0 for h in already_pruned_heads) | 88 | head = head - sum(1 if h < head else 0 for h in already_pruned_heads) |
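In `find_pruneable_heads_and_indices` above, each requested head index is shifted left by the number of already-pruned heads that precede it, so it addresses the correct row of the already-shrunken weight matrix. A small illustration with hypothetical numbers:

already_pruned_heads = {1, 3}
heads = {4}
# Two lower-numbered heads are already gone, so head 4 now lives at index 4 - 2 = 2.
adjusted = {h - sum(1 if p < h else 0 for p in already_pruned_heads) for h in heads}
print(adjusted)  # {2}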
... | @@ -106,7 +108,11 @@ class ModuleUtilsMixin: | ... | @@ -106,7 +108,11 @@ class ModuleUtilsMixin: |
106 | Returns: | 108 | Returns: |
107 | :obj:`int`: The number of parameters. | 109 | :obj:`int`: The number of parameters. |
108 | """ | 110 | """ |
109 | - params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters() | 111 | + params = ( |
112 | + filter(lambda x: x.requires_grad, self.parameters()) | ||
113 | + if only_trainable | ||
114 | + else self.parameters() | ||
115 | + ) | ||
110 | return sum(p.numel() for p in params) | 116 | return sum(p.numel() for p in params) |
111 | 117 | ||
112 | @staticmethod | 118 | @staticmethod |
... | @@ -114,7 +120,9 @@ class ModuleUtilsMixin: | ... | @@ -114,7 +120,9 @@ class ModuleUtilsMixin: |
114 | try: | 120 | try: |
115 | import psutil | 121 | import psutil |
116 | except (ImportError): | 122 | except (ImportError): |
117 | - raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") | 123 | + raise ImportError( |
124 | + "You need to install psutil (pip install psutil) to use memory tracing." | ||
125 | + ) | ||
118 | 126 | ||
119 | process = psutil.Process(os.getpid()) | 127 | process = psutil.Process(os.getpid()) |
120 | mem = process.memory_info() | 128 | mem = process.memory_info() |
... | @@ -126,13 +134,17 @@ class ModuleUtilsMixin: | ... | @@ -126,13 +134,17 @@ class ModuleUtilsMixin: |
126 | try: | 134 | try: |
127 | import psutil | 135 | import psutil |
128 | except (ImportError): | 136 | except (ImportError): |
129 | - raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") | 137 | + raise ImportError( |
138 | + "You need to install psutil (pip install psutil) to use memory tracing." | ||
139 | + ) | ||
130 | 140 | ||
131 | process = psutil.Process(os.getpid()) | 141 | process = psutil.Process(os.getpid()) |
132 | mem = process.memory_info() | 142 | mem = process.memory_info() |
133 | module.mem_rss_post_forward = mem.rss | 143 | module.mem_rss_post_forward = mem.rss |
134 | mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward | 144 | mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward |
135 | - module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0) | 145 | + module.mem_rss_diff = mem_rss_diff + ( |
146 | + module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0 | ||
147 | + ) | ||
136 | return None | 148 | return None |
137 | 149 | ||
138 | def add_memory_hooks(self): | 150 | def add_memory_hooks(self): |
... | @@ -169,7 +181,9 @@ class ModuleUtilsMixin: | ... | @@ -169,7 +181,9 @@ class ModuleUtilsMixin: |
169 | # For nn.DataParallel compatibility in PyTorch 1.5 | 181 | # For nn.DataParallel compatibility in PyTorch 1.5 |
170 | 182 | ||
171 | def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | 183 | def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: |
172 | - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] | 184 | + tuples = [ |
185 | + (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) | ||
186 | + ] | ||
173 | return tuples | 187 | return tuples |
174 | 188 | ||
175 | gen = self._named_members(get_members_fn=find_tensor_attributes) | 189 | gen = self._named_members(get_members_fn=find_tensor_attributes) |
... | @@ -187,7 +201,9 @@ class ModuleUtilsMixin: | ... | @@ -187,7 +201,9 @@ class ModuleUtilsMixin: |
187 | # For nn.DataParallel compatibility in PyTorch 1.5 | 201 | # For nn.DataParallel compatibility in PyTorch 1.5 |
188 | 202 | ||
189 | def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | 203 | def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: |
190 | - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] | 204 | + tuples = [ |
205 | + (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) | ||
206 | + ] | ||
191 | return tuples | 207 | return tuples |
192 | 208 | ||
193 | gen = self._named_members(get_members_fn=find_tensor_attributes) | 209 | gen = self._named_members(get_members_fn=find_tensor_attributes) |
... | @@ -213,12 +229,18 @@ class ModuleUtilsMixin: | ... | @@ -213,12 +229,18 @@ class ModuleUtilsMixin: |
213 | # /transformer/transformer_layers.py#L270 | 229 | # /transformer/transformer_layers.py#L270 |
214 | # encoder_extended_attention_mask = (encoder_extended_attention_mask == | 230 | # encoder_extended_attention_mask = (encoder_extended_attention_mask == |
215 | # encoder_extended_attention_mask.transpose(-1, -2)) | 231 | # encoder_extended_attention_mask.transpose(-1, -2)) |
216 | - encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility | 232 | + encoder_extended_attention_mask = encoder_extended_attention_mask.to( |
233 | + dtype=self.dtype | ||
234 | + ) # fp16 compatibility | ||
217 | 235 | ||
218 | if self.dtype == torch.float16: | 236 | if self.dtype == torch.float16: |
219 | - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4 | 237 | + encoder_extended_attention_mask = ( |
238 | + 1.0 - encoder_extended_attention_mask | ||
239 | + ) * -1e4 | ||
220 | elif self.dtype == torch.float32: | 240 | elif self.dtype == torch.float32: |
221 | - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 | 241 | + encoder_extended_attention_mask = ( |
242 | + 1.0 - encoder_extended_attention_mask | ||
243 | + ) * -1e9 | ||
222 | else: | 244 | else: |
223 | raise ValueError( | 245 | raise ValueError( |
224 | "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format( | 246 | "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format( |
... | @@ -228,7 +250,9 @@ class ModuleUtilsMixin: | ... | @@ -228,7 +250,9 @@ class ModuleUtilsMixin: |
228 | 250 | ||
229 | return encoder_extended_attention_mask | 251 | return encoder_extended_attention_mask |
230 | 252 | ||
231 | - def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor: | 253 | + def get_extended_attention_mask( |
254 | + self, attention_mask: Tensor, input_shape: Tuple[int], device: device | ||
255 | + ) -> Tensor: | ||
232 | """ | 256 | """ |
233 | Makes broadcastable attention and causal masks so that future and masked tokens are ignored. | 257 | Makes broadcastable attention and causal masks so that future and masked tokens are ignored. |
234 | 258 | ||
... | @@ -254,10 +278,15 @@ class ModuleUtilsMixin: | ... | @@ -254,10 +278,15 @@ class ModuleUtilsMixin: |
254 | if self.config.is_decoder: | 278 | if self.config.is_decoder: |
255 | batch_size, seq_length = input_shape | 279 | batch_size, seq_length = input_shape |
256 | seq_ids = torch.arange(seq_length, device=device) | 280 | seq_ids = torch.arange(seq_length, device=device) |
257 | - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] | 281 | + causal_mask = ( |
282 | + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) | ||
283 | + <= seq_ids[None, :, None] | ||
284 | + ) | ||
258 | # causal and attention masks must have same type with pytorch version < 1.3 | 285 | # causal and attention masks must have same type with pytorch version < 1.3 |
259 | causal_mask = causal_mask.to(attention_mask.dtype) | 286 | causal_mask = causal_mask.to(attention_mask.dtype) |
260 | - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] | 287 | + extended_attention_mask = ( |
288 | + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] | ||
289 | + ) | ||
261 | else: | 290 | else: |
262 | extended_attention_mask = attention_mask[:, None, None, :] | 291 | extended_attention_mask = attention_mask[:, None, None, :] |
263 | else: | 292 | else: |
... | @@ -272,12 +301,17 @@ class ModuleUtilsMixin: | ... | @@ -272,12 +301,17 @@ class ModuleUtilsMixin: |
272 | # positions we want to attend and -10000.0 for masked positions. | 301 | # positions we want to attend and -10000.0 for masked positions. |
273 | # Since we are adding it to the raw scores before the softmax, this is | 302 | # Since we are adding it to the raw scores before the softmax, this is |
274 | # effectively the same as removing these entirely. | 303 | # effectively the same as removing these entirely. |
275 | - extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility | 304 | + extended_attention_mask = extended_attention_mask.to( |
305 | + dtype=self.dtype | ||
306 | + ) # fp16 compatibility | ||
276 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | 307 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 |
277 | return extended_attention_mask | 308 | return extended_attention_mask |
278 | 309 | ||
279 | def get_head_mask( | 310 | def get_head_mask( |
280 | - self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False | 311 | + self, |
312 | + head_mask: Optional[Tensor], | ||
313 | + num_hidden_layers: int, | ||
314 | + is_attention_chunked: bool = False, | ||
281 | ) -> Tensor: | 315 | ) -> Tensor: |
282 | """ | 316 | """ |
283 | Prepare the head mask if needed. | 317 | Prepare the head mask if needed. |
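The `get_extended_attention_mask` code above turns a 0/1 padding mask into an additive bias on the attention logits: attended positions contribute 0 and padded positions a large negative value that vanishes after the softmax. A minimal standalone sketch (hypothetical function name, float32 only):

import torch

def make_additive_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    extended = attention_mask[:, None, None, :].to(dtype=torch.float32)  # broadcastable over heads and query positions
    return (1.0 - extended) * -10000.0

mask = torch.tensor([[1, 1, 1, 0]])
print(make_additive_mask(mask)[0, 0, 0])  # ~0 for the three real tokens, -10000 at the padded slot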
... | @@ -309,9 +343,13 @@ class ModuleUtilsMixin: | ... | @@ -309,9 +343,13 @@ class ModuleUtilsMixin: |
309 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) | 343 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) |
310 | head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) | 344 | head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) |
311 | elif head_mask.dim() == 2: | 345 | elif head_mask.dim() == 2: |
312 | - head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer | 346 | + head_mask = ( |
347 | + head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) | ||
348 | + ) # We can specify head_mask for each layer | ||
313 | assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" | 349 | assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" |
314 | - head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility | 350 | + head_mask = head_mask.to( |
351 | + dtype=self.dtype | ||
352 | + ) # switch to fload if need + fp16 compatibility | ||
315 | return head_mask | 353 | return head_mask |
316 | 354 | ||
317 | 355 | ||
... | @@ -420,12 +458,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -420,12 +458,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
420 | self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) | 458 | self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) |
421 | 459 | ||
422 | if self.config.is_encoder_decoder and self.config.tie_encoder_decoder: | 460 | if self.config.is_encoder_decoder and self.config.tie_encoder_decoder: |
423 | - self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) | 461 | + self._tie_encoder_decoder_weights( |
462 | + self.encoder, self.decoder, self.base_model_prefix | ||
463 | + ) | ||
424 | 464 | ||
425 | @staticmethod | 465 | @staticmethod |
426 | - def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str): | 466 | + def _tie_encoder_decoder_weights( |
467 | + encoder: nn.Module, decoder: nn.Module, base_model_prefix: str | ||
468 | + ): | ||
427 | uninitialized_encoder_weights: List[str] = [] | 469 | uninitialized_encoder_weights: List[str] = [] |
428 | - assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal." | 470 | + assert ( |
471 | + decoder.__class__ == encoder.__class__ | ||
472 | + ), f"{decoder.__class__} and {encoder.__class__} have to be equal." | ||
429 | 473 | ||
430 | def tie_encoder_to_decoder_recursively( | 474 | def tie_encoder_to_decoder_recursively( |
431 | decoder_pointer: nn.Module, | 475 | decoder_pointer: nn.Module, |
... | @@ -452,13 +496,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -452,13 +496,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
452 | len(encoder_modules) > 0 | 496 | len(encoder_modules) > 0 |
453 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" | 497 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" |
454 | 498 | ||
455 | - all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) | 499 | + all_encoder_weights = set( |
500 | + [ | ||
501 | + module_name + "/" + sub_name | ||
502 | + for sub_name in encoder_modules.keys() | ||
503 | + ] | ||
504 | + ) | ||
456 | encoder_layer_pos = 0 | 505 | encoder_layer_pos = 0 |
457 | for name, module in decoder_modules.items(): | 506 | for name, module in decoder_modules.items(): |
458 | if name.isdigit(): | 507 | if name.isdigit(): |
459 | encoder_name = str(int(name) + encoder_layer_pos) | 508 | encoder_name = str(int(name) + encoder_layer_pos) |
460 | decoder_name = name | 509 | decoder_name = name |
461 | - if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])): | 510 | + if not isinstance( |
511 | + decoder_modules[decoder_name], | ||
512 | + type(encoder_modules[encoder_name]), | ||
513 | + ): | ||
462 | # this can happen if the name corresponds to the position in a list module list of layers | 514 | # this can happen if the name corresponds to the position in a list module list of layers |
463 | # in this case the decoder has added a cross-attention that the encoder does not have | 515 | # in this case the decoder has added a cross-attention that the encoder does not have |
464 | # thus skip this step and substract one layer pos from encoder | 516 | # thus skip this step and substract one layer pos from encoder |
... | @@ -484,7 +536,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -484,7 +536,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
484 | uninitialized_encoder_weights += list(all_encoder_weights) | 536 | uninitialized_encoder_weights += list(all_encoder_weights) |
485 | 537 | ||
486 | # tie weights recursively | 538 | # tie weights recursively |
487 | - tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) | 539 | + tie_encoder_to_decoder_recursively( |
540 | + decoder, encoder, base_model_prefix, uninitialized_encoder_weights | ||
541 | + ) | ||
488 | if len(uninitialized_encoder_weights) > 0: | 542 | if len(uninitialized_encoder_weights) > 0: |
489 | logger.warning( | 543 | logger.warning( |
490 | f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" | 544 | f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" |
... | @@ -507,10 +561,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -507,10 +561,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
507 | "constant", | 561 | "constant", |
508 | 0, | 562 | 0, |
509 | ) | 563 | ) |
510 | - if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): | 564 | + if hasattr(output_embeddings, "out_features") and hasattr( |
565 | + input_embeddings, "num_embeddings" | ||
566 | + ): | ||
511 | output_embeddings.out_features = input_embeddings.num_embeddings | 567 | output_embeddings.out_features = input_embeddings.num_embeddings |
512 | 568 | ||
513 | - def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding: | 569 | + def resize_token_embeddings( |
570 | + self, new_num_tokens: Optional[int] = None | ||
571 | + ) -> torch.nn.Embedding: | ||
514 | """ | 572 | """ |
515 | Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`. | 573 | Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`. |
516 | 574 | ||
... | @@ -526,7 +584,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -526,7 +584,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
526 | Return: | 584 | Return: |
527 | :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. | 585 | :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. |
528 | """ | 586 | """ |
529 | - base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed | 587 | + base_model = getattr( |
588 | + self, self.base_model_prefix, self | ||
589 | + ) # get the base model if needed | ||
530 | model_embeds = base_model._resize_token_embeddings(new_num_tokens) | 590 | model_embeds = base_model._resize_token_embeddings(new_num_tokens) |
531 | if new_num_tokens is None: | 591 | if new_num_tokens is None: |
532 | return model_embeds | 592 | return model_embeds |
... | @@ -583,7 +643,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -583,7 +643,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
583 | 643 | ||
584 | # Copy token embeddings from the previous weights | 644 | # Copy token embeddings from the previous weights |
585 | num_tokens_to_copy = min(old_num_tokens, new_num_tokens) | 645 | num_tokens_to_copy = min(old_num_tokens, new_num_tokens) |
586 | - new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] | 646 | + new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[ |
647 | + :num_tokens_to_copy, : | ||
648 | + ] | ||
587 | 649 | ||
588 | return new_embeddings | 650 | return new_embeddings |
589 | 651 | ||
... | @@ -614,7 +676,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -614,7 +676,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
614 | # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads | 676 | # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads |
615 | for layer, heads in heads_to_prune.items(): | 677 | for layer, heads in heads_to_prune.items(): |
616 | union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) | 678 | union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) |
617 | - self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON | 679 | + self.config.pruned_heads[layer] = list( |
680 | + union_heads | ||
681 | + ) # Unfortunately we have to store it as list for JSON | ||
618 | 682 | ||
619 | self.base_model._prune_heads(heads_to_prune) | 683 | self.base_model._prune_heads(heads_to_prune) |
620 | 684 | ||
... | @@ -628,7 +692,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -628,7 +692,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
628 | Directory to which to save. Will be created if it doesn't exist. | 692 | Directory to which to save. Will be created if it doesn't exist. |
629 | """ | 693 | """ |
630 | if os.path.isfile(save_directory): | 694 | if os.path.isfile(save_directory): |
631 | - logger.error("Provided path ({}) should be a directory, not a file".format(save_directory)) | 695 | + logger.error( |
696 | + "Provided path ({}) should be a directory, not a file".format( | ||
697 | + save_directory | ||
698 | + ) | ||
699 | + ) | ||
632 | return | 700 | return |
633 | os.makedirs(save_directory, exist_ok=True) | 701 | os.makedirs(save_directory, exist_ok=True) |
634 | 702 | ||
... | @@ -775,7 +843,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -775,7 +843,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
775 | 843 | ||
776 | # Load config if we don't provide a configuration | 844 | # Load config if we don't provide a configuration |
777 | if not isinstance(config, PretrainedConfig): | 845 | if not isinstance(config, PretrainedConfig): |
778 | - config_path = config if config is not None else pretrained_model_name_or_path | 846 | + config_path = ( |
847 | + config if config is not None else pretrained_model_name_or_path | ||
848 | + ) | ||
779 | config, model_kwargs = cls.config_class.from_pretrained( | 849 | config, model_kwargs = cls.config_class.from_pretrained( |
780 | config_path, | 850 | config_path, |
781 | *model_args, | 851 | *model_args, |
... | @@ -793,23 +863,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -793,23 +863,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
793 | # Load model | 863 | # Load model |
794 | if pretrained_model_name_or_path is not None: | 864 | if pretrained_model_name_or_path is not None: |
795 | if os.path.isdir(pretrained_model_name_or_path): | 865 | if os.path.isdir(pretrained_model_name_or_path): |
796 | - if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")): | 866 | + if from_tf and os.path.isfile( |
867 | + os.path.join( | ||
868 | + pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index" | ||
869 | + ) | ||
870 | + ): | ||
797 | # Load from a TF 1.0 checkpoint | 871 | # Load from a TF 1.0 checkpoint |
798 | - archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") | 872 | + archive_file = os.path.join( |
799 | - elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): | 873 | + pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index" |
874 | + ) | ||
875 | + elif from_tf and os.path.isfile( | ||
876 | + os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) | ||
877 | + ): | ||
800 | # Load from a TF 2.0 checkpoint | 878 | # Load from a TF 2.0 checkpoint |
801 | - archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) | 879 | + archive_file = os.path.join( |
802 | - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): | 880 | + pretrained_model_name_or_path, TF2_WEIGHTS_NAME |
881 | + ) | ||
882 | + elif os.path.isfile( | ||
883 | + os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) | ||
884 | + ): | ||
803 | # Load from a PyTorch checkpoint | 885 | # Load from a PyTorch checkpoint |
804 | - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) | 886 | + archive_file = os.path.join( |
887 | + pretrained_model_name_or_path, WEIGHTS_NAME | ||
888 | + ) | ||
805 | else: | 889 | else: |
806 | raise EnvironmentError( | 890 | raise EnvironmentError( |
807 | "Error no file named {} found in directory {} or `from_tf` set to False".format( | 891 | "Error no file named {} found in directory {} or `from_tf` set to False".format( |
808 | - [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], | 892 | + [ |
893 | + WEIGHTS_NAME, | ||
894 | + TF2_WEIGHTS_NAME, | ||
895 | + TF_WEIGHTS_NAME + ".index", | ||
896 | + ], | ||
809 | pretrained_model_name_or_path, | 897 | pretrained_model_name_or_path, |
810 | ) | 898 | ) |
811 | ) | 899 | ) |
812 | - elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): | 900 | + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url( |
901 | + pretrained_model_name_or_path | ||
902 | + ): | ||
813 | archive_file = pretrained_model_name_or_path | 903 | archive_file = pretrained_model_name_or_path |
814 | elif os.path.isfile(pretrained_model_name_or_path + ".index"): | 904 | elif os.path.isfile(pretrained_model_name_or_path + ".index"): |
815 | assert ( | 905 | assert ( |
... | @@ -848,7 +938,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -848,7 +938,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
848 | if resolved_archive_file == archive_file: | 938 | if resolved_archive_file == archive_file: |
849 | logger.info("loading weights file {}".format(archive_file)) | 939 | logger.info("loading weights file {}".format(archive_file)) |
850 | else: | 940 | else: |
851 | - logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file)) | 941 | + logger.info( |
942 | + "loading weights file {} from cache at {}".format( | ||
943 | + archive_file, resolved_archive_file | ||
944 | + ) | ||
945 | + ) | ||
852 | else: | 946 | else: |
853 | resolved_archive_file = None | 947 | resolved_archive_file = None |
854 | 948 | ||
... | @@ -871,13 +965,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -871,13 +965,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
871 | if from_tf: | 965 | if from_tf: |
872 | if resolved_archive_file.endswith(".index"): | 966 | if resolved_archive_file.endswith(".index"): |
873 | # Load from a TensorFlow 1.X checkpoint - provided by original authors | 967 | # Load from a TensorFlow 1.X checkpoint - provided by original authors |
874 | - model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' | 968 | + model = cls.load_tf_weights( |
969 | + model, config, resolved_archive_file[:-6] | ||
970 | + ) # Remove the '.index' | ||
875 | else: | 971 | else: |
876 | # Load from our TensorFlow 2.0 checkpoints | 972 | # Load from our TensorFlow 2.0 checkpoints |
877 | try: | 973 | try: |
878 | from transformers import load_tf2_checkpoint_in_pytorch_model | 974 | from transformers import load_tf2_checkpoint_in_pytorch_model |
879 | 975 | ||
880 | - model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True) | 976 | + model = load_tf2_checkpoint_in_pytorch_model( |
977 | + model, resolved_archive_file, allow_missing_keys=True | ||
978 | + ) | ||
881 | except ImportError: | 979 | except ImportError: |
882 | logger.error( | 980 | logger.error( |
883 | "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " | 981 | "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " |
... | @@ -909,7 +1007,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -909,7 +1007,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
909 | # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants | 1007 | # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants |
910 | # so we need to apply the function recursively. | 1008 | # so we need to apply the function recursively. |
911 | def load(module: nn.Module, prefix=""): | 1009 | def load(module: nn.Module, prefix=""): |
912 | - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | 1010 | + local_metadata = ( |
1011 | + {} if metadata is None else metadata.get(prefix[:-1], {}) | ||
1012 | + ) | ||
913 | module._load_from_state_dict( | 1013 | module._load_from_state_dict( |
914 | state_dict, | 1014 | state_dict, |
915 | prefix, | 1015 | prefix, |
... | @@ -926,7 +1026,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -926,7 +1026,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
926 | # Make sure we are able to load base models as well as derived models (with heads) | 1026 | # Make sure we are able to load base models as well as derived models (with heads) |
927 | start_prefix = "" | 1027 | start_prefix = "" |
928 | model_to_load = model | 1028 | model_to_load = model |
929 | - has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()) | 1029 | + has_prefix_module = any( |
1030 | + s.startswith(cls.base_model_prefix) for s in state_dict.keys() | ||
1031 | + ) | ||
930 | if not hasattr(model, cls.base_model_prefix) and has_prefix_module: | 1032 | if not hasattr(model, cls.base_model_prefix) and has_prefix_module: |
931 | start_prefix = cls.base_model_prefix + "." | 1033 | start_prefix = cls.base_model_prefix + "." |
932 | if hasattr(model, cls.base_model_prefix) and not has_prefix_module: | 1034 | if hasattr(model, cls.base_model_prefix) and not has_prefix_module: |
... | @@ -937,15 +1039,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -937,15 +1039,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
937 | if model.__class__.__name__ != model_to_load.__class__.__name__: | 1039 | if model.__class__.__name__ != model_to_load.__class__.__name__: |
938 | base_model_state_dict = model_to_load.state_dict().keys() | 1040 | base_model_state_dict = model_to_load.state_dict().keys() |
939 | head_model_state_dict_without_base_prefix = [ | 1041 | head_model_state_dict_without_base_prefix = [ |
940 | - key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() | 1042 | + key.split(cls.base_model_prefix + ".")[-1] |
1043 | + for key in model.state_dict().keys() | ||
941 | ] | 1044 | ] |
942 | - missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) | 1045 | + missing_keys.extend( |
1046 | + head_model_state_dict_without_base_prefix - base_model_state_dict | ||
1047 | + ) | ||
943 | 1048 | ||
944 | # Some models may have keys that are not in the state by design, removing them before needlessly warning | 1049 | # Some models may have keys that are not in the state by design, removing them before needlessly warning |
945 | # the user. | 1050 | # the user. |
946 | if cls.authorized_missing_keys is not None: | 1051 | if cls.authorized_missing_keys is not None: |
947 | for pat in cls.authorized_missing_keys: | 1052 | for pat in cls.authorized_missing_keys: |
948 | - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] | 1053 | + missing_keys = [ |
1054 | + k for k in missing_keys if re.search(pat, k) is None | ||
1055 | + ] | ||
949 | 1056 | ||
950 | if len(unexpected_keys) > 0: | 1057 | if len(unexpected_keys) > 0: |
951 | logger.warning( | 1058 | logger.warning( |
... | @@ -957,7 +1064,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -957,7 +1064,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
957 | f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." | 1064 | f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." |
958 | ) | 1065 | ) |
959 | else: | 1066 | else: |
960 | - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") | 1067 | + logger.info( |
1068 | + f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" | ||
1069 | + ) | ||
961 | if len(missing_keys) > 0: | 1070 | if len(missing_keys) > 0: |
962 | logger.warning( | 1071 | logger.warning( |
963 | f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " | 1072 | f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " |
... | @@ -990,7 +1099,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ... | @@ -990,7 +1099,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): |
990 | } | 1099 | } |
991 | return model, loading_info | 1100 | return model, loading_info |
992 | 1101 | ||
993 | - if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available(): | 1102 | + if ( |
1103 | + hasattr(config, "xla_device") | ||
1104 | + and config.xla_device | ||
1105 | + and is_torch_tpu_available() | ||
1106 | + ): | ||
994 | import torch_xla.core.xla_model as xm | 1107 | import torch_xla.core.xla_model as xm |
995 | 1108 | ||
996 | model = xm.send_cpu_data_to_device(model, xm.xla_device()) | 1109 | model = xm.send_cpu_data_to_device(model, xm.xla_device()) |
... | @@ -1039,7 +1152,9 @@ class PoolerStartLogits(nn.Module): | ... | @@ -1039,7 +1152,9 @@ class PoolerStartLogits(nn.Module): |
1039 | self.dense = nn.Linear(config.hidden_size, 1) | 1152 | self.dense = nn.Linear(config.hidden_size, 1) |
1040 | 1153 | ||
1041 | def forward( | 1154 | def forward( |
1042 | - self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None | 1155 | + self, |
1156 | + hidden_states: torch.FloatTensor, | ||
1157 | + p_mask: Optional[torch.FloatTensor] = None, | ||
1043 | ) -> torch.FloatTensor: | 1158 | ) -> torch.FloatTensor: |
1044 | """ | 1159 | """ |
1045 | Args: | 1160 | Args: |
... | @@ -1112,8 +1227,12 @@ class PoolerEndLogits(nn.Module): | ... | @@ -1112,8 +1227,12 @@ class PoolerEndLogits(nn.Module): |
1112 | ), "One of start_states, start_positions should be not None" | 1227 | ), "One of start_states, start_positions should be not None" |
1113 | if start_positions is not None: | 1228 | if start_positions is not None: |
1114 | slen, hsz = hidden_states.shape[-2:] | 1229 | slen, hsz = hidden_states.shape[-2:] |
1115 | - start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) | 1230 | + start_positions = start_positions[:, None, None].expand( |
1116 | - start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) | 1231 | + -1, -1, hsz |
1232 | + ) # shape (bsz, 1, hsz) | ||
1233 | + start_states = hidden_states.gather( | ||
1234 | + -2, start_positions | ||
1235 | + ) # shape (bsz, 1, hsz) | ||
1117 | start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) | 1236 | start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) |
1118 | 1237 | ||
1119 | x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) | 1238 | x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) |
... | @@ -1177,12 +1296,20 @@ class PoolerAnswerClass(nn.Module): | ... | @@ -1177,12 +1296,20 @@ class PoolerAnswerClass(nn.Module): |
1177 | start_states is not None or start_positions is not None | 1296 | start_states is not None or start_positions is not None |
1178 | ), "One of start_states, start_positions should be not None" | 1297 | ), "One of start_states, start_positions should be not None" |
1179 | if start_positions is not None: | 1298 | if start_positions is not None: |
1180 | - start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) | 1299 | + start_positions = start_positions[:, None, None].expand( |
1181 | - start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) | 1300 | + -1, -1, hsz |
1301 | + ) # shape (bsz, 1, hsz) | ||
1302 | + start_states = hidden_states.gather(-2, start_positions).squeeze( | ||
1303 | + -2 | ||
1304 | + ) # shape (bsz, hsz) | ||
1182 | 1305 | ||
1183 | if cls_index is not None: | 1306 | if cls_index is not None: |
1184 | - cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) | 1307 | + cls_index = cls_index[:, None, None].expand( |
1185 | - cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) | 1308 | + -1, -1, hsz |
1309 | + ) # shape (bsz, 1, hsz) | ||
1310 | + cls_token_state = hidden_states.gather(-2, cls_index).squeeze( | ||
1311 | + -2 | ||
1312 | + ) # shape (bsz, hsz) | ||
1186 | else: | 1313 | else: |
1187 | cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) | 1314 | cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) |
1188 | 1315 | ||
... | @@ -1241,7 +1368,9 @@ class SQuADHead(nn.Module): | ... | @@ -1241,7 +1368,9 @@ class SQuADHead(nn.Module): |
1241 | self.end_logits = PoolerEndLogits(config) | 1368 | self.end_logits = PoolerEndLogits(config) |
1242 | self.answer_class = PoolerAnswerClass(config) | 1369 | self.answer_class = PoolerAnswerClass(config) |
1243 | 1370 | ||
1244 | - @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig) | 1371 | + @replace_return_docstrings( |
1372 | + output_type=SquadHeadOutput, config_class=PretrainedConfig | ||
1373 | + ) | ||
1245 | def forward( | 1374 | def forward( |
1246 | self, | 1375 | self, |
1247 | hidden_states: torch.FloatTensor, | 1376 | hidden_states: torch.FloatTensor, |
... | @@ -1281,7 +1410,9 @@ class SQuADHead(nn.Module): | ... | @@ -1281,7 +1410,9 @@ class SQuADHead(nn.Module): |
1281 | x.squeeze_(-1) | 1410 | x.squeeze_(-1) |
1282 | 1411 | ||
1283 | # during training, compute the end logits based on the ground truth of the start position | 1412 | # during training, compute the end logits based on the ground truth of the start position |
1284 | - end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) | 1413 | + end_logits = self.end_logits( |
1414 | + hidden_states, start_positions=start_positions, p_mask=p_mask | ||
1415 | + ) | ||
1285 | 1416 | ||
1286 | loss_fct = CrossEntropyLoss() | 1417 | loss_fct = CrossEntropyLoss() |
1287 | start_loss = loss_fct(start_logits, start_positions) | 1418 | start_loss = loss_fct(start_logits, start_positions) |
... | @@ -1290,7 +1421,9 @@ class SQuADHead(nn.Module): | ... | @@ -1290,7 +1421,9 @@ class SQuADHead(nn.Module): |
1290 | 1421 | ||
1291 | if cls_index is not None and is_impossible is not None: | 1422 | if cls_index is not None and is_impossible is not None: |
1292 | # Predict answerability from the representation of CLS and START | 1423 | # Predict answerability from the representation of CLS and START |
1293 | - cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) | 1424 | + cls_logits = self.answer_class( |
1425 | + hidden_states, start_positions=start_positions, cls_index=cls_index | ||
1426 | + ) | ||
1294 | loss_fct_cls = nn.BCEWithLogitsLoss() | 1427 | loss_fct_cls = nn.BCEWithLogitsLoss() |
1295 | cls_loss = loss_fct_cls(cls_logits, is_impossible) | 1428 | cls_loss = loss_fct_cls(cls_logits, is_impossible) |
1296 | 1429 | ||
... | @@ -1307,28 +1440,48 @@ class SQuADHead(nn.Module): | ... | @@ -1307,28 +1440,48 @@ class SQuADHead(nn.Module): |
1307 | start_top_log_probs, start_top_index = torch.topk( | 1440 | start_top_log_probs, start_top_index = torch.topk( |
1308 | start_log_probs, self.start_n_top, dim=-1 | 1441 | start_log_probs, self.start_n_top, dim=-1 |
1309 | ) # shape (bsz, start_n_top) | 1442 | ) # shape (bsz, start_n_top) |
1310 | - start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) | 1443 | + start_top_index_exp = start_top_index.unsqueeze(-1).expand( |
1311 | - start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) | 1444 | + -1, -1, hsz |
1312 | - start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) | 1445 | + ) # shape (bsz, start_n_top, hsz) |
1446 | + start_states = torch.gather( | ||
1447 | + hidden_states, -2, start_top_index_exp | ||
1448 | + ) # shape (bsz, start_n_top, hsz) | ||
1449 | + start_states = start_states.unsqueeze(1).expand( | ||
1450 | + -1, slen, -1, -1 | ||
1451 | + ) # shape (bsz, slen, start_n_top, hsz) | ||
1313 | 1452 | ||
1314 | hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( | 1453 | hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( |
1315 | start_states | 1454 | start_states |
1316 | ) # shape (bsz, slen, start_n_top, hsz) | 1455 | ) # shape (bsz, slen, start_n_top, hsz) |
1317 | p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None | 1456 | p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None |
1318 | - end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) | 1457 | + end_logits = self.end_logits( |
1319 | - end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) | 1458 | + hidden_states_expanded, start_states=start_states, p_mask=p_mask |
1459 | + ) | ||
1460 | + end_log_probs = F.softmax( | ||
1461 | + end_logits, dim=1 | ||
1462 | + ) # shape (bsz, slen, start_n_top) | ||
1320 | 1463 | ||
1321 | end_top_log_probs, end_top_index = torch.topk( | 1464 | end_top_log_probs, end_top_index = torch.topk( |
1322 | end_log_probs, self.end_n_top, dim=1 | 1465 | end_log_probs, self.end_n_top, dim=1 |
1323 | ) # shape (bsz, end_n_top, start_n_top) | 1466 | ) # shape (bsz, end_n_top, start_n_top) |
1324 | - end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) | 1467 | + end_top_log_probs = end_top_log_probs.view( |
1468 | + -1, self.start_n_top * self.end_n_top | ||
1469 | + ) | ||
1325 | end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) | 1470 | end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) |
1326 | 1471 | ||
1327 | start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) | 1472 | start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) |
1328 | - cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) | 1473 | + cls_logits = self.answer_class( |
1474 | + hidden_states, start_states=start_states, cls_index=cls_index | ||
1475 | + ) | ||
1329 | 1476 | ||
1330 | if not return_dict: | 1477 | if not return_dict: |
1331 | - return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) | 1478 | + return ( |
1479 | + start_top_log_probs, | ||
1480 | + start_top_index, | ||
1481 | + end_top_log_probs, | ||
1482 | + end_top_index, | ||
1483 | + cls_logits, | ||
1484 | + ) | ||
1332 | else: | 1485 | else: |
1333 | return SquadHeadOutput( | 1486 | return SquadHeadOutput( |
1334 | start_top_log_probs=start_top_log_probs, | 1487 | start_top_log_probs=start_top_log_probs, |
... | @@ -1379,17 +1532,26 @@ class SequenceSummary(nn.Module): | ... | @@ -1379,17 +1532,26 @@ class SequenceSummary(nn.Module): |
1379 | 1532 | ||
1380 | self.summary = Identity() | 1533 | self.summary = Identity() |
1381 | if hasattr(config, "summary_use_proj") and config.summary_use_proj: | 1534 | if hasattr(config, "summary_use_proj") and config.summary_use_proj: |
1382 | - if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: | 1535 | + if ( |
1536 | + hasattr(config, "summary_proj_to_labels") | ||
1537 | + and config.summary_proj_to_labels | ||
1538 | + and config.num_labels > 0 | ||
1539 | + ): | ||
1383 | num_classes = config.num_labels | 1540 | num_classes = config.num_labels |
1384 | else: | 1541 | else: |
1385 | num_classes = config.hidden_size | 1542 | num_classes = config.hidden_size |
1386 | self.summary = nn.Linear(config.hidden_size, num_classes) | 1543 | self.summary = nn.Linear(config.hidden_size, num_classes) |
1387 | 1544 | ||
1388 | activation_string = getattr(config, "summary_activation", None) | 1545 | activation_string = getattr(config, "summary_activation", None) |
1389 | - self.activation: Callable = get_activation(activation_string) if activation_string else Identity() | 1546 | + self.activation: Callable = get_activation( |
1547 | + activation_string | ||
1548 | + ) if activation_string else Identity() | ||
1390 | 1549 | ||
1391 | self.first_dropout = Identity() | 1550 | self.first_dropout = Identity() |
1392 | - if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: | 1551 | + if ( |
1552 | + hasattr(config, "summary_first_dropout") | ||
1553 | + and config.summary_first_dropout > 0 | ||
1554 | + ): | ||
1393 | self.first_dropout = nn.Dropout(config.summary_first_dropout) | 1555 | self.first_dropout = nn.Dropout(config.summary_first_dropout) |
1394 | 1556 | ||
1395 | self.last_dropout = Identity() | 1557 | self.last_dropout = Identity() |
... | @@ -1397,7 +1559,9 @@ class SequenceSummary(nn.Module): | ... | @@ -1397,7 +1559,9 @@ class SequenceSummary(nn.Module): |
1397 | self.last_dropout = nn.Dropout(config.summary_last_dropout) | 1559 | self.last_dropout = nn.Dropout(config.summary_last_dropout) |
1398 | 1560 | ||
1399 | def forward( | 1561 | def forward( |
1400 | - self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None | 1562 | + self, |
1563 | + hidden_states: torch.FloatTensor, | ||
1564 | + cls_index: Optional[torch.LongTensor] = None, | ||
1401 | ) -> torch.FloatTensor: | 1565 | ) -> torch.FloatTensor: |
1402 | """ | 1566 | """ |
1403 | Compute a single vector summary of a sequence hidden states. | 1567 | Compute a single vector summary of a sequence hidden states. |
... | @@ -1427,9 +1591,13 @@ class SequenceSummary(nn.Module): | ... | @@ -1427,9 +1591,13 @@ class SequenceSummary(nn.Module): |
1427 | ) | 1591 | ) |
1428 | else: | 1592 | else: |
1429 | cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) | 1593 | cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) |
1430 | - cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) | 1594 | + cls_index = cls_index.expand( |
1595 | + (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),) | ||
1596 | + ) | ||
1431 | # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states | 1597 | # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states |
1432 | - output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) | 1598 | + output = hidden_states.gather(-2, cls_index).squeeze( |
1599 | + -2 | ||
1600 | + ) # shape (bsz, XX, hidden_size) | ||
1433 | elif self.summary_type == "attn": | 1601 | elif self.summary_type == "attn": |
1434 | raise NotImplementedError | 1602 | raise NotImplementedError |
1435 | 1603 | ||
... | @@ -1441,7 +1609,9 @@ class SequenceSummary(nn.Module): | ... | @@ -1441,7 +1609,9 @@ class SequenceSummary(nn.Module): |
1441 | return output | 1609 | return output |
1442 | 1610 | ||
1443 | 1611 | ||
1444 | -def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear: | 1612 | +def prune_linear_layer( |
1613 | + layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0 | ||
1614 | +) -> torch.nn.Linear: | ||
1445 | """ | 1615 | """ |
1446 | Prune a linear layer to keep only entries in index. | 1616 | Prune a linear layer to keep only entries in index. |
1447 | 1617 | ||
... | @@ -1464,7 +1634,9 @@ def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int | ... | @@ -1464,7 +1634,9 @@ def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int |
1464 | b = layer.bias[index].clone().detach() | 1634 | b = layer.bias[index].clone().detach() |
1465 | new_size = list(layer.weight.size()) | 1635 | new_size = list(layer.weight.size()) |
1466 | new_size[dim] = len(index) | 1636 | new_size[dim] = len(index) |
1467 | - new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) | 1637 | + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to( |
1638 | + layer.weight.device | ||
1639 | + ) | ||
1468 | new_layer.weight.requires_grad = False | 1640 | new_layer.weight.requires_grad = False |
1469 | new_layer.weight.copy_(W.contiguous()) | 1641 | new_layer.weight.copy_(W.contiguous()) |
1470 | new_layer.weight.requires_grad = True | 1642 | new_layer.weight.requires_grad = True |
... | @@ -1509,7 +1681,9 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> | ... | @@ -1509,7 +1681,9 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> |
1509 | 1681 | ||
1510 | 1682 | ||
1511 | def prune_layer( | 1683 | def prune_layer( |
1512 | - layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None | 1684 | + layer: Union[torch.nn.Linear, Conv1D], |
1685 | + index: torch.LongTensor, | ||
1686 | + dim: Optional[int] = None, | ||
1513 | ) -> Union[torch.nn.Linear, Conv1D]: | 1687 | ) -> Union[torch.nn.Linear, Conv1D]: |
1514 | """ | 1688 | """ |
1515 | Prune a Conv1D or linear layer to keep only entries in index. | 1689 | Prune a Conv1D or linear layer to keep only entries in index. |
... | @@ -1534,7 +1708,10 @@ def prune_layer( | ... | @@ -1534,7 +1708,10 @@ def prune_layer( |
1534 | 1708 | ||
1535 | 1709 | ||
1536 | def apply_chunking_to_forward( | 1710 | def apply_chunking_to_forward( |
1537 | - forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors | 1711 | + forward_fn: Callable[..., torch.Tensor], |
1712 | + chunk_size: int, | ||
1713 | + chunk_dim: int, | ||
1714 | + *input_tensors, | ||
1538 | ) -> torch.Tensor: | 1715 | ) -> torch.Tensor: |
1539 | """ | 1716 | """ |
1540 | This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the | 1717 | This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the |
... | @@ -1568,7 +1745,9 @@ def apply_chunking_to_forward( | ... | @@ -1568,7 +1745,9 @@ def apply_chunking_to_forward( |
1568 | return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) | 1745 | return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) |
1569 | """ | 1746 | """ |
1570 | 1747 | ||
1571 | - assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) | 1748 | + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format( |
1749 | + input_tensors | ||
1750 | + ) | ||
1572 | tensor_shape = input_tensors[0].shape | 1751 | tensor_shape = input_tensors[0].shape |
1573 | assert all( | 1752 | assert all( |
1574 | input_tensor.shape == tensor_shape for input_tensor in input_tensors | 1753 | input_tensor.shape == tensor_shape for input_tensor in input_tensors |
... | @@ -1592,9 +1771,15 @@ def apply_chunking_to_forward( | ... | @@ -1592,9 +1771,15 @@ def apply_chunking_to_forward( |
1592 | num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size | 1771 | num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size |
1593 | 1772 | ||
1594 | # chunk input tensor into tuples | 1773 | # chunk input tensor into tuples |
1595 | - input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) | 1774 | + input_tensors_chunks = tuple( |
1775 | + input_tensor.chunk(num_chunks, dim=chunk_dim) | ||
1776 | + for input_tensor in input_tensors | ||
1777 | + ) | ||
1596 | # apply forward fn to every tuple | 1778 | # apply forward fn to every tuple |
1597 | - output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) | 1779 | + output_chunks = tuple( |
1780 | + forward_fn(*input_tensors_chunk) | ||
1781 | + for input_tensors_chunk in zip(*input_tensors_chunks) | ||
1782 | + ) | ||
1598 | # concatenate output at same dimension | 1783 | # concatenate output at same dimension |
1599 | return torch.cat(output_chunks, dim=chunk_dim) | 1784 | return torch.cat(output_chunks, dim=chunk_dim) |
1600 | 1785 | ... | ... |
... | @@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): | ... | @@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): |
39 | return loss, nll_loss | 39 | return loss, nll_loss |
40 | 40 | ||
41 | 41 | ||
42 | -def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): | 42 | +def encode_line( |
43 | + tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt" | ||
44 | +): | ||
43 | """Only used by LegacyDataset""" | 45 | """Only used by LegacyDataset""" |
44 | - extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} | 46 | + extra_kw = ( |
47 | + {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} | ||
48 | + ) | ||
45 | return tokenizer( | 49 | return tokenizer( |
46 | [line], | 50 | [line], |
47 | max_length=max_length, | 51 | max_length=max_length, |
... | @@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: | ... | @@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: |
63 | 67 | ||
64 | 68 | ||
65 | def trim_batch( | 69 | def trim_batch( |
66 | - input_ids, | 70 | + input_ids, pad_token_id, attention_mask=None, |
67 | - pad_token_id, | ||
68 | - attention_mask=None, | ||
69 | ): | 71 | ): |
70 | """Remove columns that are populated exclusively by pad_token_id""" | 72 | """Remove columns that are populated exclusively by pad_token_id""" |
71 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) | 73 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) |
... | @@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): | ... | @@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): |
125 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: | 127 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: |
126 | """Call tokenizer on src and tgt_lines""" | 128 | """Call tokenizer on src and tgt_lines""" |
127 | index = index + 1 # linecache starts at 1 | 129 | index = index + 1 # linecache starts at 1 |
128 | - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") | 130 | + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( |
131 | + "\n" | ||
132 | + ) | ||
129 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") | 133 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") |
130 | assert source_line, f"empty source line for index {index}" | 134 | assert source_line, f"empty source line for index {index}" |
131 | assert tgt_line, f"empty tgt line for index {index}" | 135 | assert tgt_line, f"empty tgt line for index {index}" |
... | @@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): | ... | @@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): |
147 | target_ids = torch.stack([x["labels"] for x in batch]) | 151 | target_ids = torch.stack([x["labels"] for x in batch]) |
148 | pad_token_id = self.pad_token_id | 152 | pad_token_id = self.pad_token_id |
149 | y = trim_batch(target_ids, pad_token_id) | 153 | y = trim_batch(target_ids, pad_token_id) |
150 | - source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) | 154 | + source_ids, source_mask = trim_batch( |
155 | + input_ids, pad_token_id, attention_mask=masks | ||
156 | + ) | ||
151 | batch = { | 157 | batch = { |
152 | "input_ids": source_ids, | 158 | "input_ids": source_ids, |
153 | "attention_mask": source_mask, | 159 | "attention_mask": source_mask, |
... | @@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): | ... | @@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): |
161 | 167 | ||
162 | def __getitem__(self, index) -> Dict[str, str]: | 168 | def __getitem__(self, index) -> Dict[str, str]: |
163 | index = index + 1 # linecache starts at 1 | 169 | index = index + 1 # linecache starts at 1 |
164 | - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") | 170 | + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( |
171 | + "\n" | ||
172 | + ) | ||
165 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") | 173 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") |
166 | assert source_line, f"empty source line for index {index}" | 174 | assert source_line, f"empty source line for index {index}" |
167 | assert tgt_line, f"empty tgt line for index {index}" | 175 | assert tgt_line, f"empty tgt line for index {index}" |
... | @@ -201,12 +209,23 @@ class SortishSampler(Sampler): | ... | @@ -201,12 +209,23 @@ class SortishSampler(Sampler): |
201 | idxs = np.random.permutation(len(self.data)) | 209 | idxs = np.random.permutation(len(self.data)) |
202 | sz = self.bs * 50 | 210 | sz = self.bs * 50 |
203 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] | 211 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] |
204 | - sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx]) | 212 | + sort_idx = np.concatenate( |
213 | + [sorted(s, key=self.key, reverse=True) for s in ck_idx] | ||
214 | + ) | ||
205 | sz = self.bs | 215 | sz = self.bs |
206 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] | 216 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] |
207 | - max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, | 217 | + max_ck = np.argmax( |
208 | - ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. | 218 | + [self.key(ck[0]) for ck in ck_idx] |
209 | - sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) | 219 | + ) # find the chunk with the largest key, |
220 | + ck_idx[0], ck_idx[max_ck] = ( | ||
221 | + ck_idx[max_ck], | ||
222 | + ck_idx[0], | ||
223 | + ) # then make sure it goes first. | ||
224 | + sort_idx = ( | ||
225 | + np.concatenate(np.random.permutation(ck_idx[1:])) | ||
226 | + if len(ck_idx) > 1 | ||
227 | + else np.array([], dtype=np.int) | ||
228 | + ) | ||
210 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) | 229 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) |
211 | return iter(sort_idx) | 230 | return iter(sort_idx) |
212 | 231 | ||
... | @@ -269,7 +288,9 @@ def get_git_info(): | ... | @@ -269,7 +288,9 @@ def get_git_info(): |
269 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] | 288 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] |
270 | 289 | ||
271 | 290 | ||
272 | -def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: | 291 | +def calculate_rouge( |
292 | + output_lns: List[str], reference_lns: List[str], use_stemmer=True | ||
293 | +) -> Dict: | ||
273 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) | 294 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) |
274 | aggregator = scoring.BootstrapAggregator() | 295 | aggregator = scoring.BootstrapAggregator() |
275 | 296 | ||
... | @@ -302,7 +323,9 @@ def assert_all_frozen(model): | ... | @@ -302,7 +323,9 @@ def assert_all_frozen(model): |
302 | model_grads: List[bool] = list(grad_status(model)) | 323 | model_grads: List[bool] = list(grad_status(model)) |
303 | n_require_grad = sum(lmap(int, model_grads)) | 324 | n_require_grad = sum(lmap(int, model_grads)) |
304 | npars = len(model_grads) | 325 | npars = len(model_grads) |
305 | - assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad" | 326 | + assert not any( |
327 | + model_grads | ||
328 | + ), f"{n_require_grad/npars:.1%} of {npars} weights require grad" | ||
306 | 329 | ||
307 | 330 | ||
308 | def assert_not_all_frozen(model): | 331 | def assert_not_all_frozen(model): | ... | ... |
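One more illustration, since the trim_batch helper in this file is easy to misread: it drops every column in which all rows are padding (via keep_column_mask = input_ids.ne(pad_token_id).any(dim=0), as shown in the hunk above), keeping over-padded batches as short as possible before the forward pass. The toy values below are invented for illustration; the function body simply restates the logic visible in the diff.

import torch

def trim_batch(input_ids, pad_token_id, attention_mask=None):
    # keep a column only if at least one row holds a real (non-pad) token
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    return input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]

pad = 0
batch = torch.tensor([[5, 7, pad, pad],
                      [6, pad, pad, pad]])
print(trim_batch(batch, pad))
# tensor([[5, 7],
#         [6, 0]])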