def validate_bio_sequence(labels):
    """Repair a predicted tag sequence so that it is valid BIO.

    An ``I-X`` tag may only continue an entity of the same type ``X``, which
    means it must directly follow ``B-X`` or ``I-X``.  Any ``I-X`` that does
    not is rewritten to ``B-X`` so that it opens a new entity instead.  Every
    other label is passed through unchanged.

    Args:
        labels: Iterable of string tags such as ``'O'``, ``'B-PER'`` or
            ``'I-PER'``.

    Returns:
        A new list holding the corrected tags; the input is not modified.
    """
    corrected_labels = []
    # The position before the first token behaves like an "outside" tag.
    previous_label = 'O'
    for label in labels:
        if label.startswith('I-'):
            entity_type = label[2:]
            # Match the full previous tag, prefix included.  Comparing only the
            # text after the first two characters would let any tag with the
            # same tail (e.g. 'E-PER') license an 'I-PER' continuation.
            if previous_label not in ('B-' + entity_type, 'I-' + entity_type):
                label = 'B-' + entity_type
        corrected_labels.append(label)
        # Track the corrected tag, so a repaired 'B-X' allows a following 'I-X'.
        previous_label = label
    return corrected_labels