Showing
2 changed files
with
118 additions
and
0 deletions
gitcommit.py
0 → 100644
1 | +# Copyright 2020-present Tae Hwan Jung | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +import os | ||
16 | +import re | ||
17 | +import enum | ||
18 | +import logging | ||
19 | +import argparse | ||
20 | +import whatthepatch | ||
21 | +from git import Repo | ||
22 | +from functools import partial | ||
23 | +from multiprocessing.pool import Pool | ||
24 | +from transformers import AutoTokenizer | ||
25 | + | ||
26 | +logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
27 | +logging.basicConfig( | ||
28 | + format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", | ||
29 | + datefmt="%m/%d/%Y %H:%M:%S", | ||
30 | + level=logging.INFO, | ||
31 | +) | ||
32 | + | ||
33 | +class PATCH(enum.Enum): | ||
34 | + PLUS=1 | ||
35 | + MINUS=2 | ||
36 | + | ||
37 | +def truncate(tuple, max_length, value=0): | ||
38 | + ls = [] | ||
39 | + for t in tuple: | ||
40 | + if isinstance(t, int): | ||
41 | + t = [t] | ||
42 | + ls.extend(t) | ||
43 | + ls = ls[:max_length - 1] | ||
44 | + ls.insert(0, value) | ||
45 | + if len(ls) < max_length: | ||
46 | + ls.extend([0] * (max_length - len(ls))) | ||
47 | + assert len(ls) == max_length | ||
48 | + return ls | ||
49 | + | ||
50 | +def encode_line(tokenizer, line, patch): | ||
51 | + tokens = tokenizer.tokenize(line) | ||
52 | + tokens = tokenizer.convert_tokens_to_ids(tokens) | ||
53 | + return ( | ||
54 | + tokens, | ||
55 | + [1] * len(tokens), | ||
56 | + len(tokens) * [patch.value] | ||
57 | + ) | ||
58 | + | ||
59 | +def sha_parse(sha, tokenizer, max_length=1024): | ||
60 | + | ||
61 | + chunks = [] | ||
62 | + for diff in whatthepatch.parse_patch(repo.git.show(sha)): | ||
63 | + if diff.header.old_path != diff.header.new_path: | ||
64 | + chunks.append(encode_line(tokenizer, diff.header.old_path, PATCH.MINUS)) | ||
65 | + chunks.append(encode_line(tokenizer, diff.header.new_path, PATCH.PLUS)) | ||
66 | + if not diff.changes: | ||
67 | + continue | ||
68 | + for change in diff.changes: | ||
69 | + if change.old == None and change.new != None: | ||
70 | + chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS)) | ||
71 | + elif change.old != None and change.new == None: | ||
72 | + chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS)) | ||
73 | + | ||
74 | + input_ids, attention_masks, patch_ids = zip(*chunks) | ||
75 | + input_ids = truncate(input_ids, max_length, value=0) | ||
76 | + attention_masks = truncate(attention_masks, max_length, value=1) | ||
77 | + patch_ids = truncate(patch_ids, max_length, value=0) | ||
78 | + | ||
79 | +def message_parse(msg, tokenizer, max_length=56): | ||
80 | + msg = re.sub(r'#([0-9])+', '', msg) | ||
81 | + msg = re.sub(r'(\(|)([A-z])+-([0-9])+(\)|)(:|)', '', msg) | ||
82 | + msg = msg.strip() | ||
83 | + | ||
84 | + msg = tokenizer.tokenize(msg) | ||
85 | + msg = tokenizer.convert_tokens_to_ids(msg) | ||
86 | + msg = truncate(msg, max_length, value=0) | ||
87 | + | ||
88 | + | ||
89 | +def job(sha_msgs, tokenizer): | ||
90 | + sha, msg = sha_msgs | ||
91 | + | ||
92 | + sha_parse(sha, tokenizer=tokenizer) | ||
93 | + message_parse(msg, tokenizer=tokenizer) | ||
94 | + | ||
95 | +def main(args): | ||
96 | + sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()] | ||
97 | + func = partial(job, tokenizer=args.tokenizer) | ||
98 | + with Pool(processes=args.num_workers) as pool: | ||
99 | + pool.map(func, sha_msgs) | ||
100 | + | ||
101 | +if __name__ == "__main__": | ||
102 | + parser = argparse.ArgumentParser(description="Code to collect commits on github") | ||
103 | + parser.add_argument("--url", type=str, required=True) | ||
104 | + parser.add_argument("--num_workers", type=int, default=1) | ||
105 | + args = parser.parse_args() | ||
106 | + | ||
107 | + args.local_path = args.url.split('/')[-1] | ||
108 | + logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}") | ||
109 | + repo = ( | ||
110 | + Repo(args.local_path) | ||
111 | + if os.path.exists(args.local_path) | ||
112 | + else Repo.clone_from(args.url, to_path=args.local_path, branch="master") | ||
113 | + ) | ||
114 | + args.tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6") | ||
115 | + | ||
116 | + main(args) |
requirements.txt
0 → 100644
-
Please register or login to post a comment