punctuation_restoration/util.py

29 lines
913 B
Python
Raw Normal View History

2021-06-16 12:51:01 +02:00
import torch
class Model(torch.nn.Module):
    """Dense projection + 2-layer LSTM classifier over 7 output classes.

    Architecture::

        Linear(6 -> 150, no bias) -> tanh -> LSTM(150 -> 300, 2 layers)
        -> Linear(300 -> 7) -> softmax over the class axis

    Input features: the original author's note says the input holds 4 words
    of context before and 1 word of context after. With 6 input features this
    presumably also includes the current word itself. TODO confirm against
    the data pipeline.
    """

    def __init__(self):
        super().__init__()
        # 6 input features -> 150 projected features (no bias term).
        self.dense1 = torch.nn.Linear(6, 150, bias=False)
        self.tanh1 = torch.nn.Tanh()
        # LSTM: 150 input features, 300 hidden units, 2 stacked layers.
        # batch_first is left at its default (False), so inputs are
        # (seq_len, batch, features).
        self.lstm = torch.nn.LSTM(150, 300, 2)
        # 300 hidden units -> 7 output classes.
        self.dense2 = torch.nn.Linear(300, 7)
        # Normalise over the last (class) axis. forward() produces
        # (seq_len, 1, 7); dim=1 would be the size-1 batch axis, which made
        # every "probability" exactly 1.0.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, data, hidden_state, cell_state):
        """Run one sequence through the network.

        Args:
            data: feature tensor whose features lie on the FIRST axis. It is
                transposed before ``dense1`` (6 in-features), so a 2-D input
                is expected as ``(6, seq_len)``.
            hidden_state: initial LSTM hidden state. The sequence is fed as a
                single batch of one, so this should be ``(2, 1, 300)``.
            cell_state: initial LSTM cell state, same shape as
                ``hidden_state``.

        Returns:
            ``(probs, (hidden_state, cell_state))``. ``probs`` has shape
            ``(seq_len, 1, 7)`` and sums to 1 along the last axis. The
            returned states are the LSTM's final states, which can be fed
            back in for the next call.
        """
        # (6, seq_len) -> (seq_len, 6) -> (seq_len, 150)
        data = self.dense1(data.T)
        data = self.tanh1(data)
        # Insert a batch axis of size 1: (seq_len, 1, 150).
        data, (hidden_state, cell_state) = self.lstm(
            data.unsqueeze(1), (hidden_state, cell_state)
        )
        # (seq_len, 1, 300) -> (seq_len, 1, 7)
        data = self.dense2(data)
        data = self.softmax(data)
        return data, (hidden_state, cell_state)