(refactor) create diff_parse function, (add) tokenizer arguments
Showing 2 changed files with 34 additions and 5 deletions
preprocess/__init__.py (new file, mode 100644)

+# Copyright 2020-present Tae Hwan Jung
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .gitcommit import diff_parse
+
+__all__ = [
+    'diff_parse'
+]
\ No newline at end of file
preprocess/gitcommit.py

@@ -17,6 +17,7 @@ import re
 import enum
 import random
 import logging
+import tempfile
 import argparse
 import numpy as np
 from tqdm import *
@@ -62,10 +63,9 @@ def encode_line(tokenizer, line, patch):
         len(tokens) * [patch.value]
     )
 
-def sha_parse(sha, tokenizer, max_length=1024):
-
+def diff_parse(diff, tokenizer):
     chunks = []
-    for diff in whatthepatch.parse_patch(repo.git.show(sha)):
+    for diff in whatthepatch.parse_patch(diff):
         if diff.header.old_path != diff.header.new_path:
             chunks.append(encode_line(tokenizer, diff.header.old_path, PATCH.MINUS))
             chunks.append(encode_line(tokenizer, diff.header.new_path, PATCH.PLUS))
@@ -76,7 +76,11 @@ def sha_parse(sha, tokenizer, max_length=1024):
                 chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
             elif change.old != None and change.new == None:
                 chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
+    return chunks
+
+def sha_parse(sha, tokenizer, max_length=1024):
 
+    chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer)
     if not chunks:
         return None
 
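For orientation, a minimal usage sketch of the refactored diff_parse (not part of this commit). It assumes the repository root is on PYTHONPATH so the new preprocess package export is importable, that whatthepatch and transformers are installed, and the sample patch text below is purely hypothetical.

    from transformers import AutoTokenizer
    from preprocess import diff_parse  # re-exported by the new preprocess/__init__.py

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-6-6")

    # A tiny hand-written patch, used only for illustration.
    sample_diff = (
        "diff --git a/hello.py b/hello.py\n"
        "--- a/hello.py\n"
        "+++ b/hello.py\n"
        "@@ -0,0 +1 @@\n"
        "+print('hello world')\n"
    )

    # diff_parse now takes raw diff text instead of a commit sha, so any patch
    # can be encoded without going through repo.git.show(sha).
    chunks = diff_parse(diff=sample_diff, tokenizer=tokenizer)
    print(len(chunks))  # one encoded chunk per changed line (plus path chunks if the paths differ)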
@@ -202,10 +206,16 @@ if __name__ == "__main__":
         help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
     )
     parser.add_argument(
+        "--tokenizer_name",
+        default='sshleifer/distilbart-xsum-6-6',
+        type=str,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
         "--matorage_batch",
         default=1024,
         type=int,
-        help='batch size to store data.'
+        help='The smallest batch size stored atomically in matorage.'
     )
     parser.add_argument(
         "--num_workers",
@@ -246,6 +256,6 @@ if __name__ == "__main__":
         if os.path.exists(args.local_path)
         else Repo.clone_from(args.url, to_path=args.local_path, branch="master")
     )
-    args.tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
+    args.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
 
     main(args)
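And a hedged sketch (not from the commit) of what the new --tokenizer_name argument enables: the default mirrors the diff above, while the override value below is only an example; argparse and transformers are assumed to be available.

    import argparse
    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer_name",
        default='sshleifer/distilbart-xsum-6-6',
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )

    # No flag: the default distilbart tokenizer is used (previously the script
    # hard-coded sshleifer/distilbart-xsum-12-6 with no way to change it).
    args = parser.parse_args([])
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)

    # Any Hugging Face tokenizer can now be chosen at preprocessing time,
    # e.g. a hypothetical run with facebook/bart-base:
    args = parser.parse_args(["--tokenizer_name", "facebook/bart-base"])
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)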