Showing
13 changed files
with
0 additions
and
553 deletions
commit_suggester.py
deleted
100644 → 0
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -import torch | ||
16 | -import argparse | ||
17 | -import subprocess | ||
18 | -from transformers import AutoTokenizer | ||
19 | - | ||
20 | -from preprocess import diff_parse, truncate | ||
21 | -from train import BartForConditionalGeneration | ||
22 | - | ||
23 | -def get_length(chunks): | ||
24 | - cnt = 0 | ||
25 | - for chunk in chunks: | ||
26 | - cnt += len(chunk) | ||
27 | - return cnt | ||
28 | - | ||
29 | -def suggester(chunks, model, tokenizer, device): | ||
30 | - max_source_length = get_length(chunks) | ||
31 | - | ||
32 | - input_ids, attention_masks, patch_ids = zip(*chunks) | ||
33 | - input_ids = torch.LongTensor( | ||
34 | - [truncate(input_ids, max_source_length, value=0)] | ||
35 | - ).to(device) | ||
36 | - attention_masks = torch.LongTensor( | ||
37 | - [truncate(attention_masks, max_source_length, value=1)] | ||
38 | - ).to(device) | ||
39 | - patch_ids = torch.LongTensor( | ||
40 | - [truncate(patch_ids, max_source_length, value=0)] | ||
41 | - ).to(device) | ||
42 | - | ||
43 | - summaries = model.generate( | ||
44 | - input_ids=input_ids, patch_ids=patch_ids, attention_mask=attention_masks | ||
45 | - ) | ||
46 | - return tokenizer.batch_decode( | ||
47 | - summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
48 | - ) | ||
49 | - | ||
50 | - | ||
51 | -def main(args): | ||
52 | - device = torch.device( | ||
53 | - "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" | ||
54 | - ) | ||
55 | - model = BartForConditionalGeneration.from_pretrained(args.output_dir).to(device) | ||
56 | - | ||
57 | - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name) | ||
58 | - | ||
59 | - if args.unittest: | ||
60 | - with open("test.source", "r") as f: | ||
61 | - chunks = diff_parse(f.read(), tokenizer) | ||
62 | - else: | ||
63 | - proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE) | ||
64 | - staged_files = proc.stdout.readlines() | ||
65 | - staged_files = [f.decode("utf-8") for f in staged_files] | ||
66 | - staged_files = [f.strip() for f in staged_files] | ||
67 | - chunks = "\n".join(staged_files) | ||
68 | - | ||
69 | - chunks = diff_parse(chunks, tokenizer) | ||
70 | - if not chunks: | ||
71 | - print('There is no file in staged state.') | ||
72 | - return | ||
73 | - | ||
74 | - commit_message = suggester( | ||
75 | - chunks, | ||
76 | - model=model, | ||
77 | - tokenizer=tokenizer, | ||
78 | - device=device, | ||
79 | - ) | ||
80 | - print(commit_message) | ||
81 | - | ||
82 | - | ||
83 | -if __name__ == "__main__": | ||
84 | - parser = argparse.ArgumentParser(description="Code to collect commits on github") | ||
85 | - parser.add_argument( | ||
86 | - "--no_cuda", action="store_true", help="Whether not to use CUDA when available" | ||
87 | - ) | ||
88 | - parser.add_argument( | ||
89 | - "--unittest", action="store_true", help="Unittest with an one batch git diff" | ||
90 | - ) | ||
91 | - parser.add_argument( | ||
92 | - "--output_dir", | ||
93 | - type=str, | ||
94 | - required=True, | ||
95 | - help="The output directory where the model predictions and checkpoints will be written.", | ||
96 | - ) | ||
97 | - parser.add_argument( | ||
98 | - "--tokenizer_name", | ||
99 | - default="sshleifer/distilbart-xsum-6-6", | ||
100 | - type=str, | ||
101 | - help="Pretrained tokenizer name or path if not the same as model_name", | ||
102 | - ) | ||
103 | - args = parser.parse_args() | ||
104 | - | ||
105 | - main(args) |
preprocess/__init__.py
deleted
100644 → 0
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -from .gitcommit import diff_parse, truncate | ||
16 | - | ||
17 | -__all__ = [ | ||
18 | - "diff_parse", | ||
19 | - "truncate", | ||
20 | -] |
preprocess/gitcommit.py
deleted
100644 → 0
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -import os | ||
16 | -import re | ||
17 | -import enum | ||
18 | -import random | ||
19 | -import logging | ||
20 | -import tempfile | ||
21 | -import argparse | ||
22 | -import numpy as np | ||
23 | -from tqdm import * | ||
24 | -import whatthepatch | ||
25 | -from git import Repo | ||
26 | -from functools import partial | ||
27 | -from multiprocessing.pool import Pool | ||
28 | -from transformers import AutoTokenizer | ||
29 | - | ||
30 | -from matorage import * | ||
31 | - | ||
32 | -logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
33 | -logging.basicConfig( | ||
34 | - format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", | ||
35 | - datefmt="%m/%d/%Y %H:%M:%S", | ||
36 | - level=logging.INFO, | ||
37 | -) | ||
38 | - | ||
39 | - | ||
40 | -class PATCH(enum.Enum): | ||
41 | - PLUS = 1 | ||
42 | - MINUS = 2 | ||
43 | - | ||
44 | - | ||
45 | -def truncate(tuple, max_length, value=0): | ||
46 | - ls = [] | ||
47 | - for t in tuple: | ||
48 | - if isinstance(t, int): | ||
49 | - t = [t] | ||
50 | - ls.extend(t) | ||
51 | - ls = ls[: max_length - 1] | ||
52 | - ls.insert(0, value) | ||
53 | - if len(ls) < max_length: | ||
54 | - ls.extend([0] * (max_length - len(ls))) | ||
55 | - assert len(ls) == max_length | ||
56 | - return ls | ||
57 | - | ||
58 | - | ||
59 | -def encode_line(tokenizer, line, patch): | ||
60 | - line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip() | ||
61 | - tokens = tokenizer.tokenize(line) | ||
62 | - tokens = tokenizer.convert_tokens_to_ids(tokens) | ||
63 | - return (tokens, [1] * len(tokens), len(tokens) * [patch.value]) | ||
64 | - | ||
65 | - | ||
66 | -def diff_parse(diff, tokenizer): | ||
67 | - chunks = [] | ||
68 | - for diff in whatthepatch.parse_patch(diff): | ||
69 | - if diff.header.old_path != diff.header.new_path: | ||
70 | - chunks.append(encode_line(tokenizer, diff.header.old_path, PATCH.MINUS)) | ||
71 | - chunks.append(encode_line(tokenizer, diff.header.new_path, PATCH.PLUS)) | ||
72 | - if not diff.changes: | ||
73 | - continue | ||
74 | - for change in diff.changes: | ||
75 | - if change.old == None and change.new != None: | ||
76 | - chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS)) | ||
77 | - elif change.old != None and change.new == None: | ||
78 | - chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS)) | ||
79 | - return chunks | ||
80 | - | ||
81 | - | ||
82 | -def sha_parse(sha, tokenizer, max_length=1024): | ||
83 | - | ||
84 | - chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer) | ||
85 | - if not chunks: | ||
86 | - return None | ||
87 | - | ||
88 | - input_ids, attention_masks, patch_ids = zip(*chunks) | ||
89 | - input_ids = truncate(input_ids, max_length, value=0) | ||
90 | - attention_masks = truncate(attention_masks, max_length, value=1) | ||
91 | - patch_ids = truncate(patch_ids, max_length, value=0) | ||
92 | - | ||
93 | - return (input_ids, attention_masks, patch_ids) | ||
94 | - | ||
95 | - | ||
96 | -def message_parse(msg, tokenizer, max_length=56): | ||
97 | - msg = re.sub(r"(\(|)#([0-9])+(\)|)", "", msg) | ||
98 | - | ||
99 | - msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip() | ||
100 | - msg = tokenizer.tokenize(msg) | ||
101 | - msg = tokenizer.convert_tokens_to_ids(msg) | ||
102 | - msg = truncate(msg, max_length, value=0) | ||
103 | - | ||
104 | - return msg | ||
105 | - | ||
106 | - | ||
107 | -def jobs(sha_msgs, args, data_config, train=True): | ||
108 | - | ||
109 | - input_ids, attention_masks, patch_ids, targets = [], [], [], [] | ||
110 | - data_saver = DataSaver(config=data_config) | ||
111 | - | ||
112 | - for sha_msg in sha_msgs: | ||
113 | - sha, msg = sha_msg | ||
114 | - | ||
115 | - source = sha_parse( | ||
116 | - sha, tokenizer=args.tokenizer, max_length=args.max_source_length | ||
117 | - ) | ||
118 | - if not source: | ||
119 | - continue | ||
120 | - input_id, attention_mask, patch_id = source | ||
121 | - target = message_parse( | ||
122 | - msg, | ||
123 | - tokenizer=args.tokenizer, | ||
124 | - max_length=( | ||
125 | - args.max_target_length if train else args.val_max_target_length | ||
126 | - ), | ||
127 | - ) | ||
128 | - | ||
129 | - input_ids.append(input_id) | ||
130 | - attention_masks.append(attention_mask) | ||
131 | - patch_ids.append(patch_id) | ||
132 | - targets.append(target) | ||
133 | - | ||
134 | - data_saver( | ||
135 | - { | ||
136 | - "input_ids": np.asarray(input_ids), | ||
137 | - "attention_masks": np.asarray(attention_masks), | ||
138 | - "patch_ids": np.asarray(patch_ids), | ||
139 | - "targets": np.asarray(targets), | ||
140 | - } | ||
141 | - ) | ||
142 | - data_saver.disconnect() | ||
143 | - | ||
144 | - | ||
145 | -def start(chunked_sha_msgs, train=True): | ||
146 | - | ||
147 | - logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation")) | ||
148 | - | ||
149 | - max_target_length = args.max_target_length if train else args.val_max_target_length | ||
150 | - | ||
151 | - data_config = DataConfig( | ||
152 | - endpoint=args.endpoint, | ||
153 | - access_key=os.environ["access_key"], | ||
154 | - secret_key=os.environ["secret_key"], | ||
155 | - region=args.region, | ||
156 | - dataset_name="commit-autosuggestions", | ||
157 | - additional={ | ||
158 | - "mode": ("training" if train else "evaluation"), | ||
159 | - "max_source_length": args.max_source_length, | ||
160 | - "max_target_length": max_target_length, | ||
161 | - "url": args.url, | ||
162 | - }, | ||
163 | - attributes=[ | ||
164 | - ("input_ids", "int32", (args.max_source_length,)), | ||
165 | - ("attention_masks", "int32", (args.max_source_length,)), | ||
166 | - ("patch_ids", "int32", (args.max_source_length,)), | ||
167 | - ("targets", "int32", (max_target_length,)), | ||
168 | - ], | ||
169 | - ) | ||
170 | - | ||
171 | - func = partial(jobs, args=args, data_config=data_config, train=train) | ||
172 | - with Pool(processes=args.num_workers) as pool: | ||
173 | - with tqdm(total=len(chunked_sha_msgs)) as pbar: | ||
174 | - for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))): | ||
175 | - pbar.update() | ||
176 | - | ||
177 | - | ||
178 | -def main(args): | ||
179 | - if "access_key" not in os.environ or "secret_key" not in os.environ: | ||
180 | - raise OSError("access_key or secret_key are not found.") | ||
181 | - | ||
182 | - sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()] | ||
183 | - random.shuffle(sha_msgs) | ||
184 | - chunked_sha_msgs = [ | ||
185 | - sha_msgs[x : x + args.matorage_batch] | ||
186 | - for x in range(0, len(sha_msgs), args.matorage_batch) | ||
187 | - ] | ||
188 | - | ||
189 | - barrier = int(len(chunked_sha_msgs) * (1 - args.p_val)) | ||
190 | - if args.do_train: | ||
191 | - start(chunked_sha_msgs[:barrier], train=True) | ||
192 | - if args.do_predict: | ||
193 | - start(chunked_sha_msgs[barrier:], train=False) | ||
194 | - | ||
195 | - | ||
196 | -if __name__ == "__main__": | ||
197 | - parser = argparse.ArgumentParser(description="Code to collect commits on github") | ||
198 | - parser.add_argument("--url", type=str, required=True, help="github url") | ||
199 | - parser.add_argument( | ||
200 | - "--endpoint", | ||
201 | - type=str, | ||
202 | - required=True, | ||
203 | - help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", | ||
204 | - ) | ||
205 | - parser.add_argument( | ||
206 | - "--region", | ||
207 | - type=str, | ||
208 | - default=None, | ||
209 | - help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", | ||
210 | - ) | ||
211 | - parser.add_argument( | ||
212 | - "--tokenizer_name", | ||
213 | - default="sshleifer/distilbart-xsum-6-6", | ||
214 | - type=str, | ||
215 | - help="Pretrained tokenizer name or path if not the same as model_name", | ||
216 | - ) | ||
217 | - parser.add_argument( | ||
218 | - "--matorage_batch", | ||
219 | - default=1024, | ||
220 | - type=int, | ||
221 | - help="The smallest batch size stored atomically in matorage.", | ||
222 | - ) | ||
223 | - parser.add_argument( | ||
224 | - "--num_workers", default=4, type=int, help="number of process", | ||
225 | - ) | ||
226 | - parser.add_argument( | ||
227 | - "--max_source_length", | ||
228 | - default=1024, | ||
229 | - type=int, | ||
230 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
231 | - "than this will be truncated, sequences shorter will be padded.", | ||
232 | - ) | ||
233 | - parser.add_argument( | ||
234 | - "--max_target_length", | ||
235 | - default=56, | ||
236 | - type=int, | ||
237 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
238 | - "than this will be truncated, sequences shorter will be padded.", | ||
239 | - ) | ||
240 | - parser.add_argument( | ||
241 | - "--val_max_target_length", | ||
242 | - default=142, # these defaults are optimized for CNNDM. For xsum, see README.md. | ||
243 | - type=int, | ||
244 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
245 | - "than this will be truncated, sequences shorter will be padded.", | ||
246 | - ) | ||
247 | - parser.add_argument( | ||
248 | - "--p_val", type=float, default=0.25, help="percent of validation dataset" | ||
249 | - ) | ||
250 | - parser.add_argument("--do_train", action="store_true", default=False) | ||
251 | - parser.add_argument("--do_predict", action="store_true", default=False) | ||
252 | - args = parser.parse_args() | ||
253 | - | ||
254 | - args.local_path = args.url.split("/")[-1] | ||
255 | - logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}") | ||
256 | - repo = ( | ||
257 | - Repo(args.local_path) | ||
258 | - if os.path.exists(args.local_path) | ||
259 | - else Repo.clone_from(args.url, to_path=args.local_path, branch="master") | ||
260 | - ) | ||
261 | - args.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name) | ||
262 | - | ||
263 | - main(args) |
test.source
deleted
100644 → 0
This diff is collapsed. Click to expand it.
train.py
deleted
100644 → 0
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -import os | ||
16 | -import argparse | ||
17 | -import pytorch_lightning as pl | ||
18 | -from train.finetune import main, SummarizationModule | ||
19 | - | ||
20 | -if __name__ == "__main__": | ||
21 | - parser = argparse.ArgumentParser() | ||
22 | - parser = pl.Trainer.add_argparse_args(parser) | ||
23 | - parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) | ||
24 | - | ||
25 | - args = parser.parse_args() | ||
26 | - | ||
27 | - main(args) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
train/__init__.py
deleted
100644 → 0
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -from train.modeling_bart import BartForConditionalGeneration | ||
16 | - | ||
17 | -__all__ = ["BartForConditionalGeneration"] |
train/callbacks.py
deleted
100644 → 0
1 | -import logging | ||
2 | -import os | ||
3 | -from pathlib import Path | ||
4 | - | ||
5 | -import numpy as np | ||
6 | -import pytorch_lightning as pl | ||
7 | -import torch | ||
8 | -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint | ||
9 | -from pytorch_lightning.utilities import rank_zero_only | ||
10 | - | ||
11 | - | ||
12 | -def count_trainable_parameters(model): | ||
13 | - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) | ||
14 | - params = sum([np.prod(p.size()) for p in model_parameters]) | ||
15 | - return params | ||
16 | - | ||
17 | - | ||
18 | -logger = logging.getLogger(__name__) | ||
19 | - | ||
20 | - | ||
21 | -class Seq2SeqLoggingCallback(pl.Callback): | ||
22 | - def on_batch_end(self, trainer, pl_module): | ||
23 | - lrs = { | ||
24 | - f"lr_group_{i}": param["lr"] | ||
25 | - for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups) | ||
26 | - } | ||
27 | - pl_module.logger.log_metrics(lrs) | ||
28 | - | ||
29 | - @rank_zero_only | ||
30 | - def _write_logs( | ||
31 | - self, | ||
32 | - trainer: pl.Trainer, | ||
33 | - pl_module: pl.LightningModule, | ||
34 | - type_path: str, | ||
35 | - save_generations=True, | ||
36 | - ) -> None: | ||
37 | - logger.info( | ||
38 | - f"***** {type_path} results at step {trainer.global_step:05d} *****" | ||
39 | - ) | ||
40 | - metrics = trainer.callback_metrics | ||
41 | - trainer.logger.log_metrics( | ||
42 | - { | ||
43 | - k: v | ||
44 | - for k, v in metrics.items() | ||
45 | - if k not in ["log", "progress_bar", "preds"] | ||
46 | - } | ||
47 | - ) | ||
48 | - # Log results | ||
49 | - od = Path(pl_module.hparams.output_dir) | ||
50 | - if type_path == "test": | ||
51 | - results_file = od / "test_results.txt" | ||
52 | - generations_file = od / "test_generations.txt" | ||
53 | - else: | ||
54 | - # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json | ||
55 | - # If people want this it will be easy enough to add back. | ||
56 | - results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" | ||
57 | - generations_file = ( | ||
58 | - od / f"{type_path}_generations/{trainer.global_step:05d}.txt" | ||
59 | - ) | ||
60 | - results_file.parent.mkdir(exist_ok=True) | ||
61 | - generations_file.parent.mkdir(exist_ok=True) | ||
62 | - with open(results_file, "a+") as writer: | ||
63 | - for key in sorted(metrics): | ||
64 | - if key in ["log", "progress_bar", "preds"]: | ||
65 | - continue | ||
66 | - val = metrics[key] | ||
67 | - if isinstance(val, torch.Tensor): | ||
68 | - val = val.item() | ||
69 | - msg = f"{key}: {val:.6f}\n" | ||
70 | - writer.write(msg) | ||
71 | - | ||
72 | - if not save_generations: | ||
73 | - return | ||
74 | - | ||
75 | - if "preds" in metrics: | ||
76 | - content = "\n".join(metrics["preds"]) | ||
77 | - generations_file.open("w+").write(content) | ||
78 | - | ||
79 | - @rank_zero_only | ||
80 | - def on_train_start(self, trainer, pl_module): | ||
81 | - try: | ||
82 | - npars = pl_module.model.model.num_parameters() | ||
83 | - except AttributeError: | ||
84 | - npars = pl_module.model.num_parameters() | ||
85 | - | ||
86 | - n_trainable_pars = count_trainable_parameters(pl_module) | ||
87 | - # mp stands for million parameters | ||
88 | - trainer.logger.log_metrics( | ||
89 | - {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6} | ||
90 | - ) | ||
91 | - | ||
92 | - @rank_zero_only | ||
93 | - def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): | ||
94 | - return self._write_logs(trainer, pl_module, "test") | ||
95 | - | ||
96 | - | ||
97 | -def get_checkpoint_callback(output_dir, metric): | ||
98 | - """Saves the best model by validation ROUGE2 score.""" | ||
99 | - if metric == "rouge2": | ||
100 | - exp = "{val_avg_rouge2:.4f}-{step_count}" | ||
101 | - elif metric == "bleu": | ||
102 | - exp = "{val_avg_bleu:.4f}-{step_count}" | ||
103 | - else: | ||
104 | - raise NotImplementedError( | ||
105 | - f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function." | ||
106 | - ) | ||
107 | - | ||
108 | - checkpoint_callback = ModelCheckpoint( | ||
109 | - filepath=os.path.join(output_dir, exp), | ||
110 | - monitor=f"val_{metric}", | ||
111 | - mode="max", | ||
112 | - save_top_k=1, | ||
113 | - period=0, # maybe save a checkpoint every time val is run, not just end of epoch. | ||
114 | - ) | ||
115 | - return checkpoint_callback | ||
116 | - | ||
117 | - | ||
118 | -def get_early_stopping_callback(metric, patience): | ||
119 | - return EarlyStopping( | ||
120 | - monitor=f"val_{metric}", mode="max", patience=patience, verbose=True, | ||
121 | - ) |
train/finetune.py
deleted
100644 → 0
This diff is collapsed. Click to expand it.
train/generation_utils.py
deleted
100644 → 0
This diff is collapsed. Click to expand it.
train/lightning_base.py
deleted
100644 → 0
This diff is collapsed. Click to expand it.
train/modeling_bart.py
deleted
100644 → 0
This diff is collapsed. Click to expand it.
train/modeling_utils.py
deleted
100644 → 0
This diff is collapsed. Click to expand it.
train/utils.py
deleted
100644 → 0
This diff is collapsed. Click to expand it.
-
Please register or login to post a comment