graykode

(refactor) create diff_parse function, (add) tokenizer arguments

+# Copyright 2020-present Tae Hwan Jung
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .gitcommit import diff_parse
+
+__all__ = [
+    'diff_parse'
+]
\ No newline at end of file
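The new `__init__.py` turns the directory into a package and re-exports `diff_parse`, so callers no longer need to reach into the `gitcommit` module directly. A minimal import sketch; the package name `commitsuggester` below is a placeholder, since the diff does not show the directory name:

    # Placeholder package name; substitute the directory that holds this
    # __init__.py alongside gitcommit.py.
    from commitsuggester import diff_parse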
@@ -17,6 +17,7 @@ import re
 import enum
 import random
 import logging
+import tempfile
 import argparse
 import numpy as np
 from tqdm import *
@@ -62,10 +63,9 @@ def encode_line(tokenizer, line, patch):
         len(tokens) * [patch.value]
     )
 
-def sha_parse(sha, tokenizer, max_length=1024):
-
+def diff_parse(diff, tokenizer):
     chunks = []
-    for diff in whatthepatch.parse_patch(repo.git.show(sha)):
+    for diff in whatthepatch.parse_patch(diff):
         if diff.header.old_path != diff.header.new_path:
             chunks.append(encode_line(tokenizer, diff.header.old_path, PATCH.MINUS))
             chunks.append(encode_line(tokenizer, diff.header.new_path, PATCH.PLUS))
@@ -76,7 +76,11 @@ def sha_parse(sha, tokenizer, max_length=1024):
                 chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
             elif change.old != None and change.new == None:
                 chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
+    return chunks
+
+def sha_parse(sha, tokenizer, max_length=1024):
 
+    chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer)
     if not chunks:
         return None
 
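With this refactor, `diff_parse` accepts raw diff text instead of resolving a commit SHA itself, so the same encoding path can serve patches that never came from the local repository. A minimal usage sketch, assuming a Hugging Face tokenizer; the diff string below is illustrative, not part of the commit:

    from transformers import AutoTokenizer

    # Any `git show` / `git diff` output works as input.
    sample_diff = """diff --git a/hello.py b/hello.py
    --- a/hello.py
    +++ b/hello.py
    @@ -0,0 +1 @@
    +print("hello world")
    """

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-6-6")
    chunks = diff_parse(diff=sample_diff, tokenizer=tokenizer)
    # Each chunk pairs the token ids of a line with matching PATCH.PLUS /
    # PATCH.MINUS values, as built by encode_line above.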
@@ -202,10 +206,16 @@ if __name__ == "__main__":
         help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
     )
     parser.add_argument(
+        "--tokenizer_name",
+        default='sshleifer/distilbart-xsum-6-6',
+        type=str,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
         "--matorage_batch",
         default=1024,
         type=int,
-        help='batch size to store data.'
+        help='The smallest batch size stored atomically in matorage.'
     )
     parser.add_argument(
         "--num_workers",
@@ -246,6 +256,6 @@ if __name__ == "__main__":
         if os.path.exists(args.local_path)
         else Repo.clone_from(args.url, to_path=args.local_path, branch="master")
     )
-    args.tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
+    args.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
 
     main(args)
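With `--tokenizer_name` exposed on the CLI and the hard-coded `sshleifer/distilbart-xsum-12-6` checkpoint removed, the tokenizer can now be swapped without editing the script. An illustrative run; the script name `gitcommit.py` is inferred from the module shown above, and the other required flags are omitted:

    # python gitcommit.py --tokenizer_name facebook/bart-base ...
    # which inside the script now resolves to:
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")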