# example.py 5.02 KB
import torch
from random import choice, choices, randint
import argparse
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.utils import get_tokenizer

def top_k(predict, vocab, k):
    """Return the token for one of the ``k`` highest-scoring entries of ``predict``.

    The winner is chosen uniformly at random among the top-``k`` ids
    (scores are not used as weights).

    Args:
        predict: score/logit tensor; the top-``k`` search runs over its last axis.
        vocab: object whose ``to_tokens`` maps an id back to its token.
        k: number of highest-scoring candidates to choose from.
    """
    _, candidate_ids = torch.topk(predict, k=k, dim=-1)
    picked_id = choice(candidate_ids.tolist())
    return vocab.to_tokens(picked_id)

def top_p(logits, vocab, threshold = 0.9):
    """Return a token drawn uniformly from the top-p (nucleus) set of ``logits``.

    The nucleus is the smallest prefix of the probability-sorted ids whose
    cumulative probability exceeds ``threshold``. If no strict prefix exceeds
    it (e.g. ``threshold >= 1``, or a tiny vocabulary), the whole
    distribution is used. The original loop fell back to the single most
    likely token in that case.

    Note: the draw is uniform over the nucleus, not proportional to the
    probabilities.

    Args:
        logits: 1-D tensor of unnormalised scores (one per vocab id).
        vocab: object whose ``to_tokens`` maps an id back to its token.
        threshold: cumulative-probability cut-off ``p``.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    indexs = sorted_indices.tolist()
    sorted_probs = torch.nn.functional.softmax(sorted_logits, dim=-1)
    # cum_probs[j] is the total probability of the j+1 most likely ids.
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    over = torch.nonzero(cum_probs > threshold, as_tuple=True)[0]
    # Last position (inclusive) of the nucleus; default to the full vocab.
    top_p_index = int(over[0]) if over.numel() > 0 else len(indexs) - 1
    rand_num = randint(0, top_p_index)  # uniform pick inside the nucleus
    return vocab.to_tokens(indexs[rand_num])

def weighted_random(logits, vocab):
    """Return a token sampled with probability ``softmax(logits)``.

    Args:
        logits: 1-D tensor of unnormalised scores (one per vocab id).
        vocab: object whose ``to_tokens`` maps an id back to its token.
    """
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Hand random.choices plain floats: given a tensor it accumulates the
    # weights one 0-d tensor at a time, which is very slow for a full vocab.
    # Sorting is unnecessary for proportional sampling.
    token_id = choices(range(len(probs)), weights=probs.tolist())[0]
    return vocab.to_tokens(token_id)

def generate(model, vocab, tok, prompt, pick_token, device='cpu', max_tokens=100):
    """Extend ``prompt`` one token at a time and return the resulting text.

    Each step does the following:
    - re-tokenize the whole current text with ``tok``;
    - prefix the BOS id and feed the ids to ``model``;
    - pass the logits of the last position to ``pick_token``, which returns
      the next token string.

    Args:
        model: callable whose ``[0]`` output holds the logits; assumed to be
            shaped (batch=1, seq, vocab) -- TODO confirm against the model.
        vocab: maps a token (or a list of tokens) to id(s) and exposes
            ``bos_token``.
        tok: callable that splits a string into a list of tokens.
        prompt: text to continue.
        pick_token: callable mapping a 1-D logits tensor to a token string.
        device: device the input ids are moved to (must match the model's).
        max_tokens: maximum number of tokens to append.

    Returns:
        ``prompt`` followed by the generated text. Generation stops early
        when ``'</s>'`` is picked or on Ctrl-C (KeyboardInterrupt).
    """
    sent = prompt
    toked = tok(sent)
    for _ in range(max_tokens):
        try:
            input_ids = torch.tensor(
                [vocab[vocab.bos_token]] + vocab[toked]).unsqueeze(0).to(device)
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                pred = model(input_ids)[0]
            # pred[0, -1] keeps the vocab axis even for a 1-token input,
            # unlike pred.squeeze()[-1], which collapses it when seq == 1.
            gen = pick_token(pred[0, -1])
            if gen == '</s>':
                break
            # SentencePiece marks word starts with '▁'; render it as a space.
            sent += gen.replace('▁', ' ')
            toked = tok(sent)
        except KeyboardInterrupt:
            break
    return sent


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='KoGPT2 generation  example')
    parser.add_argument('sentence', metavar='S', type=str, nargs='?', default='2019년 한해를 보내며,',
                        help='korean sentence to use as input.')
    # Parse once, before the (slow) model download/load.
    prompt = parser.parse_args().sentence

    ctx = 'cuda' if torch.cuda.is_available() else 'cpu'
    tok_path = get_tokenizer(cachedir='/code/model')
    # NOTE(review): assumes get_pytorch_kogpt2_model places the model on
    # `ctx`; generate() moves the input ids to the same device.
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)

    # (label, next-token picker) pairs; every run starts from the same prompt.
    strategies = [
        ('Greedy:', lambda logits: vocab.to_tokens(int(torch.argmax(logits)))),
        ('Top 3:', lambda logits: top_k(logits, vocab, 3)),
        ('Top 5:', lambda logits: top_k(logits, vocab, 5)),
        ('Top p=0.5:', lambda logits: top_p(logits, vocab, 0.5)),
        ('Top p=0.7:', lambda logits: top_p(logits, vocab, 0.7)),
        ('Top p=0.9:', lambda logits: top_p(logits, vocab)),
        ('Weighted random:', lambda logits: weighted_random(logits, vocab)),
    ]
    for label, pick_token in strategies:
        print(label, generate(model, vocab, tok, prompt, pick_token, device=ctx))