2021-06-16 12:51:01 +02:00
|
|
|
import torch
|
|
|
|
|
|
|
|
class Model(torch.nn.Module):
    """Dense -> tanh -> stacked LSTM -> dense -> softmax classifier.

    Every layer size defaults to the value the model was originally
    hard-coded with, so ``Model()`` builds the same architecture (same
    ``state_dict`` keys and parameter shapes) as before.

    Args:
        in_features: Size of each input feature vector consumed by ``dense1``.
        dense_size: Output size of ``dense1``; also the LSTM input size.
        hidden_size: Hidden/cell state size of every LSTM layer.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output scores produced by ``dense2``.
    """

    def __init__(self, in_features=6, dense_size=150, hidden_size=300,
                 num_layers=2, num_classes=7):
        super().__init__()
        # Input: context-word features. The original (Polish) comment read
        # "4 context words before and 1 context word after", i.e. 5 features,
        # but this layer has always been built with 6 inputs.
        # NOTE(review): confirm which count is intended; 6 is kept as the
        # default so parameter shapes do not change.
        self.dense1 = torch.nn.Linear(in_features, dense_size, bias=False)
        self.tanh1 = torch.nn.Tanh()
        # Positional args: input_size, hidden_size, num_layers.
        self.lstm = torch.nn.LSTM(dense_size, hidden_size, num_layers)
        self.dense2 = torch.nn.Linear(hidden_size, num_classes)
        # Normalise over the class axis (the last one). The previous dim=1
        # pointed at the size-1 batch axis inserted in forward(), which made
        # every output exactly 1.0.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, data, hidden_state, cell_state):
        """Run one sequence through the network.

        Args:
            data: 2-D tensor of shape (in_features, seq_len). It is transposed
                so that each column becomes one time step.
            hidden_state: LSTM initial hidden state h_0, shape
                (num_layers, 1, hidden_size).
            cell_state: LSTM initial cell state c_0, shape
                (num_layers, 1, hidden_size).

        Returns:
            Tuple ``(probs, (hidden_state, cell_state))``. ``probs`` has shape
            (seq_len, 1, num_classes) and sums to 1 over the last axis. The
            states are the LSTM's final h_n / c_n, to be fed into the next call.
        """
        data = self.dense1(data.T)  # (seq_len, dense_size)
        data = self.tanh1(data)
        # The LSTM is not batch_first, so it expects (seq_len, batch, features);
        # unsqueeze(1) adds a batch axis of size 1.
        data, (hidden_state, cell_state) = self.lstm(
            data.unsqueeze(1), (hidden_state, cell_state)
        )
        data = self.dense2(data)  # (seq_len, 1, num_classes)
        data = self.softmax(data)
        return data, (hidden_state, cell_state)
|