commit 6fde30cb6e
parent f3404fc347
Author: wangobango
Date:   2021-06-16 13:03:28 +02:00

main.py (47 changed lines)

@@ -28,8 +28,8 @@ def words_to_vecs(list_of_words):
     return [nlp(x).vector for x in list_of_words]

 def softXEnt(input, target):
-    m = torch.nn.LogSoftmax(dim = 1)
-    logprobs = m(input)
+    m = torch.nn.LogSoftmax()
+    logprobs = m(input, dim=1)
     return -(target * logprobs).sum() / input.shape[0]

 def compute_class_vector(mark, classes):
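Note that this hunk moves `dim` from the `LogSoftmax` constructor to the call site, but `torch.nn.LogSoftmax` only accepts `dim` at construction time, so `m(input, dim=1)` raises a TypeError at runtime. A minimal sketch of the same soft-label cross entropy using the functional API, which does take `dim` per call:

    import torch
    import torch.nn.functional as F

    def soft_xent(input, target):
        # F.log_softmax accepts dim at the call site, unlike the
        # nn.LogSoftmax module, which fixes dim when constructed.
        logprobs = F.log_softmax(input, dim=1)
        return -(target * logprobs).sum() / input.shape[0]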
@@ -43,6 +43,9 @@ def compute_class_vector(mark, classes):
 data_dir = "./fa/poleval_final_dataset/train"
 data_nopunc_dir = "./fa/poleval_final_dataset1/train"
+mode = "train"
+# mode = "evaluate"
+# mode = "generate"

 data_paths = os.listdir(data_dir)
 data_paths = [data_dir + "/" + x for x in data_paths]
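Switching behavior by commenting `mode` lines in and out means editing the source for every run. A sketch of the same three-way switch through the standard-library argparse; the `--mode` flag name is an assumption, not part of the commit:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["train", "evaluate", "generate"],
                        default="train", help="which pass the script runs")
    mode = parser.parse_args().mode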
@@ -61,7 +64,9 @@ model.train()
 optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
 loss_function = softXEnt

-for epoch in range(epochs):
+if mode == "train":
+    for epoch in range(epochs):
     for path in tqdm.tqdm(data_paths):
         with open(path, "r") as file:
             list = file.readlines()[:-1]
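One caveat this hunk glosses over: the context lines keep their old indentation, but once the loop sits under `if mode == "train":`, its whole body must shift one level right or the file will not parse. A structural sketch of the intended dispatch, with bodies elided:

    mode = "train"   # set near the top of the script, as in the earlier hunk
    epochs = 1       # placeholder; main.py defines its own value

    if mode == "train":
        for epoch in range(epochs):
            pass     # training loop body, one indent level deeper than before
    elif mode == "evaluate":
        pass         # evaluation pass, added at the bottom of the file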
@@ -70,13 +75,14 @@ for epoch in range(epochs):
             x = list[i: i + context_size]
             x = [line2word(y) for y in x]
             x_1 = [line2word(list[i + context_size + 1])]
+            mark = find_interpunction(x[-1], classes)
+            # mark = words_to_vecs(mark)
             x = x + x_1
             x = words_to_vecs(x)
-            mark = find_interpunction(x, classes)
-            mark = words_to_vecs(mark)
             x = torch.tensor(x, dtype=torch.float)
-            mark = torch.tensor(mark, dtype=torch.float)
+            # mark = torch.tensor(mark, dtype=torch.float)
             output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
             output = output.squeeze(1)
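The reordering here is the substantive fix: the old code searched for punctuation in `x` after `words_to_vecs`, i.e. in a list of embedding vectors where no mark can occur, while the new code inspects the last raw context word. `find_interpunction` is defined earlier in main.py and not shown in this diff; a hypothetical sketch of the contract the hunk assumes:

    def find_interpunction(word, classes):
        # Hypothetical reconstruction; the real helper lives earlier in
        # main.py. Returns the first mark from `classes` occurring in
        # `word`, or an empty string if the word carries no punctuation.
        for ch in word:
            if ch in classes:
                return ch
        return ""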
@@ -86,11 +92,34 @@ for epoch in range(epochs):
             hidden_state = hidden_state.detach()
             cell_state = cell_state.detach()

+            """
+            vector -> (96,), np ndarray
+            """
     print("Epoch: {}".format(epoch))
     torch.save(
         model.state_dict(),
         os.path.join("./", f"{output_prefix}-{epoch}.pt"),
     )
+elif mode == "evaluate":
+    # build full paths for the no-punctuation corpus, mirroring data_paths above
+    nopunc_paths = [data_nopunc_dir + "/" + x for x in os.listdir(data_nopunc_dir)]
+    for pathA, pathB in zip(nopunc_paths, data_paths):
+        listA = []
+        listB = []
+        with open(pathA, "r") as file:
+            listA = file.readlines()[:-1]
+        with open(pathB, "r") as file:
+            listB = file.readlines()[:-1]
+        for i in range(0, len(listA) - context_size - 1):
+            model.zero_grad()
+            x = listA[i: i + context_size]
+            x = [line2word(y) for y in x]
+            x_1 = [line2word(listA[i + context_size + 1])]
+            x = x + x_1
+            x = words_to_vecs(x)
+            y = line2word(listB[i + context_size])
+            mark_y = find_interpunction(y, classes)
+            x = torch.tensor(x, dtype=torch.float)
+            output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+            if classes[torch.argmax(output)] == mark_y:
+                print("correct")
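As committed, the evaluation branch runs a freshly initialized model and prints once per hit. A sketch of a more conventional pass, assuming a checkpoint saved by the training branch is restored first; the checkpoint filename, the `eval_examples` generator, and the accuracy bookkeeping are illustrative, not part of the commit:

    import torch

    # Restore the weights the training branch saved with torch.save().
    model.load_state_dict(torch.load(f"{output_prefix}-{epochs - 1}.pt"))
    model.eval()  # switch off training-only behavior for evaluation

    correct = total = 0
    with torch.no_grad():  # no gradients are needed during evaluation
        # eval_examples is a hypothetical generator over (window, mark) pairs
        for x, mark_y in eval_examples():
            output, (hidden_state, cell_state) = model(x, hidden_state, cell_state)
            correct += int(classes[torch.argmax(output)] == mark_y)
            total += 1
    print("accuracy: {:.3f}".format(correct / max(total, 1)))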