fix2
This commit is contained in:
parent
6fde30cb6e
commit
f407a2cf88
9
main.py
9
main.py
@ -28,8 +28,8 @@ def words_to_vecs(list_of_words):
|
|||||||
return [nlp(x).vector for x in list_of_words]
|
return [nlp(x).vector for x in list_of_words]
|
||||||
|
|
||||||
def softXEnt(input, target):
|
def softXEnt(input, target):
|
||||||
m = torch.nn.LogSoftmax()
|
m = torch.nn.LogSoftmax(dim=1)
|
||||||
logprobs = m(input, dim=1)
|
logprobs = m(input)
|
||||||
return -(target * logprobs).sum() / input.shape[0]
|
return -(target * logprobs).sum() / input.shape[0]
|
||||||
|
|
||||||
def compute_class_vector(mark, classes):
|
def compute_class_vector(mark, classes):
|
||||||
@ -64,6 +64,11 @@ model.train()
|
|||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
|
||||||
loss_function = softXEnt
|
loss_function = softXEnt
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO
|
||||||
|
1) metoda ewaluacyjna
|
||||||
|
2) przenieść na cude !!!!
|
||||||
|
"""
|
||||||
|
|
||||||
if mode == "train":
|
if mode == "train":
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
|
2
util.py
2
util.py
@ -18,7 +18,7 @@ class Model(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
self.lstm = torch.nn.LSTM(150, 300, 2)
|
self.lstm = torch.nn.LSTM(150, 300, 2)
|
||||||
self.dense2 = torch.nn.Linear(300, 7)
|
self.dense2 = torch.nn.Linear(300, 7)
|
||||||
self.softmax = torch.nn.Softmax()
|
self.softmax = torch.nn.Softmax(dim=1)
|
||||||
|
|
||||||
def forward(self, data, hidden_state, cell_state):
|
def forward(self, data, hidden_state, cell_state):
|
||||||
data = self.dense1(data.T)
|
data = self.dense1(data.T)
|
||||||
|
Loading…
Reference in New Issue
Block a user