punctuation_restoration/util.py
wangobango 808efa0aad 123
2021-06-16 22:23:55 +02:00

30 lines
987 B
Python

import torch
class Model(torch.nn.Module):
    """Linear -> Tanh -> stacked LSTM -> Linear -> Softmax sequence classifier.

    The defaults reproduce the original architecture: 6 input features,
    a 150-dim projection, a 2-layer LSTM with 300 hidden units, and 8 output
    classes.

    The original (Polish) comment described the input as "4 words of context
    before and 1 word of context after". Presumably the current word is the
    sixth input feature — TODO confirm against the data pipeline.

    Args:
        in_features: number of input features per time step (default 6).
        embedding_dim: width of the first linear projection / LSTM input
            (default 150).
        hidden_size: LSTM hidden units (default 300).
        num_layers: number of stacked LSTM layers (default 2).
        num_classes: number of output classes (default 8).
    """

    def __init__(
        self,
        in_features=6,
        embedding_dim=150,
        hidden_size=300,
        num_layers=2,
        num_classes=8,
    ):
        super().__init__()
        # Projects each time step's in_features values to embedding_dim (no bias).
        self.dense1 = torch.nn.Linear(in_features, embedding_dim, bias=False)
        self.tanh1 = torch.nn.Tanh()
        # batch_first is left at False, so the LSTM expects
        # (seq_len, batch, embedding_dim) input.
        self.lstm = torch.nn.LSTM(embedding_dim, hidden_size, num_layers)
        self.dense2 = torch.nn.Linear(hidden_size, num_classes)
        # Normalize over the class axis (last dim). dim=0 would normalize
        # across time steps and yields all ones for a one-step sequence.
        # NOTE(review): if training feeds this output to
        # torch.nn.CrossEntropyLoss, that loss expects raw logits — confirm.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, data, hidden_state=None, cell_state=None):
        """Classify every time step of one sequence.

        Args:
            data: 2-D tensor laid out as (in_features, seq_len). It is
                transposed so each column becomes one time step. Only 2-D
                input matches the layers below (assumed layout — confirm
                with callers).
            hidden_state: initial LSTM hidden state of shape
                (num_layers, 1, hidden_size). If omitted together with
                cell_state, the LSTM starts from zeros.
            cell_state: initial LSTM cell state, same shape as hidden_state.

        Returns:
            A tuple ``(probs, (hidden_state, cell_state))``.
            ``probs`` has shape (seq_len, 1, num_classes) and sums to 1 over
            the last axis. The returned states can be passed to the next call.

        Raises:
            ValueError: if exactly one of hidden_state / cell_state is given.
        """
        data = self.dense1(data.T)  # (seq_len, embedding_dim)
        data = self.tanh1(data)

        if hidden_state is None and cell_state is None:
            initial_state = None  # let the LSTM use its zero initial state
        elif hidden_state is None or cell_state is None:
            raise ValueError(
                "hidden_state and cell_state must be given together or not at all"
            )
        else:
            initial_state = (hidden_state, cell_state)

        # unsqueeze(1) inserts a batch axis of size 1: (seq_len, 1, embedding_dim).
        data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), initial_state)
        data = self.dense2(data)  # (seq_len, 1, num_classes)
        data = self.softmax(data)
        return data, (hidden_state, cell_state)