"""Generate short continuations for a few prompts with a trained LSTM model.

The script loads a pickled vocabulary and a model checkpoint, then samples
ten words after each prompt and prints the result.
"""
import copy
import itertools
import lzma
import os
import pdb
import pickle
import string
from collections import Counter

import numpy as np
import regex as re
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, IterableDataset

import utils

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Fall back to the CPU so the script still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_size = utils.vocab_size

# NOTE: unpickling can execute arbitrary code; only load a trusted vocab.pickle.
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)

# NOTE(review): the empty-string token looks like it may stand in for an
# unknown-word symbol -- confirm against the code that built the vocabulary.
vocab.set_default_index(vocab[''])


class Model(nn.Module):
    """Embedding -> single-layer bidirectional LSTM -> linear layer over the vocabulary."""

    def __init__(self, vocab_size):
        """Build the layers for a vocabulary of ``vocab_size`` tokens."""
        super().__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
            # dropout=0.2,
        )
        # The LSTM is bidirectional, so its output feature size is doubled.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        """Return ``(logits, state)`` for a batch of token-id sequences.

        ``x`` holds token ids laid out batch-first; ``prev_state`` is an
        optional ``(h, c)`` pair handed to the LSTM as its initial state.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero-filled ``(h, c)`` pair on the module-level ``device``.

        Despite its name, ``sequence_length`` sizes the second dimension of
        the state tensors, which ``nn.LSTM`` treats as the batch dimension;
        ``predict`` passes the batch size here.
        """
        shape = (self.num_layers * 2, sequence_length, self.lstm_size)
        return (torch.zeros(*shape).to(device),
                torch.zeros(*shape).to(device))


model = Model(vocab_size=vocab_size).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('lstm_step_10000.bin', map_location=device))
model.eval()


def predict(model, text_splitted, top_k=10):
    """Sample one word to follow the tokens in ``text_splitted``.

    The tokens are mapped to ids through the module-level ``vocab``, run
    through ``model`` as a single-sequence batch, and one word is drawn
    uniformly at random from the ``top_k`` highest-probability tokens at
    the last position (the ``''`` token is excluded when present).

    Args:
        model: Network exposing ``init_state`` and returning ``(logits, state)``.
        text_splitted: Non-empty list of word tokens forming the context.
        top_k: Number of highest-probability candidates to sample from.

    Returns:
        The sampled word, as returned by ``np.random.choice``.

    Raises:
        ValueError: If ``text_splitted`` is empty.
    """
    if not text_splitted:
        raise ValueError("text_splitted must contain at least one token")

    model.eval()
    x = torch.tensor([[vocab[w] for w in text_splitted]]).to(device)
    initial_state = model.init_state(x.size(0))

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        y_pred, _ = model(x, initial_state)

    last_word_logits = y_pred[0][-1]
    probabilities = torch.nn.functional.softmax(last_word_logits, dim=0)

    # Clamp k so vocabularies smaller than top_k do not make topk fail.
    k = min(top_k, probabilities.size(0))
    top_indices = torch.topk(probabilities, k).indices.tolist()
    top_words = vocab.lookup_tokens(top_indices)

    if '' in top_words:
        top_words.remove('')

    return np.random.choice(top_words)


prompts = [
    'These, and a thousand other means, by which the wealth of a nation may be greatly increase',
    'Pants, coat and vest of the latest styles, are provided. Whenever the fires need coaling,',
    'Mr. Deddrick intends to clothe it and\ngive it as nearly as possible a likeness',
]

for prompt in prompts:
    context = prompt.split()
    generated = []
    for _ in range(10):
        # Feed every sampled word back in so the next one is conditioned on it.
        word = str(predict(model, context))
        context.append(word)
        generated.append(word)
    answer = ' '.join(generated)
    print('Prompt: ', prompt)
    print('Answer: ', answer)

# Example output recorded from a run made before generated words were fed
# back into the context; sampling is random, so results vary between runs.
#
# Prompt:  These, and a thousand other means, by which the wealth of a nation may be greatly increase
# Answer:  as the of as and to in to for in
# Prompt:  Pants, coat and vest of the latest styles, are provided. Whenever the fires need coaling,
# Answer:  in that The a the of the to the for
# Prompt:  Mr. Deddrick intends to clothe it and
# give it as nearly as possible a likeness
# Answer:  and of\nthe for man in of\nthe and of man of