# paragraph_gen.py (5.83 KB)
from random import choice, choices, randint
import argparse
import re
import time
import torch
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.utils import get_tokenizer

def greedy(predict):
    """Return the arg-max along the last axis of ``predict`` as plain Python data.

    For a 1-D tensor of scores this is a single ``int``; for higher-rank
    input it is a nested list holding one index per row.
    """
    best = predict.argmax(dim=-1)
    return best.tolist()

def top_k(predict, k):
    """Pick one of the ``k`` highest-scoring indices uniformly at random.

    The scores themselves only decide which indices make the short list;
    every short-listed index is then equally likely to be returned.
    """
    _, candidates = predict.topk(k, dim=-1)
    return choice(candidates.tolist())

def top_p(logits, threshold = 0.9):
    """Nucleus (top-p) sampling with a uniform draw inside the nucleus.

    The nucleus is the shortest prefix of the probability-sorted vocabulary
    whose cumulative softmax probability exceeds ``threshold``. One index is
    then drawn uniformly at random from that prefix (the probabilities are
    not used as weights).

    Args:
        logits: 1-D tensor of unnormalised scores, one per vocabulary entry.
        threshold: cumulative probability mass the nucleus must exceed.

    Returns:
        The vocabulary index (``int``) of the sampled entry.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    indices = sorted_indices.tolist()
    cumulative = torch.cumsum(
        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
    )
    # Position of the first entry at which the running mass exceeds the threshold.
    crossing = torch.nonzero(cumulative > threshold)
    if crossing.numel():
        top_p_index = int(crossing[0])
    else:
        # The mass never exceeds the threshold (e.g. threshold == 1.0), so the
        # nucleus is the whole vocabulary rather than just the top entry.
        top_p_index = len(indices) - 1
    rand_num = randint(0, top_p_index)  # uniform draw over the nucleus
    return indices[rand_num]

def weighted_random(logits):
    """Sample an index with probability proportional to softmax of its logit.

    Only entries whose logit is non-negative are candidates (negative logits
    are deliberately ignored). If no logit is non-negative, every entry
    becomes a candidate so the call still returns a sample instead of
    failing on an empty population.

    Args:
        logits: 1-D tensor of unnormalised scores.

    Returns:
        The index (``int``) of the sampled entry.
    """
    candidates = torch.where(logits >= 0)[0]  # negative logits are not considered
    if candidates.numel() == 0:
        # All logits are negative: fall back to the full distribution.
        candidates = torch.arange(logits.shape[-1], device=logits.device)
    selected_logits = torch.index_select(logits, -1, candidates)
    weights = torch.nn.functional.softmax(selected_logits, dim=-1)
    # Plain floats keep random.choices from doing per-element tensor arithmetic.
    return choices(candidates.tolist(), weights=weights.tolist())[0]

def weighted_top_k(predict, k):
    """Sample among the ``k`` highest-scoring indices, weighted by softmax.

    The softmax is taken over the ``k`` retained scores only, so the
    weights describe the distribution renormalised to the short list.
    """
    top_scores, top_ids = predict.topk(k, dim=-1)
    weights = top_scores.softmax(dim=-1)
    (picked,) = choices(top_ids.tolist(), weights=weights)
    return picked

def weighted_top_p(logits, threshold = 0.9):
    """Nucleus (top-p) sampling weighted by probability inside the nucleus.

    The nucleus is the shortest prefix of the probability-sorted vocabulary
    whose cumulative softmax probability exceeds ``threshold``; it always
    holds at least one entry and grows to the whole vocabulary when the
    threshold is never exceeded (e.g. ``threshold == 1.0``). One index is
    drawn from the nucleus with probability proportional to its softmax
    probability.

    Args:
        logits: 1-D tensor of unnormalised scores, one per vocabulary entry.
        threshold: cumulative probability mass the nucleus must exceed.

    Returns:
        The vocabulary index (``int``) of the sampled entry.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = torch.nn.functional.softmax(sorted_logits, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Position of the first entry at which the running mass exceeds the threshold.
    crossing = torch.nonzero(cumulative > threshold)
    if crossing.numel():
        top_p_bound = int(crossing[0]) + 1
    else:
        top_p_bound = sorted_probs.numel()
    # random.choices normalises the weights itself, so no manual rescaling is
    # needed (dividing by a running total risks a division by zero).
    return choices(
        sorted_indices[:top_p_bound].tolist(),
        weights=sorted_probs[:top_p_bound].tolist(),
    )[0]


if __name__ == "__main__":
    # Script entry point: parse the command line, choose a sampling strategy,
    # load KoGPT2 (optionally restoring a checkpoint) and generate text one
    # token at a time until enough sentences or tokens have been produced.
    parser = argparse.ArgumentParser(description='KoGPT2 generation    example')
    group=parser.add_mutually_exclusive_group()
    group.add_argument('-g','--greedy',action='store_const',const='greedy',help='Greedy sampling')
    group.add_argument('-k','--topk',type=int, choices=range(1,51), help='Top k sampling. 1<=K<=50', metavar='K')
    group.add_argument('-p','--topp',type=float, help='Top p sampling. 0<P<=1.0', metavar='P')
    parser.add_argument('-w','--weighted',action='store_true', help='Use weighted version of sampling.')
    parser.add_argument('-d','--docker', action='store_true', help="Train on docker. Sets model cache path:/code/model, dataset path:/dataset, save path:/code/save.")
    parser.add_argument('-c','--checkpoint',type=str , help='Model chekpoint path',metavar='PATH')
    parser.add_argument('-f','--full_sentence', action='store_true' , help='Treat last S as a full_sentence. (Do not append it.)')
    parser.add_argument('-l','--length', type=int, choices=range(1,21) , help='Set length of paragraph.', metavar='LENGTH', default=15)
    parser.add_argument('sentence', metavar='S', type=str, nargs='*',
                                                                    help='korean sentence to use as input.')
    args = parser.parse_args()
    # `choices` cannot express a float interval, so the range promised by the
    # -p help text is enforced by hand.
    if args.topp is not None and not 0 < args.topp <= 1.0:
        parser.error('argument -p/--topp: P must satisfy 0<P<=1.0')
    print(args)
    model_cache_path='/code/model' if args.docker else 'model'

    # Map the mutually exclusive -g/-k/-p options (plus -w) onto a callable
    # that takes the last-position logits and returns one token id.
    if args.greedy:
        sampling_name = "Weighted" if args.weighted else "Greedy"
        sampling=weighted_random if args.weighted else  greedy
    elif args.topk is not None:
        sampling_name=f"Weighted Top k={args.topk}" if args.weighted else f"Top k={args.topk}"
        sampling= (lambda pred: weighted_top_k(pred,args.topk)) if args.weighted else (lambda pred: top_k(pred,args.topk))
    elif args.topp is not None:
        sampling_name=f"Weighted Top p={args.topp}" if args.weighted else f"Top p={args.topp}"
        sampling= (lambda pred: weighted_top_p(pred,args.topp)) if args.weighted else (lambda pred: top_p(pred,args.topp))
    else: # no strategy flag given: plain weighted sampling, with or without -w
        sampling_name="Weighted"
        sampling=weighted_random

    ctx='cuda:0' if torch.cuda.is_available() else 'cpu'
    device=torch.device(ctx)
    tok_path = get_tokenizer(cachedir=model_cache_path)
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx,cachedir=model_cache_path)
    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Build the prompt: every input sentence is followed by an EOS/BOS pair.
    toked=[]
    for sent in args.sentence:
        toked += (tok(sent)+[vocab.eos_token,vocab.bos_token])
    if not args.full_sentence:
        # The last sentence is left open so the model continues it. On an
        # empty prompt the slice is a no-op.
        toked=toked[:-2]
    token_count=0
    sent_count=0
    started=time.time()
    try:
        # Inference only: skip autograd bookkeeping for every forward pass.
        with torch.no_grad():
            while token_count<1000:
                input_ids = torch.tensor([vocab[vocab.bos_token],] + vocab[toked]).unsqueeze(0).to(device=device)
                pred = model(input_ids)[0]
                # Index batch 0 / last position explicitly; squeeze() would
                # also drop the sequence axis when the input is one token long.
                # NOTE(review): assumes pred is (1, seq_len, vocab) - confirm.
                gen_id = sampling(pred[0, -1])
                gen_token=vocab.to_tokens(gen_id)
                if gen_token == vocab.eos_token:
                    sent_count+=1
                    print(sent_count, token_count)
                    if sent_count>=args.length:
                        break
                    toked+=[vocab.eos_token,vocab.bos_token]
                    token_count+=2
                else:
                    toked.append(gen_token)
                    token_count+=1
    except KeyboardInterrupt:
        # Ctrl-C stops generation early; whatever was produced is still printed.
        pass
    print(f'{sampling_name}:',re.sub('</s>', '\r\n',re.sub('(▁|<s>)',' ',''.join(toked))))
    print("Time elapsed:", time.time()-started)