graykode

(refactor) black style

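This commit is a pure reformatting pass: the hunks below show what Black produces (double quotes, spaces around "=", long calls broken one argument per line, extra blank lines before top-level definitions). A minimal sketch, assuming the black package is installed, of applying the same formatting programmatically through its documented format_str/Mode API:

import black

# illustrative input only; mirrors the style fixed in the hunks below
messy = "PLUS=1\nconfig = {'mode' : 'training'}\n"
print(black.format_str(messy, mode=black.Mode()))
# PLUS = 1
# config = {"mode": "training"}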
...@@ -68,6 +68,7 @@ def main(args):
68 ) 68 )
69 print(commit_message) 69 print(commit_message)
70 70
71 +
71 if __name__ == "__main__": 72 if __name__ == "__main__":
72 parser = argparse.ArgumentParser(description="Code to collect commits on github") 73 parser = argparse.ArgumentParser(description="Code to collect commits on github")
73 parser.add_argument( 74 parser.add_argument(
......
1 # Copyright 2020-present Tae Hwan Jung 1 # Copyright 2020-present Tae Hwan Jung
2 -# 2 +#
3 # Licensed under the Apache License, Version 2.0 (the "License"); 3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License. 4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at 5 # You may obtain a copy of the License at
6 -# 6 +#
7 # http://www.apache.org/licenses/LICENSE-2.0 7 # http://www.apache.org/licenses/LICENSE-2.0
8 -# 8 +#
9 # Unless required by applicable law or agreed to in writing, software 9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, 10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -15,6 +15,6 @@
15 from .gitcommit import diff_parse, truncate 15 from .gitcommit import diff_parse, truncate
16 16
17 __all__ = [ 17 __all__ = [
18 - 'diff_parse',
19 - 'truncate',
20 -]
...\ No newline at end of file
18 + "diff_parse",
19 + "truncate",
20 +]
......
...@@ -36,9 +36,11 @@ logging.basicConfig(
36 level=logging.INFO, 36 level=logging.INFO,
37 ) 37 )
38 38
39 +
39 class PATCH(enum.Enum): 40 class PATCH(enum.Enum):
40 - PLUS=1 41 + PLUS = 1
41 - MINUS=2 42 + MINUS = 2
43 +
42 44
43 def truncate(tuple, max_length, value=0): 45 def truncate(tuple, max_length, value=0):
44 ls = [] 46 ls = []
...@@ -46,22 +48,20 @@ def truncate(tuple, max_length, value=0):
46 if isinstance(t, int): 48 if isinstance(t, int):
47 t = [t] 49 t = [t]
48 ls.extend(t) 50 ls.extend(t)
49 - ls = ls[:max_length - 1] 51 + ls = ls[: max_length - 1]
50 ls.insert(0, value) 52 ls.insert(0, value)
51 if len(ls) < max_length: 53 if len(ls) < max_length:
52 ls.extend([0] * (max_length - len(ls))) 54 ls.extend([0] * (max_length - len(ls)))
53 assert len(ls) == max_length 55 assert len(ls) == max_length
54 return ls 56 return ls
55 57
58 +
56 def encode_line(tokenizer, line, patch): 59 def encode_line(tokenizer, line, patch):
57 - line = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip() 60 + line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip()
58 tokens = tokenizer.tokenize(line) 61 tokens = tokenizer.tokenize(line)
59 tokens = tokenizer.convert_tokens_to_ids(tokens) 62 tokens = tokenizer.convert_tokens_to_ids(tokens)
60 - return ( 63 + return (tokens, [1] * len(tokens), len(tokens) * [patch.value])
61 - tokens, 64 +
62 - [1] * len(tokens),
63 - len(tokens) * [patch.value]
64 - )
65 65
66 def diff_parse(diff, tokenizer): 66 def diff_parse(diff, tokenizer):
67 chunks = [] 67 chunks = []
...@@ -78,6 +78,7 @@ def diff_parse(diff, tokenizer):
78 chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS)) 78 chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
79 return chunks 79 return chunks
80 80
81 +
81 def sha_parse(sha, tokenizer, max_length=1024): 82 def sha_parse(sha, tokenizer, max_length=1024):
82 83
83 chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer) 84 chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer)
...@@ -91,16 +92,18 @@ def sha_parse(sha, tokenizer, max_length=1024):
91 92
92 return (input_ids, attention_masks, patch_ids) 93 return (input_ids, attention_masks, patch_ids)
93 94
95 +
94 def message_parse(msg, tokenizer, max_length=56): 96 def message_parse(msg, tokenizer, max_length=56):
95 - msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg) 97 + msg = re.sub(r"(\(|)#([0-9])+(\)|)", "", msg)
96 98
97 - msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip() 99 + msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip()
98 msg = tokenizer.tokenize(msg) 100 msg = tokenizer.tokenize(msg)
99 msg = tokenizer.convert_tokens_to_ids(msg) 101 msg = tokenizer.convert_tokens_to_ids(msg)
100 msg = truncate(msg, max_length, value=0) 102 msg = truncate(msg, max_length, value=0)
101 103
102 return msg 104 return msg
103 105
106 +
104 def jobs(sha_msgs, args, data_config, train=True): 107 def jobs(sha_msgs, args, data_config, train=True):
105 108
106 input_ids, attention_masks, patch_ids, targets = [], [], [], [] 109 input_ids, attention_masks, patch_ids, targets = [], [], [], []
...@@ -110,9 +113,7 @@ def jobs(sha_msgs, args, data_config, train=True):
110 sha, msg = sha_msg 113 sha, msg = sha_msg
111 114
112 source = sha_parse( 115 source = sha_parse(
113 - sha, 116 + sha, tokenizer=args.tokenizer, max_length=args.max_source_length
114 - tokenizer=args.tokenizer,
115 - max_length=args.max_source_length
116 ) 117 )
117 if not source: 118 if not source:
118 continue 119 continue
...@@ -120,7 +121,9 @@ def jobs(sha_msgs, args, data_config, train=True):
120 target = message_parse( 121 target = message_parse(
121 msg, 122 msg,
122 tokenizer=args.tokenizer, 123 tokenizer=args.tokenizer,
123 - max_length=(args.max_target_length if train else args.val_max_target_length), 124 + max_length=(
125 + args.max_target_length if train else args.val_max_target_length
126 + ),
124 ) 127 )
125 128
126 input_ids.append(input_id) 129 input_ids.append(input_id)
...@@ -128,14 +131,17 @@ def jobs(sha_msgs, args, data_config, train=True):
128 patch_ids.append(patch_id) 131 patch_ids.append(patch_id)
129 targets.append(target) 132 targets.append(target)
130 133
131 - data_saver({ 134 + data_saver(
132 - "input_ids": np.asarray(input_ids), 135 + {
133 - "attention_masks": np.asarray(attention_masks), 136 + "input_ids": np.asarray(input_ids),
134 - "patch_ids": np.asarray(patch_ids), 137 + "attention_masks": np.asarray(attention_masks),
135 - "targets": np.asarray(targets), 138 + "patch_ids": np.asarray(patch_ids),
136 - }) 139 + "targets": np.asarray(targets),
140 + }
141 + )
137 data_saver.disconnect() 142 data_saver.disconnect()
138 143
144 +
139 def start(chunked_sha_msgs, train=True): 145 def start(chunked_sha_msgs, train=True):
140 146
141 logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation")) 147 logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation"))
...@@ -144,22 +150,22 @@ def start(chunked_sha_msgs, train=True):
144 150
145 data_config = DataConfig( 151 data_config = DataConfig(
146 endpoint=args.endpoint, 152 endpoint=args.endpoint,
147 - access_key=os.environ['access_key'], 153 + access_key=os.environ["access_key"],
148 - secret_key=os.environ['secret_key'], 154 + secret_key=os.environ["secret_key"],
149 region=args.region, 155 region=args.region,
150 - dataset_name='commit-autosuggestions', 156 + dataset_name="commit-autosuggestions",
151 additional={ 157 additional={
152 - "mode" : ("training" if train else "evaluation"), 158 + "mode": ("training" if train else "evaluation"),
153 "max_source_length": args.max_source_length, 159 "max_source_length": args.max_source_length,
154 "max_target_length": max_target_length, 160 "max_target_length": max_target_length,
155 - "url" : args.url, 161 + "url": args.url,
156 }, 162 },
157 attributes=[ 163 attributes=[
158 - ('input_ids', 'int32', (args.max_source_length,)), 164 + ("input_ids", "int32", (args.max_source_length,)),
159 - ('attention_masks', 'int32', (args.max_source_length,)), 165 + ("attention_masks", "int32", (args.max_source_length,)),
160 - ('patch_ids', 'int32', (args.max_source_length,)), 166 + ("patch_ids", "int32", (args.max_source_length,)),
161 - ('targets', 'int32', (max_target_length,)) 167 + ("targets", "int32", (max_target_length,)),
162 - ] 168 + ],
163 ) 169 )
164 170
165 func = partial(jobs, args=args, data_config=data_config, train=train) 171 func = partial(jobs, args=args, data_config=data_config, train=train)
...@@ -168,14 +174,15 @@ def start(chunked_sha_msgs, train=True):
168 for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))): 174 for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
169 pbar.update() 175 pbar.update()
170 176
177 +
171 def main(args): 178 def main(args):
172 - if 'access_key' not in os.environ or 'secret_key' not in os.environ: 179 + if "access_key" not in os.environ or "secret_key" not in os.environ:
173 raise OSError("access_key or secret_key are not found.") 180 raise OSError("access_key or secret_key are not found.")
174 181
175 sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()] 182 sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
176 random.shuffle(sha_msgs) 183 random.shuffle(sha_msgs)
177 chunked_sha_msgs = [ 184 chunked_sha_msgs = [
178 - sha_msgs[x:x + args.matorage_batch] 185 + sha_msgs[x : x + args.matorage_batch]
179 for x in range(0, len(sha_msgs), args.matorage_batch) 186 for x in range(0, len(sha_msgs), args.matorage_batch)
180 ] 187 ]
181 188
...@@ -185,29 +192,25 @@ def main(args):
185 if args.do_predict: 192 if args.do_predict:
186 start(chunked_sha_msgs[barrier:], train=False) 193 start(chunked_sha_msgs[barrier:], train=False)
187 194
195 +
188 if __name__ == "__main__": 196 if __name__ == "__main__":
189 parser = argparse.ArgumentParser(description="Code to collect commits on github") 197 parser = argparse.ArgumentParser(description="Code to collect commits on github")
190 - parser.add_argument( 198 + parser.add_argument("--url", type=str, required=True, help="github url")
191 - "--url",
192 - type=str,
193 - required=True,
194 - help="github url"
195 - )
196 parser.add_argument( 199 parser.add_argument(
197 "--endpoint", 200 "--endpoint",
198 type=str, 201 type=str,
199 required=True, 202 required=True,
200 - help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' 203 + help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
201 ) 204 )
202 parser.add_argument( 205 parser.add_argument(
203 "--region", 206 "--region",
204 type=str, 207 type=str,
205 default=None, 208 default=None,
206 - help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' 209 + help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
207 ) 210 )
208 parser.add_argument( 211 parser.add_argument(
209 "--tokenizer_name", 212 "--tokenizer_name",
210 - default='sshleifer/distilbart-xsum-6-6', 213 + default="sshleifer/distilbart-xsum-6-6",
211 type=str, 214 type=str,
212 help="Pretrained tokenizer name or path if not the same as model_name", 215 help="Pretrained tokenizer name or path if not the same as model_name",
213 ) 216 )
...@@ -215,41 +218,40 @@ if __name__ == "__main__":
215 "--matorage_batch", 218 "--matorage_batch",
216 default=1024, 219 default=1024,
217 type=int, 220 type=int,
218 - help='The smallest batch size stored atomically in matorage.' 221 + help="The smallest batch size stored atomically in matorage.",
219 ) 222 )
220 parser.add_argument( 223 parser.add_argument(
221 - "--num_workers", 224 + "--num_workers", default=4, type=int, help="number of process",
222 - default=4,
223 - type=int,
224 - help="number of process",
225 ) 225 )
226 parser.add_argument( 226 parser.add_argument(
227 "--max_source_length", 227 "--max_source_length",
228 default=1024, 228 default=1024,
229 type=int, 229 type=int,
230 help="The maximum total input sequence length after tokenization. Sequences longer " 230 help="The maximum total input sequence length after tokenization. Sequences longer "
231 - "than this will be truncated, sequences shorter will be padded.", 231 + "than this will be truncated, sequences shorter will be padded.",
232 ) 232 )
233 parser.add_argument( 233 parser.add_argument(
234 "--max_target_length", 234 "--max_target_length",
235 default=56, 235 default=56,
236 type=int, 236 type=int,
237 help="The maximum total input sequence length after tokenization. Sequences longer " 237 help="The maximum total input sequence length after tokenization. Sequences longer "
238 - "than this will be truncated, sequences shorter will be padded.", 238 + "than this will be truncated, sequences shorter will be padded.",
239 ) 239 )
240 parser.add_argument( 240 parser.add_argument(
241 "--val_max_target_length", 241 "--val_max_target_length",
242 default=142, # these defaults are optimized for CNNDM. For xsum, see README.md. 242 default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
243 type=int, 243 type=int,
244 help="The maximum total input sequence length after tokenization. Sequences longer " 244 help="The maximum total input sequence length after tokenization. Sequences longer "
245 - "than this will be truncated, sequences shorter will be padded.", 245 + "than this will be truncated, sequences shorter will be padded.",
246 + )
247 + parser.add_argument(
248 + "--p_val", type=float, default=0.25, help="percent of validation dataset"
246 ) 249 )
247 - parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset")
248 parser.add_argument("--do_train", action="store_true", default=False) 250 parser.add_argument("--do_train", action="store_true", default=False)
249 parser.add_argument("--do_predict", action="store_true", default=False) 251 parser.add_argument("--do_predict", action="store_true", default=False)
250 args = parser.parse_args() 252 args = parser.parse_args()
251 253
252 - args.local_path = args.url.split('/')[-1] 254 + args.local_path = args.url.split("/")[-1]
253 logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}") 255 logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}")
254 repo = ( 256 repo = (
255 Repo(args.local_path) 257 Repo(args.local_path)
......
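The helpers reformatted in the file above behave the same before and after this commit; a small worked example (the function body is copied from the hunks above, the input values are illustrative only):

import enum

class PATCH(enum.Enum):
    PLUS = 1
    MINUS = 2

def truncate(tuple, max_length, value=0):
    ls = []
    for t in tuple:
        if isinstance(t, int):
            t = [t]
        ls.extend(t)
    ls = ls[: max_length - 1]
    ls.insert(0, value)
    if len(ls) < max_length:
        ls.extend([0] * (max_length - len(ls)))
    assert len(ls) == max_length
    return ls

# Flattens nested token ids, prepends `value`, then pads or clips to max_length.
print(truncate([101, [7, 8, 9]], max_length=6))  # [0, 101, 7, 8, 9, 0]
print(PATCH.PLUS.value, PATCH.MINUS.value)       # 1 2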
1 # Copyright 2020-present Tae Hwan Jung 1 # Copyright 2020-present Tae Hwan Jung
2 -# 2 +#
3 # Licensed under the Apache License, Version 2.0 (the "License"); 3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License. 4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at 5 # You may obtain a copy of the License at
6 -# 6 +#
7 # http://www.apache.org/licenses/LICENSE-2.0 7 # http://www.apache.org/licenses/LICENSE-2.0
8 -# 8 +#
9 # Unless required by applicable law or agreed to in writing, software 9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, 10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -14,6 +14,4 @@
14 14
15 from .modeling_bart import BartForConditionalGeneration 15 from .modeling_bart import BartForConditionalGeneration
16 16
17 -__all__ = [
18 - 'BartForConditionalGeneration'
19 -]
...\ No newline at end of file
17 +__all__ = ["BartForConditionalGeneration"]
......
...@@ -20,16 +20,31 @@ logger = logging.getLogger(__name__)
20 20
21 class Seq2SeqLoggingCallback(pl.Callback): 21 class Seq2SeqLoggingCallback(pl.Callback):
22 def on_batch_end(self, trainer, pl_module): 22 def on_batch_end(self, trainer, pl_module):
23 - lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} 23 + lrs = {
24 + f"lr_group_{i}": param["lr"]
25 + for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)
26 + }
24 pl_module.logger.log_metrics(lrs) 27 pl_module.logger.log_metrics(lrs)
25 28
26 @rank_zero_only 29 @rank_zero_only
27 def _write_logs( 30 def _write_logs(
28 - self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True 31 + self,
32 + trainer: pl.Trainer,
33 + pl_module: pl.LightningModule,
34 + type_path: str,
35 + save_generations=True,
29 ) -> None: 36 ) -> None:
30 - logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") 37 + logger.info(
38 + f"***** {type_path} results at step {trainer.global_step:05d} *****"
39 + )
31 metrics = trainer.callback_metrics 40 metrics = trainer.callback_metrics
32 - trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) 41 + trainer.logger.log_metrics(
42 + {
43 + k: v
44 + for k, v in metrics.items()
45 + if k not in ["log", "progress_bar", "preds"]
46 + }
47 + )
33 # Log results 48 # Log results
34 od = Path(pl_module.hparams.output_dir) 49 od = Path(pl_module.hparams.output_dir)
35 if type_path == "test": 50 if type_path == "test":
...@@ -39,7 +54,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
39 # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json 54 # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
40 # If people want this it will be easy enough to add back. 55 # If people want this it will be easy enough to add back.
41 results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" 56 results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
42 - generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt" 57 + generations_file = (
58 + od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
59 + )
43 results_file.parent.mkdir(exist_ok=True) 60 results_file.parent.mkdir(exist_ok=True)
44 generations_file.parent.mkdir(exist_ok=True) 61 generations_file.parent.mkdir(exist_ok=True)
45 with open(results_file, "a+") as writer: 62 with open(results_file, "a+") as writer:
...@@ -68,7 +85,9 @@ class Seq2SeqLoggingCallback(pl.Callback):
68 85
69 n_trainable_pars = count_trainable_parameters(pl_module) 86 n_trainable_pars = count_trainable_parameters(pl_module)
70 # mp stands for million parameters 87 # mp stands for million parameters
71 - trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) 88 + trainer.logger.log_metrics(
89 + {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}
90 + )
72 91
73 @rank_zero_only 92 @rank_zero_only
74 def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 93 def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
...@@ -98,8 +117,5 @@ def get_checkpoint_callback(output_dir, metric):
98 117
99 def get_early_stopping_callback(metric, patience): 118 def get_early_stopping_callback(metric, patience):
100 return EarlyStopping( 119 return EarlyStopping(
101 - monitor=f"val_{metric}", 120 + monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,
102 - mode="max",
103 - patience=patience,
104 - verbose=True,
105 ) 121 )
......
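For reference, the early-stopping helper reflowed above can be wired up roughly as follows. This is a sketch assuming pytorch_lightning is installed (EarlyStopping is the callback used in callbacks.py) and "rouge2" is only an example metric name:

from pytorch_lightning.callbacks import EarlyStopping

def get_early_stopping_callback(metric, patience):
    # same body as above, only reflowed by Black
    return EarlyStopping(
        monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,
    )

es_callback = get_early_stopping_callback("rouge2", patience=3)
# typically passed to pl.Trainer(callbacks=[es_callback], ...)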
...@@ -21,7 +21,11 @@ from matorage.torch import Dataset
21 21
22 22
23 try: 23 try:
24 - from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback 24 + from .callbacks import (
25 + Seq2SeqLoggingCallback,
26 + get_checkpoint_callback,
27 + get_early_stopping_callback,
28 + )
25 from .utils import ( 29 from .utils import (
26 ROUGE_KEYS, 30 ROUGE_KEYS,
27 LegacySeq2SeqDataset, 31 LegacySeq2SeqDataset,
...@@ -40,7 +44,11 @@ try:
40 use_task_specific_params, 44 use_task_specific_params,
41 ) 45 )
42 except ImportError: 46 except ImportError:
43 - from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback 47 + from callbacks import (
48 + Seq2SeqLoggingCallback,
49 + get_checkpoint_callback,
50 + get_early_stopping_callback,
51 + )
44 from utils import ( 52 from utils import (
45 ROUGE_KEYS, 53 ROUGE_KEYS,
46 LegacySeq2SeqDataset, 54 LegacySeq2SeqDataset,
...@@ -83,8 +91,12 @@ class SummarizationModule(BaseTransformer):
83 "val": self.hparams.val_max_target_length, 91 "val": self.hparams.val_max_target_length,
84 "test": self.hparams.test_max_target_length, 92 "test": self.hparams.test_max_target_length,
85 } 93 }
86 - assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}" 94 + assert (
87 - assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}" 95 + self.target_lens["train"] <= self.target_lens["val"]
96 + ), f"target_lens: {self.target_lens}"
97 + assert (
98 + self.target_lens["train"] <= self.target_lens["test"]
99 + ), f"target_lens: {self.target_lens}"
88 100
89 if self.hparams.freeze_embeds: 101 if self.hparams.freeze_embeds:
90 self.freeze_embeds() 102 self.freeze_embeds()
...@@ -95,13 +107,27 @@ class SummarizationModule(BaseTransformer):
95 self.hparams.git_sha = get_git_info()["repo_sha"] 107 self.hparams.git_sha = get_git_info()["repo_sha"]
96 self.num_workers = hparams.num_workers 108 self.num_workers = hparams.num_workers
97 self.decoder_start_token_id = None # default to config 109 self.decoder_start_token_id = None # default to config
98 - if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer): 110 + if self.model.config.decoder_start_token_id is None and isinstance(
99 - self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] 111 + self.tokenizer, MBartTokenizer
112 + ):
113 + self.decoder_start_token_id = self.tokenizer.lang_code_to_id[
114 + hparams.tgt_lang
115 + ]
100 self.model.config.decoder_start_token_id = self.decoder_start_token_id 116 self.model.config.decoder_start_token_id = self.decoder_start_token_id
101 117
102 - self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams 118 + self.eval_beams = (
103 - assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1" 119 + self.model.config.num_beams
104 - self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric 120 + if self.hparams.eval_beams is None
121 + else self.hparams.eval_beams
122 + )
123 + assert (
124 + self.eval_beams >= 1
125 + ), f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
126 + self.val_metric = (
127 + self.default_val_metric
128 + if self.hparams.val_metric is None
129 + else self.hparams.val_metric
130 + )
105 131
106 def freeze_embeds(self): 132 def freeze_embeds(self):
107 """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" 133 """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
...@@ -133,7 +159,13 @@ class SummarizationModule(BaseTransformer):
133 else: 159 else:
134 decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) 160 decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
135 161
136 - outputs = self(src_ids, src_patch, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False) 162 + outputs = self(
163 + src_ids,
164 + src_patch,
165 + attention_mask=src_mask,
166 + decoder_input_ids=decoder_input_ids,
167 + use_cache=False,
168 + )
137 lm_logits = outputs[0] 169 lm_logits = outputs[0]
138 if self.hparams.label_smoothing == 0: 170 if self.hparams.label_smoothing == 0:
139 # Same behavior as modeling_bart.py, besides ignoring pad_token_id 171 # Same behavior as modeling_bart.py, besides ignoring pad_token_id
...@@ -157,7 +189,9 @@ class SummarizationModule(BaseTransformer):
157 189
158 logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} 190 logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
159 # tokens per batch 191 # tokens per batch
160 - logs["tpb"] = batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum() 192 + logs["tpb"] = (
193 + batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
194 + )
161 return {"loss": loss_tensors[0], "log": logs} 195 return {"loss": loss_tensors[0], "log": logs}
162 196
163 def validation_step(self, batch, batch_idx) -> Dict: 197 def validation_step(self, batch, batch_idx) -> Dict:
...@@ -165,17 +199,29 @@ class SummarizationModule(BaseTransformer):
165 199
166 def validation_epoch_end(self, outputs, prefix="val") -> Dict: 200 def validation_epoch_end(self, outputs, prefix="val") -> Dict:
167 self.step_count += 1 201 self.step_count += 1
168 - losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names} 202 + losses = {
203 + k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names
204 + }
169 loss = losses["loss"] 205 loss = losses["loss"]
170 - rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]} 206 + rouges = {
171 - rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss) 207 + k: np.array([x[k] for x in outputs]).mean()
208 + for k in self.metric_names + ["gen_time", "gen_len"]
209 + }
210 + rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(
211 + loss
212 + )
172 rouges.update({k: v.item() for k, v in losses.items()}) 213 rouges.update({k: v.item() for k, v in losses.items()})
173 losses.update(rouges) 214 losses.update(rouges)
174 metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} 215 metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
175 metrics["step_count"] = self.step_count 216 metrics["step_count"] = self.step_count
176 self.save_metrics(metrics, prefix) # writes to self.metrics_save_path 217 self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
177 preds = flatten_list([x["preds"] for x in outputs]) 218 preds = flatten_list([x["preds"] for x in outputs])
178 - return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor} 219 + return {
220 + "log": metrics,
221 + "preds": preds,
222 + f"{prefix}_loss": loss,
223 + f"{prefix}_{self.val_metric}": rouge_tensor,
224 + }
179 225
180 def save_metrics(self, latest_metrics, type_path) -> None: 226 def save_metrics(self, latest_metrics, type_path) -> None:
181 self.metrics[type_path].append(latest_metrics) 227 self.metrics[type_path].append(latest_metrics)
...@@ -200,7 +246,9 @@ class SummarizationModule(BaseTransformer):
200 base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} 246 base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
201 rouge: Dict = self.calc_generative_metrics(preds, target) 247 rouge: Dict = self.calc_generative_metrics(preds, target)
202 summ_len = np.mean(lmap(len, generated_ids)) 248 summ_len = np.mean(lmap(len, generated_ids))
203 - base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge) 249 + base_metrics.update(
250 + gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge
251 + )
204 return base_metrics 252 return base_metrics
205 253
206 def test_step(self, batch, batch_idx): 254 def test_step(self, batch, batch_idx):
...@@ -213,10 +261,10 @@ class SummarizationModule(BaseTransformer):
213 max_target_length = self.target_lens[type_path] 261 max_target_length = self.target_lens[type_path]
214 data_config = DataConfig( 262 data_config = DataConfig(
215 endpoint=args.endpoint, 263 endpoint=args.endpoint,
216 - access_key=os.environ['access_key'], 264 + access_key=os.environ["access_key"],
217 - secret_key=os.environ['secret_key'], 265 + secret_key=os.environ["secret_key"],
218 region=args.region, 266 region=args.region,
219 - dataset_name='commit-autosuggestions', 267 + dataset_name="commit-autosuggestions",
220 additional={ 268 additional={
221 "mode": ("training" if type_path == "train" else "evaluation"), 269 "mode": ("training" if type_path == "train" else "evaluation"),
222 "max_source_length": self.hparams.max_source_length, 270 "max_source_length": self.hparams.max_source_length,
...@@ -224,15 +272,17 @@ class SummarizationModule(BaseTransformer):
224 "url": args.url, 272 "url": args.url,
225 }, 273 },
226 attributes=[ 274 attributes=[
227 - ('input_ids', 'int32', (self.hparams.max_source_length,)), 275 + ("input_ids", "int32", (self.hparams.max_source_length,)),
228 - ('attention_masks', 'int32', (self.hparams.max_source_length,)), 276 + ("attention_masks", "int32", (self.hparams.max_source_length,)),
229 - ('patch_ids', 'int32', (self.hparams.max_source_length,)), 277 + ("patch_ids", "int32", (self.hparams.max_source_length,)),
230 - ('targets', 'int32', (max_target_length,)) 278 + ("targets", "int32", (max_target_length,)),
231 - ] 279 + ],
232 ) 280 )
233 return Dataset(config=data_config, clear=True) 281 return Dataset(config=data_config, clear=True)
234 282
235 - def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader: 283 + def get_dataloader(
284 + self, type_path: str, batch_size: int, shuffle: bool = False
285 + ) -> DataLoader:
236 dataset = self.get_dataset(type_path) 286 dataset = self.get_dataset(type_path)
237 sampler = None 287 sampler = None
238 288
...@@ -246,7 +296,9 @@ class SummarizationModule(BaseTransformer):
246 return dataloader 296 return dataloader
247 297
248 def train_dataloader(self) -> DataLoader: 298 def train_dataloader(self) -> DataLoader:
249 - dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True) 299 + dataloader = self.get_dataloader(
300 + "train", batch_size=self.hparams.train_batch_size, shuffle=True
301 + )
250 return dataloader 302 return dataloader
251 303
252 def val_dataloader(self) -> DataLoader: 304 def val_dataloader(self) -> DataLoader:
...@@ -259,23 +311,18 @@ class SummarizationModule(BaseTransformer):
259 def add_model_specific_args(parser, root_dir): 311 def add_model_specific_args(parser, root_dir):
260 BaseTransformer.add_model_specific_args(parser, root_dir) 312 BaseTransformer.add_model_specific_args(parser, root_dir)
261 add_generic_args(parser, root_dir) 313 add_generic_args(parser, root_dir)
262 - parser.add_argument( 314 + parser.add_argument("--url", type=str, required=True, help="github url")
263 - "--url",
264 - type=str,
265 - required=True,
266 - help="github url"
267 - )
268 parser.add_argument( 315 parser.add_argument(
269 "--endpoint", 316 "--endpoint",
270 type=str, 317 type=str,
271 required=True, 318 required=True,
272 - help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' 319 + help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
273 ) 320 )
274 parser.add_argument( 321 parser.add_argument(
275 "--region", 322 "--region",
276 type=str, 323 type=str,
277 default=None, 324 default=None,
278 - help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' 325 + help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
279 ) 326 )
280 parser.add_argument( 327 parser.add_argument(
281 "--max_source_length", 328 "--max_source_length",
...@@ -308,14 +355,43 @@ class SummarizationModule(BaseTransformer):
308 parser.add_argument("--freeze_encoder", action="store_true") 355 parser.add_argument("--freeze_encoder", action="store_true")
309 parser.add_argument("--freeze_embeds", action="store_true") 356 parser.add_argument("--freeze_embeds", action="store_true")
310 parser.add_argument("--sortish_sampler", action="store_true", default=False) 357 parser.add_argument("--sortish_sampler", action="store_true", default=False)
311 - parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
312 - parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
313 - parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
314 - parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
315 parser.add_argument( 358 parser.add_argument(
316 - "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all." 359 + "--logger_name",
360 + type=str,
361 + choices=["default", "wandb", "wandb_shared"],
362 + default="default",
363 + )
364 + parser.add_argument(
365 + "--n_train",
366 + type=int,
367 + default=-1,
368 + required=False,
369 + help="# examples. -1 means use all.",
370 + )
371 + parser.add_argument(
372 + "--n_val",
373 + type=int,
374 + default=500,
375 + required=False,
376 + help="# examples. -1 means use all.",
377 + )
378 + parser.add_argument(
379 + "--n_test",
380 + type=int,
381 + default=-1,
382 + required=False,
383 + help="# examples. -1 means use all.",
384 + )
385 + parser.add_argument(
386 + "--task",
387 + type=str,
388 + default="summarization",
389 + required=False,
390 + help="# examples. -1 means use all.",
391 + )
392 + parser.add_argument(
393 + "--label_smoothing", type=float, default=0.0, required=False
317 ) 394 )
318 - parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
319 parser.add_argument("--src_lang", type=str, default="", required=False) 395 parser.add_argument("--src_lang", type=str, default="", required=False)
320 parser.add_argument("--tgt_lang", type=str, default="", required=False) 396 parser.add_argument("--tgt_lang", type=str, default="", required=False)
321 parser.add_argument("--eval_beams", type=int, default=None, required=False) 397 parser.add_argument("--eval_beams", type=int, default=None, required=False)
...@@ -348,7 +424,11 @@ class TranslationModule(SummarizationModule):
348 def main(args, model=None) -> SummarizationModule: 424 def main(args, model=None) -> SummarizationModule:
349 Path(args.output_dir).mkdir(exist_ok=True) 425 Path(args.output_dir).mkdir(exist_ok=True)
350 if len(os.listdir(args.output_dir)) > 3 and args.do_train: 426 if len(os.listdir(args.output_dir)) > 3 and args.do_train:
351 - raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 427 + raise ValueError(
428 + "Output directory ({}) already exists and is not empty.".format(
429 + args.output_dir
430 + )
431 + )
352 if model is None: 432 if model is None:
353 if args.task == "summarization": 433 if args.task == "summarization":
354 model: SummarizationModule = SummarizationModule(args) 434 model: SummarizationModule = SummarizationModule(args)
...@@ -371,7 +451,9 @@ def main(args, model=None) -> SummarizationModule:
371 return model 451 return model
372 452
373 model.hparams.test_checkpoint = "" 453 model.hparams.test_checkpoint = ""
374 - checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))) 454 + checkpoints = list(
455 + sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
456 + )
375 if checkpoints: 457 if checkpoints:
376 model.hparams.test_checkpoint = checkpoints[-1] 458 model.hparams.test_checkpoint = checkpoints[-1]
377 trainer.resume_from_checkpoint = checkpoints[-1] 459 trainer.resume_from_checkpoint = checkpoints[-1]
......
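The DataConfig/Dataset wiring reflowed above is the only storage-facing code in this file; a sketch of the equivalent standalone setup, with the matorage calls mirrored from the hunks above (the top-level DataConfig import path, the endpoint, and the URL are placeholders/assumptions):

import os
from matorage import DataConfig          # import path assumed
from matorage.torch import Dataset       # as imported at the top of this file

data_config = DataConfig(
    endpoint="127.0.0.1:9000",           # placeholder endpoint
    access_key=os.environ["access_key"],
    secret_key=os.environ["secret_key"],
    region=None,
    dataset_name="commit-autosuggestions",
    additional={
        "mode": "training",
        "max_source_length": 1024,
        "max_target_length": 56,
        "url": "https://github.com/<user>/<repo>",   # placeholder
    },
    attributes=[
        ("input_ids", "int32", (1024,)),
        ("attention_masks", "int32", (1024,)),
        ("patch_ids", "int32", (1024,)),
        ("targets", "int32", (56,)),
    ],
)
dataset = Dataset(config=data_config, clear=True)   # as in get_dataset() above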
...@@ -30,6 +30,7 @@ logging.basicConfig(
30 level=logging.INFO, 30 level=logging.INFO,
31 ) 31 )
32 32
33 +
33 class GenerationMixin: 34 class GenerationMixin:
34 """ 35 """
35 A class contraining all of the functions supporting generation, to be used as a mixin in 36 A class contraining all of the functions supporting generation, to be used as a mixin in
...@@ -50,7 +51,9 @@ class GenerationMixin:
50 """ 51 """
51 return logits 52 return logits
52 53
53 - def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): 54 + def enforce_repetition_penalty_(
55 + self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty
56 + ):
54 """ 57 """
55 Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__). 58 Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
56 """ 59 """
...@@ -79,11 +82,7 @@ class GenerationMixin:
79 # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) 82 # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
80 if repetition_penalty != 1.0: 83 if repetition_penalty != 1.0:
81 self.enforce_repetition_penalty_( 84 self.enforce_repetition_penalty_(
82 - scores, 85 + scores, batch_size, num_beams, input_ids, repetition_penalty,
83 - batch_size,
84 - num_beams,
85 - input_ids,
86 - repetition_penalty,
87 ) 86 )
88 87
89 # set eos token prob to zero if min_length is not reached 88 # set eos token prob to zero if min_length is not reached
...@@ -102,7 +101,11 @@ class GenerationMixin:
102 101
103 if bad_words_ids is not None: 102 if bad_words_ids is not None:
104 # Exclude EOS token (already processed) 103 # Exclude EOS token (already processed)
105 - bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) 104 + bad_words_ids = list(
105 + filter(
106 + lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids
107 + )
108 + )
106 # calculate a list of banned tokens according to bad words 109 # calculate a list of banned tokens according to bad words
107 banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) 110 banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
108 # Modify the scores in place by setting the banned tokens logits to `-inf` 111 # Modify the scores in place by setting the banned tokens logits to `-inf`
...@@ -134,7 +137,7 @@ class GenerationMixin:
134 attention_mask: Optional[torch.LongTensor] = None, 137 attention_mask: Optional[torch.LongTensor] = None,
135 decoder_start_token_id: Optional[int] = None, 138 decoder_start_token_id: Optional[int] = None,
136 use_cache: Optional[bool] = None, 139 use_cache: Optional[bool] = None,
137 - **model_kwargs 140 + **model_kwargs,
138 ) -> torch.LongTensor: 141 ) -> torch.LongTensor:
139 r""" 142 r"""
140 Generates sequences for models with a language modeling head. The method currently supports greedy decoding, 143 Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
...@@ -262,26 +265,50 @@ class GenerationMixin:
262 max_length = max_length if max_length is not None else self.config.max_length 265 max_length = max_length if max_length is not None else self.config.max_length
263 min_length = min_length if min_length is not None else self.config.min_length 266 min_length = min_length if min_length is not None else self.config.min_length
264 do_sample = do_sample if do_sample is not None else self.config.do_sample 267 do_sample = do_sample if do_sample is not None else self.config.do_sample
265 - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping 268 + early_stopping = (
269 + early_stopping if early_stopping is not None else self.config.early_stopping
270 + )
266 use_cache = use_cache if use_cache is not None else self.config.use_cache 271 use_cache = use_cache if use_cache is not None else self.config.use_cache
267 num_beams = num_beams if num_beams is not None else self.config.num_beams 272 num_beams = num_beams if num_beams is not None else self.config.num_beams
268 - temperature = temperature if temperature is not None else self.config.temperature 273 + temperature = (
274 + temperature if temperature is not None else self.config.temperature
275 + )
269 top_k = top_k if top_k is not None else self.config.top_k 276 top_k = top_k if top_k is not None else self.config.top_k
270 top_p = top_p if top_p is not None else self.config.top_p 277 top_p = top_p if top_p is not None else self.config.top_p
271 - repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty 278 + repetition_penalty = (
272 - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id 279 + repetition_penalty
273 - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id 280 + if repetition_penalty is not None
274 - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id 281 + else self.config.repetition_penalty
275 - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty 282 + )
283 + bos_token_id = (
284 + bos_token_id if bos_token_id is not None else self.config.bos_token_id
285 + )
286 + pad_token_id = (
287 + pad_token_id if pad_token_id is not None else self.config.pad_token_id
288 + )
289 + eos_token_id = (
290 + eos_token_id if eos_token_id is not None else self.config.eos_token_id
291 + )
292 + length_penalty = (
293 + length_penalty if length_penalty is not None else self.config.length_penalty
294 + )
276 no_repeat_ngram_size = ( 295 no_repeat_ngram_size = (
277 - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size 296 + no_repeat_ngram_size
297 + if no_repeat_ngram_size is not None
298 + else self.config.no_repeat_ngram_size
299 + )
300 + bad_words_ids = (
301 + bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
278 ) 302 )
279 - bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
280 num_return_sequences = ( 303 num_return_sequences = (
281 - num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences 304 + num_return_sequences
305 + if num_return_sequences is not None
306 + else self.config.num_return_sequences
282 ) 307 )
283 decoder_start_token_id = ( 308 decoder_start_token_id = (
284 - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id 309 + decoder_start_token_id
310 + if decoder_start_token_id is not None
311 + else self.config.decoder_start_token_id
285 ) 312 )
286 313
287 if input_ids is not None: 314 if input_ids is not None:
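Most of the reflowed lines in the hunk above follow one pattern: each generate() argument falls back to the corresponding model-config value when the caller passes None. A tiny self-contained illustration of that pattern (class and function names here are hypothetical, not part of the diff):

class _Config:
    max_length = 20
    num_beams = 1

def resolve(config, max_length=None, num_beams=None):
    # explicit argument wins; otherwise use the config default
    max_length = max_length if max_length is not None else config.max_length
    num_beams = num_beams if num_beams is not None else config.num_beams
    return max_length, num_beams

print(resolve(_Config()))               # (20, 1)
print(resolve(_Config(), num_beams=4))  # (20, 4)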
...@@ -289,14 +316,22 @@ class GenerationMixin:
289 else: 316 else:
290 batch_size = 1 317 batch_size = 1
291 318
292 - assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." 319 + assert (
293 - assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." 320 + isinstance(max_length, int) and max_length > 0
321 + ), "`max_length` should be a strictly positive integer."
322 + assert (
323 + isinstance(min_length, int) and min_length >= 0
324 + ), "`min_length` should be a positive integer."
294 assert isinstance(do_sample, bool), "`do_sample` should be a boolean." 325 assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
295 assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." 326 assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
296 assert isinstance(use_cache, bool), "`use_cache` should be a boolean." 327 assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
297 - assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." 328 + assert (
329 + isinstance(num_beams, int) and num_beams > 0
330 + ), "`num_beams` should be a strictly positive integer."
298 assert temperature > 0, "`temperature` should be strictly positive." 331 assert temperature > 0, "`temperature` should be strictly positive."
299 - assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." 332 + assert (
333 + isinstance(top_k, int) and top_k >= 0
334 + ), "`top_k` should be a positive integer."
300 assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." 335 assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
301 assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." 336 assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
302 assert input_ids is not None or ( 337 assert input_ids is not None or (
...@@ -316,7 +351,9 @@ class GenerationMixin:
316 isinstance(num_return_sequences, int) and num_return_sequences > 0 351 isinstance(num_return_sequences, int) and num_return_sequences > 0
317 ), "`num_return_sequences` should be a strictly positive integer." 352 ), "`num_return_sequences` should be a strictly positive integer."
318 assert ( 353 assert (
319 - bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) 354 + bad_words_ids is None
355 + or isinstance(bad_words_ids, list)
356 + and isinstance(bad_words_ids[0], list)
320 ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" 357 ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
321 358
322 if input_ids is None: 359 if input_ids is None:
...@@ -331,7 +368,9 @@ class GenerationMixin:
331 device=next(self.parameters()).device, 368 device=next(self.parameters()).device,
332 ) 369 )
333 else: 370 else:
334 - assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." 371 + assert (
372 + input_ids.dim() == 2
373 + ), "Input prompt should be of shape (batch_size, sequence length)."
335 374
336 # not allow to duplicate outputs when greedy decoding 375 # not allow to duplicate outputs when greedy decoding
337 if do_sample is False: 376 if do_sample is False:
...@@ -349,7 +388,11 @@ class GenerationMixin:
349 388
350 # create attention mask if necessary 389 # create attention mask if necessary
351 # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 390 # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
352 - if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): 391 + if (
392 + (attention_mask is None)
393 + and (pad_token_id is not None)
394 + and (pad_token_id in input_ids)
395 + ):
353 attention_mask = input_ids.ne(pad_token_id).long() 396 attention_mask = input_ids.ne(pad_token_id).long()
354 elif attention_mask is None: 397 elif attention_mask is None:
355 attention_mask = input_ids.new_ones(input_ids.shape) 398 attention_mask = input_ids.new_ones(input_ids.shape)
...@@ -358,7 +401,9 @@ class GenerationMixin:
358 # attention_mask is created 401 # attention_mask is created
359 if pad_token_id is None and eos_token_id is not None: 402 if pad_token_id is None and eos_token_id is not None:
360 logger.warning( 403 logger.warning(
361 - "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) 404 + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(
405 + eos_token_id
406 + )
362 ) 407 )
363 pad_token_id = eos_token_id 408 pad_token_id = eos_token_id
364 409
...@@ -385,25 +430,37 @@ class GenerationMixin:
385 # see if BOS token can be used for decoder_start_token_id 430 # see if BOS token can be used for decoder_start_token_id
386 if bos_token_id is not None: 431 if bos_token_id is not None:
387 decoder_start_token_id = bos_token_id 432 decoder_start_token_id = bos_token_id
388 - elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"): 433 + elif hasattr(self.config, "decoder") and hasattr(
434 + self.config.decoder, "bos_token_id"
435 + ):
389 decoder_start_token_id = self.config.decoder.bos_token_id 436 decoder_start_token_id = self.config.decoder.bos_token_id
390 else: 437 else:
391 raise ValueError( 438 raise ValueError(
392 "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" 439 "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
393 ) 440 )
394 441
395 - assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) 442 + assert hasattr(
396 - assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) 443 + self, "get_encoder"
444 + ), "{} should have a 'get_encoder' function defined".format(self)
445 + assert callable(self.get_encoder), "{} should be a method".format(
446 + self.get_encoder
447 + )
397 448
398 # get encoder and store encoder outputs 449 # get encoder and store encoder outputs
399 encoder = self.get_encoder() 450 encoder = self.get_encoder()
400 - encoder_outputs: ModelOutput = encoder(input_ids, patch_ids, attention_mask=attention_mask, return_dict=True) 451 + encoder_outputs: ModelOutput = encoder(
452 + input_ids, patch_ids, attention_mask=attention_mask, return_dict=True
453 + )
401 454
402 # Expand input ids if num_beams > 1 or num_return_sequences > 1 455 # Expand input ids if num_beams > 1 or num_return_sequences > 1
403 if num_return_sequences > 1 or num_beams > 1: 456 if num_return_sequences > 1 or num_beams > 1:
404 input_ids_len = input_ids.shape[-1] 457 input_ids_len = input_ids.shape[-1]
405 - input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) 458 + input_ids = input_ids.unsqueeze(1).expand(
406 - patch_ids = patch_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) 459 + batch_size, effective_batch_mult * num_beams, input_ids_len
460 + )
461 + patch_ids = patch_ids.unsqueeze(1).expand(
462 + batch_size, effective_batch_mult * num_beams, input_ids_len
463 + )
407 attention_mask = attention_mask.unsqueeze(1).expand( 464 attention_mask = attention_mask.unsqueeze(1).expand(
408 batch_size, effective_batch_mult * num_beams, input_ids_len 465 batch_size, effective_batch_mult * num_beams, input_ids_len
409 ) 466 )
...@@ -442,9 +499,9 @@ class GenerationMixin:
442 ) 499 )
443 500
444 # expand encoder_outputs 501 # expand encoder_outputs
445 - encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( 502 + encoder_outputs[
446 - 0, expanded_batch_idxs 503 + "last_hidden_state"
447 - ) 504 + ] = encoder_outputs.last_hidden_state.index_select(0, expanded_batch_idxs)
448 505
449 # save encoder_outputs in `model_kwargs` 506 # save encoder_outputs in `model_kwargs`
450 model_kwargs["encoder_outputs"] = encoder_outputs 507 model_kwargs["encoder_outputs"] = encoder_outputs
...@@ -534,7 +591,11 @@ class GenerationMixin:
534 past = None 591 past = None
535 while cur_len < max_length: 592 while cur_len < max_length:
536 model_inputs = self.prepare_inputs_for_generation( 593 model_inputs = self.prepare_inputs_for_generation(
537 - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs 594 + input_ids,
595 + past=past,
596 + attention_mask=attention_mask,
597 + use_cache=use_cache,
598 + **model_kwargs,
538 ) 599 )
539 600
540 outputs = self(**model_inputs, return_dict=True) 601 outputs = self(**model_inputs, return_dict=True)
...@@ -565,7 +626,9 @@ class GenerationMixin:
565 if temperature != 1.0: 626 if temperature != 1.0:
566 scores = scores / temperature 627 scores = scores / temperature
567 # Top-p/top-k filtering 628 # Top-p/top-k filtering
568 - next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) 629 + next_token_logscores = top_k_top_p_filtering(
630 + scores, top_k=top_k, top_p=top_p
631 + )
569 # Sample 632 # Sample
570 probs = F.softmax(next_token_logscores, dim=-1) 633 probs = F.softmax(next_token_logscores, dim=-1)
571 next_token = torch.multinomial(probs, num_samples=1).squeeze(1) 634 next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
...@@ -576,7 +639,9 @@ class GenerationMixin:
576 # update generations and finished sentences 639 # update generations and finished sentences
577 if eos_token_id is not None: 640 if eos_token_id is not None:
578 # pad finished sentences if eos_token_id exist 641 # pad finished sentences if eos_token_id exist
579 - tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) 642 + tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (
643 + 1 - unfinished_sents
644 + )
580 else: 645 else:
581 tokens_to_add = next_token 646 tokens_to_add = next_token
582 647
...@@ -587,8 +652,12 @@ class GenerationMixin:
587 if eos_token_id is not None: 652 if eos_token_id is not None:
588 eos_in_sents = tokens_to_add == eos_token_id 653 eos_in_sents = tokens_to_add == eos_token_id
589 # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length 654 # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
590 - is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() 655 + is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(
591 - sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len) 656 + eos_in_sents.long()
657 + ).bool()
658 + sent_lengths.masked_fill_(
659 + is_sents_unfinished_and_token_to_add_is_eos, cur_len
660 + )
592 # unfinished_sents is set to zero if eos in sentence 661 # unfinished_sents is set to zero if eos in sentence
593 unfinished_sents.mul_((~eos_in_sents).long()) 662 unfinished_sents.mul_((~eos_in_sents).long())
594 663
...@@ -599,7 +668,11 @@ class GenerationMixin:
599 # extend attention_mask for new generated input if only decoder 668 # extend attention_mask for new generated input if only decoder
600 if self.config.is_encoder_decoder is False: 669 if self.config.is_encoder_decoder is False:
601 attention_mask = torch.cat( 670 attention_mask = torch.cat(
602 - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 671 + [
672 + attention_mask,
673 + attention_mask.new_ones((attention_mask.shape[0], 1)),
674 + ],
675 + dim=-1,
603 ) 676 )
604 677
605 return input_ids 678 return input_ids
...@@ -633,12 +706,16 @@ class GenerationMixin: ...@@ -633,12 +706,16 @@ class GenerationMixin:
633 706
634 # generated hypotheses 707 # generated hypotheses
635 generated_hyps = [ 708 generated_hyps = [
636 - BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) 709 + BeamHypotheses(
710 + num_beams, max_length, length_penalty, early_stopping=early_stopping
711 + )
637 for _ in range(batch_size) 712 for _ in range(batch_size)
638 ] 713 ]
639 714
640 # scores for each sentence in the beam 715 # scores for each sentence in the beam
641 - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) 716 + beam_scores = torch.zeros(
717 + (batch_size, num_beams), dtype=torch.float, device=input_ids.device
718 + )
642 719
643 # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times 720 # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
644 if do_sample is False: 721 if do_sample is False:
...@@ -653,10 +730,18 @@ class GenerationMixin: ...@@ -653,10 +730,18 @@ class GenerationMixin:
653 730
654 while cur_len < max_length: 731 while cur_len < max_length:
655 model_inputs = self.prepare_inputs_for_generation( 732 model_inputs = self.prepare_inputs_for_generation(
656 - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs 733 + input_ids,
734 + past=past,
735 + attention_mask=attention_mask,
736 + use_cache=use_cache,
737 + **model_kwargs,
657 ) 738 )
658 - outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size) 739 + outputs = self(
659 - next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size) 740 + **model_inputs, return_dict=True
741 + ) # (batch_size * num_beams, cur_len, vocab_size)
742 + next_token_logits = outputs.logits[
743 + :, -1, :
744 + ] # (batch_size * num_beams, vocab_size)
660 745
661 # if model has past, then set the past variable to speed up decoding 746 # if model has past, then set the past variable to speed up decoding
662 if "past_key_values" in outputs: 747 if "past_key_values" in outputs:
...@@ -670,7 +755,9 @@ class GenerationMixin: ...@@ -670,7 +755,9 @@ class GenerationMixin:
670 next_token_logits, cur_len=cur_len, max_length=max_length 755 next_token_logits, cur_len=cur_len, max_length=max_length
671 ) 756 )
672 757
673 - scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) 758 + scores = F.log_softmax(
759 + next_token_logits, dim=-1
760 + ) # (batch_size * num_beams, vocab_size)
674 761
675 scores = self.postprocess_next_token_scores( 762 scores = self.postprocess_next_token_scores(
676 scores=scores, 763 scores=scores,
...@@ -686,12 +773,17 @@ class GenerationMixin: ...@@ -686,12 +773,17 @@ class GenerationMixin:
686 num_beams=num_beams, 773 num_beams=num_beams,
687 ) 774 )
688 775
689 - assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format( 776 + assert scores.shape == (
777 + batch_size * num_beams,
778 + vocab_size,
779 + ), "Shapes of scores: {} != {}".format(
690 scores.shape, (batch_size * num_beams, vocab_size) 780 scores.shape, (batch_size * num_beams, vocab_size)
691 ) 781 )
692 782
693 if do_sample: 783 if do_sample:
694 - _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) 784 + _scores = scores + beam_scores[:, None].expand_as(
785 + scores
786 + ) # (batch_size * num_beams, vocab_size)
695 # Temperature 787 # Temperature
696 if temperature != 1.0: 788 if temperature != 1.0:
697 _scores = _scores / temperature 789 _scores = _scores / temperature
...@@ -706,24 +798,38 @@ class GenerationMixin: ...@@ -706,24 +798,38 @@ class GenerationMixin:
706 798
707 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) 799 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
708 probs = F.softmax(_scores, dim=-1) 800 probs = F.softmax(_scores, dim=-1)
709 - next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2) 801 + next_tokens = torch.multinomial(
802 + probs, num_samples=2 * num_beams
803 + ) # (batch_size, num_beams * 2)
710 # Compute next scores 804 # Compute next scores
711 - next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) 805 + next_scores = torch.gather(
806 + _scores, -1, next_tokens
807 + ) # (batch_size, num_beams * 2)
712 # sort the sampled vector to make sure that the first num_beams samples are the best 808 # sort the sampled vector to make sure that the first num_beams samples are the best
713 - next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1) 809 + next_scores, next_scores_indices = torch.sort(
714 - next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2) 810 + next_scores, descending=True, dim=1
811 + )
812 + next_tokens = torch.gather(
813 + next_tokens, -1, next_scores_indices
814 + ) # (batch_size, num_beams * 2)
715 815
716 else: 816 else:
717 - next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) 817 + next_scores = scores + beam_scores[:, None].expand_as(
818 + scores
819 + ) # (batch_size * num_beams, vocab_size)
718 820
719 # re-organize to group the beam together (we are keeping top hypotheses across beams) 821 # re-organize to group the beam together (we are keeping top hypotheses across beams)
720 next_scores = next_scores.view( 822 next_scores = next_scores.view(
721 batch_size, num_beams * vocab_size 823 batch_size, num_beams * vocab_size
722 ) # (batch_size, num_beams * vocab_size) 824 ) # (batch_size, num_beams * vocab_size)
723 825
724 - next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True) 826 + next_scores, next_tokens = torch.topk(
827 + next_scores, 2 * num_beams, dim=1, largest=True, sorted=True
828 + )
725 829
726 - assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) 830 + assert (
831 + next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
832 + )
727 833
728 # next batch beam content 834 # next batch beam content
729 next_batch_beam = [] 835 next_batch_beam = []
...@@ -735,11 +841,15 @@ class GenerationMixin: ...@@ -735,11 +841,15 @@ class GenerationMixin:
735 if done[batch_idx]: 841 if done[batch_idx]:
736 assert ( 842 assert (
737 len(generated_hyps[batch_idx]) >= num_beams 843 len(generated_hyps[batch_idx]) >= num_beams
738 - ), "Batch can only be done if at least {} beams have been generated".format(num_beams) 844 + ), "Batch can only be done if at least {} beams have been generated".format(
845 + num_beams
846 + )
739 assert ( 847 assert (
740 eos_token_id is not None and pad_token_id is not None 848 eos_token_id is not None and pad_token_id is not None
741 ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" 849 ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
742 - next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch 850 + next_batch_beam.extend(
851 + [(0, pad_token_id, 0)] * num_beams
852 + ) # pad the batch
743 continue 853 continue
744 854
745 # next sentence beam content, this will get added to next_batch_beam 855 # next sentence beam content, this will get added to next_batch_beam
...@@ -757,7 +867,9 @@ class GenerationMixin: ...@@ -757,7 +867,9 @@ class GenerationMixin:
757 # add to generated hypotheses if end of sentence 867 # add to generated hypotheses if end of sentence
758 if (eos_token_id is not None) and (token_id.item() == eos_token_id): 868 if (eos_token_id is not None) and (token_id.item() == eos_token_id):
759 # if beam_token does not belong to top num_beams tokens, it should not be added 869 # if beam_token does not belong to top num_beams tokens, it should not be added
760 - is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams 870 + is_beam_token_worse_than_top_num_beams = (
871 + beam_token_rank >= num_beams
872 + )
761 if is_beam_token_worse_than_top_num_beams: 873 if is_beam_token_worse_than_top_num_beams:
762 continue 874 continue
763 generated_hyps[batch_idx].add( 875 generated_hyps[batch_idx].add(
...@@ -766,7 +878,9 @@ class GenerationMixin: ...@@ -766,7 +878,9 @@ class GenerationMixin:
766 ) 878 )
767 else: 879 else:
768 # add next predicted token since it is not eos_token 880 # add next predicted token since it is not eos_token
769 - next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) 881 + next_sent_beam.append(
882 + (beam_token_score, token_id, effective_beam_id)
883 + )
770 884
771 # once the beam for next step is full, don't add more tokens to it. 885 # once the beam for next step is full, don't add more tokens to it.
772 if len(next_sent_beam) == num_beams: 886 if len(next_sent_beam) == num_beams:
...@@ -780,7 +894,9 @@ class GenerationMixin: ...@@ -780,7 +894,9 @@ class GenerationMixin:
780 # update next beam content 894 # update next beam content
781 assert len(next_sent_beam) == num_beams, "Beam should always be full" 895 assert len(next_sent_beam) == num_beams, "Beam should always be full"
782 next_batch_beam.extend(next_sent_beam) 896 next_batch_beam.extend(next_sent_beam)
783 - assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step" 897 + assert len(next_batch_beam) == num_beams * (
898 + batch_idx + 1
899 + ), "We should have added num_beams each step"
784 900
785 # stop when we are done with each sentence 901 # stop when we are done with each sentence
786 if all(done): 902 if all(done):
...@@ -804,7 +920,11 @@ class GenerationMixin: ...@@ -804,7 +920,11 @@ class GenerationMixin:
804 # extend attention_mask for new generated input if only decoder 920 # extend attention_mask for new generated input if only decoder
805 if self.config.is_encoder_decoder is False: 921 if self.config.is_encoder_decoder is False:
806 attention_mask = torch.cat( 922 attention_mask = torch.cat(
807 - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 923 + [
924 + attention_mask,
925 + attention_mask.new_ones((attention_mask.shape[0], 1)),
926 + ],
927 + dim=-1,
808 ) 928 )
809 929
810 # finalize all open beam hypotheses and add to generated hypotheses 930 # finalize all open beam hypotheses and add to generated hypotheses
...@@ -814,10 +934,12 @@ class GenerationMixin: ...@@ -814,10 +934,12 @@ class GenerationMixin:
814 934
815 # test that beam scores match previously calculated scores if not eos and batch_idx not done 935 # test that beam scores match previously calculated scores if not eos and batch_idx not done
816 if eos_token_id is not None and all( 936 if eos_token_id is not None and all(
817 - (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx] 937 + (token_id % vocab_size).item() != eos_token_id
938 + for token_id in next_tokens[batch_idx]
818 ): 939 ):
819 assert torch.all( 940 assert torch.all(
820 - next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx] 941 + next_scores[batch_idx, :num_beams]
942 + == beam_scores.view(batch_size, num_beams)[batch_idx]
821 ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( 943 ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
822 next_scores[:, :num_beams][batch_idx], 944 next_scores[:, :num_beams][batch_idx],
823 beam_scores.view(batch_size, num_beams)[batch_idx], 945 beam_scores.view(batch_size, num_beams)[batch_idx],
...@@ -831,7 +953,9 @@ class GenerationMixin: ...@@ -831,7 +953,9 @@ class GenerationMixin:
831 generated_hyps[batch_idx].add(final_tokens, final_score) 953 generated_hyps[batch_idx].add(final_tokens, final_score)
832 954
833 # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch 955 # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
834 - output_batch_size = batch_size if do_sample else batch_size * num_return_sequences 956 + output_batch_size = (
957 + batch_size if do_sample else batch_size * num_return_sequences
958 + )
835 output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences 959 output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
836 960
837 # select the best hypotheses 961 # select the best hypotheses
...@@ -861,7 +985,9 @@ class GenerationMixin: ...@@ -861,7 +985,9 @@ class GenerationMixin:
861 else: 985 else:
862 # none of the hypotheses have an eos_token 986 # none of the hypotheses have an eos_token
863 assert (len(hypo) == max_length for hypo in best) 987 assert (len(hypo) == max_length for hypo in best)
864 - decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device) 988 + decoded = (
989 + torch.stack(best).type(torch.long).to(next(self.parameters()).device)
990 + )
865 991
866 return decoded 992 return decoded
867 993
...@@ -870,7 +996,9 @@ class GenerationMixin: ...@@ -870,7 +996,9 @@ class GenerationMixin:
870 return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) 996 return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
871 997
872 998
873 -def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None: 999 +def calc_banned_ngram_tokens(
1000 + prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int
1001 +) -> None:
874 """Copied from fairseq for no_repeat_ngram in beam_search""" 1002 """Copied from fairseq for no_repeat_ngram in beam_search"""
875 if cur_len + 1 < no_repeat_ngram_size: 1003 if cur_len + 1 < no_repeat_ngram_size:
876 # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet 1004 # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
...@@ -881,7 +1009,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n ...@@ -881,7 +1009,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n
881 generated_ngram = generated_ngrams[idx] 1009 generated_ngram = generated_ngrams[idx]
882 for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): 1010 for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
883 prev_ngram_tuple = tuple(ngram[:-1]) 1011 prev_ngram_tuple = tuple(ngram[:-1])
884 - generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] 1012 + generated_ngram[prev_ngram_tuple] = generated_ngram.get(
1013 + prev_ngram_tuple, []
1014 + ) + [ngram[-1]]
885 1015
886 def _get_generated_ngrams(hypo_idx): 1016 def _get_generated_ngrams(hypo_idx):
887 # Before decoding the next token, prevent decoding of ngrams that have already appeared 1017 # Before decoding the next token, prevent decoding of ngrams that have already appeared
...@@ -893,7 +1023,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n ...@@ -893,7 +1023,9 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n
893 return banned_tokens 1023 return banned_tokens
894 1024
895 1025
896 -def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]: 1026 +def calc_banned_bad_words_ids(
1027 + prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]
1028 +) -> Iterable[int]:
897 banned_tokens = [] 1029 banned_tokens = []
898 1030
899 def _tokens_match(prev_tokens, tokens): 1031 def _tokens_match(prev_tokens, tokens):
...@@ -914,7 +1046,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter ...@@ -914,7 +1046,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
914 banned_tokens_slice = [] 1046 banned_tokens_slice = []
915 1047
916 for banned_token_seq in bad_words_ids: 1048 for banned_token_seq in bad_words_ids:
917 - assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format( 1049 + assert (
1050 + len(banned_token_seq) > 0
1051 + ), "Banned words token sequences {} cannot have an empty list".format(
918 bad_words_ids 1052 bad_words_ids
919 ) 1053 )
920 1054
...@@ -929,7 +1063,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter ...@@ -929,7 +1063,9 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
929 return banned_tokens 1063 return banned_tokens
930 1064
931 1065
932 -def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: 1066 +def set_scores_to_inf_for_banned_tokens(
1067 + scores: torch.Tensor, banned_tokens: List[List[int]]
1068 +) -> None:
933 """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be 1069 """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
934 a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...] 1070 a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
935 Args: 1071 Args:
...@@ -949,7 +1085,12 @@ def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: Lis ...@@ -949,7 +1085,12 @@ def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: Lis
949 # [ 0 0 0 ] 1085 # [ 0 0 0 ]
950 # [ 1 0 0 ] 1086 # [ 1 0 0 ]
951 1087
952 - banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() 1088 + banned_mask = (
1089 + torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
1090 + .to(scores.device)
1091 + .to_dense()
1092 + .bool()
1093 + )
953 scores.masked_fill_(banned_mask, -float("inf")) 1094 scores.masked_fill_(banned_mask, -float("inf"))
954 1095
955 1096
...@@ -989,7 +1130,9 @@ def top_k_top_p_filtering( ...@@ -989,7 +1130,9 @@ def top_k_top_p_filtering(
989 sorted_indices_to_remove[..., 0] = 0 1130 sorted_indices_to_remove[..., 0] = 0
990 1131
991 # scatter sorted tensors to original indexing 1132 # scatter sorted tensors to original indexing
992 - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 1133 + indices_to_remove = sorted_indices_to_remove.scatter(
1134 + 1, sorted_indices, sorted_indices_to_remove
1135 + )
993 logits[indices_to_remove] = filter_value 1136 logits[indices_to_remove] = filter_value
994 return logits 1137 return logits
995 1138
...@@ -1020,7 +1163,9 @@ class BeamHypotheses(object): ...@@ -1020,7 +1163,9 @@ class BeamHypotheses(object):
1020 if len(self) < self.num_beams or score > self.worst_score: 1163 if len(self) < self.num_beams or score > self.worst_score:
1021 self.beams.append((score, hyp)) 1164 self.beams.append((score, hyp))
1022 if len(self) > self.num_beams: 1165 if len(self) > self.num_beams:
1023 - sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) 1166 + sorted_scores = sorted(
1167 + [(s, idx) for idx, (s, _) in enumerate(self.beams)]
1168 + )
1024 del self.beams[sorted_scores[0][1]] 1169 del self.beams[sorted_scores[0][1]]
1025 self.worst_score = sorted_scores[1][0] 1170 self.worst_score = sorted_scores[1][0]
1026 else: 1171 else:
......
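Note on the hunks above: they only re-wrap the sampling and beam-search helpers in generation_utils.py to black's line length; behavior is unchanged. As a quick reference, below is a minimal sketch (assuming PyTorch; the name nucleus_filter is illustrative and not part of this patch) of the top-p step that top_k_top_p_filtering applies before torch.multinomial draws the next token.

import torch
import torch.nn.functional as F

def nucleus_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    # sort the vocabulary by probability and accumulate the mass
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # mask tokens once the cumulative probability exceeds top_p,
    # shifting right so the first token past the threshold is still kept
    remove = cumulative_probs > top_p
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    # map the mask back to the original vocabulary order
    indices_to_remove = remove.scatter(-1, sorted_indices, remove)
    return logits.masked_fill(indices_to_remove, float("-inf"))

Sampling then proceeds as in the loop above: probs = F.softmax(filtered_logits, dim=-1) followed by torch.multinomial(probs, num_samples=1).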
...@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule): ...@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
69 config=None, 69 config=None,
70 tokenizer=None, 70 tokenizer=None,
71 model=None, 71 model=None,
72 - **config_kwargs 72 + **config_kwargs,
73 ): 73 ):
74 """Initialize a model, tokenizer and config.""" 74 """Initialize a model, tokenizer and config."""
75 super().__init__() 75 super().__init__()
...@@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule): ...@@ -83,7 +83,9 @@ class BaseTransformer(pl.LightningModule):
83 cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None 83 cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
84 if config is None: 84 if config is None:
85 self.config = AutoConfig.from_pretrained( 85 self.config = AutoConfig.from_pretrained(
86 - self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, 86 + self.hparams.config_name
87 + if self.hparams.config_name
88 + else self.hparams.model_name_or_path,
87 **({"num_labels": num_labels} if num_labels is not None else {}), 89 **({"num_labels": num_labels} if num_labels is not None else {}),
88 cache_dir=cache_dir, 90 cache_dir=cache_dir,
89 **config_kwargs, 91 **config_kwargs,
...@@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule): ...@@ -91,15 +93,24 @@ class BaseTransformer(pl.LightningModule):
91 else: 93 else:
92 self.config: PretrainedConfig = config 94 self.config: PretrainedConfig = config
93 95
94 - extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") 96 + extra_model_params = (
97 + "encoder_layerdrop",
98 + "decoder_layerdrop",
99 + "dropout",
100 + "attention_dropout",
101 + )
95 for p in extra_model_params: 102 for p in extra_model_params:
96 if getattr(self.hparams, p, None): 103 if getattr(self.hparams, p, None):
97 - assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute" 104 + assert hasattr(
105 + self.config, p
106 + ), f"model config doesn't have a `{p}` attribute"
98 setattr(self.config, p, getattr(self.hparams, p)) 107 setattr(self.config, p, getattr(self.hparams, p))
99 108
100 if tokenizer is None: 109 if tokenizer is None:
101 self.tokenizer = AutoTokenizer.from_pretrained( 110 self.tokenizer = AutoTokenizer.from_pretrained(
102 - self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, 111 + self.hparams.tokenizer_name
112 + if self.hparams.tokenizer_name
113 + else self.hparams.model_name_or_path,
103 cache_dir=cache_dir, 114 cache_dir=cache_dir,
104 ) 115 )
105 else: 116 else:
...@@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule): ...@@ -121,7 +132,9 @@ class BaseTransformer(pl.LightningModule):
121 def get_lr_scheduler(self): 132 def get_lr_scheduler(self):
122 get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] 133 get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
123 scheduler = get_schedule_func( 134 scheduler = get_schedule_func(
124 - self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps 135 + self.opt,
136 + num_warmup_steps=self.hparams.warmup_steps,
137 + num_training_steps=self.total_steps,
125 ) 138 )
126 scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 139 scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
127 return scheduler 140 return scheduler
...@@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule): ...@@ -132,22 +145,35 @@ class BaseTransformer(pl.LightningModule):
132 no_decay = ["bias", "LayerNorm.weight"] 145 no_decay = ["bias", "LayerNorm.weight"]
133 optimizer_grouped_parameters = [ 146 optimizer_grouped_parameters = [
134 { 147 {
135 - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 148 + "params": [
149 + p
150 + for n, p in model.named_parameters()
151 + if not any(nd in n for nd in no_decay)
152 + ],
136 "weight_decay": self.hparams.weight_decay, 153 "weight_decay": self.hparams.weight_decay,
137 }, 154 },
138 { 155 {
139 - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 156 + "params": [
157 + p
158 + for n, p in model.named_parameters()
159 + if any(nd in n for nd in no_decay)
160 + ],
140 "weight_decay": 0.0, 161 "weight_decay": 0.0,
141 }, 162 },
142 ] 163 ]
143 if self.hparams.adafactor: 164 if self.hparams.adafactor:
144 optimizer = Adafactor( 165 optimizer = Adafactor(
145 - optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False 166 + optimizer_grouped_parameters,
167 + lr=self.hparams.learning_rate,
168 + scale_parameter=False,
169 + relative_step=False,
146 ) 170 )
147 171
148 else: 172 else:
149 optimizer = AdamW( 173 optimizer = AdamW(
150 - optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon 174 + optimizer_grouped_parameters,
175 + lr=self.hparams.learning_rate,
176 + eps=self.hparams.adam_epsilon,
151 ) 177 )
152 self.opt = optimizer 178 self.opt = optimizer
153 179
...@@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule): ...@@ -165,13 +191,19 @@ class BaseTransformer(pl.LightningModule):
165 def total_steps(self) -> int: 191 def total_steps(self) -> int:
166 """The number of total training steps that will be run. Used for lr scheduler purposes.""" 192 """The number of total training steps that will be run. Used for lr scheduler purposes."""
167 num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores 193 num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
168 - effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices 194 + effective_batch_size = (
195 + self.hparams.train_batch_size
196 + * self.hparams.accumulate_grad_batches
197 + * num_devices
198 + )
169 dataset_size = len(self.train_loader.dataset) 199 dataset_size = len(self.train_loader.dataset)
170 return (dataset_size / effective_batch_size) * self.hparams.max_epochs 200 return (dataset_size / effective_batch_size) * self.hparams.max_epochs
171 201
172 def setup(self, mode): 202 def setup(self, mode):
173 if mode == "fit": 203 if mode == "fit":
174 - self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True) 204 + self.train_loader = self.get_dataloader(
205 + "train", self.hparams.train_batch_size, shuffle=True
206 + )
175 207
176 def get_dataloader(self, type_path, batch_size, shuffle=False): 208 def get_dataloader(self, type_path, batch_size, shuffle=False):
177 raise NotImplementedError("You must implement this for your task") 209 raise NotImplementedError("You must implement this for your task")
...@@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule): ...@@ -212,7 +244,10 @@ class BaseTransformer(pl.LightningModule):
212 help="Path to pretrained model or model identifier from huggingface.co/models", 244 help="Path to pretrained model or model identifier from huggingface.co/models",
213 ) 245 )
214 parser.add_argument( 246 parser.add_argument(
215 - "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 247 + "--config_name",
248 + default="",
249 + type=str,
250 + help="Pretrained config name or path if not the same as model_name",
216 ) 251 )
217 parser.add_argument( 252 parser.add_argument(
218 "--tokenizer_name", 253 "--tokenizer_name",
...@@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule): ...@@ -246,7 +281,12 @@ class BaseTransformer(pl.LightningModule):
246 type=float, 281 type=float,
247 help="Attention dropout probability (Optional). Goes into model.config", 282 help="Attention dropout probability (Optional). Goes into model.config",
248 ) 283 )
249 - parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 284 + parser.add_argument(
285 + "--learning_rate",
286 + default=5e-5,
287 + type=float,
288 + help="The initial learning rate for Adam.",
289 + )
250 parser.add_argument( 290 parser.add_argument(
251 "--lr_scheduler", 291 "--lr_scheduler",
252 default="linear", 292 default="linear",
...@@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule): ...@@ -255,11 +295,30 @@ class BaseTransformer(pl.LightningModule):
255 type=str, 295 type=str,
256 help="Learning rate scheduler", 296 help="Learning rate scheduler",
257 ) 297 )
258 - parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 298 + parser.add_argument(
259 - parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 299 + "--weight_decay",
260 - parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 300 + default=0.0,
261 - parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") 301 + type=float,
262 - parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int) 302 + help="Weight decay if we apply some.",
303 + )
304 + parser.add_argument(
305 + "--adam_epsilon",
306 + default=1e-8,
307 + type=float,
308 + help="Epsilon for Adam optimizer.",
309 + )
310 + parser.add_argument(
311 + "--warmup_steps",
312 + default=0,
313 + type=int,
314 + help="Linear warmup over warmup_steps.",
315 + )
316 + parser.add_argument(
317 + "--num_workers", default=4, type=int, help="kwarg passed to DataLoader"
318 + )
319 + parser.add_argument(
320 + "--num_train_epochs", dest="max_epochs", default=3, type=int
321 + )
263 parser.add_argument("--train_batch_size", default=32, type=int) 322 parser.add_argument("--train_batch_size", default=32, type=int)
264 parser.add_argument("--eval_batch_size", default=32, type=int) 323 parser.add_argument("--eval_batch_size", default=32, type=int)
265 parser.add_argument("--adafactor", action="store_true") 324 parser.add_argument("--adafactor", action="store_true")
...@@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback): ...@@ -283,7 +342,9 @@ class LoggingCallback(pl.Callback):
283 rank_zero_info("***** Test results *****") 342 rank_zero_info("***** Test results *****")
284 metrics = trainer.callback_metrics 343 metrics = trainer.callback_metrics
285 # Log and save results to file 344 # Log and save results to file
286 - output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") 345 + output_test_results_file = os.path.join(
346 + pl_module.hparams.output_dir, "test_results.txt"
347 + )
287 with open(output_test_results_file, "w") as writer: 348 with open(output_test_results_file, "w") as writer:
288 for key in sorted(metrics): 349 for key in sorted(metrics):
289 if key not in ["log", "progress_bar"]: 350 if key not in ["log", "progress_bar"]:
...@@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None: ...@@ -314,9 +375,21 @@ def add_generic_args(parser, root_dir) -> None:
314 "See details at https://nvidia.github.io/apex/amp.html", 375 "See details at https://nvidia.github.io/apex/amp.html",
315 ) 376 )
316 parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) 377 parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
317 - parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm") 378 + parser.add_argument(
318 - parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 379 + "--max_grad_norm",
319 - parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") 380 + dest="gradient_clip_val",
381 + default=1.0,
382 + type=float,
383 + help="Max gradient norm",
384 + )
385 + parser.add_argument(
386 + "--do_train", action="store_true", help="Whether to run training."
387 + )
388 + parser.add_argument(
389 + "--do_predict",
390 + action="store_true",
391 + help="Whether to run predictions on the test set.",
392 + )
320 parser.add_argument( 393 parser.add_argument(
321 "--gradient_accumulation_steps", 394 "--gradient_accumulation_steps",
322 dest="accumulate_grad_batches", 395 dest="accumulate_grad_batches",
...@@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None: ...@@ -324,7 +397,9 @@ def add_generic_args(parser, root_dir) -> None:
324 default=1, 397 default=1,
325 help="Number of updates steps to accumulate before performing a backward/update pass.", 398 help="Number of updates steps to accumulate before performing a backward/update pass.",
326 ) 399 )
327 - parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 400 + parser.add_argument(
401 + "--seed", type=int, default=42, help="random seed for initialization"
402 + )
328 403
329 404
330 def generic_train( 405 def generic_train(
...@@ -335,7 +410,7 @@ def generic_train( ...@@ -335,7 +410,7 @@ def generic_train(
335 extra_callbacks=[], 410 extra_callbacks=[],
336 checkpoint_callback=None, 411 checkpoint_callback=None,
337 logging_callback=None, 412 logging_callback=None,
338 - **extra_train_kwargs 413 + **extra_train_kwargs,
339 ): 414 ):
340 pl.seed_everything(args.seed) 415 pl.seed_everything(args.seed)
341 416
...@@ -346,7 +421,11 @@ def generic_train( ...@@ -346,7 +421,11 @@ def generic_train(
346 # add custom checkpoints 421 # add custom checkpoints
347 if checkpoint_callback is None: 422 if checkpoint_callback is None:
348 checkpoint_callback = pl.callbacks.ModelCheckpoint( 423 checkpoint_callback = pl.callbacks.ModelCheckpoint(
349 - filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1 424 + filepath=args.output_dir,
425 + prefix="checkpoint",
426 + monitor="val_loss",
427 + mode="min",
428 + save_top_k=1,
350 ) 429 )
351 if logging_callback is None: 430 if logging_callback is None:
352 logging_callback = LoggingCallback() 431 logging_callback = LoggingCallback()
......
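The optimizer hunk above is likewise a pure re-wrap of configure_optimizers in lightning_base.py. For reference, a minimal sketch (assuming PyTorch; grouped_parameters is an illustrative helper, not part of the patch) of the decay/no-decay split it builds before passing the groups to AdamW or Adafactor:

import torch

def grouped_parameters(model: torch.nn.Module, weight_decay: float = 0.01):
    # biases and LayerNorm weights are exempt from weight decay
    no_decay = ("bias", "LayerNorm.weight")
    decay_params = [
        p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)
    ]
    no_decay_params = [
        p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)
    ]
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

# usage sketch: torch.optim.AdamW(grouped_parameters(model, 0.01), lr=5e-5, eps=1e-8)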
...@@ -141,7 +141,11 @@ def invert_mask(attention_mask): ...@@ -141,7 +141,11 @@ def invert_mask(attention_mask):
141 141
142 142
143 def _prepare_bart_decoder_inputs( 143 def _prepare_bart_decoder_inputs(
144 - config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32 144 + config,
145 + input_ids,
146 + decoder_input_ids=None,
147 + decoder_padding_mask=None,
148 + causal_mask_dtype=torch.float32,
145 ): 149 ):
146 """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if 150 """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
147 none are provided. This mimics the default behavior in fairseq. To override it pass in masks. 151 none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
...@@ -184,7 +188,9 @@ class PretrainedBartModel(PreTrainedModel): ...@@ -184,7 +188,9 @@ class PretrainedBartModel(PreTrainedModel):
184 @property 188 @property
185 def dummy_inputs(self): 189 def dummy_inputs(self):
186 pad_token = self.config.pad_token_id 190 pad_token = self.config.pad_token_id
187 - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 191 + input_ids = torch.tensor(
192 + [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device
193 + )
188 dummy_inputs = { 194 dummy_inputs = {
189 "attention_mask": input_ids.ne(pad_token), 195 "attention_mask": input_ids.ne(pad_token),
190 "input_ids": input_ids, 196 "input_ids": input_ids,
...@@ -229,7 +235,11 @@ class EncoderLayer(nn.Module): ...@@ -229,7 +235,11 @@ class EncoderLayer(nn.Module):
229 def __init__(self, config: BartConfig): 235 def __init__(self, config: BartConfig):
230 super().__init__() 236 super().__init__()
231 self.embed_dim = config.d_model 237 self.embed_dim = config.d_model
232 - self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout) 238 + self.self_attn = Attention(
239 + self.embed_dim,
240 + config.encoder_attention_heads,
241 + dropout=config.attention_dropout,
242 + )
233 self.normalize_before = config.normalize_before 243 self.normalize_before = config.normalize_before
234 self.self_attn_layer_norm = LayerNorm(self.embed_dim) 244 self.self_attn_layer_norm = LayerNorm(self.embed_dim)
235 self.dropout = config.dropout 245 self.dropout = config.dropout
...@@ -255,7 +265,10 @@ class EncoderLayer(nn.Module): ...@@ -255,7 +265,10 @@ class EncoderLayer(nn.Module):
255 if self.normalize_before: 265 if self.normalize_before:
256 x = self.self_attn_layer_norm(x) 266 x = self.self_attn_layer_norm(x)
257 x, attn_weights = self.self_attn( 267 x, attn_weights = self.self_attn(
258 - query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions 268 + query=x,
269 + key=x,
270 + key_padding_mask=encoder_padding_mask,
271 + output_attentions=output_attentions,
259 ) 272 )
260 x = F.dropout(x, p=self.dropout, training=self.training) 273 x = F.dropout(x, p=self.dropout, training=self.training)
261 x = residual + x 274 x = residual + x
...@@ -308,13 +321,23 @@ class BartEncoder(nn.Module): ...@@ -308,13 +321,23 @@ class BartEncoder(nn.Module):
308 config.extra_pos_embeddings, 321 config.extra_pos_embeddings,
309 ) 322 )
310 self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0) 323 self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0)
311 - self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) 324 + self.layers = nn.ModuleList(
312 - self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() 325 + [EncoderLayer(config) for _ in range(config.encoder_layers)]
326 + )
327 + self.layernorm_embedding = (
328 + LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
329 + )
313 # mbart has one extra layer_norm 330 # mbart has one extra layer_norm
314 self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None 331 self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
315 332
316 def forward( 333 def forward(
317 - self, input_ids, patch_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False 334 + self,
335 + input_ids,
336 + patch_ids,
337 + attention_mask=None,
338 + output_attentions=False,
339 + output_hidden_states=False,
340 + return_dict=False,
318 ): 341 ):
319 """ 342 """
320 Args: 343 Args:
...@@ -352,10 +375,14 @@ class BartEncoder(nn.Module): ...@@ -352,10 +375,14 @@ class BartEncoder(nn.Module):
352 encoder_states.append(x) 375 encoder_states.append(x)
353 # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 376 # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
354 dropout_probability = random.uniform(0, 1) 377 dropout_probability = random.uniform(0, 1)
355 - if self.training and (dropout_probability < self.layerdrop): # skip the layer 378 + if self.training and (
379 + dropout_probability < self.layerdrop
380 + ): # skip the layer
356 attn = None 381 attn = None
357 else: 382 else:
358 - x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions) 383 + x, attn = encoder_layer(
384 + x, attention_mask, output_attentions=output_attentions
385 + )
359 386
360 if output_attentions: 387 if output_attentions:
361 all_attentions = all_attentions + (attn,) 388 all_attentions = all_attentions + (attn,)
...@@ -365,14 +392,20 @@ class BartEncoder(nn.Module): ...@@ -365,14 +392,20 @@ class BartEncoder(nn.Module):
365 if output_hidden_states: 392 if output_hidden_states:
366 encoder_states.append(x) 393 encoder_states.append(x)
367 # T x B x C -> B x T x C 394 # T x B x C -> B x T x C
368 - encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states) 395 + encoder_states = tuple(
396 + hidden_state.transpose(0, 1) for hidden_state in encoder_states
397 + )
369 398
370 # T x B x C -> B x T x C 399 # T x B x C -> B x T x C
371 x = x.transpose(0, 1) 400 x = x.transpose(0, 1)
372 401
373 if not return_dict: 402 if not return_dict:
374 - return tuple(v for v in [x, encoder_states, all_attentions] if v is not None) 403 + return tuple(
375 - return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions) 404 + v for v in [x, encoder_states, all_attentions] if v is not None
405 + )
406 + return BaseModelOutput(
407 + last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions
408 + )
376 409
377 410
378 class DecoderLayer(nn.Module): 411 class DecoderLayer(nn.Module):
...@@ -498,8 +531,12 @@ class BartDecoder(nn.Module): ...@@ -498,8 +531,12 @@ class BartDecoder(nn.Module):
498 self.layers = nn.ModuleList( 531 self.layers = nn.ModuleList(
499 [DecoderLayer(config) for _ in range(config.decoder_layers)] 532 [DecoderLayer(config) for _ in range(config.decoder_layers)]
500 ) # type: List[DecoderLayer] 533 ) # type: List[DecoderLayer]
501 - self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity() 534 + self.layernorm_embedding = (
502 - self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None 535 + LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
536 + )
537 + self.layer_norm = (
538 + LayerNorm(config.d_model) if config.add_final_layer_norm else None
539 + )
503 540
504 def forward( 541 def forward(
505 self, 542 self,
...@@ -595,23 +632,34 @@ class BartDecoder(nn.Module): ...@@ -595,23 +632,34 @@ class BartDecoder(nn.Module):
595 if use_cache: 632 if use_cache:
596 next_decoder_cache.append(layer_past.copy()) 633 next_decoder_cache.append(layer_past.copy())
597 634
598 - if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART) 635 + if self.layer_norm and (
636 + idx == len(self.layers) - 1
637 + ): # if config.add_final_layer_norm (mBART)
599 x = self.layer_norm(x) 638 x = self.layer_norm(x)
600 if output_attentions: 639 if output_attentions:
601 all_self_attns += (layer_self_attn,) 640 all_self_attns += (layer_self_attn,)
602 641
603 # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) 642 # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
604 if output_hidden_states: 643 if output_hidden_states:
605 - all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states) 644 + all_hidden_states = tuple(
645 + hidden_state.transpose(0, 1) for hidden_state in all_hidden_states
646 + )
606 x = x.transpose(0, 1) 647 x = x.transpose(0, 1)
607 encoder_hidden_states = encoder_hidden_states.transpose(0, 1) 648 encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
608 649
609 next_cache = next_decoder_cache if use_cache else None 650 next_cache = next_decoder_cache if use_cache else None
610 651
611 if not return_dict: 652 if not return_dict:
612 - return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None) 653 + return tuple(
654 + v
655 + for v in [x, next_cache, all_hidden_states, all_self_attns]
656 + if v is not None
657 + )
613 return BaseModelOutputWithPast( 658 return BaseModelOutputWithPast(
614 - last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns 659 + last_hidden_state=x,
660 + past_key_values=next_cache,
661 + hidden_states=all_hidden_states,
662 + attentions=all_self_attns,
615 ) 663 )
616 664
617 665
...@@ -638,7 +686,9 @@ class Attention(nn.Module): ...@@ -638,7 +686,9 @@ class Attention(nn.Module):
638 self.num_heads = num_heads 686 self.num_heads = num_heads
639 self.dropout = dropout 687 self.dropout = dropout
640 self.head_dim = embed_dim // num_heads 688 self.head_dim = embed_dim // num_heads
641 - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 689 + assert (
690 + self.head_dim * num_heads == self.embed_dim
691 + ), "embed_dim must be divisible by num_heads"
642 self.scaling = self.head_dim ** -0.5 692 self.scaling = self.head_dim ** -0.5
643 693
644 self.encoder_decoder_attention = encoder_decoder_attention 694 self.encoder_decoder_attention = encoder_decoder_attention
...@@ -649,7 +699,11 @@ class Attention(nn.Module): ...@@ -649,7 +699,11 @@ class Attention(nn.Module):
649 self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" 699 self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
650 700
651 def _shape(self, tensor, seq_len, bsz): 701 def _shape(self, tensor, seq_len, bsz):
652 - return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 702 + return (
703 + tensor.contiguous()
704 + .view(seq_len, bsz * self.num_heads, self.head_dim)
705 + .transpose(0, 1)
706 + )
653 707
654 def forward( 708 def forward(
655 self, 709 self,
...@@ -693,7 +747,9 @@ class Attention(nn.Module): ...@@ -693,7 +747,9 @@ class Attention(nn.Module):
693 v = self._shape(v, -1, bsz) 747 v = self._shape(v, -1, bsz)
694 748
695 if saved_state is not None: 749 if saved_state is not None:
696 - k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz) 750 + k, v, key_padding_mask = self._use_saved_state(
751 + k, v, saved_state, key_padding_mask, static_kv, bsz
752 + )
697 753
698 # Update cache 754 # Update cache
699 layer_state[self.cache_key] = { 755 layer_state[self.cache_key] = {
...@@ -708,7 +764,9 @@ class Attention(nn.Module): ...@@ -708,7 +764,9 @@ class Attention(nn.Module):
708 assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) 764 assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
709 765
710 if attn_mask is not None: 766 if attn_mask is not None:
711 - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask 767 + attn_weights = (
768 + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
769 + )
712 attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 770 attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
713 771
714 # This is part of a workaround to get around fork/join parallelism not supporting Optional types. 772 # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
...@@ -725,16 +783,14 @@ class Attention(nn.Module): ...@@ -725,16 +783,14 @@ class Attention(nn.Module):
725 attn_weights = attn_weights.masked_fill(reshaped, float("-inf")) 783 attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
726 attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 784 attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
727 attn_weights = F.softmax(attn_weights, dim=-1) 785 attn_weights = F.softmax(attn_weights, dim=-1)
728 - attn_probs = F.dropout( 786 + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
729 - attn_weights,
730 - p=self.dropout,
731 - training=self.training,
732 - )
733 787
734 assert v is not None 788 assert v is not None
735 attn_output = torch.bmm(attn_probs, v) 789 attn_output = torch.bmm(attn_probs, v)
736 assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) 790 assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
737 - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 791 + attn_output = (
792 + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
793 + )
738 attn_output = self.out_proj(attn_output) 794 attn_output = self.out_proj(attn_output)
739 if output_attentions: 795 if output_attentions:
740 attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 796 attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
...@@ -763,12 +819,16 @@ class Attention(nn.Module): ...@@ -763,12 +819,16 @@ class Attention(nn.Module):
763 assert v is not None 819 assert v is not None
764 v = torch.cat([prev_value, v], dim=1) 820 v = torch.cat([prev_value, v], dim=1)
765 assert k is not None and v is not None 821 assert k is not None and v is not None
766 - prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None) 822 + prev_key_padding_mask: Optional[Tensor] = saved_state.get(
823 + "prev_key_padding_mask", None
824 + )
767 if prev_key_padding_mask is not None: 825 if prev_key_padding_mask is not None:
768 if static_kv: 826 if static_kv:
769 new_key_padding_mask = prev_key_padding_mask 827 new_key_padding_mask = prev_key_padding_mask
770 else: 828 else:
771 - new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1) 829 + new_key_padding_mask = torch.cat(
830 + [prev_key_padding_mask, key_padding_mask], dim=1
831 + )
772 else: 832 else:
773 new_key_padding_mask = key_padding_mask 833 new_key_padding_mask = key_padding_mask
774 return k, v, new_key_padding_mask 834 return k, v, new_key_padding_mask
...@@ -780,11 +840,7 @@ class BartClassificationHead(nn.Module): ...@@ -780,11 +840,7 @@ class BartClassificationHead(nn.Module):
780 # This can trivially be shared with RobertaClassificationHead 840 # This can trivially be shared with RobertaClassificationHead
781 841
782 def __init__( 842 def __init__(
783 - self, 843 + self, input_dim, inner_dim, num_classes, pooler_dropout,
784 - input_dim,
785 - inner_dim,
786 - num_classes,
787 - pooler_dropout,
788 ): 844 ):
789 super().__init__() 845 super().__init__()
790 self.dense = nn.Linear(input_dim, inner_dim) 846 self.dense = nn.Linear(input_dim, inner_dim)
...@@ -808,7 +864,9 @@ class LearnedPositionalEmbedding(nn.Embedding): ...@@ -808,7 +864,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
808 position ids are passed to the forward function. 864 position ids are passed to the forward function.
809 """ 865 """
810 866
811 - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset): 867 + def __init__(
868 + self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset
869 + ):
812 # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 870 # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
813 # and adjust num_embeddings appropriately. Other models don't have this hack 871 # and adjust num_embeddings appropriately. Other models don't have this hack
814 self.offset = offset 872 self.offset = offset
...@@ -820,10 +878,14 @@ class LearnedPositionalEmbedding(nn.Embedding): ...@@ -820,10 +878,14 @@ class LearnedPositionalEmbedding(nn.Embedding):
820 """Input is expected to be of size [bsz x seqlen].""" 878 """Input is expected to be of size [bsz x seqlen]."""
821 bsz, seq_len = input_ids.shape[:2] 879 bsz, seq_len = input_ids.shape[:2]
822 if use_cache: 880 if use_cache:
823 - positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing 881 + positions = input_ids.data.new(1, 1).fill_(
882 + seq_len - 1
883 + ) # called before slicing
824 else: 884 else:
825 # starts at 0, ends at seq_len - 1 885 # starts at 0, ends at seq_len - 1
826 - positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device) 886 + positions = torch.arange(
887 + seq_len, dtype=torch.long, device=self.weight.device
888 + )
827 return super().forward(positions + self.offset) 889 return super().forward(positions + self.offset)
828 890
829 891
...@@ -896,16 +958,28 @@ class BartModel(PretrainedBartModel): ...@@ -896,16 +958,28 @@ class BartModel(PretrainedBartModel):
896 if decoder_input_ids is None: 958 if decoder_input_ids is None:
897 use_cache = False 959 use_cache = False
898 960
899 - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 961 + output_attentions = (
962 + output_attentions
963 + if output_attentions is not None
964 + else self.config.output_attentions
965 + )
900 output_hidden_states = ( 966 output_hidden_states = (
901 - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 967 + output_hidden_states
968 + if output_hidden_states is not None
969 + else self.config.output_hidden_states
902 ) 970 )
903 use_cache = use_cache if use_cache is not None else self.config.use_cache 971 use_cache = use_cache if use_cache is not None else self.config.use_cache
904 - return_dict = return_dict if return_dict is not None else self.config.use_return_dict 972 + return_dict = (
973 + return_dict if return_dict is not None else self.config.use_return_dict
974 + )
905 975
906 # make masks if user doesn't supply 976 # make masks if user doesn't supply
907 if not use_cache: 977 if not use_cache:
908 - decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs( 978 + (
979 + decoder_input_ids,
980 + decoder_padding_mask,
981 + causal_mask,
982 + ) = _prepare_bart_decoder_inputs(
909 self.config, 983 self.config,
910 input_ids, 984 input_ids,
911 decoder_input_ids=decoder_input_ids, 985 decoder_input_ids=decoder_input_ids,
...@@ -974,17 +1048,24 @@ class BartModel(PretrainedBartModel): ...@@ -974,17 +1048,24 @@ class BartModel(PretrainedBartModel):
974 1048
975 1049
976 @add_start_docstrings( 1050 @add_start_docstrings(
977 - "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING 1051 + "The BART Model with a language modeling head. Can be used for summarization.",
1052 + BART_START_DOCSTRING,
978 ) 1053 )
979 class BartForConditionalGeneration(PretrainedBartModel): 1054 class BartForConditionalGeneration(PretrainedBartModel):
980 base_model_prefix = "model" 1055 base_model_prefix = "model"
981 - authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"] 1056 + authorized_missing_keys = [
1057 + r"final_logits_bias",
1058 + r"encoder\.version",
1059 + r"decoder\.version",
1060 + ]
982 1061
983 def __init__(self, config: BartConfig): 1062 def __init__(self, config: BartConfig):
984 super().__init__(config) 1063 super().__init__(config)
985 base_model = BartModel(config) 1064 base_model = BartModel(config)
986 self.model = base_model 1065 self.model = base_model
987 - self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 1066 + self.register_buffer(
1067 + "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
1068 + )
988 1069
989 def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 1070 def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
990 old_num_tokens = self.model.shared.num_embeddings 1071 old_num_tokens = self.model.shared.num_embeddings
...@@ -993,16 +1074,23 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -993,16 +1074,23 @@ class BartForConditionalGeneration(PretrainedBartModel):
993 self._resize_final_logits_bias(new_num_tokens, old_num_tokens) 1074 self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
994 return new_embeddings 1075 return new_embeddings
995 1076
996 - def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None: 1077 + def _resize_final_logits_bias(
1078 + self, new_num_tokens: int, old_num_tokens: int
1079 + ) -> None:
997 if new_num_tokens <= old_num_tokens: 1080 if new_num_tokens <= old_num_tokens:
998 new_bias = self.final_logits_bias[:, :new_num_tokens] 1081 new_bias = self.final_logits_bias[:, :new_num_tokens]
999 else: 1082 else:
1000 - extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 1083 + extra_bias = torch.zeros(
1084 + (1, new_num_tokens - old_num_tokens),
1085 + device=self.final_logits_bias.device,
1086 + )
1001 new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 1087 new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1002 self.register_buffer("final_logits_bias", new_bias) 1088 self.register_buffer("final_logits_bias", new_bias)
1003 1089
1004 @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) 1090 @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
1005 - @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 1091 + @replace_return_docstrings(
1092 + output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
1093 + )
1006 @add_end_docstrings(BART_GENERATION_EXAMPLE) 1094 @add_end_docstrings(BART_GENERATION_EXAMPLE)
1007 def forward( 1095 def forward(
1008 self, 1096 self,
...@@ -1065,7 +1153,9 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1065,7 +1153,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
1065 FutureWarning, 1153 FutureWarning,
1066 ) 1154 )
1067 past_key_values = unused.pop("decoder_past_key_values") 1155 past_key_values = unused.pop("decoder_past_key_values")
1068 - return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1156 + return_dict = (
1157 + return_dict if return_dict is not None else self.config.use_return_dict
1158 + )
1069 1159
1070 if labels is not None: 1160 if labels is not None:
1071 use_cache = False 1161 use_cache = False
...@@ -1085,17 +1175,23 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1085,17 +1175,23 @@ class BartForConditionalGeneration(PretrainedBartModel):
1085 output_hidden_states=output_hidden_states, 1175 output_hidden_states=output_hidden_states,
1086 return_dict=return_dict, 1176 return_dict=return_dict,
1087 ) 1177 )
1088 - lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias) 1178 + lm_logits = F.linear(
1179 + outputs[0], self.model.shared.weight, bias=self.final_logits_bias
1180 + )
1089 1181
1090 masked_lm_loss = None 1182 masked_lm_loss = None
1091 if labels is not None: 1183 if labels is not None:
1092 loss_fct = CrossEntropyLoss() 1184 loss_fct = CrossEntropyLoss()
1093 # TODO(SS): do we need to ignore pad tokens in labels? 1185 # TODO(SS): do we need to ignore pad tokens in labels?
1094 - masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1186 + masked_lm_loss = loss_fct(
1187 + lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
1188 + )
1095 1189
1096 if not return_dict: 1190 if not return_dict:
1097 output = (lm_logits,) + outputs[1:] 1191 output = (lm_logits,) + outputs[1:]
1098 - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1192 + return (
1193 + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1194 + )
1099 1195
1100 return Seq2SeqLMOutput( 1196 return Seq2SeqLMOutput(
1101 loss=masked_lm_loss, 1197 loss=masked_lm_loss,
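Note: the `lm_logits = F.linear(...)` change above shows how BART projects decoder states back onto the vocabulary with the shared (tied) embedding matrix plus `final_logits_bias`, and how the optional masked-LM loss is computed. A hedged sketch with made-up sizes:

import torch
import torch.nn.functional as F

vocab_size, d_model, bsz, seq_len = 50, 16, 2, 7
shared_weight = torch.randn(vocab_size, d_model)       # tied embedding matrix
final_logits_bias = torch.zeros(1, vocab_size)
decoder_hidden = torch.randn(bsz, seq_len, d_model)

# Same pattern as the diff: hidden states projected onto the vocabulary, bias added.
lm_logits = F.linear(decoder_hidden, shared_weight, bias=final_logits_bias)
assert lm_logits.shape == (bsz, seq_len, vocab_size)

labels = torch.randint(0, vocab_size, (bsz, seq_len))
loss = torch.nn.CrossEntropyLoss()(lm_logits.view(-1, vocab_size), labels.view(-1))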
...@@ -1109,7 +1205,13 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1109,7 +1205,13 @@ class BartForConditionalGeneration(PretrainedBartModel):
1109 ) 1205 )
1110 1206
1111 def prepare_inputs_for_generation( 1207 def prepare_inputs_for_generation(
1112 - self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs 1208 + self,
1209 + decoder_input_ids,
1210 + past,
1211 + attention_mask,
1212 + use_cache,
1213 + encoder_outputs,
1214 + **kwargs,
1113 ): 1215 ):
1114 return { 1216 return {
1115 "input_ids": None, # encoder_outputs is defined. input_ids not needed 1217 "input_ids": None, # encoder_outputs is defined. input_ids not needed
...@@ -1130,7 +1232,9 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1130,7 +1232,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
1130 1232
1131 def _force_token_ids_generation(self, scores, token_id) -> None: 1233 def _force_token_ids_generation(self, scores, token_id) -> None:
1132 """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" 1234 """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
1133 - scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf") 1235 + scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float(
1236 + "inf"
1237 + )
1134 1238
1135 @staticmethod 1239 @staticmethod
1136 def _reorder_cache(past, beam_idx): 1240 def _reorder_cache(past, beam_idx):
...@@ -1138,7 +1242,8 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1138,7 +1242,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
1138 for layer_past in past: 1242 for layer_past in past:
1139 # get the correct batch idx from decoder layer's batch dim for cross and self-attn 1243 # get the correct batch idx from decoder layer's batch dim for cross and self-attn
1140 layer_past_new = { 1244 layer_past_new = {
1141 - attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() 1245 + attn_key: _reorder_buffer(attn_cache, beam_idx)
1246 + for attn_key, attn_cache in layer_past.items()
1142 } 1247 }
1143 reordered_past.append(layer_past_new) 1248 reordered_past.append(layer_past_new)
1144 return reordered_past 1249 return reordered_past
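Note: `_reorder_cache` above maps every cached attention buffer onto the new beam order after a beam-search step. The `_reorder_buffer` helper it calls is not part of this hunk, so the sketch below reconstructs the idea with a hypothetical `reorder_buffer`; it only illustrates the batch-dimension reindexing, not the exact upstream helper:

import torch

def reorder_buffer(attn_cache: dict, beam_idx: torch.Tensor) -> dict:
    # Select, for every cached tensor, the rows that correspond to the surviving beams.
    return {
        k: v.index_select(0, beam_idx) if torch.is_tensor(v) else v
        for k, v in attn_cache.items()
    }

cache = {"prev_key": torch.randn(4, 2, 5, 8), "prev_value": torch.randn(4, 2, 5, 8)}
beam_idx = torch.tensor([2, 2, 0, 1])  # e.g. beam 2 was duplicated, beam 3 dropped
reordered = reorder_buffer(cache, beam_idx)
assert reordered["prev_key"].shape == cache["prev_key"].shape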
...@@ -1159,10 +1264,7 @@ class BartForSequenceClassification(PretrainedBartModel): ...@@ -1159,10 +1264,7 @@ class BartForSequenceClassification(PretrainedBartModel):
1159 super().__init__(config, **kwargs) 1264 super().__init__(config, **kwargs)
1160 self.model = BartModel(config) 1265 self.model = BartModel(config)
1161 self.classification_head = BartClassificationHead( 1266 self.classification_head = BartClassificationHead(
1162 - config.d_model, 1267 + config.d_model, config.d_model, config.num_labels, config.classif_dropout,
1163 - config.d_model,
1164 - config.num_labels,
1165 - config.classif_dropout,
1166 ) 1268 )
1167 self.model._init_weights(self.classification_head.dense) 1269 self.model._init_weights(self.classification_head.dense)
1168 self.model._init_weights(self.classification_head.out_proj) 1270 self.model._init_weights(self.classification_head.out_proj)
...@@ -1193,7 +1295,9 @@ class BartForSequenceClassification(PretrainedBartModel): ...@@ -1193,7 +1295,9 @@ class BartForSequenceClassification(PretrainedBartModel):
1193 Indices should be in :obj:`[0, ..., config.num_labels - 1]`. 1295 Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
1194 If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1296 If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1195 """ 1297 """
1196 - return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1298 + return_dict = (
1299 + return_dict if return_dict is not None else self.config.use_return_dict
1300 + )
1197 if labels is not None: 1301 if labels is not None:
1198 use_cache = False 1302 use_cache = False
1199 1303
...@@ -1212,7 +1316,9 @@ class BartForSequenceClassification(PretrainedBartModel): ...@@ -1212,7 +1316,9 @@ class BartForSequenceClassification(PretrainedBartModel):
1212 eos_mask = input_ids.eq(self.config.eos_token_id) 1316 eos_mask = input_ids.eq(self.config.eos_token_id)
1213 if len(torch.unique(eos_mask.sum(1))) > 1: 1317 if len(torch.unique(eos_mask.sum(1))) > 1:
1214 raise ValueError("All examples must have the same number of <eos> tokens.") 1318 raise ValueError("All examples must have the same number of <eos> tokens.")
1215 - sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] 1319 + sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[
1320 + :, -1, :
1321 + ]
1216 logits = self.classification_head(sentence_representation) 1322 logits = self.classification_head(sentence_representation)
1217 1323
1218 loss = None 1324 loss = None
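Note: the classification hunk above takes the hidden state at the final <eos> token of each sequence as the sentence representation. A small sketch of that masking/reshaping step with illustrative sizes:

import torch

bsz, seq_len, hidden = 3, 6, 4
eos_token_id = 2
input_ids = torch.full((bsz, seq_len), 5)
input_ids[:, -1] = eos_token_id            # every example ends with one <eos>
x = torch.randn(bsz, seq_len, hidden)      # decoder hidden states

eos_mask = input_ids.eq(eos_token_id)
# Keep only <eos> positions, then take the last one per example.
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
assert sentence_representation.shape == (bsz, hidden)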
...@@ -1284,7 +1390,9 @@ class BartForQuestionAnswering(PretrainedBartModel): ...@@ -1284,7 +1390,9 @@ class BartForQuestionAnswering(PretrainedBartModel):
1284 Positions are clamped to the length of the sequence (`sequence_length`). 1390 Positions are clamped to the length of the sequence (`sequence_length`).
1285 Position outside of the sequence are not taken into account for computing the loss. 1391 Position outside of the sequence are not taken into account for computing the loss.
1286 """ 1392 """
1287 - return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1393 + return_dict = (
1394 + return_dict if return_dict is not None else self.config.use_return_dict
1395 + )
1288 if start_positions is not None and end_positions is not None: 1396 if start_positions is not None and end_positions is not None:
1289 use_cache = False 1397 use_cache = False
1290 1398
...@@ -1325,10 +1433,7 @@ class BartForQuestionAnswering(PretrainedBartModel): ...@@ -1325,10 +1433,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
1325 total_loss = (start_loss + end_loss) / 2 1433 total_loss = (start_loss + end_loss) / 2
1326 1434
1327 if not return_dict: 1435 if not return_dict:
1328 - output = ( 1436 + output = (start_logits, end_logits,) + outputs[1:]
1329 - start_logits,
1330 - end_logits,
1331 - ) + outputs[1:]
1332 return ((total_loss,) + output) if total_loss is not None else output 1437 return ((total_loss,) + output) if total_loss is not None else output
1333 1438
1334 return Seq2SeqQuestionAnsweringModelOutput( 1439 return Seq2SeqQuestionAnsweringModelOutput(
...@@ -1350,7 +1455,9 @@ class SinusoidalPositionalEmbedding(nn.Embedding): ...@@ -1350,7 +1455,9 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
1350 def __init__(self, num_positions, embedding_dim, padding_idx=None): 1455 def __init__(self, num_positions, embedding_dim, padding_idx=None):
1351 super().__init__(num_positions, embedding_dim) 1456 super().__init__(num_positions, embedding_dim)
1352 if embedding_dim % 2 != 0: 1457 if embedding_dim % 2 != 0:
1353 - raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") 1458 + raise NotImplementedError(
1459 + f"odd embedding_dim {embedding_dim} not supported"
1460 + )
1354 self.weight = self._init_weight(self.weight) 1461 self.weight = self._init_weight(self.weight)
1355 1462
1356 @staticmethod 1463 @staticmethod
...@@ -1360,9 +1467,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding): ...@@ -1360,9 +1467,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
1360 """ 1467 """
1361 n_pos, dim = out.shape 1468 n_pos, dim = out.shape
1362 position_enc = np.array( 1469 position_enc = np.array(
1363 - [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] 1470 + [
1471 + [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
1472 + for pos in range(n_pos)
1473 + ]
1364 ) 1474 )
1365 - out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos 1475 + out[:, 0 : dim // 2] = torch.FloatTensor(
1476 + np.sin(position_enc[:, 0::2])
1477 + ) # This line breaks for odd n_pos
1366 out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) 1478 out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
1367 out.detach_() 1479 out.detach_()
1368 out.requires_grad = False 1480 out.requires_grad = False
...@@ -1373,8 +1485,12 @@ class SinusoidalPositionalEmbedding(nn.Embedding): ...@@ -1373,8 +1485,12 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
1373 """Input is expected to be of size [bsz x seqlen].""" 1485 """Input is expected to be of size [bsz x seqlen]."""
1374 bsz, seq_len = input_ids.shape[:2] 1486 bsz, seq_len = input_ids.shape[:2]
1375 if use_cache: 1487 if use_cache:
1376 - positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing 1488 + positions = input_ids.data.new(1, 1).fill_(
1489 + seq_len - 1
1490 + ) # called before slicing
1377 else: 1491 else:
1378 # starts at 0, ends at 1-seq_len 1492 # starts at 0, ends at 1-seq_len
1379 - positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device) 1493 + positions = torch.arange(
1494 + seq_len, dtype=torch.long, device=self.weight.device
1495 + )
1380 return super().forward(positions) 1496 return super().forward(positions)
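Note: the `SinusoidalPositionalEmbedding` hunks above only re-wrap `_init_weight`; the underlying table is the classic fixed sin/cos encoding, with sines in the first half of each row and cosines in the second half. A self-contained sketch (even `dim` assumed, as the class itself requires):

import numpy as np
import torch

def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out = torch.zeros(n_pos, dim)
    out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # first half: sines
    out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))    # second half: cosines
    return out

table = sinusoidal_table(10, 8)
assert table.shape == (10, 8)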
......
...@@ -80,7 +80,9 @@ def find_pruneable_heads_and_indices( ...@@ -80,7 +80,9 @@ def find_pruneable_heads_and_indices(
80 :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices. 80 :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
81 """ 81 """
82 mask = torch.ones(n_heads, head_size) 82 mask = torch.ones(n_heads, head_size)
83 - heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads 83 + heads = (
84 + set(heads) - already_pruned_heads
85 + ) # Convert to set and remove already pruned heads
84 for head in heads: 86 for head in heads:
85 # Compute how many pruned heads are before the head and move the index accordingly 87 # Compute how many pruned heads are before the head and move the index accordingly
86 head = head - sum(1 if h < head else 0 for h in already_pruned_heads) 88 head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
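Note: the `find_pruneable_heads_and_indices` hunk above first discards heads that were pruned earlier, then shifts each remaining head index left by the number of already-pruned heads before it. A sketch of just that adjustment (the mask-to-index step of the full helper is omitted):

already_pruned_heads = {1, 4}
heads_to_prune = [1, 2, 5]

heads = set(heads_to_prune) - already_pruned_heads   # {2, 5}: head 1 is already gone
adjusted = []
for head in heads:
    # Indices shift left by the number of previously pruned heads before this one.
    adjusted.append(head - sum(1 if h < head else 0 for h in already_pruned_heads))
print(sorted(adjusted))  # [1, 3] when heads 1 and 4 were already pruned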
...@@ -106,7 +108,11 @@ class ModuleUtilsMixin: ...@@ -106,7 +108,11 @@ class ModuleUtilsMixin:
106 Returns: 108 Returns:
107 :obj:`int`: The number of parameters. 109 :obj:`int`: The number of parameters.
108 """ 110 """
109 - params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters() 111 + params = (
112 + filter(lambda x: x.requires_grad, self.parameters())
113 + if only_trainable
114 + else self.parameters()
115 + )
110 return sum(p.numel() for p in params) 116 return sum(p.numel() for p in params)
111 117
112 @staticmethod 118 @staticmethod
...@@ -114,7 +120,9 @@ class ModuleUtilsMixin: ...@@ -114,7 +120,9 @@ class ModuleUtilsMixin:
114 try: 120 try:
115 import psutil 121 import psutil
116 except (ImportError): 122 except (ImportError):
117 - raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") 123 + raise ImportError(
124 + "You need to install psutil (pip install psutil) to use memory tracing."
125 + )
118 126
119 process = psutil.Process(os.getpid()) 127 process = psutil.Process(os.getpid())
120 mem = process.memory_info() 128 mem = process.memory_info()
...@@ -126,13 +134,17 @@ class ModuleUtilsMixin: ...@@ -126,13 +134,17 @@ class ModuleUtilsMixin:
126 try: 134 try:
127 import psutil 135 import psutil
128 except (ImportError): 136 except (ImportError):
129 - raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") 137 + raise ImportError(
138 + "You need to install psutil (pip install psutil) to use memory tracing."
139 + )
130 140
131 process = psutil.Process(os.getpid()) 141 process = psutil.Process(os.getpid())
132 mem = process.memory_info() 142 mem = process.memory_info()
133 module.mem_rss_post_forward = mem.rss 143 module.mem_rss_post_forward = mem.rss
134 mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward 144 mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
135 - module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0) 145 + module.mem_rss_diff = mem_rss_diff + (
146 + module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0
147 + )
136 return None 148 return None
137 149
138 def add_memory_hooks(self): 150 def add_memory_hooks(self):
...@@ -169,7 +181,9 @@ class ModuleUtilsMixin: ...@@ -169,7 +181,9 @@ class ModuleUtilsMixin:
169 # For nn.DataParallel compatibility in PyTorch 1.5 181 # For nn.DataParallel compatibility in PyTorch 1.5
170 182
171 def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: 183 def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
172 - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 184 + tuples = [
185 + (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)
186 + ]
173 return tuples 187 return tuples
174 188
175 gen = self._named_members(get_members_fn=find_tensor_attributes) 189 gen = self._named_members(get_members_fn=find_tensor_attributes)
...@@ -187,7 +201,9 @@ class ModuleUtilsMixin: ...@@ -187,7 +201,9 @@ class ModuleUtilsMixin:
187 # For nn.DataParallel compatibility in PyTorch 1.5 201 # For nn.DataParallel compatibility in PyTorch 1.5
188 202
189 def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: 203 def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
190 - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 204 + tuples = [
205 + (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)
206 + ]
191 return tuples 207 return tuples
192 208
193 gen = self._named_members(get_members_fn=find_tensor_attributes) 209 gen = self._named_members(get_members_fn=find_tensor_attributes)
...@@ -213,12 +229,18 @@ class ModuleUtilsMixin: ...@@ -213,12 +229,18 @@ class ModuleUtilsMixin:
213 # /transformer/transformer_layers.py#L270 229 # /transformer/transformer_layers.py#L270
214 # encoder_extended_attention_mask = (encoder_extended_attention_mask == 230 # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
215 # encoder_extended_attention_mask.transpose(-1, -2)) 231 # encoder_extended_attention_mask.transpose(-1, -2))
216 - encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 232 + encoder_extended_attention_mask = encoder_extended_attention_mask.to(
233 + dtype=self.dtype
234 + ) # fp16 compatibility
217 235
218 if self.dtype == torch.float16: 236 if self.dtype == torch.float16:
219 - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4 237 + encoder_extended_attention_mask = (
238 + 1.0 - encoder_extended_attention_mask
239 + ) * -1e4
220 elif self.dtype == torch.float32: 240 elif self.dtype == torch.float32:
221 - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 241 + encoder_extended_attention_mask = (
242 + 1.0 - encoder_extended_attention_mask
243 + ) * -1e9
222 else: 244 else:
223 raise ValueError( 245 raise ValueError(
224 "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format( 246 "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
...@@ -228,7 +250,9 @@ class ModuleUtilsMixin: ...@@ -228,7 +250,9 @@ class ModuleUtilsMixin:
228 250
229 return encoder_extended_attention_mask 251 return encoder_extended_attention_mask
230 252
231 - def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor: 253 + def get_extended_attention_mask(
254 + self, attention_mask: Tensor, input_shape: Tuple[int], device: device
255 + ) -> Tensor:
232 """ 256 """
233 Makes broadcastable attention and causal masks so that future and masked tokens are ignored. 257 Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
234 258
...@@ -254,10 +278,15 @@ class ModuleUtilsMixin: ...@@ -254,10 +278,15 @@ class ModuleUtilsMixin:
254 if self.config.is_decoder: 278 if self.config.is_decoder:
255 batch_size, seq_length = input_shape 279 batch_size, seq_length = input_shape
256 seq_ids = torch.arange(seq_length, device=device) 280 seq_ids = torch.arange(seq_length, device=device)
257 - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] 281 + causal_mask = (
282 + seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
283 + <= seq_ids[None, :, None]
284 + )
258 # causal and attention masks must have same type with pytorch version < 1.3 285 # causal and attention masks must have same type with pytorch version < 1.3
259 causal_mask = causal_mask.to(attention_mask.dtype) 286 causal_mask = causal_mask.to(attention_mask.dtype)
260 - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] 287 + extended_attention_mask = (
288 + causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
289 + )
261 else: 290 else:
262 extended_attention_mask = attention_mask[:, None, None, :] 291 extended_attention_mask = attention_mask[:, None, None, :]
263 else: 292 else:
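Note: the decoder branch above builds the causal mask by comparing position indices, so position i may only attend to positions <= i. A tiny sketch of that comparison:

import torch

seq_length, batch_size = 5, 2
seq_ids = torch.arange(seq_length)
causal_mask = (
    seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
)
print(causal_mask[0].int())  # lower-triangular: row i has ones up to column i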
...@@ -272,12 +301,17 @@ class ModuleUtilsMixin: ...@@ -272,12 +301,17 @@ class ModuleUtilsMixin:
272 # positions we want to attend and -10000.0 for masked positions. 301 # positions we want to attend and -10000.0 for masked positions.
273 # Since we are adding it to the raw scores before the softmax, this is 302 # Since we are adding it to the raw scores before the softmax, this is
274 # effectively the same as removing these entirely. 303 # effectively the same as removing these entirely.
275 - extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 304 + extended_attention_mask = extended_attention_mask.to(
305 + dtype=self.dtype
306 + ) # fp16 compatibility
276 extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 307 extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
277 return extended_attention_mask 308 return extended_attention_mask
278 309
279 def get_head_mask( 310 def get_head_mask(
280 - self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False 311 + self,
312 + head_mask: Optional[Tensor],
313 + num_hidden_layers: int,
314 + is_attention_chunked: bool = False,
281 ) -> Tensor: 315 ) -> Tensor:
282 """ 316 """
283 Prepare the head mask if needed. 317 Prepare the head mask if needed.
...@@ -309,9 +343,13 @@ class ModuleUtilsMixin: ...@@ -309,9 +343,13 @@ class ModuleUtilsMixin:
309 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 343 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
310 head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) 344 head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
311 elif head_mask.dim() == 2: 345 elif head_mask.dim() == 2:
312 - head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer 346 + head_mask = (
347 + head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
348 + ) # We can specify head_mask for each layer
313 assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" 349 assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
314 - head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility 350 + head_mask = head_mask.to(
351 + dtype=self.dtype
352 + ) # switch to fload if need + fp16 compatibility
315 return head_mask 353 return head_mask
316 354
317 355
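Note: the attention-mask hunks above all rely on the same trick: a 1/0 padding mask is broadcast to (batch, heads, query_len, key_len) and turned into an additive mask of 0 / -10000 (or -1e4 / -1e9 depending on dtype), so masked positions effectively vanish after the softmax. A minimal sketch:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])        # 1 = attend, 0 = padding
extended = attention_mask[:, None, None, :].float()     # broadcastable over (bsz, heads, q, k)
extended = (1.0 - extended) * -10000.0

scores = torch.zeros(1, 1, 1, 5)                        # fake raw attention scores
probs = torch.softmax(scores + extended, dim=-1)
print(probs)                                            # padded positions get ~0 probability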
...@@ -420,12 +458,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -420,12 +458,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
420 self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) 458 self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
421 459
422 if self.config.is_encoder_decoder and self.config.tie_encoder_decoder: 460 if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
423 - self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) 461 + self._tie_encoder_decoder_weights(
462 + self.encoder, self.decoder, self.base_model_prefix
463 + )
424 464
425 @staticmethod 465 @staticmethod
426 - def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str): 466 + def _tie_encoder_decoder_weights(
467 + encoder: nn.Module, decoder: nn.Module, base_model_prefix: str
468 + ):
427 uninitialized_encoder_weights: List[str] = [] 469 uninitialized_encoder_weights: List[str] = []
428 - assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal." 470 + assert (
471 + decoder.__class__ == encoder.__class__
472 + ), f"{decoder.__class__} and {encoder.__class__} have to be equal."
429 473
430 def tie_encoder_to_decoder_recursively( 474 def tie_encoder_to_decoder_recursively(
431 decoder_pointer: nn.Module, 475 decoder_pointer: nn.Module,
...@@ -452,13 +496,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -452,13 +496,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
452 len(encoder_modules) > 0 496 len(encoder_modules) > 0
453 ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 497 ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
454 498
455 - all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) 499 + all_encoder_weights = set(
500 + [
501 + module_name + "/" + sub_name
502 + for sub_name in encoder_modules.keys()
503 + ]
504 + )
456 encoder_layer_pos = 0 505 encoder_layer_pos = 0
457 for name, module in decoder_modules.items(): 506 for name, module in decoder_modules.items():
458 if name.isdigit(): 507 if name.isdigit():
459 encoder_name = str(int(name) + encoder_layer_pos) 508 encoder_name = str(int(name) + encoder_layer_pos)
460 decoder_name = name 509 decoder_name = name
461 - if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])): 510 + if not isinstance(
511 + decoder_modules[decoder_name],
512 + type(encoder_modules[encoder_name]),
513 + ):
462 # this can happen if the name corresponds to the position in a list module list of layers 514 # this can happen if the name corresponds to the position in a list module list of layers
463 # in this case the decoder has added a cross-attention that the encoder does not have 515 # in this case the decoder has added a cross-attention that the encoder does not have
464 # thus skip this step and substract one layer pos from encoder 516 # thus skip this step and substract one layer pos from encoder
...@@ -484,7 +536,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -484,7 +536,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
484 uninitialized_encoder_weights += list(all_encoder_weights) 536 uninitialized_encoder_weights += list(all_encoder_weights)
485 537
486 # tie weights recursively 538 # tie weights recursively
487 - tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) 539 + tie_encoder_to_decoder_recursively(
540 + decoder, encoder, base_model_prefix, uninitialized_encoder_weights
541 + )
488 if len(uninitialized_encoder_weights) > 0: 542 if len(uninitialized_encoder_weights) > 0:
489 logger.warning( 543 logger.warning(
490 f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" 544 f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
...@@ -507,10 +561,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -507,10 +561,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
507 "constant", 561 "constant",
508 0, 562 0,
509 ) 563 )
510 - if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): 564 + if hasattr(output_embeddings, "out_features") and hasattr(
565 + input_embeddings, "num_embeddings"
566 + ):
511 output_embeddings.out_features = input_embeddings.num_embeddings 567 output_embeddings.out_features = input_embeddings.num_embeddings
512 568
513 - def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding: 569 + def resize_token_embeddings(
570 + self, new_num_tokens: Optional[int] = None
571 + ) -> torch.nn.Embedding:
514 """ 572 """
515 Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`. 573 Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
516 574
...@@ -526,7 +584,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -526,7 +584,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
526 Return: 584 Return:
527 :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. 585 :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
528 """ 586 """
529 - base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed 587 + base_model = getattr(
588 + self, self.base_model_prefix, self
589 + ) # get the base model if needed
530 model_embeds = base_model._resize_token_embeddings(new_num_tokens) 590 model_embeds = base_model._resize_token_embeddings(new_num_tokens)
531 if new_num_tokens is None: 591 if new_num_tokens is None:
532 return model_embeds 592 return model_embeds
...@@ -583,7 +643,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -583,7 +643,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
583 643
584 # Copy token embeddings from the previous weights 644 # Copy token embeddings from the previous weights
585 num_tokens_to_copy = min(old_num_tokens, new_num_tokens) 645 num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
586 - new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] 646 + new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[
647 + :num_tokens_to_copy, :
648 + ]
587 649
588 return new_embeddings 650 return new_embeddings
589 651
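Note: the resize hunk above copies the overlapping rows of the old embedding matrix into the freshly initialized, resized one. A sketch with made-up sizes:

import torch
import torch.nn as nn

old = nn.Embedding(10, 4)
new = nn.Embedding(15, 4)   # resized vocabulary

num_tokens_to_copy = min(old.num_embeddings, new.num_embeddings)
with torch.no_grad():
    # Rows for tokens that existed before keep their trained values; new rows stay random.
    new.weight.data[:num_tokens_to_copy, :] = old.weight.data[:num_tokens_to_copy, :]

assert torch.equal(new.weight[:10], old.weight)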
...@@ -614,7 +676,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -614,7 +676,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
614 # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads 676 # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
615 for layer, heads in heads_to_prune.items(): 677 for layer, heads in heads_to_prune.items():
616 union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) 678 union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
617 - self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON 679 + self.config.pruned_heads[layer] = list(
680 + union_heads
681 + ) # Unfortunately we have to store it as list for JSON
618 682
619 self.base_model._prune_heads(heads_to_prune) 683 self.base_model._prune_heads(heads_to_prune)
620 684
...@@ -628,7 +692,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -628,7 +692,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
628 Directory to which to save. Will be created if it doesn't exist. 692 Directory to which to save. Will be created if it doesn't exist.
629 """ 693 """
630 if os.path.isfile(save_directory): 694 if os.path.isfile(save_directory):
631 - logger.error("Provided path ({}) should be a directory, not a file".format(save_directory)) 695 + logger.error(
696 + "Provided path ({}) should be a directory, not a file".format(
697 + save_directory
698 + )
699 + )
632 return 700 return
633 os.makedirs(save_directory, exist_ok=True) 701 os.makedirs(save_directory, exist_ok=True)
634 702
...@@ -775,7 +843,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -775,7 +843,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
775 843
776 # Load config if we don't provide a configuration 844 # Load config if we don't provide a configuration
777 if not isinstance(config, PretrainedConfig): 845 if not isinstance(config, PretrainedConfig):
778 - config_path = config if config is not None else pretrained_model_name_or_path 846 + config_path = (
847 + config if config is not None else pretrained_model_name_or_path
848 + )
779 config, model_kwargs = cls.config_class.from_pretrained( 849 config, model_kwargs = cls.config_class.from_pretrained(
780 config_path, 850 config_path,
781 *model_args, 851 *model_args,
...@@ -793,23 +863,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -793,23 +863,43 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
793 # Load model 863 # Load model
794 if pretrained_model_name_or_path is not None: 864 if pretrained_model_name_or_path is not None:
795 if os.path.isdir(pretrained_model_name_or_path): 865 if os.path.isdir(pretrained_model_name_or_path):
796 - if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")): 866 + if from_tf and os.path.isfile(
867 + os.path.join(
868 + pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index"
869 + )
870 + ):
797 # Load from a TF 1.0 checkpoint 871 # Load from a TF 1.0 checkpoint
798 - archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") 872 + archive_file = os.path.join(
799 - elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): 873 + pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index"
874 + )
875 + elif from_tf and os.path.isfile(
876 + os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
877 + ):
800 # Load from a TF 2.0 checkpoint 878 # Load from a TF 2.0 checkpoint
801 - archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) 879 + archive_file = os.path.join(
802 - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): 880 + pretrained_model_name_or_path, TF2_WEIGHTS_NAME
881 + )
882 + elif os.path.isfile(
883 + os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
884 + ):
803 # Load from a PyTorch checkpoint 885 # Load from a PyTorch checkpoint
804 - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) 886 + archive_file = os.path.join(
887 + pretrained_model_name_or_path, WEIGHTS_NAME
888 + )
805 else: 889 else:
806 raise EnvironmentError( 890 raise EnvironmentError(
807 "Error no file named {} found in directory {} or `from_tf` set to False".format( 891 "Error no file named {} found in directory {} or `from_tf` set to False".format(
808 - [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], 892 + [
893 + WEIGHTS_NAME,
894 + TF2_WEIGHTS_NAME,
895 + TF_WEIGHTS_NAME + ".index",
896 + ],
809 pretrained_model_name_or_path, 897 pretrained_model_name_or_path,
810 ) 898 )
811 ) 899 )
812 - elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): 900 + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
901 + pretrained_model_name_or_path
902 + ):
813 archive_file = pretrained_model_name_or_path 903 archive_file = pretrained_model_name_or_path
814 elif os.path.isfile(pretrained_model_name_or_path + ".index"): 904 elif os.path.isfile(pretrained_model_name_or_path + ".index"):
815 assert ( 905 assert (
...@@ -848,7 +938,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -848,7 +938,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
848 if resolved_archive_file == archive_file: 938 if resolved_archive_file == archive_file:
849 logger.info("loading weights file {}".format(archive_file)) 939 logger.info("loading weights file {}".format(archive_file))
850 else: 940 else:
851 - logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file)) 941 + logger.info(
942 + "loading weights file {} from cache at {}".format(
943 + archive_file, resolved_archive_file
944 + )
945 + )
852 else: 946 else:
853 resolved_archive_file = None 947 resolved_archive_file = None
854 948
...@@ -871,13 +965,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -871,13 +965,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
871 if from_tf: 965 if from_tf:
872 if resolved_archive_file.endswith(".index"): 966 if resolved_archive_file.endswith(".index"):
873 # Load from a TensorFlow 1.X checkpoint - provided by original authors 967 # Load from a TensorFlow 1.X checkpoint - provided by original authors
874 - model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' 968 + model = cls.load_tf_weights(
969 + model, config, resolved_archive_file[:-6]
970 + ) # Remove the '.index'
875 else: 971 else:
876 # Load from our TensorFlow 2.0 checkpoints 972 # Load from our TensorFlow 2.0 checkpoints
877 try: 973 try:
878 from transformers import load_tf2_checkpoint_in_pytorch_model 974 from transformers import load_tf2_checkpoint_in_pytorch_model
879 975
880 - model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True) 976 + model = load_tf2_checkpoint_in_pytorch_model(
977 + model, resolved_archive_file, allow_missing_keys=True
978 + )
881 except ImportError: 979 except ImportError:
882 logger.error( 980 logger.error(
883 "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " 981 "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
...@@ -909,7 +1007,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -909,7 +1007,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
909 # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants 1007 # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
910 # so we need to apply the function recursively. 1008 # so we need to apply the function recursively.
911 def load(module: nn.Module, prefix=""): 1009 def load(module: nn.Module, prefix=""):
912 - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 1010 + local_metadata = (
1011 + {} if metadata is None else metadata.get(prefix[:-1], {})
1012 + )
913 module._load_from_state_dict( 1013 module._load_from_state_dict(
914 state_dict, 1014 state_dict,
915 prefix, 1015 prefix,
...@@ -926,7 +1026,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -926,7 +1026,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
926 # Make sure we are able to load base models as well as derived models (with heads) 1026 # Make sure we are able to load base models as well as derived models (with heads)
927 start_prefix = "" 1027 start_prefix = ""
928 model_to_load = model 1028 model_to_load = model
929 - has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()) 1029 + has_prefix_module = any(
1030 + s.startswith(cls.base_model_prefix) for s in state_dict.keys()
1031 + )
930 if not hasattr(model, cls.base_model_prefix) and has_prefix_module: 1032 if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
931 start_prefix = cls.base_model_prefix + "." 1033 start_prefix = cls.base_model_prefix + "."
932 if hasattr(model, cls.base_model_prefix) and not has_prefix_module: 1034 if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
...@@ -937,15 +1039,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -937,15 +1039,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
937 if model.__class__.__name__ != model_to_load.__class__.__name__: 1039 if model.__class__.__name__ != model_to_load.__class__.__name__:
938 base_model_state_dict = model_to_load.state_dict().keys() 1040 base_model_state_dict = model_to_load.state_dict().keys()
939 head_model_state_dict_without_base_prefix = [ 1041 head_model_state_dict_without_base_prefix = [
940 - key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() 1042 + key.split(cls.base_model_prefix + ".")[-1]
1043 + for key in model.state_dict().keys()
941 ] 1044 ]
942 - missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) 1045 + missing_keys.extend(
1046 + head_model_state_dict_without_base_prefix - base_model_state_dict
1047 + )
943 1048
944 # Some models may have keys that are not in the state by design, removing them before needlessly warning 1049 # Some models may have keys that are not in the state by design, removing them before needlessly warning
945 # the user. 1050 # the user.
946 if cls.authorized_missing_keys is not None: 1051 if cls.authorized_missing_keys is not None:
947 for pat in cls.authorized_missing_keys: 1052 for pat in cls.authorized_missing_keys:
948 - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] 1053 + missing_keys = [
1054 + k for k in missing_keys if re.search(pat, k) is None
1055 + ]
949 1056
950 if len(unexpected_keys) > 0: 1057 if len(unexpected_keys) > 0:
951 logger.warning( 1058 logger.warning(
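Note: the loading hunk above (together with the `authorized_missing_keys` list at the top of this diff) drops expected misses from `missing_keys` with `re.search` before warning the user. A tiny sketch of that filtering:

import re

authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
missing_keys = ["final_logits_bias", "model.encoder.version", "model.decoder.layers.0.fc1.weight"]

for pat in authorized_missing_keys:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

print(missing_keys)  # only the genuinely missing weight remains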
...@@ -957,7 +1064,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -957,7 +1064,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
957 f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." 1064 f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
958 ) 1065 )
959 else: 1066 else:
960 - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") 1067 + logger.info(
1068 + f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
1069 + )
961 if len(missing_keys) > 0: 1070 if len(missing_keys) > 0:
962 logger.warning( 1071 logger.warning(
963 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " 1072 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
...@@ -990,7 +1099,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -990,7 +1099,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
990 } 1099 }
991 return model, loading_info 1100 return model, loading_info
992 1101
993 - if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available(): 1102 + if (
1103 + hasattr(config, "xla_device")
1104 + and config.xla_device
1105 + and is_torch_tpu_available()
1106 + ):
994 import torch_xla.core.xla_model as xm 1107 import torch_xla.core.xla_model as xm
995 1108
996 model = xm.send_cpu_data_to_device(model, xm.xla_device()) 1109 model = xm.send_cpu_data_to_device(model, xm.xla_device())
...@@ -1039,7 +1152,9 @@ class PoolerStartLogits(nn.Module): ...@@ -1039,7 +1152,9 @@ class PoolerStartLogits(nn.Module):
1039 self.dense = nn.Linear(config.hidden_size, 1) 1152 self.dense = nn.Linear(config.hidden_size, 1)
1040 1153
1041 def forward( 1154 def forward(
1042 - self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None 1155 + self,
1156 + hidden_states: torch.FloatTensor,
1157 + p_mask: Optional[torch.FloatTensor] = None,
1043 ) -> torch.FloatTensor: 1158 ) -> torch.FloatTensor:
1044 """ 1159 """
1045 Args: 1160 Args:
...@@ -1112,8 +1227,12 @@ class PoolerEndLogits(nn.Module): ...@@ -1112,8 +1227,12 @@ class PoolerEndLogits(nn.Module):
1112 ), "One of start_states, start_positions should be not None" 1227 ), "One of start_states, start_positions should be not None"
1113 if start_positions is not None: 1228 if start_positions is not None:
1114 slen, hsz = hidden_states.shape[-2:] 1229 slen, hsz = hidden_states.shape[-2:]
1115 - start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 1230 + start_positions = start_positions[:, None, None].expand(
1116 - start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) 1231 + -1, -1, hsz
1232 + ) # shape (bsz, 1, hsz)
1233 + start_states = hidden_states.gather(
1234 + -2, start_positions
1235 + ) # shape (bsz, 1, hsz)
1117 start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) 1236 start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
1118 1237
1119 x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) 1238 x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
...@@ -1177,12 +1296,20 @@ class PoolerAnswerClass(nn.Module): ...@@ -1177,12 +1296,20 @@ class PoolerAnswerClass(nn.Module):
1177 start_states is not None or start_positions is not None 1296 start_states is not None or start_positions is not None
1178 ), "One of start_states, start_positions should be not None" 1297 ), "One of start_states, start_positions should be not None"
1179 if start_positions is not None: 1298 if start_positions is not None:
1180 - start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 1299 + start_positions = start_positions[:, None, None].expand(
1181 - start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) 1300 + -1, -1, hsz
1301 + ) # shape (bsz, 1, hsz)
1302 + start_states = hidden_states.gather(-2, start_positions).squeeze(
1303 + -2
1304 + ) # shape (bsz, hsz)
1182 1305
1183 if cls_index is not None: 1306 if cls_index is not None:
1184 - cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) 1307 + cls_index = cls_index[:, None, None].expand(
1185 - cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) 1308 + -1, -1, hsz
1309 + ) # shape (bsz, 1, hsz)
1310 + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(
1311 + -2
1312 + ) # shape (bsz, hsz)
1186 else: 1313 else:
1187 cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) 1314 cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
1188 1315
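Note: `PoolerAnswerClass` above (and `SequenceSummary` further down) pick out one hidden state per example by expanding an index tensor and calling `gather` on the sequence dimension. A sketch of the pattern with illustrative sizes:

import torch

bsz, seq_len, hsz = 2, 5, 3
hidden_states = torch.randn(bsz, seq_len, hsz)
cls_index = torch.tensor([4, 2])                              # one position per example

idx = cls_index[:, None, None].expand(-1, -1, hsz)            # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, idx).squeeze(-2)   # shape (bsz, hsz)

assert torch.allclose(cls_token_state[0], hidden_states[0, 4])
assert torch.allclose(cls_token_state[1], hidden_states[1, 2])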
...@@ -1241,7 +1368,9 @@ class SQuADHead(nn.Module): ...@@ -1241,7 +1368,9 @@ class SQuADHead(nn.Module):
1241 self.end_logits = PoolerEndLogits(config) 1368 self.end_logits = PoolerEndLogits(config)
1242 self.answer_class = PoolerAnswerClass(config) 1369 self.answer_class = PoolerAnswerClass(config)
1243 1370
1244 - @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig) 1371 + @replace_return_docstrings(
1372 + output_type=SquadHeadOutput, config_class=PretrainedConfig
1373 + )
1245 def forward( 1374 def forward(
1246 self, 1375 self,
1247 hidden_states: torch.FloatTensor, 1376 hidden_states: torch.FloatTensor,
...@@ -1281,7 +1410,9 @@ class SQuADHead(nn.Module): ...@@ -1281,7 +1410,9 @@ class SQuADHead(nn.Module):
1281 x.squeeze_(-1) 1410 x.squeeze_(-1)
1282 1411
1283 # during training, compute the end logits based on the ground truth of the start position 1412 # during training, compute the end logits based on the ground truth of the start position
1284 - end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) 1413 + end_logits = self.end_logits(
1414 + hidden_states, start_positions=start_positions, p_mask=p_mask
1415 + )
1285 1416
1286 loss_fct = CrossEntropyLoss() 1417 loss_fct = CrossEntropyLoss()
1287 start_loss = loss_fct(start_logits, start_positions) 1418 start_loss = loss_fct(start_logits, start_positions)
...@@ -1290,7 +1421,9 @@ class SQuADHead(nn.Module): ...@@ -1290,7 +1421,9 @@ class SQuADHead(nn.Module):
1290 1421
1291 if cls_index is not None and is_impossible is not None: 1422 if cls_index is not None and is_impossible is not None:
1292 # Predict answerability from the representation of CLS and START 1423 # Predict answerability from the representation of CLS and START
1293 - cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) 1424 + cls_logits = self.answer_class(
1425 + hidden_states, start_positions=start_positions, cls_index=cls_index
1426 + )
1294 loss_fct_cls = nn.BCEWithLogitsLoss() 1427 loss_fct_cls = nn.BCEWithLogitsLoss()
1295 cls_loss = loss_fct_cls(cls_logits, is_impossible) 1428 cls_loss = loss_fct_cls(cls_logits, is_impossible)
1296 1429
...@@ -1307,28 +1440,48 @@ class SQuADHead(nn.Module): ...@@ -1307,28 +1440,48 @@ class SQuADHead(nn.Module):
1307 start_top_log_probs, start_top_index = torch.topk( 1440 start_top_log_probs, start_top_index = torch.topk(
1308 start_log_probs, self.start_n_top, dim=-1 1441 start_log_probs, self.start_n_top, dim=-1
1309 ) # shape (bsz, start_n_top) 1442 ) # shape (bsz, start_n_top)
1310 - start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) 1443 + start_top_index_exp = start_top_index.unsqueeze(-1).expand(
1311 - start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) 1444 + -1, -1, hsz
1312 - start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) 1445 + ) # shape (bsz, start_n_top, hsz)
1446 + start_states = torch.gather(
1447 + hidden_states, -2, start_top_index_exp
1448 + ) # shape (bsz, start_n_top, hsz)
1449 + start_states = start_states.unsqueeze(1).expand(
1450 + -1, slen, -1, -1
1451 + ) # shape (bsz, slen, start_n_top, hsz)
1313 1452
1314 hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( 1453 hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
1315 start_states 1454 start_states
1316 ) # shape (bsz, slen, start_n_top, hsz) 1455 ) # shape (bsz, slen, start_n_top, hsz)
1317 p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None 1456 p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
1318 - end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) 1457 + end_logits = self.end_logits(
1319 - end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) 1458 + hidden_states_expanded, start_states=start_states, p_mask=p_mask
1459 + )
1460 + end_log_probs = F.softmax(
1461 + end_logits, dim=1
1462 + ) # shape (bsz, slen, start_n_top)
1320 1463
1321 end_top_log_probs, end_top_index = torch.topk( 1464 end_top_log_probs, end_top_index = torch.topk(
1322 end_log_probs, self.end_n_top, dim=1 1465 end_log_probs, self.end_n_top, dim=1
1323 ) # shape (bsz, end_n_top, start_n_top) 1466 ) # shape (bsz, end_n_top, start_n_top)
1324 - end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) 1467 + end_top_log_probs = end_top_log_probs.view(
1468 + -1, self.start_n_top * self.end_n_top
1469 + )
1325 end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) 1470 end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
1326 1471
1327 start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) 1472 start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
1328 - cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) 1473 + cls_logits = self.answer_class(
1474 + hidden_states, start_states=start_states, cls_index=cls_index
1475 + )
1329 1476
1330 if not return_dict: 1477 if not return_dict:
1331 - return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) 1478 + return (
1479 + start_top_log_probs,
1480 + start_top_index,
1481 + end_top_log_probs,
1482 + end_top_index,
1483 + cls_logits,
1484 + )
1332 else: 1485 else:
1333 return SquadHeadOutput( 1486 return SquadHeadOutput(
1334 start_top_log_probs=start_top_log_probs, 1487 start_top_log_probs=start_top_log_probs,
...@@ -1379,17 +1532,26 @@ class SequenceSummary(nn.Module): ...@@ -1379,17 +1532,26 @@ class SequenceSummary(nn.Module):
1379 1532
1380 self.summary = Identity() 1533 self.summary = Identity()
1381 if hasattr(config, "summary_use_proj") and config.summary_use_proj: 1534 if hasattr(config, "summary_use_proj") and config.summary_use_proj:
1382 - if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: 1535 + if (
1536 + hasattr(config, "summary_proj_to_labels")
1537 + and config.summary_proj_to_labels
1538 + and config.num_labels > 0
1539 + ):
1383 num_classes = config.num_labels 1540 num_classes = config.num_labels
1384 else: 1541 else:
1385 num_classes = config.hidden_size 1542 num_classes = config.hidden_size
1386 self.summary = nn.Linear(config.hidden_size, num_classes) 1543 self.summary = nn.Linear(config.hidden_size, num_classes)
1387 1544
1388 activation_string = getattr(config, "summary_activation", None) 1545 activation_string = getattr(config, "summary_activation", None)
1389 - self.activation: Callable = get_activation(activation_string) if activation_string else Identity() 1546 + self.activation: Callable = get_activation(
1547 + activation_string
1548 + ) if activation_string else Identity()
1390 1549
1391 self.first_dropout = Identity() 1550 self.first_dropout = Identity()
1392 - if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: 1551 + if (
1552 + hasattr(config, "summary_first_dropout")
1553 + and config.summary_first_dropout > 0
1554 + ):
1393 self.first_dropout = nn.Dropout(config.summary_first_dropout) 1555 self.first_dropout = nn.Dropout(config.summary_first_dropout)
1394 1556
1395 self.last_dropout = Identity() 1557 self.last_dropout = Identity()
...@@ -1397,7 +1559,9 @@ class SequenceSummary(nn.Module): ...@@ -1397,7 +1559,9 @@ class SequenceSummary(nn.Module):
1397 self.last_dropout = nn.Dropout(config.summary_last_dropout) 1559 self.last_dropout = nn.Dropout(config.summary_last_dropout)
1398 1560
1399 def forward( 1561 def forward(
1400 - self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None 1562 + self,
1563 + hidden_states: torch.FloatTensor,
1564 + cls_index: Optional[torch.LongTensor] = None,
1401 ) -> torch.FloatTensor: 1565 ) -> torch.FloatTensor:
1402 """ 1566 """
1403 Compute a single vector summary of a sequence hidden states. 1567 Compute a single vector summary of a sequence hidden states.
...@@ -1427,9 +1591,13 @@ class SequenceSummary(nn.Module): ...@@ -1427,9 +1591,13 @@ class SequenceSummary(nn.Module):
1427 ) 1591 )
1428 else: 1592 else:
1429 cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) 1593 cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
1430 - cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) 1594 + cls_index = cls_index.expand(
1595 + (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)
1596 + )
1431 # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states 1597 # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
1432 - output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) 1598 + output = hidden_states.gather(-2, cls_index).squeeze(
1599 + -2
1600 + ) # shape (bsz, XX, hidden_size)
1433 elif self.summary_type == "attn": 1601 elif self.summary_type == "attn":
1434 raise NotImplementedError 1602 raise NotImplementedError
1435 1603
...@@ -1441,7 +1609,9 @@ class SequenceSummary(nn.Module): ...@@ -1441,7 +1609,9 @@ class SequenceSummary(nn.Module):
1441 return output 1609 return output
1442 1610
1443 1611
1444 -def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear: 1612 +def prune_linear_layer(
1613 + layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0
1614 +) -> torch.nn.Linear:
1445 """ 1615 """
1446 Prune a linear layer to keep only entries in index. 1616 Prune a linear layer to keep only entries in index.
1447 1617
...@@ -1464,7 +1634,9 @@ def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int ...@@ -1464,7 +1634,9 @@ def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int
1464 b = layer.bias[index].clone().detach() 1634 b = layer.bias[index].clone().detach()
1465 new_size = list(layer.weight.size()) 1635 new_size = list(layer.weight.size())
1466 new_size[dim] = len(index) 1636 new_size[dim] = len(index)
1467 - new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) 1637 + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(
1638 + layer.weight.device
1639 + )
1468 new_layer.weight.requires_grad = False 1640 new_layer.weight.requires_grad = False
1469 new_layer.weight.copy_(W.contiguous()) 1641 new_layer.weight.copy_(W.contiguous())
1470 new_layer.weight.requires_grad = True 1642 new_layer.weight.requires_grad = True
...@@ -1509,7 +1681,9 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> ...@@ -1509,7 +1681,9 @@ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) ->
1509 1681
1510 1682
1511 def prune_layer( 1683 def prune_layer(
1512 - layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None 1684 + layer: Union[torch.nn.Linear, Conv1D],
1685 + index: torch.LongTensor,
1686 + dim: Optional[int] = None,
1513 ) -> Union[torch.nn.Linear, Conv1D]: 1687 ) -> Union[torch.nn.Linear, Conv1D]:
1514 """ 1688 """
1515 Prune a Conv1D or linear layer to keep only entries in index. 1689 Prune a Conv1D or linear layer to keep only entries in index.
...@@ -1534,7 +1708,10 @@ def prune_layer( ...@@ -1534,7 +1708,10 @@ def prune_layer(
1534 1708
1535 1709
1536 def apply_chunking_to_forward( 1710 def apply_chunking_to_forward(
1537 - forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors 1711 + forward_fn: Callable[..., torch.Tensor],
1712 + chunk_size: int,
1713 + chunk_dim: int,
1714 + *input_tensors,
1538 ) -> torch.Tensor: 1715 ) -> torch.Tensor:
1539 """ 1716 """
1540 This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the 1717 This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
...@@ -1568,7 +1745,9 @@ def apply_chunking_to_forward( ...@@ -1568,7 +1745,9 @@ def apply_chunking_to_forward(
1568 return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) 1745 return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
1569 """ 1746 """
1570 1747
1571 - assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) 1748 + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(
1749 + input_tensors
1750 + )
1572 tensor_shape = input_tensors[0].shape 1751 tensor_shape = input_tensors[0].shape
1573 assert all( 1752 assert all(
1574 input_tensor.shape == tensor_shape for input_tensor in input_tensors 1753 input_tensor.shape == tensor_shape for input_tensor in input_tensors
...@@ -1592,9 +1771,15 @@ def apply_chunking_to_forward( ...@@ -1592,9 +1771,15 @@ def apply_chunking_to_forward(
1592 num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size 1771 num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
1593 1772
1594 # chunk input tensor into tuples 1773 # chunk input tensor into tuples
1595 - input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) 1774 + input_tensors_chunks = tuple(
1775 + input_tensor.chunk(num_chunks, dim=chunk_dim)
1776 + for input_tensor in input_tensors
1777 + )
1596 # apply forward fn to every tuple 1778 # apply forward fn to every tuple
1597 - output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) 1779 + output_chunks = tuple(
1780 + forward_fn(*input_tensors_chunk)
1781 + for input_tensors_chunk in zip(*input_tensors_chunks)
1782 + )
1598 # concatenate output at same dimension 1783 # concatenate output at same dimension
1599 return torch.cat(output_chunks, dim=chunk_dim) 1784 return torch.cat(output_chunks, dim=chunk_dim)
1600 1785
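
The three hunks above just re-wrap apply_chunking_to_forward; the chunked application stays numerically equivalent to calling the forward function on the full tensor. A small self-check with illustrative sizes, assuming the function is importable from transformers.modeling_utils as in the file being reformatted:

    import torch
    from transformers.modeling_utils import apply_chunking_to_forward

    ff = torch.nn.Linear(16, 16)

    def forward_chunk(hidden_states):
        return ff(hidden_states)

    hidden_states = torch.rand(2, 64, 16)   # (batch, seq_len, hidden)
    full = forward_chunk(hidden_states)
    # chunk_size=8 along dim 1 (the sequence dimension)
    chunked = apply_chunking_to_forward(forward_chunk, 8, 1, hidden_states)
    assert torch.allclose(full, chunked, atol=1e-6)
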
......
...@@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): ...@@ -39,9 +39,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
39 return loss, nll_loss 39 return loss, nll_loss
40 40
41 41
42 -def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): 42 +def encode_line(
43 + tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"
44 +):
43 """Only used by LegacyDataset""" 45 """Only used by LegacyDataset"""
44 - extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 46 + extra_kw = (
47 + {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
48 + )
45 return tokenizer( 49 return tokenizer(
46 [line], 50 [line],
47 max_length=max_length, 51 max_length=max_length,
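
encode_line's behaviour is unchanged by the re-wrap: it tokenizes a single line, pads or truncates to max_length, and only passes add_prefix_space for BART tokenizers. A rough usage sketch, assuming this utils module is importable and using an illustrative checkpoint name:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    batch = encode_line(tokenizer, "a single source line", max_length=32)
    print(batch["input_ids"].shape)   # torch.Size([1, 32]) when padded to max_length
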
...@@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: ...@@ -63,9 +67,7 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
63 67
64 68
65 def trim_batch( 69 def trim_batch(
66 - input_ids, 70 + input_ids, pad_token_id, attention_mask=None,
67 - pad_token_id,
68 - attention_mask=None,
69 ): 71 ):
70 """Remove columns that are populated exclusively by pad_token_id""" 72 """Remove columns that are populated exclusively by pad_token_id"""
71 keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 73 keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
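
trim_batch (unchanged apart from the collapsed signature) drops every column that is pad in all rows, so tensors padded to a fixed length shrink to the longest real sequence in the batch. A minimal sketch using the helper shown above:

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 6, 0, 0, 0, 0],
                              [7, 8, 9, 0, 0, 0]])
    print(trim_batch(input_ids, pad_token_id))
    # tensor([[5, 6, 0],
    #         [7, 8, 9]])
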
...@@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): ...@@ -125,7 +127,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
125 def __getitem__(self, index) -> Dict[str, torch.Tensor]: 127 def __getitem__(self, index) -> Dict[str, torch.Tensor]:
126 """Call tokenizer on src and tgt_lines""" 128 """Call tokenizer on src and tgt_lines"""
127 index = index + 1 # linecache starts at 1 129 index = index + 1 # linecache starts at 1
128 - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 130 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
131 + "\n"
132 + )
129 tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 133 tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
130 assert source_line, f"empty source line for index {index}" 134 assert source_line, f"empty source line for index {index}"
131 assert tgt_line, f"empty tgt line for index {index}" 135 assert tgt_line, f"empty tgt line for index {index}"
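
The only subtlety in __getitem__ is the index offset: linecache.getline numbers file lines from 1, while Dataset indices start at 0, hence the "+ 1". A tiny standalone illustration (the file path is hypothetical):

    import linecache

    index = 0                                            # dataset index
    line = linecache.getline("train.source", index + 1)  # file line 1
    print(line.rstrip("\n"))                             # "" if the file does not exist
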
...@@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): ...@@ -147,7 +151,9 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
147 target_ids = torch.stack([x["labels"] for x in batch]) 151 target_ids = torch.stack([x["labels"] for x in batch])
148 pad_token_id = self.pad_token_id 152 pad_token_id = self.pad_token_id
149 y = trim_batch(target_ids, pad_token_id) 153 y = trim_batch(target_ids, pad_token_id)
150 - source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) 154 + source_ids, source_mask = trim_batch(
155 + input_ids, pad_token_id, attention_mask=masks
156 + )
151 batch = { 157 batch = {
152 "input_ids": source_ids, 158 "input_ids": source_ids,
153 "attention_mask": source_mask, 159 "attention_mask": source_mask,
...@@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): ...@@ -161,7 +167,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
161 167
162 def __getitem__(self, index) -> Dict[str, str]: 168 def __getitem__(self, index) -> Dict[str, str]:
163 index = index + 1 # linecache starts at 1 169 index = index + 1 # linecache starts at 1
164 - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 170 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
171 + "\n"
172 + )
165 tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 173 tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
166 assert source_line, f"empty source line for index {index}" 174 assert source_line, f"empty source line for index {index}"
167 assert tgt_line, f"empty tgt line for index {index}" 175 assert tgt_line, f"empty tgt line for index {index}"
...@@ -201,12 +209,23 @@ class SortishSampler(Sampler): ...@@ -201,12 +209,23 @@ class SortishSampler(Sampler):
201 idxs = np.random.permutation(len(self.data)) 209 idxs = np.random.permutation(len(self.data))
202 sz = self.bs * 50 210 sz = self.bs * 50
203 ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] 211 ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
204 - sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx]) 212 + sort_idx = np.concatenate(
213 + [sorted(s, key=self.key, reverse=True) for s in ck_idx]
214 + )
205 sz = self.bs 215 sz = self.bs
206 ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] 216 ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
207 - max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, 217 + max_ck = np.argmax(
208 - ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. 218 + [self.key(ck[0]) for ck in ck_idx]
209 - sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) 219 + ) # find the chunk with the largest key,
220 + ck_idx[0], ck_idx[max_ck] = (
221 + ck_idx[max_ck],
222 + ck_idx[0],
223 + ) # then make sure it goes first.
224 + sort_idx = (
225 + np.concatenate(np.random.permutation(ck_idx[1:]))
226 + if len(ck_idx) > 1
227 + else np.array([], dtype=np.int)
228 + )
210 sort_idx = np.concatenate((ck_idx[0], sort_idx)) 229 sort_idx = np.concatenate((ck_idx[0], sort_idx))
211 return iter(sort_idx) 230 return iter(sort_idx)
212 231
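
The SortishSampler hunk is the most heavily re-wrapped, but the algorithm is intact: shuffle the indices, sort within mega-chunks of bs * 50 by key (sequence length), cut into batches of bs, put the batch with the longest first element first, then shuffle the remaining batches. A hedged NumPy sketch of that idea, with sizes chosen so every batch is full; note the original's np.int alias is deprecated in newer NumPy, so the sketch uses plain int:

    import numpy as np

    lengths = np.random.randint(1, 50, size=200)   # illustrative sequence lengths
    bs = 8

    idxs = np.random.permutation(len(lengths))
    mega = bs * 50
    chunks = [idxs[i : i + mega] for i in range(0, len(idxs), mega)]
    sort_idx = np.concatenate(
        [sorted(c, key=lambda i: lengths[i], reverse=True) for c in chunks]
    )
    batches = [sort_idx[i : i + bs] for i in range(0, len(sort_idx), bs)]

    max_ck = np.argmax([lengths[b[0]] for b in batches])      # longest first element
    batches[0], batches[max_ck] = batches[max_ck], batches[0]  # that batch goes first
    rest = (
        np.concatenate(np.random.permutation(batches[1:]))
        if len(batches) > 1
        else np.array([], dtype=int)
    )
    order = np.concatenate((batches[0], rest))   # final iteration order
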
...@@ -269,7 +288,9 @@ def get_git_info(): ...@@ -269,7 +288,9 @@ def get_git_info():
269 ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] 288 ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
270 289
271 290
272 -def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: 291 +def calculate_rouge(
292 + output_lns: List[str], reference_lns: List[str], use_stemmer=True
293 +) -> Dict:
273 scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) 294 scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
274 aggregator = scoring.BootstrapAggregator() 295 aggregator = scoring.BootstrapAggregator()
275 296
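
calculate_rouge keeps its behaviour: it aggregates rouge1/rouge2/rougeL F-measures over paired prediction and reference lines via rouge_scorer and a bootstrap aggregator. A minimal usage sketch, assuming the rouge_score package is installed and this utils module is importable:

    preds = ["the cat sat on the mat"]
    refs = ["the cat is on the mat"]
    scores = calculate_rouge(preds, refs)
    print(scores)   # dict with keys "rouge1", "rouge2", "rougeL"
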
...@@ -302,7 +323,9 @@ def assert_all_frozen(model): ...@@ -302,7 +323,9 @@ def assert_all_frozen(model):
302 model_grads: List[bool] = list(grad_status(model)) 323 model_grads: List[bool] = list(grad_status(model))
303 n_require_grad = sum(lmap(int, model_grads)) 324 n_require_grad = sum(lmap(int, model_grads))
304 npars = len(model_grads) 325 npars = len(model_grads)
305 - assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad" 326 + assert not any(
327 + model_grads
328 + ), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
306 329
307 330
308 def assert_not_all_frozen(model): 331 def assert_not_all_frozen(model):
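
assert_all_frozen (and its counterpart assert_not_all_frozen, which starts here) simply checks the requires_grad flags gathered by grad_status. A short usage sketch, assuming those helpers are imported from this utils module:

    import torch

    model = torch.nn.Linear(4, 4)
    for p in model.parameters():
        p.requires_grad = False    # freeze everything

    assert_all_frozen(model)       # passes: no parameters require grad
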
......