30 lines
987 B
Python
30 lines
987 B
Python
import torch
|
|
|
|
class Model(torch.nn.Module):
    """Dense -> tanh -> 2-layer LSTM -> dense -> softmax sequence model.

    Each time step is described by 6 input features and is mapped to a
    probability distribution over 8 output classes.
    """

    def __init__(self):
        """Create the layers; all sizes are fixed."""
        super().__init__()
        # Input: 4 context words before and 1 context word after.
        # NOTE(review): the original note said "5 in features" while the layer
        # takes 6 -- presumably the current word is the sixth; confirm.
        # 6 in features -> 150 out features, no bias term.
        self.dense1 = torch.nn.Linear(6, 150, bias=False)
        self.tanh1 = torch.nn.Tanh()
        # 150 in features, 300 hidden values, 2 stacked layers.
        self.lstm = torch.nn.LSTM(150, 300, 2)
        self.dense2 = torch.nn.Linear(300, 8)
        # Normalize over the class axis (the last one). The previous dim=0
        # normalized across time steps, so the 8 scores of a single step did
        # not form a probability distribution.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, data, hidden_state, cell_state):
        """Run one sequence through the network.

        Args:
            data: 2-D tensor of shape (6, seq_len); it is transposed so that
                each row holds the 6 features of one time step.
            hidden_state: initial LSTM hidden state, shape (2, 1, 300).
            cell_state: initial LSTM cell state, shape (2, 1, 300).

        Returns:
            A pair ``(probs, (hidden_state, cell_state))`` where ``probs`` has
            shape (seq_len, 1, 8) and sums to 1 along its last axis, and the
            inner pair holds the LSTM states after the last time step.
        """
        data = self.dense1(data.T)
        data = self.tanh1(data)
        # The LSTM is not batch_first: insert a batch axis of size 1 so the
        # input is (seq_len, batch=1, 150).
        data, (hidden_state, cell_state) = self.lstm(
            data.unsqueeze(1), (hidden_state, cell_state)
        )
        data = self.dense2(data)
        data = self.softmax(data)
        return data, (hidden_state, cell_state)