Showing 3 changed files with 945 additions and 0 deletions
code2nl/bleu.py
new file mode 100644
+#!/usr/bin/python
+
+'''
+This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
+'''
+
+# $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $
+
+'''Provides:
+
+cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
+cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
+score_cooked(alltest, n=4): Score a list of cooked test sentences.
+
+score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids.
+
+The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible.
+'''
+
+import sys, math, re, xml.sax.saxutils
+import subprocess
+import os
+
+# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
+nonorm = 0
+
+preserve_case = False
+eff_ref_len = "shortest"
+
+normalize1 = [
+    ('<skipped>', ''),          # strip "skipped" tags
+    (r'-\n', ''),               # strip end-of-line hyphenation and join lines
+    (r'\n', ' '),               # join lines
+#   (r'(\d)\s+(?=\d)', r'\1'),  # join digits
+]
+normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1]
+
+normalize2 = [
+    (r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 '), # tokenize punctuation. apostrophe is missing
+    (r'([^0-9])([\.,])', r'\1 \2 '),              # tokenize period and comma unless preceded by a digit
+    (r'([\.,])([^0-9])', r' \1 \2'),              # tokenize period and comma unless followed by a digit
+    (r'([0-9])(-)', r'\1 \2 ')                    # tokenize dash when preceded by a digit
+]
+normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2]
+
+def normalize(s):
+    '''Normalize and tokenize text. This is lifted from NIST mteval-v11a.pl.'''
+    # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
+    if nonorm:
+        return s.split()
+    if type(s) is not str:
+        s = " ".join(s)
+    # language-independent part:
+    for (pattern, replace) in normalize1:
+        s = re.sub(pattern, replace, s)
+    s = xml.sax.saxutils.unescape(s, {'&quot;': '"'})
+    # language-dependent part (assuming Western languages):
+    s = " %s " % s
+    if not preserve_case:
+        s = s.lower()  # this might not be identical to the original
+    for (pattern, replace) in normalize2:
+        s = re.sub(pattern, replace, s)
+    return s.split()
+
+def count_ngrams(words, n=4):
+    counts = {}
+    for k in range(1, n+1):
+        for i in range(len(words)-k+1):
+            ngram = tuple(words[i:i+k])
+            counts[ngram] = counts.get(ngram, 0) + 1
+    return counts
+
+def cook_refs(refs, n=4):
+    '''Takes a list of reference sentences for a single segment
+    and returns an object that encapsulates everything that BLEU
+    needs to know about them.'''
+
+    refs = [normalize(ref) for ref in refs]
+    maxcounts = {}
+    for ref in refs:
+        counts = count_ngrams(ref, n)
+        for (ngram, count) in counts.items():
+            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
+    return ([len(ref) for ref in refs], maxcounts)
+
+def cook_test(test, item, n=4):
+    '''Takes a test sentence and returns an object that
+    encapsulates everything that BLEU needs to know about it.'''
+    (reflens, refmaxcounts) = item
+    test = normalize(test)
+    result = {}
+    result["testlen"] = len(test)
+
+    # Calculate effective reference sentence length.
+
+    if eff_ref_len == "shortest":
+        result["reflen"] = min(reflens)
+    elif eff_ref_len == "average":
+        result["reflen"] = float(sum(reflens))/len(reflens)
+    elif eff_ref_len == "closest":
+        min_diff = None
+        for reflen in reflens:
+            if min_diff is None or abs(reflen-len(test)) < min_diff:
+                min_diff = abs(reflen-len(test))
+                result['reflen'] = reflen
+
+    result["guess"] = [max(len(test)-k+1, 0) for k in range(1, n+1)]
+
+    result['correct'] = [0]*n
+    counts = count_ngrams(test, n)
+    for (ngram, count) in counts.items():
+        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram, 0), count)
+
+    return result
+
+def score_cooked(allcomps, n=4, ground=0, smooth=1):
+    totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0]*n, 'correct': [0]*n}
+    for comps in allcomps:
+        for key in ['testlen', 'reflen']:
+            totalcomps[key] += comps[key]
+        for key in ['guess', 'correct']:
+            for k in range(n):
+                totalcomps[key][k] += comps[key][k]
+    logbleu = 0.0
+    all_bleus = []
+    for k in range(n):
+        correct = totalcomps['correct'][k]
+        guess = totalcomps['guess'][k]
+        addsmooth = 0
+        if smooth == 1 and k > 0:
+            addsmooth = 1
+        logbleu += math.log(correct + addsmooth + sys.float_info.min) - math.log(guess + addsmooth + sys.float_info.min)
+        if guess == 0:
+            all_bleus.append(-10000000)
+        else:
+            all_bleus.append(math.log(correct + sys.float_info.min) - math.log(guess))
+
+    logbleu /= float(n)
+    all_bleus.insert(0, logbleu)
+
+    brevPenalty = min(0, 1 - float(totalcomps['reflen'] + 1)/(totalcomps['testlen'] + 1))
+    for i in range(len(all_bleus)):
+        if i == 0:
+            all_bleus[i] += brevPenalty
+        all_bleus[i] = math.exp(all_bleus[i])
+    return all_bleus
+
+def bleu(refs, candidate, ground=0, smooth=1):
+    refs = cook_refs(refs)
+    test = cook_test(candidate, refs)
+    return score_cooked([test], ground=ground, smooth=smooth)
+
+def splitPuncts(line):
+    return ' '.join(re.findall(r"[\w]+|[^\s\w]", line))
+
+def computeMaps(predictions, goldfile):
+    predictionMap = {}
+    goldMap = {}
+    gf = open(goldfile, 'r')
+
+    for row in predictions:
+        cols = row.strip().split('\t')
+        if len(cols) == 1:
+            (rid, pred) = (cols[0], '')
+        else:
+            (rid, pred) = (cols[0], cols[1])
+        predictionMap[rid] = [splitPuncts(pred.strip().lower())]
+
+    for row in gf:
+        (rid, pred) = row.split('\t')
+        if rid in predictionMap:  # Only insert if the id exists for the method
+            if rid not in goldMap:
+                goldMap[rid] = []
+            goldMap[rid].append(splitPuncts(pred.strip().lower()))
+
+    sys.stderr.write('Total: ' + str(len(goldMap)) + '\n')
+    return (goldMap, predictionMap)
+
+
+# m1 is the reference map
+# m2 is the prediction map
+def bleuFromMaps(m1, m2):
+    score = [0] * 5
+    num = 0.0
+
+    for key in m1:
+        if key in m2:
+            bl = bleu(m1[key], m2[key][0])
+            score = [score[i] + bl[i] for i in range(0, len(bl))]
+            num += 1
+    return [s * 100.0 / num for s in score]
+
+if __name__ == '__main__':
+    reference_file = sys.argv[1]
+    predictions = []
+    for row in sys.stdin:
+        predictions.append(row)
+    (goldMap, predictionMap) = computeMaps(predictions, reference_file)
+    print(bleuFromMaps(goldMap, predictionMap)[0])
+
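
A minimal usage sketch of the API above (not part of the commit; the ids, sentences, and the dev.gold file name are made up for illustration):

import bleu

# Sentence-level scoring through the three-phase API:
# cook_refs() -> cook_test() -> score_cooked(), wrapped by bleu().
refs = ["returns the sum of two numbers"]   # hypothetical reference
cand = "return the sum of two numbers"      # hypothetical candidate
scores = bleu.bleu(refs, cand)
print(scores[0])  # smoothed BLEU with brevity penalty; scores[1:] are per-order n-gram precisions

# Corpus-level scoring the way run.py drives it: "<id>\t<sentence>" lines.
predictions = ["0\treturn the sum of two numbers"]
with open("dev.gold", "w") as f:            # hypothetical gold file
    f.write("0\treturns the sum of two numbers\n")
goldMap, predictionMap = bleu.computeMaps(predictions, "dev.gold")
print(bleu.bleuFromMaps(goldMap, predictionMap)[0])  # score * 100, logged as "bleu-4" by run.py
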
code2nl/model.py
new file mode 100644
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+
+class Seq2Seq(nn.Module):
+    """
+    Build Sequence-to-Sequence.
+
+    Parameters:
+
+    * `encoder` - encoder of seq2seq model. e.g. roberta
+    * `decoder` - decoder of seq2seq model. e.g. transformer
+    * `config` - configuration of encoder model.
+    * `beam_size` - beam size for beam search.
+    * `max_length` - max length of target for beam search.
+    * `sos_id` - start-of-sequence symbol id in target for beam search.
+    * `eos_id` - end-of-sequence symbol id in target for beam search.
+    """
+    def __init__(self, encoder, decoder, config, beam_size=None, max_length=None, sos_id=None, eos_id=None):
+        super(Seq2Seq, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.config = config
+        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.lsm = nn.LogSoftmax(dim=-1)
+        self.tie_weights()
+
+        self.beam_size = beam_size
+        self.max_length = max_length
+        self.sos_id = sos_id
+        self.eos_id = eos_id
+
+    def _tie_or_clone_weights(self, first_module, second_module):
+        """ Tie or clone module weights depending on whether we are using TorchScript or not
+        """
+        if self.config.torchscript:
+            first_module.weight = nn.Parameter(second_module.weight.clone())
+        else:
+            first_module.weight = second_module.weight
+
+    def tie_weights(self):
+        """ Make sure we are sharing the input and output embeddings.
+        Export to TorchScript can't handle parameter sharing so we are cloning them instead.
+        """
+        self._tie_or_clone_weights(self.lm_head,
+                                   self.encoder.embeddings.word_embeddings)
+
+    def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None, args=None):
+        outputs = self.encoder(source_ids, attention_mask=source_mask)
+        encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
+        if target_ids is not None:
+            attn_mask = -1e4 * (1 - self.bias[:target_ids.shape[1], :target_ids.shape[1]])
+            tgt_embeddings = self.encoder.embeddings(target_ids).permute([1, 0, 2]).contiguous()
+            out = self.decoder(tgt_embeddings, encoder_output, tgt_mask=attn_mask, memory_key_padding_mask=(1-source_mask).bool())
+            hidden_states = torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous()
+            lm_logits = self.lm_head(hidden_states)
+            # Shift so that tokens < n predict n
+            active_loss = target_mask[..., 1:].ne(0).view(-1) == 1
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = target_ids[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
+                            shift_labels.view(-1)[active_loss])
+
+            outputs = loss, loss*active_loss.sum(), active_loss.sum()
+            return outputs
+        else:
+            # Predict
+            preds = []
+            zero = torch.cuda.LongTensor(1).fill_(0)
+            for i in range(source_ids.shape[0]):
+                context = encoder_output[:, i:i+1]
+                context_mask = source_mask[i:i+1, :]
+                beam = Beam(self.beam_size, self.sos_id, self.eos_id)
+                input_ids = beam.getCurrentState()
+                context = context.repeat(1, self.beam_size, 1)
+                context_mask = context_mask.repeat(self.beam_size, 1)
+                for _ in range(self.max_length):
+                    if beam.done():
+                        break
+                    attn_mask = -1e4 * (1 - self.bias[:input_ids.shape[1], :input_ids.shape[1]])
+                    tgt_embeddings = self.encoder.embeddings(input_ids).permute([1, 0, 2]).contiguous()
+                    out = self.decoder(tgt_embeddings, context, tgt_mask=attn_mask, memory_key_padding_mask=(1-context_mask).bool())
+                    out = torch.tanh(self.dense(out))
+                    hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :]
+                    out = self.lsm(self.lm_head(hidden_states)).data
+                    beam.advance(out)
+                    input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
+                    input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
+                hyp = beam.getHyp(beam.getFinal())
+                pred = beam.buildTargetTokens(hyp)[:self.beam_size]
+                pred = [torch.cat([x.view(-1) for x in p] + [zero]*(self.max_length-len(p))).view(1, -1) for p in pred]
+                preds.append(torch.cat(pred, 0).unsqueeze(0))
+
+            preds = torch.cat(preds, 0)
+            return preds
+
+
+
+class Beam(object):
+    def __init__(self, size, sos, eos):
+        self.size = size
+        self.tt = torch.cuda  # tensor factory; a CUDA device is assumed
+        # The score for each translation on the beam.
+        self.scores = self.tt.FloatTensor(size).zero_()
+        # The backpointers at each time-step.
+        self.prevKs = []
+        # The outputs at each time-step.
+        self.nextYs = [self.tt.LongTensor(size).fill_(0)]
+        self.nextYs[0][0] = sos
+        # Has EOS topped the beam yet.
+        self._eos = eos
+        self.eosTop = False
+        # Time and k pair for finished.
+        self.finished = []
+
+    def getCurrentState(self):
+        "Get the outputs for the current timestep."
+        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
+        return batch
+
+    def getCurrentOrigin(self):
+        "Get the backpointers for the current timestep."
+        return self.prevKs[-1]
+
+    def advance(self, wordLk):
+        """
+        Given log-probs over words for every last beam `wordLk`:
+        compute and update the beam search.
+
+        Parameters:
+
+        * `wordLk` - probs of advancing from the last step (K x words)
+
+        Returns: True if beam search is complete.
+        """
+        numWords = wordLk.size(1)
+
+        # Sum the previous scores.
+        if len(self.prevKs) > 0:
+            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
+
+            # Don't let EOS have children.
+            for i in range(self.nextYs[-1].size(0)):
+                if self.nextYs[-1][i] == self._eos:
+                    beamLk[i] = -1e20
+        else:
+            beamLk = wordLk[0]
+        flatBeamLk = beamLk.view(-1)
+        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
+
+        self.scores = bestScores
+
+        # bestScoresId is flattened beam x word array, so calculate which
+        # word and beam each score came from
+        prevK = bestScoresId // numWords  # integer division; `/` yields floats on newer PyTorch
+        self.prevKs.append(prevK)
+        self.nextYs.append((bestScoresId - prevK * numWords))
+
+        for i in range(self.nextYs[-1].size(0)):
+            if self.nextYs[-1][i] == self._eos:
+                s = self.scores[i]
+                self.finished.append((s, len(self.nextYs) - 1, i))
+
+        # End condition is when top-of-beam is EOS and no global score.
+        if self.nextYs[-1][0] == self._eos:
+            self.eosTop = True
+
+    def done(self):
+        return self.eosTop and len(self.finished) >= self.size
+
+    def getFinal(self):
+        if len(self.finished) == 0:
+            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
+        self.finished.sort(key=lambda a: -a[0])
+        if len(self.finished) != self.size:
+            unfinished = []
+            for i in range(self.nextYs[-1].size(0)):
+                if self.nextYs[-1][i] != self._eos:
+                    s = self.scores[i]
+                    unfinished.append((s, len(self.nextYs) - 1, i))
+            unfinished.sort(key=lambda a: -a[0])
+            self.finished += unfinished[:self.size-len(self.finished)]
+        return self.finished[:self.size]
+
+    def getHyp(self, beam_res):
+        """
+        Walk back to construct the full hypothesis.
+        """
+        hyps = []
+        for _, timestep, k in beam_res:
+            hyp = []
+            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
+                hyp.append(self.nextYs[j+1][k])
+                k = self.prevKs[j][k]
+            hyps.append(hyp[::-1])
+        return hyps
+
+    def buildTargetTokens(self, preds):
+        sentence = []
+        for pred in preds:
+            tokens = []
+            for tok in pred:
+                if tok == self._eos:
+                    break
+                tokens.append(tok)
+            sentence.append(tokens)
+        return sentence
+
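
For orientation, a hedged construction sketch for the module above; it mirrors how run.py below wires things up, uses made-up tensor shapes, and assumes a CUDA device (Beam hardcodes torch.cuda):

import torch
import torch.nn as nn
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from model import Seq2Seq

config = RobertaConfig.from_pretrained("roberta-base")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
encoder = RobertaModel.from_pretrained("roberta-base", config=config)
decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                beam_size=10, max_length=32,
                sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id).cuda()

src = torch.randint(0, config.vocab_size, (2, 16)).cuda()  # dummy batch of 2 sources
src_mask = torch.ones_like(src)
tgt = torch.randint(0, config.vocab_size, (2, 8)).cuda()   # dummy targets
tgt_mask = torch.ones_like(tgt)

# With target_ids: returns (mean loss, loss * num_active_tokens, num_active_tokens).
loss, _, _ = model(source_ids=src, source_mask=src_mask, target_ids=tgt, target_mask=tgt_mask)

# Without target_ids: runs beam search, returns ids of shape (batch, beam_size, max_length).
preds = model(source_ids=src, source_mask=src_mask)
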
code2nl/run.py
new file mode 100644
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
+GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
+using a masked language modeling (MLM) loss.
+"""
+
+from __future__ import absolute_import
+import os
+import sys
+import bleu
+import pickle
+import torch
+import json
+import random
+import logging
+import argparse
+import numpy as np
+from io import open
+from itertools import cycle
+import torch.nn as nn
+from model import Seq2Seq
+from tqdm import tqdm, trange
+from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
+from torch.utils.data.distributed import DistributedSampler
+from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
+                          RobertaConfig, RobertaModel, RobertaTokenizer)
+MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class Example(object):
+    """A single training/test example."""
+    def __init__(self,
+                 idx,
+                 source,
+                 target,
+                 ):
+        self.idx = idx
+        self.source = source
+        self.target = target
+
+def read_examples(filename):
+    """Read examples from filename."""
+    examples = []
+    with open(filename, encoding="utf-8") as f:
+        for idx, line in enumerate(f):
+            line = line.strip()
+            js = json.loads(line)
+            if 'idx' not in js:
+                js['idx'] = idx
+            code = ' '.join(js['code_tokens']).replace('\n', ' ')
+            code = ' '.join(code.strip().split())
+            nl = ' '.join(js['docstring_tokens']).replace('\n', '')
+            nl = ' '.join(nl.strip().split())
+            examples.append(
+                Example(
+                    idx=idx,
+                    source=code,
+                    target=nl,
+                )
+            )
+    return examples
+
+
+class InputFeatures(object):
+    """A single set of training/test features for an example."""
+    def __init__(self,
+                 example_id,
+                 source_ids,
+                 target_ids,
+                 source_mask,
+                 target_mask,
+                 ):
+        self.example_id = example_id
+        self.source_ids = source_ids
+        self.target_ids = target_ids
+        self.source_mask = source_mask
+        self.target_mask = target_mask
+
+
+def convert_examples_to_features(examples, tokenizer, args, stage=None):
+    features = []
+    for example_index, example in enumerate(examples):
+        # source
+        source_tokens = tokenizer.tokenize(example.source)[:args.max_source_length-2]
+        source_tokens = [tokenizer.cls_token]+source_tokens+[tokenizer.sep_token]
+        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
+        source_mask = [1] * (len(source_tokens))
+        padding_length = args.max_source_length - len(source_ids)
+        source_ids += [tokenizer.pad_token_id]*padding_length
+        source_mask += [0]*padding_length
+
+        # target
+        if stage == "test":
+            target_tokens = tokenizer.tokenize("None")
+        else:
+            target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2]
+        target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token]
+        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
+        target_mask = [1]*len(target_ids)
+        padding_length = args.max_target_length - len(target_ids)
+        target_ids += [tokenizer.pad_token_id]*padding_length
+        target_mask += [0]*padding_length
+
+        if example_index < 5:
+            if stage == 'train':
+                logger.info("*** Example ***")
+                logger.info("idx: {}".format(example.idx))
+
+                logger.info("source_tokens: {}".format([x.replace('\u0120', '_') for x in source_tokens]))
+                logger.info("source_ids: {}".format(' '.join(map(str, source_ids))))
+                logger.info("source_mask: {}".format(' '.join(map(str, source_mask))))
+
+                logger.info("target_tokens: {}".format([x.replace('\u0120', '_') for x in target_tokens]))
+                logger.info("target_ids: {}".format(' '.join(map(str, target_ids))))
+                logger.info("target_mask: {}".format(' '.join(map(str, target_mask))))
+
+        features.append(
+            InputFeatures(
+                example_index,
+                source_ids,
+                target_ids,
+                source_mask,
+                target_mask,
+            )
+        )
+    return features
+
+
+def set_seed(args):
+    """Set random seed."""
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    ## Required parameters
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type: e.g. roberta")
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model: e.g. roberta-base")
+    parser.add_argument("--output_dir", default=None, type=str, required=True,
+                        help="The output directory where the model predictions and checkpoints will be written.")
+    parser.add_argument("--load_model_path", default=None, type=str,
+                        help="Path to trained model: Should contain the .bin files")
+    ## Other parameters
+    parser.add_argument("--train_filename", default=None, type=str,
+                        help="The train filename. Should contain the .jsonl files for this task.")
+    parser.add_argument("--dev_filename", default=None, type=str,
+                        help="The dev filename. Should contain the .jsonl files for this task.")
+    parser.add_argument("--test_filename", default=None, type=str,
+                        help="The test filename. Should contain the .jsonl files for this task.")
+
+    parser.add_argument("--config_name", default="", type=str,
+                        help="Pretrained config name or path if not the same as model_name")
+    parser.add_argument("--tokenizer_name", default="", type=str,
+                        help="Pretrained tokenizer name or path if not the same as model_name")
+    parser.add_argument("--max_source_length", default=64, type=int,
+                        help="The maximum total source sequence length after tokenization. Sequences longer "
+                             "than this will be truncated, sequences shorter will be padded.")
+    parser.add_argument("--max_target_length", default=32, type=int,
+                        help="The maximum total target sequence length after tokenization. Sequences longer "
+                             "than this will be truncated, sequences shorter will be padded.")
+
+    parser.add_argument("--do_train", action='store_true',
+                        help="Whether to run training.")
+    parser.add_argument("--do_eval", action='store_true',
+                        help="Whether to run eval on the dev set.")
+    parser.add_argument("--do_test", action='store_true',
+                        help="Whether to run eval on the test set.")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Set this flag if you are using an uncased model.")
+    parser.add_argument("--no_cuda", action='store_true',
+                        help="Avoid using CUDA when available")
+
+    parser.add_argument("--train_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for training.")
+    parser.add_argument("--eval_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for evaluation.")
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+                        help="Number of updates steps to accumulate before performing a backward/update pass.")
+    parser.add_argument("--learning_rate", default=5e-5, type=float,
+                        help="The initial learning rate for Adam.")
+    parser.add_argument("--beam_size", default=10, type=int,
+                        help="beam size for beam search")
+    parser.add_argument("--weight_decay", default=0.0, type=float,
+                        help="Weight decay if we apply some.")
+    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
+                        help="Epsilon for Adam optimizer.")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                        help="Max gradient norm.")
+    parser.add_argument("--num_train_epochs", default=3.0, type=float,
+                        help="Total number of training epochs to perform.")
+    parser.add_argument("--max_steps", default=-1, type=int,
+                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
+    parser.add_argument("--eval_steps", default=-1, type=int,
+                        help="Evaluate on the dev set every this many optimization steps.")
+    parser.add_argument("--train_steps", default=-1, type=int,
+                        help="Total number of optimization steps to train for.")
+    parser.add_argument("--warmup_steps", default=0, type=int,
+                        help="Linear warmup over warmup_steps.")
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="For distributed training: local_rank")
+    parser.add_argument('--seed', type=int, default=42,
+                        help="random seed for initialization")
+    # print arguments
+    args = parser.parse_args()
+    logger.info(args)
+
+    # Setup CUDA, GPU & distributed training
+    if args.local_rank == -1 or args.no_cuda:
+        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+        args.n_gpu = torch.cuda.device_count()
+    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+        torch.cuda.set_device(args.local_rank)
+        device = torch.device("cuda", args.local_rank)
+        torch.distributed.init_process_group(backend='nccl')
+        args.n_gpu = 1
+    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
+                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1))
+    args.device = device
+    # Set seed
+    set_seed(args)
+    # make dir if output_dir does not exist
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+
+    # build model
+    encoder = model_class.from_pretrained(args.model_name_or_path, config=config)
+    decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
+    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
+    model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
+                    beam_size=args.beam_size, max_length=args.max_target_length,
+                    sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
+    if args.load_model_path is not None:
+        logger.info("reload model from {}".format(args.load_model_path))
+        model.load_state_dict(torch.load(args.load_model_path))
+
+    model.to(device)
+    if args.local_rank != -1:
+        # Distributed training
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
+    elif args.n_gpu > 1:
+        # multi-gpu training
+        model = torch.nn.DataParallel(model)
+
+
+    if args.do_train:
+        # Prepare training data loader
+        train_examples = read_examples(args.train_filename)
+        train_features = convert_examples_to_features(train_examples, tokenizer, args, stage='train')
+        all_source_ids = torch.tensor([f.source_ids for f in train_features], dtype=torch.long)
+        all_source_mask = torch.tensor([f.source_mask for f in train_features], dtype=torch.long)
+        all_target_ids = torch.tensor([f.target_ids for f in train_features], dtype=torch.long)
+        all_target_mask = torch.tensor([f.target_mask for f in train_features], dtype=torch.long)
+        train_data = TensorDataset(all_source_ids, all_source_mask, all_target_ids, all_target_mask)
+
+        if args.local_rank == -1:
+            train_sampler = RandomSampler(train_data)
+        else:
+            train_sampler = DistributedSampler(train_data)
+        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size//args.gradient_accumulation_steps)
+
+        num_train_optimization_steps = args.train_steps
+
+        # Prepare optimizer and schedule (linear warmup and decay)
+        no_decay = ['bias', 'LayerNorm.weight']
+        optimizer_grouped_parameters = [
+            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+             'weight_decay': args.weight_decay},
+            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+        ]
+        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
+                                                    num_training_steps=num_train_optimization_steps)
+
+        # Start training
+        logger.info("***** Running training *****")
+        logger.info("  Num examples = %d", len(train_examples))
+        logger.info("  Batch size = %d", args.train_batch_size)
+        logger.info("  Num epoch = %d", num_train_optimization_steps*args.train_batch_size//len(train_examples))
+
+        model.train()
+        dev_dataset = {}
+        nb_tr_examples, nb_tr_steps, tr_loss, global_step, best_bleu, best_loss = 0, 0, 0, 0, 0, 1e6
+        bar = tqdm(range(num_train_optimization_steps), total=num_train_optimization_steps)
+        train_dataloader = cycle(train_dataloader)
+        eval_flag = True
+        for step in bar:
+            batch = next(train_dataloader)
+            batch = tuple(t.to(device) for t in batch)
+            source_ids, source_mask, target_ids, target_mask = batch
+            loss, _, _ = model(source_ids=source_ids, source_mask=source_mask, target_ids=target_ids, target_mask=target_mask)
+
+            if args.n_gpu > 1:
+                loss = loss.mean()  # mean() to average on multi-gpu.
+            if args.gradient_accumulation_steps > 1:
+                loss = loss / args.gradient_accumulation_steps
+            tr_loss += loss.item()
+            train_loss = round(tr_loss*args.gradient_accumulation_steps/(nb_tr_steps+1), 4)
+            bar.set_description("loss {}".format(train_loss))
+            nb_tr_examples += source_ids.size(0)
+            nb_tr_steps += 1
+            loss.backward()
+
+            if (nb_tr_steps + 1) % args.gradient_accumulation_steps == 0:
+                # Update parameters
+                optimizer.step()
+                optimizer.zero_grad()
+                scheduler.step()
+                global_step += 1
+                eval_flag = True
+
+            if args.do_eval and ((global_step + 1) % args.eval_steps == 0) and eval_flag:
+                # Eval model with dev dataset
+                tr_loss = 0
+                nb_tr_examples, nb_tr_steps = 0, 0
+                eval_flag = False
+                if 'dev_loss' in dev_dataset:
+                    eval_examples, eval_data = dev_dataset['dev_loss']
+                else:
+                    eval_examples = read_examples(args.dev_filename)
+                    eval_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='dev')
+                    all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
+                    all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
+                    all_target_ids = torch.tensor([f.target_ids for f in eval_features], dtype=torch.long)
+                    all_target_mask = torch.tensor([f.target_mask for f in eval_features], dtype=torch.long)
+                    eval_data = TensorDataset(all_source_ids, all_source_mask, all_target_ids, all_target_mask)
+                    dev_dataset['dev_loss'] = eval_examples, eval_data
+                eval_sampler = SequentialSampler(eval_data)
+                eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+                logger.info("\n***** Running evaluation *****")
+                logger.info("  Num examples = %d", len(eval_examples))
+                logger.info("  Batch size = %d", args.eval_batch_size)
+
+                # Start evaluating the model
+                model.eval()
+                eval_loss, tokens_num = 0, 0
+                for batch in eval_dataloader:
+                    batch = tuple(t.to(device) for t in batch)
+                    source_ids, source_mask, target_ids, target_mask = batch
+
+                    with torch.no_grad():
+                        _, loss, num = model(source_ids=source_ids, source_mask=source_mask,
+                                             target_ids=target_ids, target_mask=target_mask)
+                    eval_loss += loss.sum().item()
+                    tokens_num += num.sum().item()
+                # Print loss of dev dataset
+                model.train()
+                eval_loss = eval_loss / tokens_num
+                result = {'eval_ppl': round(np.exp(eval_loss), 5),
+                          'global_step': global_step+1,
+                          'train_loss': round(train_loss, 5)}
+                for key in sorted(result.keys()):
+                    logger.info("  %s = %s", key, str(result[key]))
+                logger.info("  "+"*"*20)
+
+                # save last checkpoint
+                last_output_dir = os.path.join(args.output_dir, 'checkpoint-last')
+                if not os.path.exists(last_output_dir):
+                    os.makedirs(last_output_dir)
+                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
+                output_model_file = os.path.join(last_output_dir, "pytorch_model.bin")
+                torch.save(model_to_save.state_dict(), output_model_file)
+                if eval_loss < best_loss:
+                    logger.info("  Best ppl:%s", round(np.exp(eval_loss), 5))
+                    logger.info("  "+"*"*20)
+                    best_loss = eval_loss
+                    # Save best checkpoint for best ppl
+                    output_dir = os.path.join(args.output_dir, 'checkpoint-best-ppl')
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
+                    output_model_file = os.path.join(output_dir, "pytorch_model.bin")
+                    torch.save(model_to_save.state_dict(), output_model_file)
+
+
+                # Calculate bleu
+                if 'dev_bleu' in dev_dataset:
+                    eval_examples, eval_data = dev_dataset['dev_bleu']
+                else:
+                    eval_examples = read_examples(args.dev_filename)
+                    eval_examples = random.sample(eval_examples, min(1000, len(eval_examples)))
+                    eval_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='test')
+                    all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
+                    all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
+                    eval_data = TensorDataset(all_source_ids, all_source_mask)
+                    dev_dataset['dev_bleu'] = eval_examples, eval_data
+
+                eval_sampler = SequentialSampler(eval_data)
+                eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+                model.eval()
+                p = []
+                for batch in eval_dataloader:
+                    batch = tuple(t.to(device) for t in batch)
+                    source_ids, source_mask = batch
+                    with torch.no_grad():
+                        preds = model(source_ids=source_ids, source_mask=source_mask)
+                        for pred in preds:
+                            t = pred[0].cpu().numpy()
+                            t = list(t)
+                            if 0 in t:
+                                t = t[:t.index(0)]
+                            text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
+                            p.append(text)
+                model.train()
+                predictions = []
+                with open(os.path.join(args.output_dir, "dev.output"), 'w') as f, open(os.path.join(args.output_dir, "dev.gold"), 'w') as f1:
+                    for ref, gold in zip(p, eval_examples):
+                        predictions.append(str(gold.idx)+'\t'+ref)
+                        f.write(str(gold.idx)+'\t'+ref+'\n')
+                        f1.write(str(gold.idx)+'\t'+gold.target+'\n')
+
+                (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "dev.gold"))
+                dev_bleu = round(bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
+                logger.info("  %s = %s " % ("bleu-4", str(dev_bleu)))
+                logger.info("  "+"*"*20)
+                if dev_bleu > best_bleu:
+                    logger.info("  Best bleu:%s", dev_bleu)
+                    logger.info("  "+"*"*20)
+                    best_bleu = dev_bleu
+                    # Save best checkpoint for best bleu
+                    output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu')
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
+                    output_model_file = os.path.join(output_dir, "pytorch_model.bin")
+                    torch.save(model_to_save.state_dict(), output_model_file)
+
+    if args.do_test:
+        files = []
+        if args.dev_filename is not None:
+            files.append(args.dev_filename)
+        if args.test_filename is not None:
+            files.append(args.test_filename)
+        for idx, file in enumerate(files):
+            logger.info("Test file: {}".format(file))
+            eval_examples = read_examples(file)
+            eval_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='test')
+            all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
+            all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
+            eval_data = TensorDataset(all_source_ids, all_source_mask)
+
+            # Calculate bleu
+            eval_sampler = SequentialSampler(eval_data)
+            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+            model.eval()
+            p = []
+            for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
+                batch = tuple(t.to(device) for t in batch)
+                source_ids, source_mask = batch
+                with torch.no_grad():
+                    preds = model(source_ids=source_ids, source_mask=source_mask)
+                    for pred in preds:
+                        t = pred[0].cpu().numpy()
+                        t = list(t)
+                        if 0 in t:
+                            t = t[:t.index(0)]
+                        text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
+                        p.append(text)
+            model.train()
+            predictions = []
+            with open(os.path.join(args.output_dir, "test_{}.output".format(str(idx))), 'w') as f, open(os.path.join(args.output_dir, "test_{}.gold".format(str(idx))), 'w') as f1:
+                for ref, gold in zip(p, eval_examples):
+                    predictions.append(str(gold.idx)+'\t'+ref)
+                    f.write(str(gold.idx)+'\t'+ref+'\n')
+                    f1.write(str(gold.idx)+'\t'+gold.target+'\n')
+
+            (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "test_{}.gold".format(idx)))
+            dev_bleu = round(bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
+            logger.info("  %s = %s " % ("bleu-4", str(dev_bleu)))
+            logger.info("  "+"*"*20)
+
+
+if __name__ == "__main__":
+    main()
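
The expected data format is implicit in read_examples() above: each line of the train/dev/test files is a JSON object with code_tokens and docstring_tokens fields (idx is optional). A small illustrative sketch, with a made-up file name:

import json
from run import read_examples

sample = {
    "code_tokens": ["def", "add", "(", "a", ",", "b", ")", ":", "return", "a", "+", "b"],
    "docstring_tokens": ["Add", "two", "numbers", "."],
}
with open("toy.jsonl", "w") as f:   # hypothetical input file
    f.write(json.dumps(sample) + "\n")

examples = read_examples("toy.jsonl")
print(examples[0].source)  # "def add ( a , b ) : return a + b"
print(examples[0].target)  # "Add two numbers ."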