Fix sequence problem
This commit is contained in:
parent
2a2b5b5bc4
commit
ae653a40e3
15
RNN_6.ipynb
15
RNN_6.ipynb
@ -985,6 +985,20 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"def validate_bio_sequence(labels):\n",
|
||||||
|
" corrected_labels = []\n",
|
||||||
|
" previous_label = 'O'\n",
|
||||||
|
" for label in labels:\n",
|
||||||
|
" if label.startswith('I-'):\n",
|
||||||
|
" if previous_label == 'O' or previous_label[2:] != label[2:]:\n",
|
||||||
|
" corrected_labels.append('B-' + label[2:])\n",
|
||||||
|
" else:\n",
|
||||||
|
" corrected_labels.append(label)\n",
|
||||||
|
" else:\n",
|
||||||
|
" corrected_labels.append(label)\n",
|
||||||
|
" previous_label = corrected_labels[-1]\n",
|
||||||
|
" return corrected_labels\n",
|
||||||
|
"\n",
|
||||||
"def save_predictions(tokens_ids, model, output_path, label_mapping):\n",
|
"def save_predictions(tokens_ids, model, output_path, label_mapping):\n",
|
||||||
" predictions = []\n",
|
" predictions = []\n",
|
||||||
" for i in tqdm(range(len(tokens_ids))):\n",
|
" for i in tqdm(range(len(tokens_ids))):\n",
|
||||||
@ -992,6 +1006,7 @@
|
|||||||
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
|
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
|
||||||
" Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
|
" Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
|
||||||
" bio_labels = [label_mapping[label] for label in Y_batch_pred.numpy()[1:-1]]\n",
|
" bio_labels = [label_mapping[label] for label in Y_batch_pred.numpy()[1:-1]]\n",
|
||||||
|
" bio_labels = validate_bio_sequence(bio_labels)\n",
|
||||||
" predictions.append(\" \".join(bio_labels))\n",
|
" predictions.append(\" \".join(bio_labels))\n",
|
||||||
"\n",
|
"\n",
|
||||||
" with open(output_path, 'w') as f:\n",
|
" with open(output_path, 'w') as f:\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user