Showing 2 changed files with 21 additions and 9 deletions
... | @@ -15,7 +15,6 @@ | ... | @@ -15,7 +15,6 @@ |
15 | import os | 15 | import os |
16 | import torch | 16 | import torch |
17 | import argparse | 17 | import argparse |
18 | -import whatthepatch | ||
19 | from tqdm import tqdm | 18 | from tqdm import tqdm |
20 | import torch.nn as nn | 19 | import torch.nn as nn |
21 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler | 20 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler |
... | @@ -47,7 +46,7 @@ def get_model(model_class, config, tokenizer, mode): | ... | @@ -47,7 +46,7 @@ def get_model(model_class, config, tokenizer, mode): |
47 | model.load_state_dict( | 46 | model.load_state_dict( |
48 | torch.load( | 47 | torch.load( |
49 | os.path.join(args.load_model_path, mode, 'pytorch_model.bin'), | 48 | os.path.join(args.load_model_path, mode, 'pytorch_model.bin'), |
50 | - map_location=torch.device(args.device) | 49 | + map_location=torch.device('cpu') |
51 | ), | 50 | ), |
52 | strict=False | 51 | strict=False |
53 | ) | 52 | ) |
... | @@ -55,9 +54,15 @@ def get_model(model_class, config, tokenizer, mode): | ... | @@ -55,9 +54,15 @@ def get_model(model_class, config, tokenizer, mode): |
55 | 54 | ||
56 | def get_features(examples): | 55 | def get_features(examples): |
57 | features = convert_examples_to_features(examples, args.tokenizer, args, stage='test') | 56 | features = convert_examples_to_features(examples, args.tokenizer, args, stage='test') |
58 | - all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long) | 57 | + all_source_ids = torch.tensor( |
59 | - all_source_mask = torch.tensor([f.source_mask for f in features], dtype=torch.long) | 58 | + [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long |
60 | - all_patch_ids = torch.tensor([f.patch_ids for f in features], dtype=torch.long) | 59 | + ) |
60 | + all_source_mask = torch.tensor( | ||
61 | + [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long | ||
62 | + ) | ||
63 | + all_patch_ids = torch.tensor( | ||
64 | + [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long | ||
65 | + ) | ||
61 | return TensorDataset(all_source_ids, all_source_mask, all_patch_ids) | 66 | return TensorDataset(all_source_ids, all_source_mask, all_patch_ids) |
62 | 67 | ||
63 | def create_app(): | 68 | def create_app(): |
... | @@ -150,7 +155,7 @@ if __name__ == '__main__': | ... | @@ -150,7 +155,7 @@ if __name__ == '__main__': |
150 | help="Pretrained config name or path if not the same as model_name") | 155 | help="Pretrained config name or path if not the same as model_name") |
151 | parser.add_argument("--tokenizer_name", type=str, | 156 | parser.add_argument("--tokenizer_name", type=str, |
152 | default="microsoft/codebert-base", help="The name of tokenizer", ) | 157 | default="microsoft/codebert-base", help="The name of tokenizer", ) |
153 | - parser.add_argument("--max_source_length", default=256, type=int, | 158 | + parser.add_argument("--max_source_length", default=512, type=int, |
154 | help="The maximum total source sequence length after tokenization. Sequences longer " | 159 | help="The maximum total source sequence length after tokenization. Sequences longer " |
155 | "than this will be truncated, sequences shorter will be padded.") | 160 | "than this will be truncated, sequences shorter will be padded.") |
156 | parser.add_argument("--max_target_length", default=128, type=int, | 161 | parser.add_argument("--max_target_length", default=128, type=int, | ... | ... |
... | @@ -27,8 +27,12 @@ def tokenizing(code): | ... | @@ -27,8 +27,12 @@ def tokenizing(code): |
27 | ) | 27 | ) |
28 | return json.loads(res.text)["tokens"] | 28 | return json.loads(res.text)["tokens"] |
29 | 29 | ||
30 | -def preprocessing(diffs): | 30 | +def autocommit(diffs): |
31 | + commit_message = [] | ||
31 | for idx, example in enumerate(whatthepatch.parse_patch(diffs)): | 32 | for idx, example in enumerate(whatthepatch.parse_patch(diffs)): |
33 | + if not example.changes: | ||
34 | + continue | ||
35 | + | ||
32 | isadded, isdeleted = False, False | 36 | isadded, isdeleted = False, False |
33 | added, deleted = [], [] | 37 | added, deleted = [], [] |
34 | for change in example.changes: | 38 | for change in example.changes: |
... | @@ -46,7 +50,7 @@ def preprocessing(diffs): | ... | @@ -46,7 +50,7 @@ def preprocessing(diffs): |
46 | data=json.dumps(data), | 50 | data=json.dumps(data), |
47 | headers=args.headers | 51 | headers=args.headers |
48 | ) | 52 | ) |
49 | - print(json.loads(res.text)) | 53 | + commit_message.append(json.loads(res.text)) |
50 | else: | 54 | else: |
51 | data = {"idx": idx, "added": added, "deleted": deleted} | 55 | data = {"idx": idx, "added": added, "deleted": deleted} |
52 | res = requests.post( | 56 | res = requests.post( |
... | @@ -54,7 +58,8 @@ def preprocessing(diffs): | ... | @@ -54,7 +58,8 @@ def preprocessing(diffs): |
54 | data=json.dumps(data), | 58 | data=json.dumps(data), |
55 | headers=args.headers | 59 | headers=args.headers |
56 | ) | 60 | ) |
57 | - print(json.loads(res.text)) | 61 | + commit_message.append(json.loads(res.text)) |
62 | + return commit_message | ||
58 | 63 | ||
59 | def main(): | 64 | def main(): |
60 | 65 | ||
... | @@ -64,6 +69,8 @@ def main(): | ... | @@ -64,6 +69,8 @@ def main(): |
64 | staged_files = [f.strip() for f in staged_files] | 69 | staged_files = [f.strip() for f in staged_files] |
65 | diffs = "\n".join(staged_files) | 70 | diffs = "\n".join(staged_files) |
66 | 71 | ||
72 | + message = autocommit(diffs=diffs) | ||
73 | + print(message) | ||
67 | 74 | ||
68 | if __name__ == '__main__': | 75 | if __name__ == '__main__': |
69 | parser = argparse.ArgumentParser(description="") | 76 | parser = argparse.ArgumentParser(description="") | ... | ... |
Please register or login to post a comment