Fix sequence problem

This commit is contained in:
AWieczarek 2024-05-27 14:16:53 +02:00
parent 2a2b5b5bc4
commit ae653a40e3

View File

@ -985,6 +985,20 @@
}
],
"source": [
"def validate_bio_sequence(labels):\n",
" corrected_labels = []\n",
" previous_label = 'O'\n",
" for label in labels:\n",
" if label.startswith('I-'):\n",
" if previous_label == 'O' or previous_label[2:] != label[2:]:\n",
" corrected_labels.append('B-' + label[2:])\n",
" else:\n",
" corrected_labels.append(label)\n",
" else:\n",
" corrected_labels.append(label)\n",
" previous_label = corrected_labels[-1]\n",
" return corrected_labels\n",
"\n",
"def save_predictions(tokens_ids, model, output_path, label_mapping):\n",
" predictions = []\n",
" for i in tqdm(range(len(tokens_ids))):\n",
@ -992,6 +1006,7 @@
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
" bio_labels = [label_mapping[label] for label in Y_batch_pred.numpy()[1:-1]]\n",
" bio_labels = validate_bio_sequence(bio_labels)\n",
" predictions.append(\" \".join(bio_labels))\n",
"\n",
" with open(output_path, 'w') as f:\n",