graykode

(add) unittest for api

1 +# Copyright 2020-present Tae Hwan Jung
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import os
16 +import torch
17 +import logging
18 +from tqdm import tqdm
19 +import torch.nn as nn
20 +from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
21 +from transformers import (RobertaConfig, RobertaTokenizer)
22 +
23 +import argparse
24 +import whatthepatch
25 +from train.run import (Example, convert_examples_to_features)
26 +from train.model import Seq2Seq
27 +from train.customized_roberta import RobertaModel
28 +
29 +MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
30 +
31 +logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
32 +                    datefmt = '%m/%d/%Y %H:%M:%S',
33 +                    level = logging.INFO)
34 +logger = logging.getLogger(__name__)
35 +
36 +def create_examples(diff, tokenizer):
37 +    examples = []
38 +    for idx, example in enumerate(whatthepatch.parse_patch(diff)):
39 +        added, deleted = [], []
40 +        for change in example.changes:
41 +            if change.old is None and change.new is not None:  # line added by the patch
42 +                added.extend(tokenizer.tokenize(change.line))
43 +            elif change.old is not None and change.new is None:  # line removed by the patch
44 +                deleted.extend(tokenizer.tokenize(change.line))
45 +        examples.append(
46 +            Example(
47 +                idx=idx,
48 +                added=added,
49 +                deleted=deleted,
50 +                target=None
51 +            )
52 +        )
53 +
54 +    return examples
55 +
56 +def main(args):
57 +
58 +    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
59 +    config = config_class.from_pretrained(args.config_name)
60 +    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
61 +
62 +    # build model
63 +    encoder = model_class(config=config)
64 +    decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
65 +    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
66 +    model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
67 +                    beam_size=args.beam_size, max_length=args.max_target_length,
68 +                    sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
69 +    if args.load_model_path is not None:
70 +        logger.info("reload model from {}".format(args.load_model_path))
71 +        model.load_state_dict(torch.load(args.load_model_path), strict=False)
72 +
73 +    model.to(args.device)
74 +    with open("test.source", "r") as f:
75 +        eval_examples = create_examples(f.read(), tokenizer)
76 +
77 +    test_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='test')
78 +    all_source_ids = torch.tensor([f.source_ids for f in test_features], dtype=torch.long)
79 +    all_source_mask = torch.tensor([f.source_mask for f in test_features], dtype=torch.long)
80 +    all_patch_ids = torch.tensor([f.patch_ids for f in test_features], dtype=torch.long)
81 +    test_data = TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
82 +
83 +    # generate and print commit messages for the parsed diff
84 +    eval_sampler = SequentialSampler(test_data)
85 +    eval_dataloader = DataLoader(test_data, sampler=eval_sampler, batch_size=len(test_data))
86 +
87 +    model.eval()
88 +    for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
89 +        batch = tuple(t.to(args.device) for t in batch)
90 +        source_ids, source_mask, patch_ids = batch
91 +        with torch.no_grad():
92 +            preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)
93 +            for pred in preds:
94 +                t = pred[0].cpu().numpy()
95 +                t = list(t)
96 +                if 0 in t:
97 +                    t = t[:t.index(0)]  # strip the zero padding appended by beam search
98 +                text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
99 +                print(text)
100 +
101 +
102 +if __name__ == '__main__':
103 +    parser = argparse.ArgumentParser(description="")
104 +    parser.add_argument("--load_model_path", default=None, type=str, required=True,
105 +                        help="Path to trained model: Should contain the .bin files")
106 +
107 +    parser.add_argument("--model_type", default='roberta', type=str,
108 +                        help="Model type: e.g. roberta")
109 +    parser.add_argument("--config_name", default="microsoft/codebert-base", type=str,
110 +                        help="Pretrained config name or path if not the same as model_name")
111 +    parser.add_argument("--tokenizer_name", type=str,
112 +                        default="microsoft/codebert-base", help="The name of the tokenizer", )
113 +    parser.add_argument("--max_source_length", default=256, type=int,
114 +                        help="The maximum total source sequence length after tokenization. Sequences longer "
115 +                             "than this will be truncated, sequences shorter will be padded.")
116 +    parser.add_argument("--max_target_length", default=128, type=int,
117 +                        help="The maximum total target sequence length after tokenization. Sequences longer "
118 +                             "than this will be truncated, sequences shorter will be padded.")
119 +    parser.add_argument("--beam_size", default=10, type=int,
120 +                        help="beam size for beam search")
121 +    parser.add_argument("--do_lower_case", action='store_true',
122 +                        help="Set this flag if you are using an uncased model.")
123 +    parser.add_argument("--no_cuda", action='store_true',
124 +                        help="Avoid using CUDA when available")
125 +
126 +    args = parser.parse_args()
127 +
128 +    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
129 +
130 +    main(args)
\ No newline at end of file
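
For reviewers: a minimal sketch of driving the new test script programmatically instead of through its CLI. The module name (test) and the checkpoint path are hypothetical; the remaining values simply mirror the argparse defaults above, and a trained .bin checkpoint is assumed to exist.

import torch
from argparse import Namespace
from test import main  # hypothetical module name for the script above

args = Namespace(
    load_model_path="weight/pytorch_model.bin",  # hypothetical checkpoint path
    model_type="roberta",
    config_name="microsoft/codebert-base",
    tokenizer_name="microsoft/codebert-base",
    max_source_length=256,
    max_target_length=128,
    beam_size=10,
    do_lower_case=False,
    no_cuda=True,
)
args.device = torch.device("cpu")
main(args)  # prints one generated commit message per file diff in the fixture

The file that follows is the diff fixture the script reads (presumably test.source):
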
1 +diff --git a/src/train/model.py b/src/train/model.py
2 +index 20e56b3..cab82e5 100644
3 +--- a/src/train/model.py
4 ++++ b/src/train/model.py
5 +@@ -3,9 +3,7 @@
6 +
7 + import torch
8 + import torch.nn as nn
9 +-import torch
10 +-from torch.autograd import Variable
11 +-import copy
12 ++
13 + class Seq2Seq(nn.Module):
14 + """
15 + Build Seqence-to-Sequence.
16 +diff --git a/src/train/run.py b/src/train/run.py
17 +index 5961ad1..be98fec 100644
18 +--- a/src/train/run.py
19 ++++ b/src/train/run.py
20 +@@ -22,7 +22,6 @@ using a masked language modeling (MLM) loss.
21 + from __future__ import absolute_import
22 + import os
23 + import sys
24 +-import bleu
25 + import pickle
26 + import torch
27 + import json
28 +@@ -35,11 +34,14 @@ from itertools import cycle
29 + import torch.nn as nn
30 + from model import Seq2Seq
31 + from tqdm import tqdm, trange
32 +-from customized_roberta import RobertaModel
33 + from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
34 + from torch.utils.data.distributed import DistributedSampler
35 + from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
36 + RobertaConfig, RobertaTokenizer)
37 ++
38 ++import train.bleu as bleu
39 ++from train.customized_roberta import RobertaModel
40 ++
41 + MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
42 +
43 + logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
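
The fixture contains two file diffs (src/train/model.py and src/train/run.py), so create_examples should produce two Example objects, each with target=None. A minimal sketch of sanity-checking the parsing outside the model, assuming the microsoft/codebert-base tokenizer can be loaded; it repeats the same whatthepatch logic the script uses.

import whatthepatch
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

with open("test.source", "r") as f:
    diff = f.read()

# One parsed patch per "diff --git" section; count added/deleted lines per file.
for idx, patch in enumerate(whatthepatch.parse_patch(diff)):
    added = [c.line for c in patch.changes if c.old is None and c.new is not None]
    deleted = [c.line for c in patch.changes if c.old is not None and c.new is None]
    print(idx, len(tokenizer.tokenize(" ".join(added))), len(tokenizer.tokenize(" ".join(deleted))))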