This commit is contained in:
wangobango 2021-06-16 16:04:38 +02:00
parent 6fde30cb6e
commit f407a2cf88
2 changed files with 8 additions and 3 deletions

View File

@ -28,8 +28,8 @@ def words_to_vecs(list_of_words):
return [nlp(x).vector for x in list_of_words]
def softXEnt(input, target):
m = torch.nn.LogSoftmax()
logprobs = m(input, dim=1)
m = torch.nn.LogSoftmax(dim=1)
logprobs = m(input)
return -(target * logprobs).sum() / input.shape[0]
def compute_class_vector(mark, classes):
@ -64,6 +64,11 @@ model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
loss_function = softXEnt
"""
TODO
1) metoda ewaluacyjna
2) przenieść na cude !!!!
"""
if mode == "train":
for epoch in range(epochs):

View File

@ -18,7 +18,7 @@ class Model(torch.nn.Module):
"""
self.lstm = torch.nn.LSTM(150, 300, 2)
self.dense2 = torch.nn.Linear(300, 7)
self.softmax = torch.nn.Softmax()
self.softmax = torch.nn.Softmax(dim=1)
def forward(self, data, hidden_state, cell_state):
data = self.dense1(data.T)