Showing 3 changed files with 419 additions and 0 deletions
code2nl/bleu.py
0 → 100644
+#!/usr/bin/python
+
+'''
+This script was adapted from the original version by hieuhoang1972, which is part of MOSES.
+'''
+
+# $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $
+
+'''Provides:
+
+cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
+cook_test(test, item, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
+score_cooked(allcomps, n=4): Score a list of cooked test sentences.
+
+score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids. (Not included in this adaptation.)
+
+The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible.
+'''
+
+import sys, math, re, xml.sax.saxutils
+
+# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
+nonorm = 0
+
+preserve_case = False
+eff_ref_len = "shortest"
+
+normalize1 = [
+    ('<skipped>', ''),          # strip "skipped" tags
+    (r'-\n', ''),               # strip end-of-line hyphenation and join lines
+    (r'\n', ' '),               # join lines
+#    (r'(\d)\s+(?=\d)', r'\1'), # join digits
+]
+normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1]
+
+normalize2 = [
+    (r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 '), # tokenize punctuation. apostrophe is missing
+    (r'([^0-9])([\.,])', r'\1 \2 '),              # tokenize period and comma unless preceded by a digit
+    (r'([\.,])([^0-9])', r' \1 \2'),              # tokenize period and comma unless followed by a digit
+    (r'([0-9])(-)', r'\1 \2 ')                    # tokenize dash when preceded by a digit
+]
+normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2]
+def normalize(s):
+    '''Normalize and tokenize text. This is lifted from NIST mteval-v11a.pl.'''
+    # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
+    if nonorm:
+        return s.split()
+    if type(s) is not str:
+        s = " ".join(s)
+    # language-independent part:
+    for (pattern, replace) in normalize1:
+        s = re.sub(pattern, replace, s)
+    s = xml.sax.saxutils.unescape(s, {'&quot;': '"'})
+    # language-dependent part (assuming Western languages):
+    s = " %s " % s
+    if not preserve_case:
+        s = s.lower()  # this might not be identical to the original
+    for (pattern, replace) in normalize2:
+        s = re.sub(pattern, replace, s)
+    return s.split()
+
+def count_ngrams(words, n=4):
+    counts = {}
+    for k in range(1, n+1):
+        for i in range(len(words)-k+1):
+            ngram = tuple(words[i:i+k])
+            counts[ngram] = counts.get(ngram, 0) + 1
+    return counts
+
+def cook_refs(refs, n=4):
+    '''Takes a list of reference sentences for a single segment
+    and returns an object that encapsulates everything that BLEU
+    needs to know about them.'''
+
+    refs = [normalize(ref) for ref in refs]
+    maxcounts = {}
+    for ref in refs:
+        counts = count_ngrams(ref, n)
+        for (ngram, count) in counts.items():
+            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
+    return ([len(ref) for ref in refs], maxcounts)
+
+def cook_test(test, item, n=4):
+    '''Takes a test sentence and returns an object that
+    encapsulates everything that BLEU needs to know about it.'''
+    (reflens, refmaxcounts) = item
+    test = normalize(test)
+    result = {}
+    result["testlen"] = len(test)
+
+    # Calculate effective reference sentence length.
+
+    if eff_ref_len == "shortest":
+        result["reflen"] = min(reflens)
+    elif eff_ref_len == "average":
+        result["reflen"] = float(sum(reflens))/len(reflens)
+    elif eff_ref_len == "closest":
+        min_diff = None
+        for reflen in reflens:
+            if min_diff is None or abs(reflen-len(test)) < min_diff:
+                min_diff = abs(reflen-len(test))
+                result['reflen'] = reflen
+
+    result["guess"] = [max(len(test)-k+1, 0) for k in range(1, n+1)]
+
+    result['correct'] = [0]*n
+    counts = count_ngrams(test, n)
+    for (ngram, count) in counts.items():
+        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram, 0), count)
+
+    return result
+
+def score_cooked(allcomps, n=4, ground=0, smooth=1):
+    totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0]*n, 'correct': [0]*n}
+    for comps in allcomps:
+        for key in ['testlen', 'reflen']:
+            totalcomps[key] += comps[key]
+        for key in ['guess', 'correct']:
+            for k in range(n):
+                totalcomps[key][k] += comps[key][k]
+    logbleu = 0.0
+    all_bleus = []
+    for k in range(n):
+        correct = totalcomps['correct'][k]
+        guess = totalcomps['guess'][k]
+        addsmooth = 0
+        if smooth == 1 and k > 0:
+            addsmooth = 1
+        logbleu += math.log(correct + addsmooth + sys.float_info.min) - math.log(guess + addsmooth + sys.float_info.min)
+        if guess == 0:
+            all_bleus.append(-10000000)
+        else:
+            all_bleus.append(math.log(correct + sys.float_info.min) - math.log(guess))
+
+    logbleu /= float(n)
+    all_bleus.insert(0, logbleu)
+
+    brevPenalty = min(0, 1 - float(totalcomps['reflen'] + 1) / (totalcomps['testlen'] + 1))
+    for i in range(len(all_bleus)):
+        if i == 0:
+            all_bleus[i] += brevPenalty
+        all_bleus[i] = math.exp(all_bleus[i])
+    return all_bleus
+
+def bleu(refs, candidate, ground=0, smooth=1):
+    refs = cook_refs(refs)
+    test = cook_test(candidate, refs)
+    return score_cooked([test], ground=ground, smooth=smooth)
+
+def splitPuncts(line):
+    return ' '.join(re.findall(r"[\w]+|[^\s\w]", line))
+
+def computeMaps(predictions, goldfile):
+    predictionMap = {}
+    goldMap = {}
+
+    for row in predictions:
+        cols = row.strip().split('\t')
+        if len(cols) == 1:
+            (rid, pred) = (cols[0], '')
+        else:
+            (rid, pred) = (cols[0], cols[1])
+        predictionMap[rid] = [splitPuncts(pred.strip().lower())]
+
+    with open(goldfile, 'r') as gf:
+        for row in gf:
+            (rid, pred) = row.split('\t')
+            if rid in predictionMap:  # Only insert if the id exists for the method
+                if rid not in goldMap:
+                    goldMap[rid] = []
+                goldMap[rid].append(splitPuncts(pred.strip().lower()))
+
+    sys.stderr.write('Total: ' + str(len(goldMap)) + '\n')
+    return (goldMap, predictionMap)
+
+
+# m1 is the reference map
+# m2 is the prediction map
+def bleuFromMaps(m1, m2):
+    score = [0] * 5
+    num = 0.0
+
+    for key in m1:
+        if key in m2:
+            bl = bleu(m1[key], m2[key][0])
+            score = [score[i] + bl[i] for i in range(0, len(bl))]
+            num += 1
+    return [s * 100.0 / num for s in score]
+
+if __name__ == '__main__':
+    reference_file = sys.argv[1]
+    predictions = []
+    for row in sys.stdin:
+        predictions.append(row)
+    (goldMap, predictionMap) = computeMaps(predictions, reference_file)
+    print(bleuFromMaps(goldMap, predictionMap)[0])
+
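For orientation, here is a minimal usage sketch of this file's entry points, mirroring the three-phase design the module docstring describes (cook the references once, then score any number of candidates) and the one-shot bleu() wrapper. The example strings are illustrative only, not part of the commit.

from bleu import cook_refs, cook_test, score_cooked, bleu

# Phase 1: cook the references once per segment.
cooked_refs = cook_refs(["return the sum of two numbers"])
# Phase 2: cook each candidate against the cooked references.
cooked = cook_test("returns the sum of two numbers", cooked_refs)
# Phase 3: aggregate any number of cooked candidates into one score.
print(score_cooked([cooked])[0])   # smoothed BLEU in [0, 1]

# Equivalent one-shot wrapper over the same three phases.
print(bleu(["return the sum of two numbers"], "returns the sum of two numbers")[0])

The __main__ block applies the same machinery corpus-wide: it reads "id<TAB>hypothesis" lines from stdin, pairs them by id against "id<TAB>reference" lines in the gold file, and prints BLEU scaled by 100.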
code2nl/model.py
0 → 100644
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+
+class Seq2Seq(nn.Module):
+    """
+    Build Sequence-to-Sequence.
+
+    Parameters:
+
+    * `encoder`- encoder of seq2seq model. e.g. roberta
+    * `decoder`- decoder of seq2seq model. e.g. transformer
+    * `config`- configuration of encoder model.
+    * `beam_size`- beam size for beam search.
+    * `max_length`- max length of target for beam search.
+    * `sos_id`- start-of-sequence symbol id in target for beam search.
+    * `eos_id`- end-of-sequence symbol id in target for beam search.
+    """
+    def __init__(self, encoder, decoder, config, beam_size=None, max_length=None, sos_id=None, eos_id=None):
+        super(Seq2Seq, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.config = config
+        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.lsm = nn.LogSoftmax(dim=-1)
+        self.tie_weights()
+
+        self.beam_size = beam_size
+        self.max_length = max_length
+        self.sos_id = sos_id
+        self.eos_id = eos_id
+
+    def _tie_or_clone_weights(self, first_module, second_module):
+        """ Tie or clone module weights depending on whether we are using TorchScript or not
+        """
+        if self.config.torchscript:
+            first_module.weight = nn.Parameter(second_module.weight.clone())
+        else:
+            first_module.weight = second_module.weight
+
+    def tie_weights(self):
+        """ Make sure we are sharing the input and output embeddings.
+            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
+        """
+        self._tie_or_clone_weights(self.lm_head,
+                                   self.encoder.embeddings.word_embeddings)
+    def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None, args=None):
+        outputs = self.encoder(source_ids, attention_mask=source_mask)
+        encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
+        if target_ids is not None:
+            # Training: teacher-forced decoding under a causal attention mask.
+            attn_mask = -1e4 * (1 - self.bias[:target_ids.shape[1], :target_ids.shape[1]])
+            tgt_embeddings = self.encoder.embeddings(target_ids).permute([1, 0, 2]).contiguous()
+            out = self.decoder(tgt_embeddings, encoder_output, tgt_mask=attn_mask,
+                               memory_key_padding_mask=(1 - source_mask).bool())
+            hidden_states = torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous()
+            lm_logits = self.lm_head(hidden_states)
+            # Shift so that tokens < n predict n
+            active_loss = target_mask[..., 1:].ne(0).view(-1) == 1
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = target_ids[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
+                            shift_labels.view(-1)[active_loss])
+
+            outputs = loss, loss * active_loss.sum(), active_loss.sum()
+            return outputs
+        else:
+            # Predict: beam search, one source sequence at a time.
+            preds = []
+            zero = torch.cuda.LongTensor(1).fill_(0)
+            for i in range(source_ids.shape[0]):
+                context = encoder_output[:, i:i+1]
+                context_mask = source_mask[i:i+1, :]
+                beam = Beam(self.beam_size, self.sos_id, self.eos_id)
+                input_ids = beam.getCurrentState()
+                context = context.repeat(1, self.beam_size, 1)
+                context_mask = context_mask.repeat(self.beam_size, 1)
+                for _ in range(self.max_length):
+                    if beam.done():
+                        break
+                    attn_mask = -1e4 * (1 - self.bias[:input_ids.shape[1], :input_ids.shape[1]])
+                    tgt_embeddings = self.encoder.embeddings(input_ids).permute([1, 0, 2]).contiguous()
+                    out = self.decoder(tgt_embeddings, context, tgt_mask=attn_mask,
+                                       memory_key_padding_mask=(1 - context_mask).bool())
+                    out = torch.tanh(self.dense(out))
+                    hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :]
+                    out = self.lsm(self.lm_head(hidden_states)).data
+                    beam.advance(out)
+                    input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
+                    input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
+                hyp = beam.getHyp(beam.getFinal())
+                pred = beam.buildTargetTokens(hyp)[:self.beam_size]
+                pred = [torch.cat([x.view(-1) for x in p] + [zero] * (self.max_length - len(p))).view(1, -1) for p in pred]
+                preds.append(torch.cat(pred, 0).unsqueeze(0))
+
+            preds = torch.cat(preds, 0)
+            return preds
+
+
+class Beam(object):
+    def __init__(self, size, sos, eos):
+        self.size = size
+        self.tt = torch.cuda
+        # The score for each translation on the beam.
+        self.scores = self.tt.FloatTensor(size).zero_()
+        # The backpointers at each time-step.
+        self.prevKs = []
+        # The outputs at each time-step.
+        self.nextYs = [self.tt.LongTensor(size)
+                       .fill_(0)]
+        self.nextYs[0][0] = sos
+        # Has EOS topped the beam yet.
+        self._eos = eos
+        self.eosTop = False
+        # Time and k pair for finished.
+        self.finished = []
+
+    def getCurrentState(self):
+        "Get the outputs for the current timestep."
+        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
+        return batch
+
+    def getCurrentOrigin(self):
+        "Get the backpointers for the current timestep."
+        return self.prevKs[-1]
+
+    def advance(self, wordLk):
+        """
+        Given log-probs over words for every last beam `wordLk`,
+        compute and update the beam search.
+
+        Parameters:
+
+        * `wordLk`- probs of advancing from the last step (K x words)
+        """
+        numWords = wordLk.size(1)
+
+        # Sum the previous scores.
+        if len(self.prevKs) > 0:
+            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
+
+            # Don't let EOS have children.
+            for i in range(self.nextYs[-1].size(0)):
+                if self.nextYs[-1][i] == self._eos:
+                    beamLk[i] = -1e20
+        else:
+            beamLk = wordLk[0]
+        flatBeamLk = beamLk.view(-1)
+        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
+
+        self.scores = bestScores
+
+        # bestScoresId is a flattened beam x word array, so calculate which
+        # word and beam each score came from. Integer division is required
+        # here; true division on LongTensors fails on recent PyTorch.
+        prevK = bestScoresId // numWords
+        self.prevKs.append(prevK)
+        self.nextYs.append((bestScoresId - prevK * numWords))
+
+        for i in range(self.nextYs[-1].size(0)):
+            if self.nextYs[-1][i] == self._eos:
+                s = self.scores[i]
+                self.finished.append((s, len(self.nextYs) - 1, i))
+
+        # End condition is when top-of-beam is EOS and no global score.
+        if self.nextYs[-1][0] == self._eos:
+            self.eosTop = True
+
+    def done(self):
+        return self.eosTop and len(self.finished) >= self.size
+
+    def getFinal(self):
+        if len(self.finished) == 0:
+            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
+        self.finished.sort(key=lambda a: -a[0])
+        if len(self.finished) != self.size:
+            unfinished = []
+            for i in range(self.nextYs[-1].size(0)):
+                if self.nextYs[-1][i] != self._eos:
+                    s = self.scores[i]
+                    unfinished.append((s, len(self.nextYs) - 1, i))
+            unfinished.sort(key=lambda a: -a[0])
+            self.finished += unfinished[:self.size - len(self.finished)]
+        return self.finished[:self.size]
+
+    def getHyp(self, beam_res):
+        """
+        Walk back to construct the full hypothesis.
+        """
+        hyps = []
+        for _, timestep, k in beam_res:
+            hyp = []
+            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
+                hyp.append(self.nextYs[j+1][k])
+                k = self.prevKs[j][k]
+            hyps.append(hyp[::-1])
+        return hyps
+
+    def buildTargetTokens(self, preds):
+        sentence = []
+        for pred in preds:
+            tokens = []
+            for tok in pred:
+                if tok == self._eos:
+                    break
+                tokens.append(tok)
+            sentence.append(tokens)
+        return sentence
+
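Since run.py is collapsed below, here is a hedged sketch of how Seq2Seq is typically assembled in this kind of pipeline, not the exact file contents: a pretrained RoBERTa-style encoder paired with a randomly initialized nn.TransformerDecoder, with the tokenizer's CLS and SEP ids serving as sos_id and eos_id. The checkpoint name, beam size, and max length below are assumptions.

import torch.nn as nn
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from model import Seq2Seq

# Assumed checkpoint and hyperparameters -- placeholders only.
config = RobertaConfig.from_pretrained("microsoft/codebert-base")
tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
encoder = RobertaModel.from_pretrained("microsoft/codebert-base", config=config)
decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size,
                                           nhead=config.num_attention_heads)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                beam_size=10, max_length=128,
                sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)

Note that the prediction branch of forward() and the Beam class allocate CUDA tensors directly (torch.cuda.LongTensor), so inference as written requires a GPU.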
code2nl/run.py
0 → 100644
This diff is collapsed.