Showing 2 changed files with 21 additions and 9 deletions
First changed file (path not shown in this rendering):

```diff
@@ -15,7 +15,6 @@
 import os
 import torch
 import argparse
-import whatthepatch
 from tqdm import tqdm
 import torch.nn as nn
 from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
```
```diff
@@ -47,7 +46,7 @@ def get_model(model_class, config, tokenizer, mode):
     model.load_state_dict(
         torch.load(
             os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),
-            map_location=torch.device(args.device)
+            map_location=torch.device('cpu')
         ),
         strict=False
     )
```
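Remapping the checkpoint onto the CPU at deserialization time lets the serving script start on hosts without a GPU; the model can still be moved to `args.device` afterwards. A minimal sketch of the pattern (the path is hypothetical, and the commented lines stand in for the surrounding `get_model` logic):

```python
import torch

# Load a GPU-trained checkpoint on a CPU-only host by remapping all storages
# to CPU; without map_location this raises a CUDA deserialization error.
state_dict = torch.load(
    "pytorch_model.bin",               # hypothetical path for illustration
    map_location=torch.device("cpu"),
)
# model.load_state_dict(state_dict, strict=False)
# model.to(device)  # move to the actual inference device afterwards
```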
```diff
@@ -55,9 +54,15 @@ def get_model(model_class, config, tokenizer, mode):
 
 def get_features(examples):
     features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')
-    all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
-    all_source_mask = torch.tensor([f.source_mask for f in features], dtype=torch.long)
-    all_patch_ids = torch.tensor([f.patch_ids for f in features], dtype=torch.long)
+    all_source_ids = torch.tensor(
+        [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long
+    )
+    all_source_mask = torch.tensor(
+        [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long
+    )
+    all_patch_ids = torch.tensor(
+        [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long
+    )
     return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
 
 def create_app():
```
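`torch.tensor` only accepts rectangular nested lists, so as soon as any feature sequence can exceed `args.max_source_length`, batch construction fails with a `ValueError`; clipping every row to the same bound restores a rectangular batch. A standalone sketch of the failure mode and the fix (it assumes, as the feature builder here appears to, that shorter rows are already padded up to the bound):

```python
import torch

max_source_length = 4                            # stand-in for args.max_source_length
source_ids = [[1, 2, 3, 4, 5, 6], [1, 2, 3, 4]]  # ragged: torch.tensor(source_ids) raises ValueError

# Truncating every row to the same upper bound makes the batch rectangular again.
clipped = [ids[:max_source_length] for ids in source_ids]
batch = torch.tensor(clipped, dtype=torch.long)
print(batch.shape)  # torch.Size([2, 4])
```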
```diff
@@ -150,7 +155,7 @@ if __name__ == '__main__':
                         help="Pretrained config name or path if not the same as model_name")
     parser.add_argument("--tokenizer_name", type=str,
                         default="microsoft/codebert-base", help="The name of tokenizer", )
-    parser.add_argument("--max_source_length", default=256, type=int,
+    parser.add_argument("--max_source_length", default=512, type=int,
                         help="The maximum total source sequence length after tokenization. Sequences longer "
                              "than this will be truncated, sequences shorter will be padded.")
     parser.add_argument("--max_target_length", default=128, type=int,
```
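The new default of 512 matches the positional limit of RoBERTa-style encoders such as `microsoft/codebert-base`, so inputs are no longer cut to half of what the model can attend to. A sanity check along those lines (assuming the `transformers` tokenizer is available):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
max_source_length = 512  # the new CLI default
# RoBERTa-style tokenizers report 512 here; fail fast if the bound is exceeded.
assert max_source_length <= tokenizer.model_max_length
```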
Second changed file (path not shown in this rendering):

```diff
@@ -27,8 +27,12 @@ def tokenizing(code):
     )
     return json.loads(res.text)["tokens"]
 
-def preprocessing(diffs):
+def autocommit(diffs):
+    commit_message = []
     for idx, example in enumerate(whatthepatch.parse_patch(diffs)):
+        if not example.changes:
+            continue
+
         isadded, isdeleted = False, False
         added, deleted = [], []
         for change in example.changes:
```
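`whatthepatch.parse_patch` yields one entry per file in a multi-file diff, and entries for binary files or pure renames can carry no hunks at all; that is what the new `if not example.changes: continue` guard skips. A self-contained sketch of separating added and deleted lines with this API (the sample diff text is invented for illustration):

```python
import whatthepatch

diff_text = """\
--- a/example.py
+++ b/example.py
@@ -1 +1 @@
-print("old")
+print("new")
"""

for diff in whatthepatch.parse_patch(diff_text):
    if not diff.changes:  # e.g. binary files or renames: nothing to classify
        continue
    added = [c.line for c in diff.changes if c.old is None]    # present only in the new file
    deleted = [c.line for c in diff.changes if c.new is None]  # present only in the old file
    print(added, deleted)  # ['print("new")'] ['print("old")']
```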
```diff
@@ -46,7 +50,7 @@ def preprocessing(diffs):
             data=json.dumps(data),
             headers=args.headers
         )
-        print(json.loads(res.text))
+        commit_message.append(json.loads(res.text))
     else:
         data = {"idx": idx, "added": added, "deleted": deleted}
         res = requests.post(
```
```diff
@@ -54,7 +58,8 @@
             data=json.dumps(data),
             headers=args.headers
         )
-        print(json.loads(res.text))
+        commit_message.append(json.loads(res.text))
+    return commit_message
 
 def main():
```
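Collecting the server responses and returning them, rather than printing inside the loop, makes `autocommit` usable as a library function. A hypothetical usage sketch:

```python
# One suggestion per changed file in the staged diff.
messages = autocommit(diffs=diff_text)
for idx, message in enumerate(messages):
    print(f"file {idx}: {message}")
```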
```diff
@@ -64,6 +69,8 @@ def main():
     staged_files = [f.strip() for f in staged_files]
     diffs = "\n".join(staged_files)
 
+    message = autocommit(diffs=diffs)
+    print(message)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="")
```
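With the new call, `main` prints whatever `autocommit` returns for the staged diff. An end-to-end sketch of that client flow; the `git diff --cached` invocation is an assumption, since the command that fills `staged_files` is not shown in this changeset:

```python
import subprocess

def main():
    # Assumed source of the staged diff; the real command is outside this hunk.
    diffs = subprocess.run(
        ["git", "diff", "--cached"],
        capture_output=True, text=True, check=True,
    ).stdout
    message = autocommit(diffs=diffs)  # list of per-file suggestions
    print(message)
```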