forked from kubapok/en-ner-conll-2003
fix
This commit is contained in:
parent
7e3768c22d
commit
77d269f43b
6
main.py
6
main.py
@ -172,7 +172,7 @@ if mode == "eval" or mode == "generate":
|
|||||||
model.load_state_dict(torch.load("model.torch"))
|
model.load_state_dict(torch.load("model.torch"))
|
||||||
for i in tqdm(range(0, len(train_tokens_ids))):
|
for i in tqdm(range(0, len(train_tokens_ids))):
|
||||||
last_idx = 0
|
last_idx = 0
|
||||||
for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
|
for k in range(0, len(train_tokens_ids[i]) - seq_length + 1, seq_length):
|
||||||
batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
|
batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
|
||||||
tags = train_labels[i][k: k + seq_length].unsqueeze(1)
|
tags = train_labels[i][k: k + seq_length].unsqueeze(1)
|
||||||
predicted_tags = model.decode(batch_tokens.to(device))
|
predicted_tags = model.decode(batch_tokens.to(device))
|
||||||
@ -194,14 +194,14 @@ if mode == "eval" or mode == "generate":
|
|||||||
print(f1_score(correct, predicted, average="weighted"))
|
print(f1_score(correct, predicted, average="weighted"))
|
||||||
|
|
||||||
predicted = list(map(lambda x: inv_labels_vocab[x], predicted))
|
predicted = list(map(lambda x: inv_labels_vocab[x], predicted))
|
||||||
slices = [len(x.split(" ")) for x in in_data]
|
slices = [len(x.split(" ")) for x in target]
|
||||||
with open(save_path, "w") as save:
|
with open(save_path, "w") as save:
|
||||||
writer = csv.writer(save, delimiter='\t', lineterminator='\n')
|
writer = csv.writer(save, delimiter='\t', lineterminator='\n')
|
||||||
accumulator = 0
|
accumulator = 0
|
||||||
output = []
|
output = []
|
||||||
for slice in slices:
|
for slice in slices:
|
||||||
output.append(predicted[accumulator: accumulator + slice])
|
output.append(predicted[accumulator: accumulator + slice])
|
||||||
accumulator += slice - 1
|
accumulator += slice
|
||||||
for line in process_output(output):
|
for line in process_output(output):
|
||||||
writer.writerow([line])
|
writer.writerow([line])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user