forked from kubapok/en-ner-conll-2003

fix

parent 7e3768c22d
commit 77d269f43b

main.py (6 lines changed: 3 additions, 3 deletions)
@@ -172,7 +172,7 @@ if mode == "eval" or mode == "generate":
     model.load_state_dict(torch.load("model.torch"))
     for i in tqdm(range(0, len(train_tokens_ids))):
         last_idx = 0
-        for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
+        for k in range(0, len(train_tokens_ids[i]) - seq_length + 1, seq_length):
             batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
             tags = train_labels[i][k: k + seq_length].unsqueeze(1)
             predicted_tags = model.decode(batch_tokens.to(device))
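
This hunk looks like an off-by-one fix: with range(0, len(train_tokens_ids[i]) - seq_length, seq_length) the final full window is skipped whenever the sequence length is an exact multiple of seq_length, while the added "+ 1" keeps it. A minimal sketch of the difference, using a toy 6-token sequence and seq_length = 3 (both values are illustrative, not taken from the repository):

    # Toy stand-ins for train_tokens_ids[i] and seq_length; values are illustrative only.
    seq = list(range(6))
    seq_length = 3

    old_starts = list(range(0, len(seq) - seq_length, seq_length))      # [0]    -> the window seq[3:6] is never decoded
    new_starts = list(range(0, len(seq) - seq_length + 1, seq_length))  # [0, 3] -> every full window is decoded

    print(old_starts, new_starts)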
@@ -194,14 +194,14 @@ if mode == "eval" or mode == "generate":
     print(f1_score(correct, predicted, average="weighted"))

     predicted = list(map(lambda x: inv_labels_vocab[x], predicted))
-    slices = [len(x.split(" ")) for x in in_data]
+    slices = [len(x.split(" ")) for x in target]
     with open(save_path, "w") as save:
         writer = csv.writer(save, delimiter='\t', lineterminator='\n')
         accumulator = 0
         output = []
         for slice in slices:
             output.append(predicted[accumulator: accumulator + slice])
-            accumulator += slice - 1
+            accumulator += slice
         for line in process_output(output):
             writer.writerow([line])
