From 77d269f43b4deb200bb9ca10a53677eb8fcbe55e Mon Sep 17 00:00:00 2001 From: wangobango Date: Mon, 21 Jun 2021 21:12:35 +0200 Subject: [PATCH] fix --- main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 76da06d..56db375 100644 --- a/main.py +++ b/main.py @@ -172,7 +172,7 @@ if mode == "eval" or mode == "generate": model.load_state_dict(torch.load("model.torch")) for i in tqdm(range(0, len(train_tokens_ids))): last_idx = 0 - for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length): + for k in range(0, len(train_tokens_ids[i]) - seq_length + 1, seq_length): batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0) tags = train_labels[i][k: k + seq_length].unsqueeze(1) predicted_tags = model.decode(batch_tokens.to(device)) @@ -194,14 +194,14 @@ if mode == "eval" or mode == "generate": print(f1_score(correct, predicted, average="weighted")) predicted = list(map(lambda x: inv_labels_vocab[x], predicted)) - slices = [len(x.split(" ")) for x in in_data] + slices = [len(x.split(" ")) for x in target] with open(save_path, "w") as save: writer = csv.writer(save, delimiter='\t', lineterminator='\n') accumulator = 0 output = [] for slice in slices: output.append(predicted[accumulator: accumulator + slice]) - accumulator += slice - 1 + accumulator += slice for line in process_output(output): writer.writerow([line])