graykode

(add) matorage feature

Showing 1 changed file with 104 additions and 10 deletions
@@ -17,12 +17,16 @@ import re
 import enum
 import logging
 import argparse
+import numpy as np
+from tqdm import *
 import whatthepatch
 from git import Repo
 from functools import partial
 from multiprocessing.pool import Pool
 from transformers import AutoTokenizer
 
+from matorage import *
+
 logger = logging.getLogger(__name__) # pylint: disable=invalid-name
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
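Note: the wildcard import `from matorage import *` above is used for matorage's DataConfig/DataSaver pair further down in this diff. A minimal save sketch following the same pattern; the endpoint, credentials, dataset name, and attribute shape below are placeholders for illustration, not values from this commit:

import numpy as np
from matorage import DataConfig, DataSaver

# Placeholder config mirroring the pattern main() uses below:
# one int32 attribute of fixed length, stored at a hypothetical endpoint.
config = DataConfig(
    endpoint='127.0.0.1:9000',     # placeholder storage endpoint
    access_key='minio',            # placeholder credential
    secret_key='miniosecretkey',   # placeholder credential
    dataset_name='toy-dataset',
    attributes=[('ids', 'int32', (4,))],
)

saver = DataSaver(config=config)
saver({'ids': np.asarray([[1, 2, 3, 4]], dtype='int32')})  # one 4-token sample
saver.disconnect()

main() below builds the real config from the script's command-line arguments and credentials taken from environment variables.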
@@ -48,6 +52,7 @@ def truncate(tuple, max_length, value=0):
     return ls
 
 def encode_line(tokenizer, line, patch):
+    line = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip()
     tokens = tokenizer.tokenize(line)
     tokens = tokenizer.convert_tokens_to_ids(tokens)
     return (
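The regex added to encode_line here (and to message_parse further down) strips every character above U+00FF before tokenization, so CJK text, emoji, and other non-Latin-1 content is dropped from both patch lines and commit messages. A small illustration of the filter; the sample string is made up:

import re

line = "fix 버그 in tokenizer 🎉"
cleaned = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', line).strip()
print(cleaned)  # "fix  in tokenizer" -- the Korean word and the emoji are removed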
@@ -69,39 +74,128 @@ def sha_parse(sha, tokenizer, max_length=1024):
             if change.old == None and change.new != None:
                 chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
             elif change.old != None and change.new == None:
-                chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
+                chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
+
+    if not chunks:
+        return None
 
     input_ids, attention_masks, patch_ids = zip(*chunks)
     input_ids = truncate(input_ids, max_length, value=0)
     attention_masks = truncate(attention_masks, max_length, value=1)
     patch_ids = truncate(patch_ids, max_length, value=0)
 
+    return (input_ids, attention_masks, patch_ids)
+
 def message_parse(msg, tokenizer, max_length=56):
     msg = re.sub(r'#([0-9])+', '', msg)
     msg = re.sub(r'(\(|)([A-z])+-([0-9])+(\)|)(:|)', '', msg)
-    msg = msg.strip()
 
+    msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip()
     msg = tokenizer.tokenize(msg)
     msg = tokenizer.convert_tokens_to_ids(msg)
     msg = truncate(msg, max_length, value=0)
 
+    return msg
 
-def job(sha_msgs, tokenizer):
-    sha, msg = sha_msgs
+def jobs(sha_msgs, args, data_config):
 
-    sha_parse(sha, tokenizer=tokenizer)
-    message_parse(msg, tokenizer=tokenizer)
+    input_ids, attention_masks, patch_ids, targets = [], [], [], []
+    data_saver = DataSaver(config=data_config)
+
+    for sha_msg in sha_msgs:
+        sha, msg = sha_msg
+
+        source = sha_parse(sha, tokenizer=args.tokenizer)
+        if not source:
+            continue
+        input_id, attention_mask, patch_id = source
+        target = message_parse(msg, tokenizer=args.tokenizer)
+
+        input_ids.append(input_id)
+        attention_masks.append(attention_mask)
+        patch_ids.append(patch_id)
+        targets.append(target)
+
+    data_saver({
+        "input_ids": np.asarray(input_ids),
+        "attention_masks": np.asarray(attention_masks),
+        "patch_ids": np.asarray(patch_ids),
+        "targets": np.asarray(targets),
+    })
+    data_saver.disconnect()
 
 def main(args):
+    if 'access_key' not in os.environ or 'secret_key' not in os.environ:
+        raise OSError("access_key or secret_key not found.")
+
+    data_config = DataConfig(
+        endpoint=args.matorage_dir,
+        access_key=os.environ['access_key'],
+        secret_key=os.environ['secret_key'],
+        dataset_name='commit-autosuggestions',
+        additional={
+            "max_source_length": args.max_source_length,
+            "max_target_length": args.max_target_length,
+        },
+        attributes=[
+            ('input_ids', 'int32', (args.max_source_length,)),
+            ('attention_masks', 'int32', (args.max_source_length,)),
+            ('patch_ids', 'int32', (args.max_source_length,)),
+            ('targets', 'int32', (args.max_target_length,))
+        ]
+    )
+
     sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
-    func = partial(job, tokenizer=args.tokenizer)
+    chunked_sha_msgs = [
+        sha_msgs[x:x + args.matorage_batch]
+        for x in range(0, len(sha_msgs), args.matorage_batch)
+    ]
+    func = partial(jobs, args=args, data_config=data_config)
     with Pool(processes=args.num_workers) as pool:
-        pool.map(func, sha_msgs)
+        with tqdm(total=len(chunked_sha_msgs)) as pbar:
+            for _ in pool.imap_unordered(func, chunked_sha_msgs):
+                pbar.update()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Code to collect commits on github")
-    parser.add_argument("--url", type=str, required=True)
-    parser.add_argument("--num_workers", type=int, default=1)
+    parser.add_argument(
+        "--url",
+        type=str,
+        required=True,
+        help="GitHub repository URL"
+    )
+    parser.add_argument(
+        "--matorage_dir",
+        type=str,
+        required=True,
+        help='matorage save directory (storage endpoint).'
+    )
+    parser.add_argument(
+        "--matorage_batch",
+        default=1024,
+        type=int,
+        help='batch size to store data.'
+    )
+    parser.add_argument(
+        "--num_workers",
+        default=4,
+        type=int,
+        help="number of worker processes",
+    )
+    parser.add_argument(
+        "--max_source_length",
+        default=1024,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded.",
+    )
+    parser.add_argument(
+        "--max_target_length",
+        default=56,
+        type=int,
+        help="The maximum total target sequence length after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded.",
+    )
     args = parser.parse_args()
 
     args.local_path = args.url.split('/')[-1]
...
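For context on the one-character fix in sha_parse above (PATCH.PLUS corrected to PATCH.MINUS for deleted lines): whatthepatch reports each changed line with its old and new line numbers, and the parser treats old == None as an added line and new == None as a removed line. A self-contained sketch of that classification, using a made-up patch string:

import whatthepatch

sample = """\
--- a/hello.py
+++ b/hello.py
@@ -1,2 +1,2 @@
-print('hello')
+print('hello, world')
 pass
"""

for diff in whatthepatch.parse_patch(sample):
    for change in diff.changes:
        if change.old is None and change.new is not None:
            print('PLUS ', change.line)   # added line -> tagged PATCH.PLUS
        elif change.old is not None and change.new is None:
            print('MINUS', change.line)   # removed line -> tagged PATCH.MINUS

With the fix, removed lines carry a distinct patch id, so additions and deletions are distinguishable in the patch_ids channel that sha_parse returns.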