#!/usr/bin/env python3
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer


logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys())
                  for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())

MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution, shape (vocabulary size,)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


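# The sketch below illustrates the two filters above on a toy distribution. It is
# a hypothetical helper added for illustration only; it is never called by this
# script, and the toy logits are made up.
def _demo_top_k_top_p_filtering():
    toy_logits = torch.tensor([4.0, 3.0, 2.0, 1.0])  # softmax ~ [0.64, 0.24, 0.09, 0.03]
    # top_k=2 keeps only the two largest logits:
    # -> tensor([4., 3., -inf, -inf])
    print(top_k_top_p_filtering(toy_logits.clone(), top_k=2))
    # top_p=0.9 keeps the smallest set of tokens whose cumulative probability
    # reaches 0.9. Here the first two only cover ~0.88, so the third survives
    # too (the shift keeps the token that crosses the threshold):
    # -> tensor([4., 3., 2., -inf])
    print(top_k_top_p_filtering(toy_logits.clone(), top_p=0.9))

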
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context
    with torch.no_grad():
        for _ in trange(length):

            inputs = {'input_ids': generated}
            if is_xlnet:
                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
                target_mapping[0, 0, -1] = 1.0  # predict last token
                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            outputs = model(**inputs)
            next_token_logits = outputs[0][0, -1, :] / temperature  # temperature < 1 sharpens, > 1 flattens the distribution
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated


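# The sketch below makes the XLNet input construction above concrete for a toy
# 4-token context. It is a hypothetical helper added for illustration only; it
# is never called by this script, and the token ids are made up.
def _demo_xlnet_inputs():
    generated = torch.tensor([[10, 20, 30, 40]])  # (1, 4) context
    # Append one dummy token; the model is asked to predict it in place.
    input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long)), dim=1)  # (1, 5)
    perm_mask = torch.zeros((1, 5, 5))
    perm_mask[:, :, -1] = 1.0  # no position may attend to the dummy last token
    target_mapping = torch.zeros((1, 1, 5))
    target_mapping[0, 0, -1] = 1.0  # the single prediction target is the last position
    return {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

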
def generate_text(
    padding_text=None,
    model_type='gpt2',
    model_name_or_path='gpt2',
    prompt='',
    length=20,
    temperature=1.0,
    top_k=0,
    top_p=0.9,
    no_cuda=True,
    seed=42,
):
    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    model_type = model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[model_type]
    tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
    model = model_class.from_pretrained(model_name_or_path)
    model.to(device)
    model.eval()

    if length < 0 and model.config.max_position_embeddings > 0:
        length = model.config.max_position_embeddings
    elif 0 < model.config.max_position_embeddings < length:
        length = model.config.max_position_embeddings  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop

    while True:
        raw_text = prompt if prompt else input("Model prompt >>> ")
        if model_type in ["transfo-xl", "xlnet"]:
            # Models with memory like to have a long prompt for short inputs.
            raw_text = (padding_text if padding_text else PADDING_TEXT) + raw_text
        context_tokens = tokenizer.encode(raw_text)
        out = sample_sequence(
            model=model,
            context=context_tokens,
            length=length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            device=device,
            is_xlnet=bool(model_type == "xlnet"),
        )
        out = out[0, len(context_tokens):].tolist()
        text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
        print(text)
        if prompt:
            break
    return text
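

# Example invocations (hypothetical; any key of MODEL_CLASSES works as model_type):
#
#     generate_text(prompt="The Manhattan bridge is", length=50, top_p=0.9)
#     generate_text(model_type='xlnet', model_name_or_path='xlnet-base-cased',
#                   prompt="The weather today is", length=40, top_k=40)
#
if __name__ == '__main__':
    # Hypothetical entry point (an assumption, not part of the original file):
    # with no arguments, generate_text() reads prompts interactively from stdin.
    generate_text()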