This commit is contained in:
wangobango 2021-06-16 16:04:38 +02:00
parent 6fde30cb6e
commit f407a2cf88
2 changed files with 8 additions and 3 deletions

View File

@ -28,8 +28,8 @@ def words_to_vecs(list_of_words):
return [nlp(x).vector for x in list_of_words] return [nlp(x).vector for x in list_of_words]
def softXEnt(input, target): def softXEnt(input, target):
m = torch.nn.LogSoftmax() m = torch.nn.LogSoftmax(dim=1)
logprobs = m(input, dim=1) logprobs = m(input)
return -(target * logprobs).sum() / input.shape[0] return -(target * logprobs).sum() / input.shape[0]
def compute_class_vector(mark, classes): def compute_class_vector(mark, classes):
@ -64,6 +64,11 @@ model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02) optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
loss_function = softXEnt loss_function = softXEnt
"""
TODO
1) metoda ewaluacyjna
2) przenieść na cude !!!!
"""
if mode == "train": if mode == "train":
for epoch in range(epochs): for epoch in range(epochs):

View File

@ -18,7 +18,7 @@ class Model(torch.nn.Module):
""" """
self.lstm = torch.nn.LSTM(150, 300, 2) self.lstm = torch.nn.LSTM(150, 300, 2)
self.dense2 = torch.nn.Linear(300, 7) self.dense2 = torch.nn.Linear(300, 7)
self.softmax = torch.nn.Softmax() self.softmax = torch.nn.Softmax(dim=1)
def forward(self, data, hidden_state, cell_state): def forward(self, data, hidden_state, cell_state):
data = self.dense1(data.T) data = self.dense1(data.T)