Showing 2 changed files with 173 additions and 0 deletions
src/api.py (new file, mode 100644)
# Copyright 2020-present Tae Hwan Jung
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import logging
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from transformers import (RobertaConfig, RobertaTokenizer)

import argparse
import whatthepatch
from train.run import (Example, convert_examples_to_features)
from train.model import Seq2Seq
from train.customized_roberta import RobertaModel

MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

def create_examples(diff, tokenizer):
    """Build one Example per changed file in the diff, collecting tokenized added and deleted lines."""
    examples = []
    for idx, example in enumerate(whatthepatch.parse_patch(diff)):
        added, deleted = [], []
        for change in example.changes:
            if change.old is None and change.new is not None:
                added.extend(tokenizer.tokenize(change.line))
            elif change.old is not None and change.new is None:
                deleted.extend(tokenizer.tokenize(change.line))
        examples.append(
            Example(
                idx=idx,
                added=added,
                deleted=deleted,
                target=None
            )
        )

    return examples

def main(args):

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(args.config_name)
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)

    # build model
    encoder = model_class(config=config)
    decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
    model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                    beam_size=args.beam_size, max_length=args.max_target_length,
                    sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
    if args.load_model_path is not None:
        logger.info("reload model from {}".format(args.load_model_path))
        model.load_state_dict(torch.load(args.load_model_path), strict=False)

    model.to(args.device)
    with open("test.source", "r") as f:
        eval_examples = create_examples(f.read(), tokenizer)

    test_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='test')
    all_source_ids = torch.tensor([f.source_ids for f in test_features], dtype=torch.long)
    all_source_mask = torch.tensor([f.source_mask for f in test_features], dtype=torch.long)
    all_patch_ids = torch.tensor([f.patch_ids for f in test_features], dtype=torch.long)
    test_data = TensorDataset(all_source_ids, all_source_mask, all_patch_ids)

    # Generate commit messages for the parsed diffs
    eval_sampler = SequentialSampler(test_data)
    eval_dataloader = DataLoader(test_data, sampler=eval_sampler, batch_size=len(test_data))

    model.eval()
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
        batch = tuple(t.to(args.device) for t in batch)
        source_ids, source_mask, patch_ids = batch
        with torch.no_grad():
            preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)
            for pred in preds:
                t = pred[0].cpu().numpy()
                t = list(t)
                if 0 in t:
                    t = t[:t.index(0)]  # truncate at the first padding id
                text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
                print(text)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--load_model_path", default=None, type=str, required=True,
                        help="Path to trained model: Should contain the .bin files")

    parser.add_argument("--model_type", default='roberta', type=str,
                        help="Model type: e.g. roberta")
    parser.add_argument("--config_name", default="microsoft/codebert-base", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", type=str,
                        default="microsoft/codebert-base", help="The name of tokenizer")
    parser.add_argument("--max_source_length", default=256, type=int,
                        help="The maximum total source sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--max_target_length", default=128, type=int,
                        help="The maximum total target sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--beam_size", default=10, type=int,
                        help="beam size for beam search")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")

    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    main(args)
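Not part of the diff: a minimal sketch of how create_examples from src/api.py could be exercised on its own. The sample diff string below is hypothetical, and the sketch assumes the function and its imports are available (e.g. run from within src/api.py's context).

# Usage sketch (assumption: create_examples is importable/defined in the current session)
from transformers import RobertaTokenizer

sample_diff = """diff --git a/hello.py b/hello.py
--- a/hello.py
+++ b/hello.py
@@ -1,2 +1,1 @@
-print("hi")
-print("bye")
+print("hello world")
"""

tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
examples = create_examples(sample_diff, tokenizer)
# One Example per changed file: here a single entry whose `added` holds the tokens of
# the new print line and whose `deleted` holds the tokens of the two removed lines.
print(len(examples))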
src/test.source (new file, mode 100644)
diff --git a/src/train/model.py b/src/train/model.py
index 20e56b3..cab82e5 100644
--- a/src/train/model.py
+++ b/src/train/model.py
@@ -3,9 +3,7 @@

 import torch
 import torch.nn as nn
-import torch
-from torch.autograd import Variable
-import copy
+
 class Seq2Seq(nn.Module):
     """
     Build Seqence-to-Sequence.
diff --git a/src/train/run.py b/src/train/run.py
index 5961ad1..be98fec 100644
--- a/src/train/run.py
+++ b/src/train/run.py
@@ -22,7 +22,6 @@ using a masked language modeling (MLM) loss.
 from __future__ import absolute_import
 import os
 import sys
-import bleu
 import pickle
 import torch
 import json
@@ -35,11 +34,14 @@ from itertools import cycle
 import torch.nn as nn
 from model import Seq2Seq
 from tqdm import tqdm, trange
-from customized_roberta import RobertaModel
 from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
 from torch.utils.data.distributed import DistributedSampler
 from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
                           RobertaConfig, RobertaTokenizer)
+
+import train.bleu as bleu
+from train.customized_roberta import RobertaModel
+
 MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}

 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
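Also not part of the diff: a rough sketch of what api.py pulls out of this file's first entry, assuming the script is run from src/ so that the relative path test.source resolves.

# Sketch: inspect the first file-level diff in test.source the same way api.py does
import whatthepatch

with open("test.source") as f:
    first = next(whatthepatch.parse_patch(f.read()))

# Lines with no new line number were deleted; lines with no old line number were added.
deleted = [c.line for c in first.changes if c.new is None and c.old is not None]
added = [c.line for c in first.changes if c.old is None and c.new is not None]
print(deleted)  # the three removed imports from src/train/model.py
print(added)    # the single blank line that replaced them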