Showing 1 changed file with 104 additions and 10 deletions
```diff
@@ -17,12 +17,16 @@ import re
 import enum
 import logging
 import argparse
+import numpy as np
+from tqdm import *
 import whatthepatch
 from git import Repo
 from functools import partial
 from multiprocessing.pool import Pool
 from transformers import AutoTokenizer
 
+from matorage import *
+
 logger = logging.getLogger(__name__) # pylint: disable=invalid-name
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
@@ -48,6 +52,7 @@ def truncate(tuple, max_length, value=0):
     return ls
 
 def encode_line(tokenizer, line, patch):
+    line = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip()
     tokens = tokenizer.tokenize(line)
     tokens = tokenizer.convert_tokens_to_ids(tokens)
    return (
```
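As a side note on the new character filter in `encode_line`: the added regular expression drops everything above the Latin-1 range (emoji, CJK text, and so on) before the line reaches the tokenizer. A minimal standalone illustration, with a made-up input line:

```python
import re

# Same pattern as the one added to encode_line (and to message_parse below).
NON_LATIN1 = r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+'

line = "Fix tokenizer bug 🐛 in the 파서 module"   # illustrative input only
print(re.sub(NON_LATIN1, '', line).strip())
# prints "Fix tokenizer bug  in the  module" (removed characters leave double spaces)
```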
```diff
@@ -69,39 +74,128 @@ def sha_parse(sha, tokenizer, max_length=1024):
             if change.old == None and change.new != None:
                 chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
             elif change.old != None and change.new == None:
-                chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
+                chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
+
+    if not chunks:
+        return None
 
     input_ids, attention_masks, patch_ids = zip(*chunks)
     input_ids = truncate(input_ids, max_length, value=0)
     attention_masks = truncate(attention_masks, max_length, value=1)
     patch_ids = truncate(patch_ids, max_length, value=0)
 
+    return (input_ids, attention_masks, patch_ids)
+
```
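The one-character fix above (PLUS becomes MINUS for removed lines) relies on how whatthepatch labels changes: each change entry exposes `old`, `new`, and `line`, with `old is None` marking an added line and `new is None` a removed one. A standalone sketch of that convention, using a tiny handwritten diff rather than the project's data:

```python
import whatthepatch

diff_text = """\
--- a/hello.py
+++ b/hello.py
@@ -1,2 +1,2 @@
 print("hello")
-print("bye")
+print("goodbye")
"""

for diff in whatthepatch.parse_patch(diff_text):
    for change in diff.changes:
        if change.old is None and change.new is not None:
            print("added  :", change.line)    # tagged PATCH.PLUS in the script
        elif change.old is not None and change.new is None:
            print("removed:", change.line)    # now correctly tagged PATCH.MINUS
```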
```diff
 def message_parse(msg, tokenizer, max_length=56):
     msg = re.sub(r'#([0-9])+', '', msg)
     msg = re.sub(r'(\(|)([A-z])+-([0-9])+(\)|)(:|)', '', msg)
-    msg = msg.strip()
 
+    msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip()
     msg = tokenizer.tokenize(msg)
     msg = tokenizer.convert_tokens_to_ids(msg)
     msg = truncate(msg, max_length, value=0)
 
+    return msg
 
```
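For context on `message_parse` itself: the two substitutions that precede the new Unicode filter strip GitHub issue references and JIRA-style ticket prefixes from the commit summary (note that `[A-z]` matches a few punctuation characters in addition to ASCII letters). A quick illustration with a made-up summary:

```python
import re

msg = "(JIRA-1234): fix parser crash #56"                  # illustrative only
msg = re.sub(r'#([0-9])+', '', msg)                        # drops the "#56" issue reference
msg = re.sub(r'(\(|)([A-z])+-([0-9])+(\)|)(:|)', '', msg)  # drops the "(JIRA-1234):" prefix
print(msg.strip())                                         # prints "fix parser crash"
```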
```diff
-def job(sha_msgs, tokenizer):
-    sha, msg = sha_msgs
+def jobs(sha_msgs, args, data_config):
 
-    sha_parse(sha, tokenizer=tokenizer)
-    message_parse(msg, tokenizer=tokenizer)
+    input_ids, attention_masks, patch_ids, targets = [], [], [], []
+    data_saver = DataSaver(config=data_config)
+
+    for sha_msg in sha_msgs:
+        sha, msg = sha_msg
+
+        source = sha_parse(sha, tokenizer=args.tokenizer)
+        if not source:
+            continue
+        input_id, attention_mask, patch_id = source
+        target = message_parse(msg, tokenizer=args.tokenizer)
+
+        input_ids.append(input_id)
+        attention_masks.append(attention_mask)
+        patch_ids.append(patch_id)
+        targets.append(target)
+
+    data_saver({
+        "input_ids": np.asarray(input_ids),
+        "attention_masks": np.asarray(attention_masks),
+        "patch_ids": np.asarray(patch_ids),
+        "targets": np.asarray(targets),
+    })
+    data_saver.disconnect()
 
 def main(args):
+    if 'access_key' not in os.environ or 'secret_key' not in os.environ:
+        raise OSError("access_key or secret_key are not found.")
+
+    data_config = DataConfig(
+        endpoint=args.matorage_dir,
+        access_key=os.environ['access_key'],
+        secret_key=os.environ['secret_key'],
+        dataset_name='commit-autosuggestions',
+        additional={
+            "max_source_length": args.max_source_length,
+            "max_target_length": args.max_target_length,
+        },
+        attributes = [
+            ('input_ids', 'int32', (args.max_source_length,)),
+            ('attention_masks', 'int32', (args.max_source_length,)),
+            ('patch_ids', 'int32', (args.max_source_length,)),
+            ('targets', 'int32', (args.max_target_length,))
+        ]
+    )
+
     sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
-    func = partial(job, tokenizer=args.tokenizer)
+    chunked_sha_msgs = [
+        sha_msgs[x:x + args.matorage_batch]
+        for x in range(0, len(sha_msgs), args.matorage_batch)
+    ]
+    func = partial(jobs, args=args, data_config=data_config)
     with Pool(processes=args.num_workers) as pool:
-        pool.map(func, sha_msgs)
+        with tqdm(total=len(chunked_sha_msgs)) as pbar:
+            for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
+                pbar.update()
 
```
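The rewritten `main()` batches the commit list (`--matorage_batch` commits per chunk) and streams the chunks through a process pool, ticking a progress bar as results arrive; a single `tqdm` wrapper would be enough here, since the inner `tqdm(enumerate(...))` adds a second, redundant bar. A minimal standalone sketch of the batching-plus-`imap_unordered` pattern with a toy workload:

```python
from multiprocessing.pool import Pool
from tqdm import tqdm

def work(batch):
    # Stand-in for jobs(): process one batch of (sha, message) pairs.
    return len(batch)

if __name__ == "__main__":
    items = list(range(10_000))
    batch_size = 1024
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

    with Pool(processes=4) as pool:
        with tqdm(total=len(batches)) as pbar:
            # imap_unordered yields as soon as any worker finishes a batch.
            for _ in pool.imap_unordered(work, batches):
                pbar.update()
```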
```diff
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Code to collect commits on github")
-    parser.add_argument("--url", type=str, required=True)
-    parser.add_argument("--num_workers", type=int, default=1)
+    parser.add_argument(
+        "--url",
+        type=str,
+        required=True,
+        help="github url"
+    )
+    parser.add_argument(
+        "--matorage_dir",
+        type=str,
+        required=True,
+        help='matorage saved directory.'
+    )
+    parser.add_argument(
+        "--matorage_batch",
+        default=1024,
+        type=int,
+        help='batch size to store data.'
+    )
+    parser.add_argument(
+        "--num_workers",
+        default=4,
+        type=int,
+        help="number of process",
+    )
+    parser.add_argument(
+        "--max_source_length",
+        default=1024,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded.",
+    )
+    parser.add_argument(
+        "--max_target_length",
+        default=56,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded.",
+    )
     args = parser.parse_args()
 
     args.local_path = args.url.split('/')[-1]
```