graykode

(refactor) collect and return commit messages from API calls instead of printing inside the helper

...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
15 import os 15 import os
16 import torch 16 import torch
17 import argparse 17 import argparse
18 -import whatthepatch
19 from tqdm import tqdm 18 from tqdm import tqdm
20 import torch.nn as nn 19 import torch.nn as nn
21 from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 20 from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
...@@ -47,7 +46,7 @@ def get_model(model_class, config, tokenizer, mode): ...@@ -47,7 +46,7 @@ def get_model(model_class, config, tokenizer, mode):
47 model.load_state_dict( 46 model.load_state_dict(
48 torch.load( 47 torch.load(
49 os.path.join(args.load_model_path, mode, 'pytorch_model.bin'), 48 os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),
50 - map_location=torch.device(args.device) 49 + map_location=torch.device('cpu')
51 ), 50 ),
52 strict=False 51 strict=False
53 ) 52 )
...@@ -55,9 +54,15 @@ def get_model(model_class, config, tokenizer, mode): ...@@ -55,9 +54,15 @@ def get_model(model_class, config, tokenizer, mode):
55 54
56 def get_features(examples): 55 def get_features(examples):
57 features = convert_examples_to_features(examples, args.tokenizer, args, stage='test') 56 features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')
58 - all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long) 57 + all_source_ids = torch.tensor(
59 - all_source_mask = torch.tensor([f.source_mask for f in features], dtype=torch.long) 58 + [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long
60 - all_patch_ids = torch.tensor([f.patch_ids for f in features], dtype=torch.long) 59 + )
60 + all_source_mask = torch.tensor(
61 + [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long
62 + )
63 + all_patch_ids = torch.tensor(
64 + [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long
65 + )
61 return TensorDataset(all_source_ids, all_source_mask, all_patch_ids) 66 return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
62 67
63 def create_app(): 68 def create_app():
...@@ -150,7 +155,7 @@ if __name__ == '__main__': ...@@ -150,7 +155,7 @@ if __name__ == '__main__':
150 help="Pretrained config name or path if not the same as model_name") 155 help="Pretrained config name or path if not the same as model_name")
151 parser.add_argument("--tokenizer_name", type=str, 156 parser.add_argument("--tokenizer_name", type=str,
152 default="microsoft/codebert-base", help="The name of tokenizer", ) 157 default="microsoft/codebert-base", help="The name of tokenizer", )
153 - parser.add_argument("--max_source_length", default=256, type=int, 158 + parser.add_argument("--max_source_length", default=512, type=int,
154 help="The maximum total source sequence length after tokenization. Sequences longer " 159 help="The maximum total source sequence length after tokenization. Sequences longer "
155 "than this will be truncated, sequences shorter will be padded.") 160 "than this will be truncated, sequences shorter will be padded.")
156 parser.add_argument("--max_target_length", default=128, type=int, 161 parser.add_argument("--max_target_length", default=128, type=int,
......
...@@ -27,8 +27,12 @@ def tokenizing(code): ...@@ -27,8 +27,12 @@ def tokenizing(code):
27 ) 27 )
28 return json.loads(res.text)["tokens"] 28 return json.loads(res.text)["tokens"]
29 29
30 -def preprocessing(diffs): 30 +def autocommit(diffs):
31 + commit_message = []
31 for idx, example in enumerate(whatthepatch.parse_patch(diffs)): 32 for idx, example in enumerate(whatthepatch.parse_patch(diffs)):
33 + if not example.changes:
34 + continue
35 +
32 isadded, isdeleted = False, False 36 isadded, isdeleted = False, False
33 added, deleted = [], [] 37 added, deleted = [], []
34 for change in example.changes: 38 for change in example.changes:
...@@ -46,7 +50,7 @@ def preprocessing(diffs): ...@@ -46,7 +50,7 @@ def preprocessing(diffs):
46 data=json.dumps(data), 50 data=json.dumps(data),
47 headers=args.headers 51 headers=args.headers
48 ) 52 )
49 - print(json.loads(res.text)) 53 + commit_message.append(json.loads(res.text))
50 else: 54 else:
51 data = {"idx": idx, "added": added, "deleted": deleted} 55 data = {"idx": idx, "added": added, "deleted": deleted}
52 res = requests.post( 56 res = requests.post(
...@@ -54,7 +58,8 @@ def preprocessing(diffs): ...@@ -54,7 +58,8 @@ def preprocessing(diffs):
54 data=json.dumps(data), 58 data=json.dumps(data),
55 headers=args.headers 59 headers=args.headers
56 ) 60 )
57 - print(json.loads(res.text)) 61 + commit_message.append(json.loads(res.text))
62 + return commit_message
58 63
59 def main(): 64 def main():
60 65
...@@ -64,6 +69,8 @@ def main(): ...@@ -64,6 +69,8 @@ def main():
64 staged_files = [f.strip() for f in staged_files] 69 staged_files = [f.strip() for f in staged_files]
65 diffs = "\n".join(staged_files) 70 diffs = "\n".join(staged_files)
66 71
72 + message = autocommit(diffs=diffs)
73 + print(message)
67 74
68 if __name__ == '__main__': 75 if __name__ == '__main__':
69 parser = argparse.ArgumentParser(description="") 76 parser = argparse.ArgumentParser(description="")
......