# challenging-america-word-ga.../generator.py
import os
import pickle

import numpy as np
import torch
from torch import nn

import utils

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = 'cuda'
vocab_size = utils.vocab_size

# Load the vocabulary built during training; unseen words map to '<unk>'.
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])
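
# The pickled vocabulary is assumed to be a torchtext Vocab (it supports
# set_default_index and lookup_tokens below). A minimal sketch of how such a
# pickle could have been produced; `tokenized_lines` (an iterator of token
# lists over the training corpus) is hypothetical:
#
#     from torchtext.vocab import build_vocab_from_iterator
#     vocab = build_vocab_from_iterator(tokenized_lines,
#                                       specials=['<unk>'],
#                                       max_tokens=utils.vocab_size)
#     with open("vocab.pickle", 'wb') as handle:
#         pickle.dump(vocab, handle)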
class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
            # dropout=0.2,
        )
        # The bidirectional LSTM doubles the feature size fed to the classifier.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        # Hidden/cell state shape: (num_layers * num_directions, batch, hidden_size).
        return (torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device),
                torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device))

model = Model(vocab_size=vocab_size).to(device)
model.load_state_dict(torch.load('lstm_step_10000.bin'))
model.eval()
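
# Quick shape sanity check (a sketch added for illustration, not part of the
# original pipeline): a (1, 5) batch of token ids should yield logits of
# shape (1, 5, vocab_size).
with torch.no_grad():
    _probe = torch.zeros(1, 5, dtype=torch.long, device=device)
    _logits, _ = model(_probe, model.init_state(_probe.size()[0]))
    assert _logits.shape == (1, 5, vocab_size)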

def predict(model, text_splitted):
    model.eval()
    words = text_splitted
    with torch.no_grad():
        # Encode the prompt as a (1, seq_len) batch of vocab indices.
        x = torch.tensor([[vocab[w] for w in words]]).to(device)
        state_h, state_c = model.init_state(x.size()[0])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
    # Distribution over the vocabulary at the last position.
    last_word_logits = y_pred[0][-1]
    p = torch.nn.functional.softmax(last_word_logits, dim=0)
    top = torch.topk(p, 10)
    top_indices = top.indices.tolist()
    top_words = vocab.lookup_tokens(top_indices)
    if '<unk>' in top_words:
        top_words.remove('<unk>')
    # Draw uniformly from the remaining top-10 candidates.
    return np.random.choice(top_words)
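
# A sketch of an alternative sampler (not the one that produced the outputs
# recorded below): draw from the same top-10 candidates, but weighted by the
# model's own probabilities instead of uniformly.
def predict_weighted(model, text_splitted, k=10):
    model.eval()
    with torch.no_grad():
        x = torch.tensor([[vocab[w] for w in text_splitted]]).to(device)
        y_pred, _ = model(x, model.init_state(x.size()[0]))
        p = torch.nn.functional.softmax(y_pred[0][-1], dim=0)
        # Take one extra candidate in case '<unk>' is among the top k.
        top = torch.topk(p, k + 1)
        candidates = [(i, v) for i, v in zip(top.indices.tolist(), top.values.tolist())
                      if vocab.lookup_tokens([i])[0] != '<unk>'][:k]
        weights = torch.tensor([v for _, v in candidates])
        choice = candidates[torch.multinomial(weights, 1).item()][0]
        return vocab.lookup_tokens([choice])[0]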

prompts = [
    'These, and a thousand other means, by which the wealth of a nation may be greatly increase',
    'Pants, coat and vest of the latest styles, are provided. Whenever the fires need coaling,',
    'Mr. Deddrick intends to clothe it and\ngive it as nearly as possible a likeness'
]

# Note: every draw is conditioned on the original prompt only, so the ten
# sampled words are independent continuations rather than a running sequence.
for p in prompts:
    answer = ''
    for i in range(10):
        answer += predict(model, p.split()) + ' '
    print('Prompt: ', p)
    print('Answer: ', answer)
# Prompt: These, and a thousand other means, by which the wealth of a nation may be greatly increase
# Answer: as the of as and to in to for in
# Prompt: Pants, coat and vest of the latest styles, are provided. Whenever the fires need coaling,
# Answer: in that The a the of the to the for
# Prompt: Mr. Deddrick intends to clothe it and
#   give it as nearly as possible a likeness
# Answer: and of\nthe for man in of\nthe and of man of
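
# A sketch of autoregressive generation (an extension for illustration, not
# the behavior that produced the outputs above): feed each sampled word back
# into the context so later draws condition on earlier ones.
def generate_autoregressive(prompt, n=10):
    context = prompt.split()
    for _ in range(n):
        context.append(predict(model, context))
    return ' '.join(context[len(prompt.split()):])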