fix

parent f3404fc347
commit 6fde30cb6e

main.py (95 changed lines)
@@ -28,8 +28,8 @@ def words_to_vecs(list_of_words):
     return [nlp(x).vector for x in list_of_words]

 def softXEnt(input, target):
-    m = torch.nn.LogSoftmax(dim = 1)
-    logprobs = m(input)
+    m = torch.nn.LogSoftmax()
+    logprobs = m(input, dim=1)
     return -(target * logprobs).sum() / input.shape[0]

 def compute_class_vector(mark, classes):
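One caveat on this hunk: torch.nn.LogSoftmax accepts dim only in its constructor, so the new m(input, dim=1) call will raise a TypeError at run time; the pre-change lines had dim in the right place. A minimal sketch of softXEnt via the functional API, which does take dim at the call site:

import torch.nn.functional as F

def softXEnt(input, target):
    # log-softmax over the class dimension; unlike the nn.LogSoftmax
    # module, F.log_softmax takes dim at call time
    logprobs = F.log_softmax(input, dim=1)
    # soft-target cross entropy, averaged over the batch
    return -(target * logprobs).sum() / input.shape[0]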
@@ -43,6 +43,9 @@ def compute_class_vector(mark, classes):

 data_dir = "./fa/poleval_final_dataset/train"
 data_nopunc_dir = "./fa/poleval_final_dataset1/train"
+mode = "train"
+# mode = "evaluate"
+# mode = "generate"

 data_paths = os.listdir(data_dir)
 data_paths = [data_dir + "/" + x for x in data_paths]
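The new mode toggle is a hardcoded variable with the alternatives left as comments, so switching behavior means editing the source. A sketch of the same switch as a command-line flag (the argument name and layout are illustrative, not part of the script):

import argparse

# hypothetical CLI switch replacing the hardcoded toggle
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["train", "evaluate", "generate"],
                    default="train")
mode = parser.parse_args().mode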
@@ -61,36 +64,62 @@ model.train()
 optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
 loss_function = softXEnt

-for epoch in range(epochs):
-    for path in tqdm.tqdm(data_paths):
-        with open(path, "r") as file:
-            list = file.readlines()[:-1]
-        for i in range(0, len(list) - context_size - 1):
-            model.zero_grad()
-            x = list[i: i + context_size]
-            x = [line2word(y) for y in x]
-            x_1 = [line2word(list[i + context_size + 1])]
-            x = x + x_1
-            x = words_to_vecs(x)
-            mark = find_interpunction(x, classes)
-            mark = words_to_vecs(mark)
-
-            x = torch.tensor(x, dtype=torch.float)
-            mark = torch.tensor(mark, dtype=torch.float)
-
-            output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
-            output = output.squeeze(1)
-            loss = loss_function(output, compute_class_vector(mark, classes))
-            loss.backward()
-            optimizer.step()
-            hidden_state = hidden_state.detach()
-            cell_state = cell_state.detach()
-
-    """
-    vector -> (96,), np nadarray
-    """
-    print("Epoch: {}".format(epoch))
-    torch.save(
-        model.state_dict(),
-        os.path.join("./", f"{output_prefix}-{epoch}.pt"),
-    )
+if mode == "train":
+    for epoch in range(epochs):
+        for path in tqdm.tqdm(data_paths):
+            with open(path, "r") as file:
+                list = file.readlines()[:-1]
+            for i in range(0, len(list) - context_size - 1):
+                model.zero_grad()
+                x = list[i: i + context_size]
+                x = [line2word(y) for y in x]
+                x_1 = [line2word(list[i + context_size + 1])]
+                mark = find_interpunction(x[-1], classes)
+                # mark = words_to_vecs(mark)
+
+                x = x + x_1
+                x = words_to_vecs(x)
+
+                x = torch.tensor(x, dtype=torch.float)
+                # mark = torch.tensor(mark, dtype=torch.float)
+
+                output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+                output = output.squeeze(1)
+                loss = loss_function(output, compute_class_vector(mark, classes))
+                loss.backward()
+                optimizer.step()
+                hidden_state = hidden_state.detach()
+                cell_state = cell_state.detach()
+
+        print("Epoch: {}".format(epoch))
+        torch.save(
+            model.state_dict(),
+            os.path.join("./", f"{output_prefix}-{epoch}.pt"),
+        )
+
+elif mode == "evaluate":
+    for pathA, pathB in zip(data_nopunc_dir, data_dir):
+        listA = []
+        listB = []
+        with open(pathA, "r") as file:
+            listA = file.readlines()[:-1]
+        with open(pathA, "r") as file:
+            listb = file.readlines()[:-1]
+        for i in range(0, len(list) - context_size - 1):
+            model.zero_grad()
+            x = listA[i: i + context_size]
+            x = [line2word(y) for y in x]
+            x_1 = [line2word(listA[i + context_size + 1])]
+
+            x = x + x_1
+            x = words_to_vecs(x)
+
+            y = listB[i + context_size]
+            y = [line2word(x) for x in y]
+            mark_y = find_interpunction(y)
+            x = torch.tensor(x, dtype=torch.float)
+
+            output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+            if classes[np.argmax(output)] == mark_y:
+                print('dupa')
+
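The detach() calls at the end of each training step implement truncated backpropagation through time: gradients stop at the step boundary so the old graph can be freed, while the hidden state itself still carries context forward. A minimal self-contained sketch of the pattern, using a stock torch.nn.LSTM as a stand-in for the script's model (whose definition is outside this diff):

import torch

lstm = torch.nn.LSTM(input_size=96, hidden_size=96)
hidden_state = torch.zeros(1, 1, 96)
cell_state = torch.zeros(1, 1, 96)

for step in range(3):
    x = torch.randn(5, 1, 96)  # (seq_len, batch, features)
    output, (hidden_state, cell_state) = lstm(x, (hidden_state, cell_state))
    loss = output.sum()  # stand-in loss
    loss.backward()
    # cut the autograd graph here so the next step does not backprop
    # through every earlier step and old graphs can be garbage-collected
    hidden_state = hidden_state.detach()
    cell_state = cell_state.detach()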
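As committed, the evaluate branch has several issues a follow-up would need: zip() is applied to the two directory strings instead of their file listings, pathA is opened twice (and the second read lands in the misspelled listb, so listB stays empty), the loop bound uses the undefined name list, the reference line is iterated character by character, and find_interpunction(y) drops the classes argument used in training. A hypothetical corrected skeleton under those assumptions, reusing the script's globals (model, classes, context_size, hidden_state, cell_state, line2word, words_to_vecs, find_interpunction):

import os
import torch

pathsA = sorted(os.listdir(data_nopunc_dir))  # text without punctuation
pathsB = sorted(os.listdir(data_dir))         # reference with punctuation
for nameA, nameB in zip(pathsA, pathsB):
    with open(os.path.join(data_nopunc_dir, nameA), "r") as file:
        listA = file.readlines()[:-1]
    with open(os.path.join(data_dir, nameB), "r") as file:  # second file, not pathA again
        listB = file.readlines()[:-1]
    for i in range(0, len(listA) - context_size - 1):  # listA, not the undefined list
        x = [line2word(y) for y in listA[i: i + context_size]]
        x = x + [line2word(listA[i + context_size + 1])]
        x = torch.tensor(words_to_vecs(x), dtype=torch.float)

        y = line2word(listB[i + context_size])  # one reference word, not its characters
        mark_y = find_interpunction(y, classes)  # keep the classes argument

        with torch.no_grad():  # no graph needed when only scoring
            output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
        # argmax on the tensor directly; np.argmax would choke on a grad tensor
        if classes[int(torch.argmax(output))] == mark_y:
            print('dupa')  # the original's hit marker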