TAU_22_sane_words_torch_nn/new_lab.py
2019-12-04 22:52:59 +01:00

94 lines
2.5 KiB
Python

import sys
from collections import deque

import torch
from torch import nn
from torch import optim
# Hyper-parameters shared by the model and the training loop below.
history_length = 32  # number of preceding characters fed to the model
nb_of_char_codes = 128  # vocabulary size: only codes below this are used
# Rolling context window of character codes, initially filled with newlines.
history_encoded = [ord("\n") for _ in range(history_length)]
embedding_size = 10  # width of each character embedding
hidden_size = 100  # width of the hidden Linear layer
device = torch.device("cpu")
print(history_encoded)
def char_source(stream=None, max_code=None):
    """Yield the code point of every character read from a text stream.

    Characters whose code is not below ``max_code`` are skipped, so every
    yielded value is a valid row index for an embedding table with
    ``max_code`` rows. Newlines are yielded like any other character.

    Args:
        stream: Iterable of text lines (e.g. an open text file). Defaults to
            ``sys.stdin``.
        max_code: Exclusive upper bound on yielded codes. Defaults to the
            module-level ``nb_of_char_codes``.

    Yields:
        int: ``ord`` of each kept character, in input order.
    """
    # Resolved lazily (on first ``next``), exactly like the original read of stdin.
    if stream is None:
        stream = sys.stdin
    if max_code is None:
        max_code = nb_of_char_codes
    for line in stream:
        for char in line:
            code = ord(char)
            if code < max_code:
                yield code
class NGramLanguageModel(nn.Module):
    """Character-level neural n-gram language model.

    The previous ``history_length`` character codes are embedded, the
    embeddings are concatenated, and the result goes through two ``Linear``
    layers followed by ``LogSoftmax``, giving log-probabilities over the
    ``nb_of_char_codes`` possible next characters.

    NOTE(review): there is no non-linearity between the two Linear layers, so
    they compose into a single (rank-limited) linear map. Adding one would
    change the model, so it is left as-is here.
    """

    def __init__(self, nb_of_char_codes, history_length, embedding_size, hidden_size):
        super().__init__()
        # Remembered so that generate() does not depend on module-level globals.
        self.nb_of_char_codes = nb_of_char_codes
        self.history_length = history_length
        # Both sub-modules are moved to the module-level `device`.
        self.embeddings = nn.Embedding(nb_of_char_codes, embedding_size).to(device)
        self.model = nn.Sequential(
            nn.Linear(history_length * embedding_size, hidden_size),
            nn.Linear(hidden_size, nb_of_char_codes),
            # Explicit dim: the implicit-dim form is deprecated. For a 1-D input
            # dim=-1 is the same axis the implicit form picked.
            nn.LogSoftmax(dim=-1),
        ).to(device)

    def forward(self, inputs):
        """Return next-character log-probabilities for a context window.

        Args:
            inputs: LongTensor of ``history_length`` character codes, either
                shape ``(history_length,)`` or a batch of windows with shape
                ``(batch, history_length)``.

        Returns:
            Log-probabilities of shape ``(nb_of_char_codes,)``, or
            ``(batch, nb_of_char_codes)`` for batched input.
        """
        embedded_inputs = self.embeddings(inputs)
        # Flatten (concatenate) the per-character embeddings of each window;
        # a leading batch axis, if present, is preserved.
        return self.model(embedded_inputs.flatten(start_dim=-2))

    def generate(self, to_be_continued, n, top_k=None):
        """Continue ``to_be_continued`` with ``n`` sampled characters.

        The prompt is left-padded with spaces and cut to its last
        ``history_length`` characters. That window is the start of the
        returned string, so a longer prompt comes back truncated.

        Args:
            to_be_continued: Prompt text. Characters whose code is not below
                ``nb_of_char_codes`` are dropped, because the model has no
                embedding row for them.
            n: Number of characters to sample.
            top_k: If given, sample each character only among the ``top_k``
                most probable codes (e.g. 4). By default the full predicted
                distribution is sampled.

        Returns:
            The prompt window followed by the ``n`` sampled characters.
        """
        prompt = "".join(
            ch for ch in to_be_continued if ord(ch) < self.nb_of_char_codes
        )
        t = (" " * self.history_length + prompt)[-self.history_length:]
        history = [ord(c) for c in t]
        # Build inputs on whatever device the parameters live on.
        param_device = self.embeddings.weight.device
        with torch.no_grad():
            for _ in range(n):
                x = torch.tensor(history, dtype=torch.long, device=param_device)
                # Use this instance, not a module-level model object.
                probs = torch.exp(self(x))
                if top_k is not None:
                    top_probs, top_idx = torch.topk(probs, top_k)
                    probs = torch.zeros_like(probs).scatter_(0, top_idx, top_probs)
                # multinomial accepts non-negative weights that need not sum to 1.
                c = torch.multinomial(probs, 1)[0].item()
                t += chr(c)
                # Slide the context window by one character.
                history.pop(0)
                history.append(c)
        return t
model = NGramLanguageModel(nb_of_char_codes, history_length, embedding_size, hidden_size)
counter = 0  # characters processed so far
step = 1000  # report/sample every `step` characters; also the loss-average window
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())
# Losses of the most recent `step` characters; maxlen drops the oldest entry.
losses = deque(maxlen=step)

# Online training: one optimizer step per character read from stdin, predicting
# that character from the `history_length` characters that precede it.
for c in char_source():
    x = torch.tensor(history_encoded, dtype=torch.long, device=device)
    model.zero_grad()
    y = model(x)
    # NLLLoss takes (batch, classes) log-probabilities and a (batch,) target.
    target = torch.tensor([c], dtype=torch.long, device=device)
    loss = criterion(y.view(1, -1), target)
    losses.append(loss.item())
    if counter % step == 0:
        avg_loss = sum(losses) / len(losses)
        print(f"{counter}: avg loss {avg_loss:.4f}, last loss {loss.item():.4f}")
        print(model.generate("Machine translation is", 200))
    loss.backward()
    optimizer.step()
    # Slide the context window: drop the oldest code, append the one just seen.
    history_encoded.pop(0)
    history_encoded.append(c)
    # Without this, `counter % step == 0` is always true and the script would
    # report and sample 200 characters for every single training character.
    counter += 1
"""
zd, zrobić generator
Nucleus - do generowania
"""