This commit is contained in:
wangobango 2021-06-21 21:12:35 +02:00
parent 7e3768c22d
commit 77d269f43b

View File

@ -172,7 +172,7 @@ if mode == "eval" or mode == "generate":
model.load_state_dict(torch.load("model.torch")) model.load_state_dict(torch.load("model.torch"))
for i in tqdm(range(0, len(train_tokens_ids))): for i in tqdm(range(0, len(train_tokens_ids))):
last_idx = 0 last_idx = 0
for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length): for k in range(0, len(train_tokens_ids[i]) - seq_length + 1, seq_length):
batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0) batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
tags = train_labels[i][k: k + seq_length].unsqueeze(1) tags = train_labels[i][k: k + seq_length].unsqueeze(1)
predicted_tags = model.decode(batch_tokens.to(device)) predicted_tags = model.decode(batch_tokens.to(device))
@ -194,14 +194,14 @@ if mode == "eval" or mode == "generate":
print(f1_score(correct, predicted, average="weighted")) print(f1_score(correct, predicted, average="weighted"))
predicted = list(map(lambda x: inv_labels_vocab[x], predicted)) predicted = list(map(lambda x: inv_labels_vocab[x], predicted))
slices = [len(x.split(" ")) for x in in_data] slices = [len(x.split(" ")) for x in target]
with open(save_path, "w") as save: with open(save_path, "w") as save:
writer = csv.writer(save, delimiter='\t', lineterminator='\n') writer = csv.writer(save, delimiter='\t', lineterminator='\n')
accumulator = 0 accumulator = 0
output = [] output = []
for slice in slices: for slice in slices:
output.append(predicted[accumulator: accumulator + slice]) output.append(predicted[accumulator: accumulator + slice])
accumulator += slice - 1 accumulator += slice
for line in process_output(output): for line in process_output(output):
writer.writerow([line]) writer.writerow([line])