105 lines
3.1 KiB
Python
105 lines
3.1 KiB
Python
|
import torch
|
||
|
from torch import nn, optim
|
||
|
from torch.utils.data import DataLoader
|
||
|
import numpy as np
|
||
|
from collections import Counter
|
||
|
import string
|
||
|
import lzma
|
||
|
import pdb
|
||
|
import copy
|
||
|
from torch.utils.data import IterableDataset
|
||
|
import itertools
|
||
|
import lzma
|
||
|
import regex as re
|
||
|
import pickle
|
||
|
import string
|
||
|
import pdb
|
||
|
import utils
|
||
|
import os
|
||
|
# Expose only the first GPU to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Use the GPU when one is usable, otherwise fall back to the CPU so the
# script can still run on hosts without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Vocabulary size is defined once in the project's utils module.
vocab_size = utils.vocab_size

# NOTE(security): pickle.load can execute arbitrary code; only load a
# vocab.pickle that comes from a trusted source.
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)

# Tokens missing from the vocabulary resolve to the '<unk>' index
# (presumably a torchtext Vocab -- TODO confirm the pickled type).
vocab.set_default_index(vocab['<unk>'])
|
||
|
class Model(nn.Module):
    """Single-layer bidirectional LSTM language model.

    Token ids are embedded, passed through a bidirectional LSTM and projected
    to one logit per vocabulary entry at every position.
    """

    def __init__(self, vocab_size):
        """Build the embedding, LSTM and output projection.

        Args:
            vocab_size: Number of entries in the vocabulary; used both for the
                embedding table and for the size of the output logits.
        """
        super().__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        # No dropout argument: nn.LSTM applies dropout only between stacked
        # layers, and this model has a single layer.
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # The LSTM output concatenates both directions, hence lstm_size * 2.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        """Compute per-position vocabulary logits.

        Args:
            x: Integer tensor of token ids, batch first (the LSTM is built
                with ``batch_first=True``).
            prev_state: Optional ``(h_0, c_0)`` tuple passed to the LSTM;
                ``None`` lets the LSTM start from zeros.

        Returns:
            ``(logits, state)`` where ``logits`` has one score per vocabulary
            entry for every position and ``state`` is the LSTM's final
            ``(h_n, c_n)``.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return zero-filled ``(h_0, c_0)`` tensors for the LSTM.

        Args:
            sequence_length: Size of the second dimension of the state
                tensors. Despite the name, nn.LSTM interprets this dimension
                as the batch size, so callers should pass the batch size.

        Returns:
            Tuple of two zero tensors shaped
            ``(num_layers * 2, sequence_length, lstm_size)``, allocated on the
            same device as the model's weights.
        """
        # Follow the model's own device rather than a module-level global so
        # the state always matches the weights, wherever the model was moved.
        state_device = self.embedding.weight.device
        shape = (self.num_layers * 2, sequence_length, self.lstm_size)
        return (
            torch.zeros(shape, device=state_device),
            torch.zeros(shape, device=state_device),
        )
|
||
|
|
||
|
|
||
|
model = Model(vocab_size=vocab_size).to(device)

# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved from a different device than the one in use here still loads.
model.load_state_dict(torch.load('lstm_step_10000.bin', map_location=device))
# Inference only: switch the module to evaluation mode.
model.eval()
|
||
|
|
||
|
def predict(model, text_splitted):
    """Pick a candidate next word for a tokenised prompt.

    The prompt tokens are mapped to ids through the module-level ``vocab``,
    run through ``model`` as a batch of one, and the 10 highest-probability
    tokens at the last position are collected. ``'<unk>'`` is dropped from
    those candidates and one of the remaining words is chosen uniformly at
    random.

    Args:
        model: Model exposing ``init_state`` and returning ``(logits, state)``
            when called.
        text_splitted: Non-empty sequence of prompt tokens.

    Returns:
        One word drawn with ``np.random.choice`` from the top candidates.

    Raises:
        ValueError: If ``text_splitted`` is empty.
    """
    if not text_splitted:
        # An empty list would become a float tensor and fail inside the
        # embedding lookup with an unrelated-looking error.
        raise ValueError("predict() needs at least one token of context")

    model.eval()

    # Shape (1, len(text_splitted)): a single-sequence batch.
    token_ids = torch.tensor([[vocab[w] for w in text_splitted]]).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # init_state is given the batch dimension of the input.
        state_h, state_c = model.init_state(token_ids.size(0))
        y_pred, _ = model(token_ids, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        probabilities = torch.nn.functional.softmax(last_word_logits, dim=0)
        top_indices = torch.topk(probabilities, 10).indices.tolist()

    top_words = vocab.lookup_tokens(top_indices)
    # Never emit the unknown-token placeholder as a prediction.
    if '<unk>' in top_words:
        top_words.remove('<unk>')

    return np.random.choice(top_words)
|
||
|
|
||
|
prompts = [
    'These, and a thousand other means, by which the wealth of a nation may be greatly increase',
    'Pants, coat and vest of the latest styles, are provided. Whenever the fires need coaling,',
    'Mr. Deddrick intends to clothe it and\ngive it as nearly as possible a likeness',
]

# Generate a 10-word continuation for every prompt.
for prompt in prompts:
    context = prompt.split()
    generated = []
    for _ in range(10):
        # Feed each sampled word back into the context; otherwise every
        # iteration samples the same position and the "answer" is just ten
        # independent guesses for the first next word.
        next_word = str(predict(model, context))
        generated.append(next_word)
        context.append(next_word)
    print('Prompt: ', prompt)
    print('Answer: ', ' '.join(generated))

# Sample output recorded from an earlier run, in which every word was sampled
# from the unchanged prompt; current output will differ.
#
# Prompt: These, and a thousand other means, by which the wealth of a nation may be greatly increase
# Answer: as the of as and to in to for in
# Prompt: Pants, coat and vest of the latest styles, are provided. Whenever the fires need coaling,
# Answer: in that The a the of the to the for
# Prompt: Mr. Deddrick intends to clothe it and
# give it as nearly as possible a likeness
# Answer: and of\nthe for man in of\nthe and of man of
|