graykode

(add) code2nl finetuning code

# ---------------- bleu.py ----------------

#!/usr/bin/python

'''
This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
'''

# $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $

'''Provides:

cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
score_cooked(alltest, n=4): Score a list of cooked test sentences.

score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids.

The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible.
'''

import sys, math, re, xml.sax.saxutils

# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
nonorm = 0

preserve_case = False
eff_ref_len = "shortest"

normalize1 = [
    ('<skipped>', ''),          # strip "skipped" tags
    (r'-\n', ''),               # strip end-of-line hyphenation and join lines
    (r'\n', ' '),               # join lines
#   (r'(\d)\s+(?=\d)', r'\1'),  # join digits
]
normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1]

normalize2 = [
    (r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 '),  # tokenize punctuation. apostrophe is missing
    (r'([^0-9])([\.,])', r'\1 \2 '),               # tokenize period and comma unless preceded by a digit
    (r'([\.,])([^0-9])', r' \1 \2'),               # tokenize period and comma unless followed by a digit
    (r'([0-9])(-)', r'\1 \2 ')                     # tokenize dash when preceded by a digit
]
normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2]

def normalize(s):
    '''Normalize and tokenize text. This is lifted from NIST mteval-v11a.pl.'''
    # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
    if nonorm:
        return s.split()
    if type(s) is not str:
        s = " ".join(s)
    # language-independent part:
    for (pattern, replace) in normalize1:
        s = re.sub(pattern, replace, s)
    s = xml.sax.saxutils.unescape(s, {'&quot;': '"'})
    # language-dependent part (assuming Western languages):
    s = " %s " % s
    if not preserve_case:
        s = s.lower()  # this might not be identical to the original
    for (pattern, replace) in normalize2:
        s = re.sub(pattern, replace, s)
    return s.split()
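
# For example (a sketch, with the default settings above):
#   normalize('Hello, world!')  ->  ['hello', ',', 'world', '!']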

def count_ngrams(words, n=4):
    counts = {}
    for k in range(1, n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] = counts.get(ngram, 0)+1
    return counts
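
# For example (a sketch):
#   count_ngrams(['the', 'cat', 'sat'], n=2)
#   -> {('the',): 1, ('cat',): 1, ('sat',): 1, ('the', 'cat'): 1, ('cat', 'sat'): 1}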

def cook_refs(refs, n=4):
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.'''

    refs = [normalize(ref) for ref in refs]
    maxcounts = {}
    for ref in refs:
        counts = count_ngrams(ref, n)
        for (ngram, count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
    return ([len(ref) for ref in refs], maxcounts)

def cook_test(test, item, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.'''
    (reflens, refmaxcounts) = item
    test = normalize(test)
    result = {}
    result["testlen"] = len(test)

    # Calculate effective reference sentence length.

    if eff_ref_len == "shortest":
        result["reflen"] = min(reflens)
    elif eff_ref_len == "average":
        result["reflen"] = float(sum(reflens))/len(reflens)
    elif eff_ref_len == "closest":
        min_diff = None
        for reflen in reflens:
            if min_diff is None or abs(reflen-len(test)) < min_diff:
                min_diff = abs(reflen-len(test))
                result['reflen'] = reflen

    result["guess"] = [max(len(test)-k+1, 0) for k in range(1, n+1)]

    result['correct'] = [0]*n
    counts = count_ngrams(test, n)
    for (ngram, count) in counts.items():
        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram, 0), count)

    return result

def score_cooked(allcomps, n=4, ground=0, smooth=1):
    totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0]*n, 'correct': [0]*n}
    for comps in allcomps:
        for key in ['testlen', 'reflen']:
            totalcomps[key] += comps[key]
        for key in ['guess', 'correct']:
            for k in range(n):
                totalcomps[key][k] += comps[key][k]
    logbleu = 0.0
    all_bleus = []
    for k in range(n):
        correct = totalcomps['correct'][k]
        guess = totalcomps['guess'][k]
        addsmooth = 0
        if smooth == 1 and k > 0:
            addsmooth = 1
        logbleu += math.log(correct + addsmooth + sys.float_info.min) - math.log(guess + addsmooth + sys.float_info.min)
        if guess == 0:
            all_bleus.append(-10000000)
        else:
            all_bleus.append(math.log(correct + sys.float_info.min) - math.log(guess))

    logbleu /= float(n)
    all_bleus.insert(0, logbleu)

    brevPenalty = min(0, 1-float(totalcomps['reflen'] + 1)/(totalcomps['testlen'] + 1))
    for i in range(len(all_bleus)):
        if i == 0:
            all_bleus[i] += brevPenalty
        all_bleus[i] = math.exp(all_bleus[i])
    return all_bleus

def bleu(refs, candidate, ground=0, smooth=1):
    refs = cook_refs(refs)
    test = cook_test(candidate, refs)
    return score_cooked([test], ground=ground, smooth=smooth)
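
# A minimal sketch of the three-phase API described above (the example
# sentences here are hypothetical, not part of the original script):
def _bleu_demo():
    refs = cook_refs(['the cat sat on the mat', 'a cat sat on the mat'])
    cooked = cook_test('the cat sat on a mat', refs)
    # score_cooked() returns [BLEU, 1-gram precision, ..., n-gram precision].
    return score_cooked([cooked])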

def splitPuncts(line):
    return ' '.join(re.findall(r"[\w]+|[^\s\w]", line))

def computeMaps(predictions, goldfile):
    predictionMap = {}
    goldMap = {}

    for row in predictions:
        cols = row.strip().split('\t')
        if len(cols) == 1:
            (rid, pred) = (cols[0], '')
        else:
            (rid, pred) = (cols[0], cols[1])
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    with open(goldfile, 'r') as gf:
        for row in gf:
            (rid, pred) = row.split('\t', 1)  # split on the first tab only, in case the text itself contains tabs
            if rid in predictionMap:  # Only insert if the id exists for the method
                if rid not in goldMap:
                    goldMap[rid] = []
                goldMap[rid].append(splitPuncts(pred.strip().lower()))

    sys.stderr.write('Total: ' + str(len(goldMap)) + '\n')
    return (goldMap, predictionMap)
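
# Both the prediction rows and the gold file are "id<TAB>sentence" lines;
# only gold rows whose id also appears in the predictions are kept.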

# m1 is the reference map
# m2 is the prediction map
def bleuFromMaps(m1, m2):
    score = [0] * 5
    num = 0.0

    for key in m1:
        if key in m2:
            bl = bleu(m1[key], m2[key][0])
            score = [score[i] + bl[i] for i in range(0, len(bl))]
            num += 1
    return [s * 100.0 / num for s in score]

if __name__ == '__main__':
    reference_file = sys.argv[1]
    predictions = []
    for row in sys.stdin:
        predictions.append(row)
    (goldMap, predictionMap) = computeMaps(predictions, reference_file)
    print(bleuFromMaps(goldMap, predictionMap)[0])
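
# Typical invocation (a sketch): the gold file is passed as argv[1] and the
# predictions arrive on stdin as "id<TAB>text" rows, e.g.
#   python bleu.py dev.gold < dev.output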


# ---------------- model.py ----------------

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    """
    Build a Sequence-to-Sequence model.

    Parameters:

    * `encoder`- encoder of seq2seq model. e.g. roberta
    * `decoder`- decoder of seq2seq model. e.g. transformer
    * `config`- configuration of encoder model.
    * `beam_size`- beam size for beam search.
    * `max_length`- max length of target for beam search.
    * `sos_id`- start-of-sequence symbol id in target for beam search.
    * `eos_id`- end-of-sequence symbol id in target for beam search.
    """
    def __init__(self, encoder, decoder, config, beam_size=None, max_length=None, sos_id=None, eos_id=None):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()

        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.encoder.embeddings.word_embeddings)

    def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None, args=None):
        outputs = self.encoder(source_ids, attention_mask=source_mask)
        encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
        if target_ids is not None:
            attn_mask = -1e4 * (1 - self.bias[:target_ids.shape[1], :target_ids.shape[1]])
            tgt_embeddings = self.encoder.embeddings(target_ids).permute([1, 0, 2]).contiguous()
            out = self.decoder(tgt_embeddings, encoder_output, tgt_mask=attn_mask,
                               memory_key_padding_mask=(1-source_mask).bool())
            hidden_states = torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous()
            lm_logits = self.lm_head(hidden_states)
            # Shift so that tokens < n predict n
            active_loss = target_mask[..., 1:].ne(0).view(-1) == 1
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = target_ids[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                            shift_labels.view(-1)[active_loss])

            outputs = loss, loss*active_loss.sum(), active_loss.sum()
            return outputs
        else:
            # Predict (beam search; this path assumes a CUDA device, see Beam below)
            preds = []
            zero = torch.cuda.LongTensor(1).fill_(0)
            for i in range(source_ids.shape[0]):
                context = encoder_output[:, i:i+1]
                context_mask = source_mask[i:i+1, :]
                beam = Beam(self.beam_size, self.sos_id, self.eos_id)
                input_ids = beam.getCurrentState()
                context = context.repeat(1, self.beam_size, 1)
                context_mask = context_mask.repeat(self.beam_size, 1)
                for _ in range(self.max_length):
                    if beam.done():
                        break
                    attn_mask = -1e4 * (1 - self.bias[:input_ids.shape[1], :input_ids.shape[1]])
                    tgt_embeddings = self.encoder.embeddings(input_ids).permute([1, 0, 2]).contiguous()
                    out = self.decoder(tgt_embeddings, context, tgt_mask=attn_mask,
                                       memory_key_padding_mask=(1-context_mask).bool())
                    out = torch.tanh(self.dense(out))
                    hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
                    input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
                hyp = beam.getHyp(beam.getFinal())
                pred = beam.buildTargetTokens(hyp)[:self.beam_size]
                pred = [torch.cat([x.view(-1) for x in p]+[zero]*(self.max_length-len(p))).view(1, -1) for p in pred]
                preds.append(torch.cat(pred, 0).unsqueeze(0))

            preds = torch.cat(preds, 0)
            return preds
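
    # In short (a sketch of the two call modes): with target_ids the forward
    # pass returns (mean loss, summed token loss, active token count); without
    # target_ids it returns a [batch, beam_size, max_length] tensor of token ids.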

class Beam(object):
    def __init__(self, size, sos, eos):
        self.size = size
        self.tt = torch.cuda
        # The score for each translation on the beam.
        self.scores = self.tt.FloatTensor(size).zero_()
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        self.nextYs = [self.tt.LongTensor(size)
                       .fill_(0)]
        self.nextYs[0][0] = sos
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Given prob over words for every last beam `wordLk` and attention
        `attnOut`: Compute and update the beam search.

        Parameters:

        * `wordLk`- probs of advancing from the last step (K x words)
        * `attnOut`- attention at the last step

        Returns: True if beam search is complete.
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from. Use integer (floor) division:
        # plain `/` on LongTensors returns floats on recent PyTorch versions,
        # which breaks the index_select on the backpointers.
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size-len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j+1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
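
# A minimal CPU smoke test for the training path (a sketch: the tiny config
# below is arbitrary, and the beam-search predict path additionally needs CUDA):
if __name__ == "__main__":
    from transformers import RobertaConfig, RobertaModel

    config = RobertaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                           num_attention_heads=2, intermediate_size=64)
    encoder = RobertaModel(config)  # randomly initialized, no download needed
    decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size,
                                               nhead=config.num_attention_heads)
    decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)
    model = Seq2Seq(encoder, decoder, config, beam_size=2, max_length=8,
                    sos_id=0, eos_id=2)
    source_ids = torch.randint(0, 100, (2, 10))
    target_ids = torch.randint(0, 100, (2, 8))
    loss, _, _ = model(source_ids, torch.ones_like(source_ids),
                       target_ids, torch.ones_like(target_ids))
    print(loss.item())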


# ---------------- run.py (fine-tuning entry script; filename assumed) ----------------

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tune a RoBERTa encoder with a Transformer decoder (the Seq2Seq model above)
on the code2nl task: generating a natural-language docstring from source code.
"""

from __future__ import absolute_import
import os
import bleu
import torch
import json
import random
import logging
import argparse
import numpy as np
from io import open
from itertools import cycle
import torch.nn as nn
from model import Seq2Seq
from tqdm import tqdm
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from transformers import (AdamW, get_linear_schedule_with_warmup,
                          RobertaConfig, RobertaModel, RobertaTokenizer)

MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

class Example(object):
    """A single training/test example."""
    def __init__(self,
                 idx,
                 source,
                 target,
                 ):
        self.idx = idx
        self.source = source
        self.target = target


def read_examples(filename):
    """Read examples from filename."""
    examples = []
    with open(filename, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            line = line.strip()
            js = json.loads(line)
            if 'idx' not in js:
                js['idx'] = idx
            code = ' '.join(js['code_tokens']).replace('\n', ' ')
            code = ' '.join(code.strip().split())
            nl = ' '.join(js['docstring_tokens']).replace('\n', '')
            nl = ' '.join(nl.strip().split())
            examples.append(
                Example(
                    idx=idx,
                    source=code,
                    target=nl,
                )
            )
    return examples
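
# Each input line is a JSON object; the fields actually consumed are
# 'code_tokens' and 'docstring_tokens' (plus an optional 'idx'). A sketch:
#   {"idx": 0,
#    "code_tokens": ["def", "add", "(", "a", ",", "b", ")", ":", "return", "a", "+", "b"],
#    "docstring_tokens": ["Add", "two", "numbers", "."]}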


class InputFeatures(object):
    """Features for a single training/test example."""
    def __init__(self,
                 example_id,
                 source_ids,
                 target_ids,
                 source_mask,
                 target_mask,
                 ):
        self.example_id = example_id
        self.source_ids = source_ids
        self.target_ids = target_ids
        self.source_mask = source_mask
        self.target_mask = target_mask


def convert_examples_to_features(examples, tokenizer, args, stage=None):
    features = []
    for example_index, example in enumerate(examples):
        # source
        source_tokens = tokenizer.tokenize(example.source)[:args.max_source_length-2]
        source_tokens = [tokenizer.cls_token]+source_tokens+[tokenizer.sep_token]
        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
        source_mask = [1] * (len(source_tokens))
        padding_length = args.max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_token_id]*padding_length
        source_mask += [0]*padding_length

        # target
        if stage == "test":
            target_tokens = tokenizer.tokenize("None")
        else:
            target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2]
        target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token]
        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        target_mask = [1] * len(target_ids)
        padding_length = args.max_target_length - len(target_ids)
        target_ids += [tokenizer.pad_token_id]*padding_length
        target_mask += [0]*padding_length

        if example_index < 5:
            if stage == 'train':
                logger.info("*** Example ***")
                logger.info("idx: {}".format(example.idx))

                logger.info("source_tokens: {}".format([x.replace('\u0120', '_') for x in source_tokens]))
                logger.info("source_ids: {}".format(' '.join(map(str, source_ids))))
                logger.info("source_mask: {}".format(' '.join(map(str, source_mask))))

                logger.info("target_tokens: {}".format([x.replace('\u0120', '_') for x in target_tokens]))
                logger.info("target_ids: {}".format(' '.join(map(str, target_ids))))
                logger.info("target_mask: {}".format(' '.join(map(str, target_mask))))

        features.append(
            InputFeatures(
                example_index,
                source_ids,
                target_ids,
                source_mask,
                target_mask,
            )
        )
    return features
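
# Resulting layout per example (a sketch): both sides are
#   [CLS] tokens ... [SEP] [PAD] ...
# with mask 1 over real tokens and 0 over padding; at test time the target is
# the placeholder string "None" since the reference is not used for decoding.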


def set_seed(args):
    """Set random seed."""
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type: e.g. roberta")
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model: e.g. roberta-base")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--load_model_path", default=None, type=str,
                        help="Path to trained model: Should contain the .bin files")

    ## Other parameters
    parser.add_argument("--train_filename", default=None, type=str,
                        help="The train filename. Should contain the .jsonl files for this task.")
    parser.add_argument("--dev_filename", default=None, type=str,
                        help="The dev filename. Should contain the .jsonl files for this task.")
    parser.add_argument("--test_filename", default=None, type=str,
                        help="The test filename. Should contain the .jsonl files for this task.")

    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--max_source_length", default=64, type=int,
                        help="The maximum total source sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--max_target_length", default=32, type=int,
                        help="The maximum total target sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")

    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")

    parser.add_argument("--train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--beam_size", default=10, type=int,
                        help="beam size for beam search")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--eval_steps", default=-1, type=int,
                        help="Evaluate on the dev set every X update steps.")
    parser.add_argument("--train_steps", default=-1, type=int,
                        help="Total number of optimizer update steps to train for.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    # print arguments
    args = parser.parse_args()
    logger.info(args)

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1))
    args.device = device
    # Set seed
    set_seed(args)
    # Make output_dir if it does not exist
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
                                                do_lower_case=args.do_lower_case)

    # Build model
    encoder = model_class.from_pretrained(args.model_name_or_path, config=config)
    decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
    model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                    beam_size=args.beam_size, max_length=args.max_target_length,
                    sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
    if args.load_model_path is not None:
        logger.info("reload model from {}".format(args.load_model_path))
        model.load_state_dict(torch.load(args.load_model_path))

    model.to(device)
    if args.local_rank != -1:
        # Distributed training
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif args.n_gpu > 1:
        # multi-gpu training
        model = torch.nn.DataParallel(model)

    if args.do_train:
        # Prepare training data loader
        train_examples = read_examples(args.train_filename)
        train_features = convert_examples_to_features(train_examples, tokenizer, args, stage='train')
        all_source_ids = torch.tensor([f.source_ids for f in train_features], dtype=torch.long)
        all_source_mask = torch.tensor([f.source_mask for f in train_features], dtype=torch.long)
        all_target_ids = torch.tensor([f.target_ids for f in train_features], dtype=torch.long)
        all_target_mask = torch.tensor([f.target_mask for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_source_ids, all_source_mask, all_target_ids, all_target_mask)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                      batch_size=args.train_batch_size//args.gradient_accumulation_steps)

        num_train_optimization_steps = args.train_steps

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': args.weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
                                                    num_training_steps=num_train_optimization_steps)

        # Start training
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num epoch = %d", num_train_optimization_steps*args.train_batch_size//len(train_examples))

        model.train()
        dev_dataset = {}
        nb_tr_examples, nb_tr_steps, tr_loss, global_step, best_bleu, best_loss = 0, 0, 0, 0, 0, 1e6
        bar = tqdm(range(num_train_optimization_steps), total=num_train_optimization_steps)
        train_dataloader = cycle(train_dataloader)
        eval_flag = True
        for step in bar:
            batch = next(train_dataloader)
            batch = tuple(t.to(device) for t in batch)
            source_ids, source_mask, target_ids, target_mask = batch
            loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                               target_ids=target_ids, target_mask=target_mask)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()
            train_loss = round(tr_loss*args.gradient_accumulation_steps/(nb_tr_steps+1), 4)
            bar.set_description("loss {}".format(train_loss))
            nb_tr_examples += source_ids.size(0)
            nb_tr_steps += 1
            loss.backward()

            if (nb_tr_steps + 1) % args.gradient_accumulation_steps == 0:
                # Update parameters
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                eval_flag = True

            if args.do_eval and ((global_step + 1) % args.eval_steps == 0) and eval_flag:
                # Eval model with dev dataset
                tr_loss = 0
                nb_tr_examples, nb_tr_steps = 0, 0
                eval_flag = False
                if 'dev_loss' in dev_dataset:
                    eval_examples, eval_data = dev_dataset['dev_loss']
                else:
                    eval_examples = read_examples(args.dev_filename)
                    eval_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='dev')
                    all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
                    all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
                    all_target_ids = torch.tensor([f.target_ids for f in eval_features], dtype=torch.long)
                    all_target_mask = torch.tensor([f.target_mask for f in eval_features], dtype=torch.long)
                    eval_data = TensorDataset(all_source_ids, all_source_mask, all_target_ids, all_target_mask)
                    dev_dataset['dev_loss'] = eval_examples, eval_data
                eval_sampler = SequentialSampler(eval_data)
                eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

                logger.info("\n***** Running evaluation *****")
                logger.info("  Num examples = %d", len(eval_examples))
                logger.info("  Batch size = %d", args.eval_batch_size)

                # Start evaluating the model
                model.eval()
                eval_loss, tokens_num = 0, 0
                for batch in eval_dataloader:
                    batch = tuple(t.to(device) for t in batch)
                    source_ids, source_mask, target_ids, target_mask = batch

                    with torch.no_grad():
                        _, loss, num = model(source_ids=source_ids, source_mask=source_mask,
                                             target_ids=target_ids, target_mask=target_mask)
                    eval_loss += loss.sum().item()
                    tokens_num += num.sum().item()
                # Print loss of dev dataset
                model.train()
                eval_loss = eval_loss / tokens_num
                result = {'eval_ppl': round(np.exp(eval_loss), 5),
                          'global_step': global_step+1,
                          'train_loss': round(train_loss, 5)}
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                logger.info("  "+"*"*20)

                # Save last checkpoint
                last_output_dir = os.path.join(args.output_dir, 'checkpoint-last')
                if not os.path.exists(last_output_dir):
                    os.makedirs(last_output_dir)
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(last_output_dir, "pytorch_model.bin")
                torch.save(model_to_save.state_dict(), output_model_file)
                if eval_loss < best_loss:
                    logger.info("  Best ppl:%s", round(np.exp(eval_loss), 5))
                    logger.info("  "+"*"*20)
                    best_loss = eval_loss
                    # Save best checkpoint for best ppl
                    output_dir = os.path.join(args.output_dir, 'checkpoint-best-ppl')
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                    output_model_file = os.path.join(output_dir, "pytorch_model.bin")
                    torch.save(model_to_save.state_dict(), output_model_file)

                # Calculate BLEU on a dev sample
                if 'dev_bleu' in dev_dataset:
                    eval_examples, eval_data = dev_dataset['dev_bleu']
                else:
                    eval_examples = read_examples(args.dev_filename)
                    eval_examples = random.sample(eval_examples, min(1000, len(eval_examples)))
                    eval_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='test')
                    all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
                    all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
                    eval_data = TensorDataset(all_source_ids, all_source_mask)
                    dev_dataset['dev_bleu'] = eval_examples, eval_data

                eval_sampler = SequentialSampler(eval_data)
                eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

                model.eval()
                p = []
                for batch in eval_dataloader:
                    batch = tuple(t.to(device) for t in batch)
                    source_ids, source_mask = batch
                    with torch.no_grad():
                        preds = model(source_ids=source_ids, source_mask=source_mask)
                        for pred in preds:
                            t = pred[0].cpu().numpy()
                            t = list(t)
                            if 0 in t:
                                t = t[:t.index(0)]
                            text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
                            p.append(text)
                model.train()
                predictions = []
                with open(os.path.join(args.output_dir, "dev.output"), 'w') as f, \
                     open(os.path.join(args.output_dir, "dev.gold"), 'w') as f1:
                    for ref, gold in zip(p, eval_examples):
                        predictions.append(str(gold.idx)+'\t'+ref)
                        f.write(str(gold.idx)+'\t'+ref+'\n')
                        f1.write(str(gold.idx)+'\t'+gold.target+'\n')

                (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "dev.gold"))
                dev_bleu = round(bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
                logger.info("  %s = %s " % ("bleu-4", str(dev_bleu)))
                logger.info("  "+"*"*20)
                if dev_bleu > best_bleu:
                    logger.info("  Best bleu:%s", dev_bleu)
                    logger.info("  "+"*"*20)
                    best_bleu = dev_bleu
                    # Save best checkpoint for best bleu
                    output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu')
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                    output_model_file = os.path.join(output_dir, "pytorch_model.bin")
                    torch.save(model_to_save.state_dict(), output_model_file)

    if args.do_test:
        files = []
        if args.dev_filename is not None:
            files.append(args.dev_filename)
        if args.test_filename is not None:
            files.append(args.test_filename)
        for idx, file in enumerate(files):
            logger.info("Test file: {}".format(file))
            eval_examples = read_examples(file)
            eval_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='test')
            all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
            all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
            eval_data = TensorDataset(all_source_ids, all_source_mask)

            # Calculate BLEU
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

            model.eval()
            p = []
            for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
                batch = tuple(t.to(device) for t in batch)
                source_ids, source_mask = batch
                with torch.no_grad():
                    preds = model(source_ids=source_ids, source_mask=source_mask)
                    for pred in preds:
                        t = pred[0].cpu().numpy()
                        t = list(t)
                        if 0 in t:
                            t = t[:t.index(0)]
                        text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
                        p.append(text)
            model.train()
            predictions = []
            with open(os.path.join(args.output_dir, "test_{}.output".format(str(idx))), 'w') as f, \
                 open(os.path.join(args.output_dir, "test_{}.gold".format(str(idx))), 'w') as f1:
                for ref, gold in zip(p, eval_examples):
                    predictions.append(str(gold.idx)+'\t'+ref)
                    f.write(str(gold.idx)+'\t'+ref+'\n')
                    f1.write(str(gold.idx)+'\t'+gold.target+'\n')

            (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "test_{}.gold".format(idx)))
            dev_bleu = round(bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
            logger.info("  %s = %s " % ("bleu-4", str(dev_bleu)))
            logger.info("  "+"*"*20)


if __name__ == "__main__":
    main()
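
# Example invocation (a sketch; the script name, paths and model name are placeholders):
#   python run.py --model_type roberta --model_name_or_path microsoft/codebert-base \
#     --do_train --do_eval --train_filename train.jsonl --dev_filename dev.jsonl \
#     --output_dir ./output --max_source_length 256 --max_target_length 128 \
#     --train_batch_size 32 --eval_batch_size 32 --learning_rate 5e-5 \
#     --train_steps 50000 --eval_steps 1000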