graykode

(add) commit suggester
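
Load a fine-tuned BART checkpoint, parse the staged diff (`git diff --cached`)
with diff_parse, and print generated commit message suggestions.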

# Copyright 2020-present Tae Hwan Jung
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import subprocess

import torch
from transformers import AutoTokenizer

from preprocess import diff_parse, truncate
from train import BartForConditionalGeneration


def suggester(chunks, max_source_length, model, tokenizer, device):
    # Each chunk produced by diff_parse is an (input_ids, attention_mask,
    # patch_ids) triple; unzip them into parallel sequences.
    input_ids, attention_masks, patch_ids = zip(*chunks)
    # Pad or truncate every sequence to max_source_length and batch it.
    input_ids = torch.LongTensor([truncate(input_ids, max_source_length, value=0)]).to(
        device
    )
    attention_masks = torch.LongTensor(
        [truncate(attention_masks, max_source_length, value=1)]
    ).to(device)
    patch_ids = torch.LongTensor([truncate(patch_ids, max_source_length, value=0)]).to(
        device
    )

    # The modified BART in train.py takes patch_ids alongside the usual inputs.
    summaries = model.generate(
        input_ids=input_ids, patch_ids=patch_ids, attention_mask=attention_masks
    )
    return tokenizer.batch_decode(
        summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )


def main(args):
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
    )
    model = BartForConditionalGeneration.from_pretrained(args.output_dir).to(device)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)

    if args.unittest:
        # One-batch test: parse a saved diff instead of the staging area.
        with open("test.source", "r") as f:
            chunks = diff_parse(f.read(), tokenizer)
    else:
        # Parse whatever is currently staged; suggester expects the
        # (input_ids, attention_mask, patch_ids) chunks that diff_parse
        # emits, not the raw diff text.
        proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE)
        staged_diff = proc.stdout.read().decode("utf-8")
        chunks = diff_parse(staged_diff, tokenizer)

    commit_message = suggester(
        chunks,
        max_source_length=args.max_source_length,
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    print(commit_message)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Suggest a commit message for the staged git diff"
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Avoid using CUDA when available"
    )
    parser.add_argument(
        "--unittest",
        action="store_true",
        help="Run a one-batch test on the saved diff in test.source",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The directory containing the fine-tuned model checkpoint to load.",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="sshleifer/distilbart-xsum-6-6",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--max_source_length",
        default=1024,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    args = parser.parse_args()

    main(args)
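
A usage sketch, assuming the script above is saved as commit_suggester.py and a
fine-tuned checkpoint lives in ./weights (both names illustrative, not part of
this commit):

    # one-batch test against the saved diff in test.source
    python commit_suggester.py --output_dir ./weights --unittest

    # suggest a message for the currently staged changes
    git add <files>
    python commit_suggester.py --output_dir ./weights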