Showing 1 changed file with 100 additions and 0 deletions
commit_suggester.py
0 → 100644
1 | +# Copyright 2020-present Tae Hwan Jung | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +import torch | ||
16 | +import argparse | ||
17 | +import subprocess | ||
18 | +from transformers import AutoTokenizer | ||
19 | + | ||
20 | +from preprocess import diff_parse, truncate | ||
21 | +from train import BartForConditionalGeneration | ||
22 | + | ||
23 | + | ||
def suggester(chunks, max_source_length, model, tokenizer, device):
    """Generate commit-message suggestions for a parsed diff.

    Args:
        chunks: iterable of (input_ids, attention_mask, patch_ids) triples,
            as produced by ``preprocess.diff_parse``.
        max_source_length: maximum encoder input length; ``truncate`` pads or
            cuts each sequence to this length.
        model: a ``BartForConditionalGeneration`` whose ``generate`` accepts
            a ``patch_ids`` keyword.
        tokenizer: tokenizer used to decode the generated token ids.
        device: torch device the input tensors are moved to.

    Returns:
        The decoded summaries, one string per generated sequence.
    """
    ids, masks, patches = zip(*chunks)

    def as_batch(seq, pad_value):
        # Wrap the single example in a batch dimension and move it to device.
        return torch.LongTensor(
            [truncate(seq, max_source_length, value=pad_value)]
        ).to(device)

    input_ids = as_batch(ids, 0)
    # NOTE(review): the attention mask is padded with 1, which would let the
    # model attend to padding positions — confirm against truncate()'s
    # semantics in preprocess.py.
    attention_masks = as_batch(masks, 1)
    patch_ids = as_batch(patches, 0)

    generated = model.generate(
        input_ids=input_ids, patch_ids=patch_ids, attention_mask=attention_masks
    )
    return tokenizer.batch_decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
42 | + | ||
43 | + | ||
def main(args):
    """Load the fine-tuned model and print a suggested commit message.

    In ``--unittest`` mode the diff text is read from ``test.source``;
    otherwise it is taken from ``git diff --cached``. In both cases the raw
    diff is parsed by ``diff_parse`` into the (input_ids, attention_mask,
    patch_ids) chunks that ``suggester`` expects.
    """
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
    )
    model = BartForConditionalGeneration.from_pretrained(args.output_dir).to(device)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)

    if args.unittest:
        with open("test.source", "r") as f:
            diff_text = f.read()
    else:
        # subprocess.run (list argument, shell=False) waits for git to finish
        # and reaps the child; the previous Popen was never waited on.
        proc = subprocess.run(
            ["git", "diff", "--cached"], stdout=subprocess.PIPE, check=True
        )
        # Decode the diff as-is. The old code stripped every line, which
        # destroyed the leading space that marks context lines in a diff.
        diff_text = proc.stdout.decode("utf-8")

    # BUG FIX: the staged-diff path previously passed the raw diff *string*
    # to suggester(), which unpacks chunk triples via zip(*chunks); a string
    # would be iterated character-by-character and crash. Parse it first,
    # exactly as the --unittest path always did.
    chunks = diff_parse(diff_text, tokenizer)

    commit_message = suggester(
        chunks,
        max_source_length=args.max_source_length,
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    print(commit_message)
70 | + | ||
if __name__ == "__main__":
    # FIX: the description was copy-pasted from the commit-collection script;
    # this entry point suggests a commit message for the staged diff.
    parser = argparse.ArgumentParser(
        description="Suggest a commit message for the currently staged git diff"
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--unittest", action="store_true", help="Unittest with an one batch git diff"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="sshleifer/distilbart-xsum-6-6",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--max_source_length",
        default=1024,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    args = parser.parse_args()

    main(args)
-
Please register or login to post a comment