graykode

(add) git commit parser

1 +# Copyright 2020-present Tae Hwan Jung
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import os
16 +import re
17 +import enum
18 +import logging
19 +import argparse
20 +import whatthepatch
21 +from git import Repo
22 +from functools import partial
23 +from multiprocessing.pool import Pool
24 +from transformers import AutoTokenizer
25 +
26 +logger = logging.getLogger(__name__) # pylint: disable=invalid-name
27 +logging.basicConfig(
28 + format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
29 + datefmt="%m/%d/%Y %H:%M:%S",
30 + level=logging.INFO,
31 +)
32 +
33 +class PATCH(enum.Enum):
34 + PLUS=1
35 + MINUS=2
36 +
37 +def truncate(tuple, max_length, value=0):
38 + ls = []
39 + for t in tuple:
40 + if isinstance(t, int):
41 + t = [t]
42 + ls.extend(t)
43 + ls = ls[:max_length - 1]
44 + ls.insert(0, value)
45 + if len(ls) < max_length:
46 + ls.extend([0] * (max_length - len(ls)))
47 + assert len(ls) == max_length
48 + return ls
49 +
50 +def encode_line(tokenizer, line, patch):
51 + tokens = tokenizer.tokenize(line)
52 + tokens = tokenizer.convert_tokens_to_ids(tokens)
53 + return (
54 + tokens,
55 + [1] * len(tokens),
56 + len(tokens) * [patch.value]
57 + )
58 +
59 +def sha_parse(sha, tokenizer, max_length=1024):
60 +
61 + chunks = []
62 + for diff in whatthepatch.parse_patch(repo.git.show(sha)):
63 + if diff.header.old_path != diff.header.new_path:
64 + chunks.append(encode_line(tokenizer, diff.header.old_path, PATCH.MINUS))
65 + chunks.append(encode_line(tokenizer, diff.header.new_path, PATCH.PLUS))
66 + if not diff.changes:
67 + continue
68 + for change in diff.changes:
69 + if change.old == None and change.new != None:
70 + chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
71 + elif change.old != None and change.new == None:
72 + chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
73 +
74 + input_ids, attention_masks, patch_ids = zip(*chunks)
75 + input_ids = truncate(input_ids, max_length, value=0)
76 + attention_masks = truncate(attention_masks, max_length, value=1)
77 + patch_ids = truncate(patch_ids, max_length, value=0)
78 +
79 +def message_parse(msg, tokenizer, max_length=56):
80 + msg = re.sub(r'#([0-9])+', '', msg)
81 + msg = re.sub(r'(\(|)([A-z])+-([0-9])+(\)|)(:|)', '', msg)
82 + msg = msg.strip()
83 +
84 + msg = tokenizer.tokenize(msg)
85 + msg = tokenizer.convert_tokens_to_ids(msg)
86 + msg = truncate(msg, max_length, value=0)
87 +
88 +
89 +def job(sha_msgs, tokenizer):
90 + sha, msg = sha_msgs
91 +
92 + sha_parse(sha, tokenizer=tokenizer)
93 + message_parse(msg, tokenizer=tokenizer)
94 +
95 +def main(args):
96 + sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
97 + func = partial(job, tokenizer=args.tokenizer)
98 + with Pool(processes=args.num_workers) as pool:
99 + pool.map(func, sha_msgs)
100 +
101 +if __name__ == "__main__":
102 + parser = argparse.ArgumentParser(description="Code to collect commits on github")
103 + parser.add_argument("--url", type=str, required=True)
104 + parser.add_argument("--num_workers", type=int, default=1)
105 + args = parser.parse_args()
106 +
107 + args.local_path = args.url.split('/')[-1]
108 + logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}")
109 + repo = (
110 + Repo(args.local_path)
111 + if os.path.exists(args.local_path)
112 + else Repo.clone_from(args.url, to_path=args.local_path, branch="master")
113 + )
114 + args.tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
115 +
116 + main(args)
1 +whatthepatch
2 +gitpython
...\ No newline at end of file ...\ No newline at end of file