graykode

(add) code2nl finetuning code

#!/usr/bin/python

'''
This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
'''

# $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $

'''Provides:

cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
score_cooked(alltest, n=4): Score a list of cooked test sentences.

The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible.
'''
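
# Usage sketch (illustrative, not part of the original script): cook each
# reference list once, cook each candidate against it, then score any number
# of cooked candidates together.
#
#   cooked_refs = cook_refs(["the cat sat on the mat"])
#   cooked_test = cook_test("the cat is on the mat", cooked_refs)
#   scores = score_cooked([cooked_test])  # [overall BLEU, p1, p2, p3, p4]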

import sys, math, re, xml.sax.saxutils

# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
nonorm = 0

preserve_case = False
eff_ref_len = "shortest"

normalize1 = [
    ('<skipped>', ''),          # strip "skipped" tags
    (r'-\n', ''),               # strip end-of-line hyphenation and join lines
    (r'\n', ' '),               # join lines
#   (r'(\d)\s+(?=\d)', r'\1'),  # join digits
]
normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1]

normalize2 = [
    (r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 '), # tokenize punctuation. apostrophe is missing
    (r'([^0-9])([\.,])', r'\1 \2 '),              # tokenize period and comma unless preceded by a digit
    (r'([\.,])([^0-9])', r' \1 \2'),              # tokenize period and comma unless followed by a digit
    (r'([0-9])(-)', r'\1 \2 ')                    # tokenize dash when preceded by a digit
]
normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2]

def normalize(s):
    '''Normalize and tokenize text. This is lifted from NIST mteval-v11a.pl.'''
    # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
    if (nonorm):
        return s.split()
    if type(s) is not str:
        s = " ".join(s)
    # language-independent part:
    for (pattern, replace) in normalize1:
        s = re.sub(pattern, replace, s)
    s = xml.sax.saxutils.unescape(s, {'&quot;': '"'})
    # language-dependent part (assuming Western languages):
    s = " %s " % s
    if not preserve_case:
        s = s.lower()  # this might not be identical to the original
    for (pattern, replace) in normalize2:
        s = re.sub(pattern, replace, s)
    return s.split()
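
# Illustrative example (not in the original file): with default settings,
# normalize("Hello, world!") lowercases the text and tokenizes punctuation,
# giving ['hello', ',', 'world', '!'].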

def count_ngrams(words, n=4):
    counts = {}
    for k in range(1, n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] = counts.get(ngram, 0) + 1
    return counts
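
# Illustrative example (not in the original file):
# count_ngrams(['a', 'b', 'a'], 2) ==
#     {('a',): 2, ('b',): 1, ('a', 'b'): 1, ('b', 'a'): 1}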

def cook_refs(refs, n=4):
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.'''

    refs = [normalize(ref) for ref in refs]
    maxcounts = {}
    for ref in refs:
        counts = count_ngrams(ref, n)
        for (ngram, count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
    return ([len(ref) for ref in refs], maxcounts)
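
# Illustrative example (not in the original file): the clipped count of each
# n-gram is the maximum over the references, so
# cook_refs(["the cat", "the the cat"]) returns reference lengths [2, 3] and
# a maxcounts table in which ('the',) maps to 2.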

def cook_test(test, item, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.'''
    (reflens, refmaxcounts) = item
    test = normalize(test)
    result = {}
    result["testlen"] = len(test)

    # Calculate effective reference sentence length.

    if eff_ref_len == "shortest":
        result["reflen"] = min(reflens)
    elif eff_ref_len == "average":
        result["reflen"] = float(sum(reflens))/len(reflens)
    elif eff_ref_len == "closest":
        min_diff = None
        for reflen in reflens:
            if min_diff is None or abs(reflen-len(test)) < min_diff:
                min_diff = abs(reflen-len(test))
                result['reflen'] = reflen

    result["guess"] = [max(len(test)-k+1, 0) for k in range(1, n+1)]

    result['correct'] = [0]*n
    counts = count_ngrams(test, n)
    for (ngram, count) in counts.items():
        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram, 0), count)

    return result

def score_cooked(allcomps, n=4, ground=0, smooth=1):
    totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0]*n, 'correct': [0]*n}
    for comps in allcomps:
        for key in ['testlen', 'reflen']:
            totalcomps[key] += comps[key]
        for key in ['guess', 'correct']:
            for k in range(n):
                totalcomps[key][k] += comps[key][k]
    logbleu = 0.0
    all_bleus = []
    for k in range(n):
        correct = totalcomps['correct'][k]
        guess = totalcomps['guess'][k]
        addsmooth = 0
        if smooth == 1 and k > 0:
            addsmooth = 1
        logbleu += math.log(correct + addsmooth + sys.float_info.min) - math.log(guess + addsmooth + sys.float_info.min)
        if guess == 0:
            all_bleus.append(-10000000)
        else:
            all_bleus.append(math.log(correct + sys.float_info.min) - math.log(guess))

    logbleu /= float(n)
    all_bleus.insert(0, logbleu)

    brevPenalty = min(0, 1 - float(totalcomps['reflen'] + 1) / (totalcomps['testlen'] + 1))
    for i in range(len(all_bleus)):
        if i == 0:
            all_bleus[i] += brevPenalty
        all_bleus[i] = math.exp(all_bleus[i])
    return all_bleus
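
# Note (added for clarity, not in the original file): with smooth=1 the
# higher-order n-gram counts get add-one smoothing, and the brevity penalty
# is applied in log space to all_bleus[0] only, so the returned list is
# [overall BLEU, p1, p2, p3, p4] with the per-order precisions unpenalized.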

def bleu(refs, candidate, ground=0, smooth=1):
    refs = cook_refs(refs)
    test = cook_test(candidate, refs)
    return score_cooked([test], ground=ground, smooth=smooth)

def splitPuncts(line):
    return ' '.join(re.findall(r"[\w]+|[^\s\w]", line))
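
# Illustrative example (not in the original file):
# splitPuncts("hello, world!") == 'hello , world !'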

def computeMaps(predictions, goldfile):
    predictionMap = {}
    goldMap = {}
    gf = open(goldfile, 'r')

    for row in predictions:
        cols = row.strip().split('\t')
        if len(cols) == 1:
            (rid, pred) = (cols[0], '')
        else:
            (rid, pred) = (cols[0], cols[1])
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    for row in gf:
        (rid, pred) = row.split('\t')
        if rid in predictionMap:  # Only insert if the id exists for the method
            if rid not in goldMap:
                goldMap[rid] = []
            goldMap[rid].append(splitPuncts(pred.strip().lower()))

    sys.stderr.write('Total: ' + str(len(goldMap)) + '\n')
    return (goldMap, predictionMap)
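
# Format note (added for clarity, not in the original file): both the
# prediction rows and the gold file are tab-separated "id<TAB>sentence"
# lines; predictions with nothing after the id are scored as empty strings.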


# m1 is the reference map
# m2 is the prediction map
def bleuFromMaps(m1, m2):
    score = [0] * 5
    num = 0.0

    for key in m1:
        if key in m2:
            bl = bleu(m1[key], m2[key][0])
            score = [score[i] + bl[i] for i in range(0, len(bl))]
            num += 1
    return [s * 100.0 / num for s in score]


if __name__ == '__main__':
    reference_file = sys.argv[1]
    predictions = []
    for row in sys.stdin:
        predictions.append(row)
    (goldMap, predictionMap) = computeMaps(predictions, reference_file)
    print(bleuFromMaps(goldMap, predictionMap)[0])
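
# Usage sketch (illustrative; file names are hypothetical):
#   python bleu.py test.gold < test.output
# prints the smoothed sentence-level BLEU-4 averaged over all ids, scaled to 0-100.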


# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
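
# Construction sketch (illustrative, not part of this file; the beam size,
# max length, and checkpoint name are hypothetical). The class docstring below
# suggests a RoBERTa-style encoder and a transformer decoder, e.g.:
#
#   from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
#   config = RobertaConfig.from_pretrained("microsoft/codebert-base")
#   tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
#   encoder = RobertaModel.from_pretrained("microsoft/codebert-base", config=config)
#   decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size,
#                                              nhead=config.num_attention_heads)
#   decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
#   model = Seq2Seq(encoder, decoder, config, beam_size=10, max_length=128,
#                   sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)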


class Seq2Seq(nn.Module):
    """
    Build Sequence-to-Sequence.

    Parameters:

    * `encoder`- encoder of seq2seq model. e.g. roberta
    * `decoder`- decoder of seq2seq model. e.g. transformer
    * `config`- configuration of encoder model.
    * `beam_size`- beam size for beam search.
    * `max_length`- max length of target for beam search.
    * `sos_id`- start-of-sequence symbol id in target for beam search.
    * `eos_id`- end-of-sequence symbol id in target for beam search.
    """
    def __init__(self, encoder, decoder, config, beam_size=None, max_length=None, sos_id=None, eos_id=None):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Lower-triangular ones used as a causal attention mask for the decoder.
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()

        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.encoder.embeddings.word_embeddings)

    def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None, args=None):
        outputs = self.encoder(source_ids, attention_mask=source_mask)
        encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
        if target_ids is not None:
            # Causal mask: positions a target token may not attend to get a large negative bias.
            attn_mask = -1e4 * (1 - self.bias[:target_ids.shape[1], :target_ids.shape[1]])
            tgt_embeddings = self.encoder.embeddings(target_ids).permute([1, 0, 2]).contiguous()
            out = self.decoder(tgt_embeddings, encoder_output, tgt_mask=attn_mask,
                               memory_key_padding_mask=(1 - source_mask).bool())
            hidden_states = torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous()
            lm_logits = self.lm_head(hidden_states)
            # Shift so that tokens < n predict n
            active_loss = target_mask[..., 1:].ne(0).view(-1) == 1
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = target_ids[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                            shift_labels.view(-1)[active_loss])

            # Return the mean loss, the summed loss, and the number of active tokens.
            outputs = loss, loss * active_loss.sum(), active_loss.sum()
            return outputs
        else:
            # Predict
            preds = []
            zero = torch.cuda.LongTensor(1).fill_(0)
            for i in range(source_ids.shape[0]):
                context = encoder_output[:, i:i+1]
                context_mask = source_mask[i:i+1, :]
                beam = Beam(self.beam_size, self.sos_id, self.eos_id)
                input_ids = beam.getCurrentState()
                context = context.repeat(1, self.beam_size, 1)
                context_mask = context_mask.repeat(self.beam_size, 1)
                for _ in range(self.max_length):
                    if beam.done():
                        break
                    attn_mask = -1e4 * (1 - self.bias[:input_ids.shape[1], :input_ids.shape[1]])
                    tgt_embeddings = self.encoder.embeddings(input_ids).permute([1, 0, 2]).contiguous()
                    out = self.decoder(tgt_embeddings, context, tgt_mask=attn_mask,
                                       memory_key_padding_mask=(1 - context_mask).bool())
                    out = torch.tanh(self.dense(out))
                    hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
                    input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
                hyp = beam.getHyp(beam.getFinal())
                pred = beam.buildTargetTokens(hyp)[:self.beam_size]
                # Pad each hypothesis with zeros up to max_length.
                pred = [torch.cat([x.view(-1) for x in p] + [zero] * (self.max_length - len(p))).view(1, -1) for p in pred]
                preds.append(torch.cat(pred, 0).unsqueeze(0))

            preds = torch.cat(preds, 0)
            return preds
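
# Note (added for clarity, not in the original file): in generation mode
# forward() returns a LongTensor of shape [batch, beam_size, max_length]
# holding the zero-padded token ids of each beam hypothesis.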


class Beam(object):
    def __init__(self, size, sos, eos):
        self.size = size
        self.tt = torch.cuda
        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        self.nextYs = [self.tt.LongTensor(size)
                       .fill_(0)]
        self.nextYs[0][0] = sos
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Given the log-probs over words for every last beam `wordLk`:
        compute and update the beam search.

        Parameters:

        * `wordLk`- probs of advancing from the last step (K x words)

        Completion is checked separately via done().
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        prevK = bestScoresId // numWords  # floor division keeps the backpointers integer-typed on newer PyTorch
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size - len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j+1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence

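
# Protocol note (added for clarity, not in the original file): callers drive
# Beam exactly as Seq2Seq.forward does above -- getCurrentState() supplies the
# next decoder input, advance() consumes the (beam x vocab) log-probs,
# getCurrentOrigin() reorders the live hypotheses before the new state is
# appended, and getFinal()/getHyp()/buildTargetTokens() recover the outputs.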