fix2
This commit is contained in:
parent
6fde30cb6e
commit
f407a2cf88
9
main.py
9
main.py
@ -28,8 +28,8 @@ def words_to_vecs(list_of_words):
|
||||
return [nlp(x).vector for x in list_of_words]
|
||||
|
||||
def softXEnt(input, target):
|
||||
m = torch.nn.LogSoftmax()
|
||||
logprobs = m(input, dim=1)
|
||||
m = torch.nn.LogSoftmax(dim=1)
|
||||
logprobs = m(input)
|
||||
return -(target * logprobs).sum() / input.shape[0]
|
||||
|
||||
def compute_class_vector(mark, classes):
|
||||
@ -64,6 +64,11 @@ model.train()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
|
||||
loss_function = softXEnt
|
||||
|
||||
"""
|
||||
TODO
|
||||
1) metoda ewaluacyjna
|
||||
2) przenieść na cude !!!!
|
||||
"""
|
||||
|
||||
if mode == "train":
|
||||
for epoch in range(epochs):
|
||||
|
Loading…
Reference in New Issue
Block a user