commit bb848936c7c85bc9581496146364289253f1abd7 Author: s464909 Date: Tue May 28 16:48:34 2024 +0200 Upload files to "/" diff --git a/RNN.ipynb b/RNN.ipynb new file mode 100644 index 0000000..aea2f4c --- /dev/null +++ b/RNN.ipynb @@ -0,0 +1,418 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "c80ac05e-c22e-4f7f-a48d-ca85173f0f86", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-05-28 15:20:02.976189: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-05-28 15:20:04.436596: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", + "from tensorflow.keras.utils import to_categorical\n", + "from sklearn.metrics import classification_report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84d7598c-6184-4583-8d86-2ecedf9b3117", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "05b3715f-de04-40fe-ae59-3918a665f78b", + "metadata": {}, + "outputs": [], + "source": [ + "def load_data(file_path):\n", + " with open(file_path, 'r', encoding='utf-8') as file:\n", + " lines = file.readlines()\n", + " sentences = [line.strip() for line in lines]\n", + " return sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b109e4c3-e781-44f3-bf4b-0b2ce7432195", + "metadata": {}, + "outputs": [], + "source": [ + "train_data = pd.read_csv('en-ner-conll-2003/train/train.tsv', sep='\\t', header=None, names=['label', 'sentence'], encoding='utf-8')\n", + "dev_sentences = load_data('en-ner-conll-2003/dev-0/in.tsv')\n", + "dev_labels = load_data('en-ner-conll-2003/dev-0/expected.tsv')\n", + "test_sentences = load_data('en-ner-conll-2003/test-A/in.tsv')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "88fe9aa9-f3f1-4be3-aa93-83f45812c9bb", + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_data(sentences, labels=None):\n", + " tokenized_sentences = [sentence.split() for sentence in sentences]\n", + " if labels is not None:\n", + " tokenized_labels = [label.split() for label in labels]\n", + " return tokenized_sentences, tokenized_labels\n", + " return tokenized_sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "466f1b4b-bf34-4cca-a2ee-3adbf7da9337", + "metadata": {}, + "outputs": [], + "source": [ + "train_sentences, train_labels = preprocess_data(train_data['sentence'].values, train_data['label'].values)\n", + "dev_sentences, dev_labels = preprocess_data(dev_sentences, dev_labels)\n", + "test_sentences = preprocess_data(test_sentences)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3cb4de1c-d5e0-46c5-b80d-df69c0adbaa4", + "metadata": {}, + "outputs": [], + "source": [ + "special_tokens = ['', '', '', '']\n", + "word2idx = {w: i + len(special_tokens) for i, w in enumerate(set(word for sentence in train_sentences for word in sentence))}\n", + "for i, token in enumerate(special_tokens):\n", + " word2idx[token] = i\n", + "\n", + "idx2word = {i: w for w, i in word2idx.items()}\n", + "\n", + "label2idx = {\n", + " 'O': 0,\n", + " 'B-PER': 1, 'I-PER': 2,\n", + " 'B-ORG': 3, 'I-ORG': 4,\n", + " 'B-LOC': 5, 'I-LOC': 6,\n", + " 'B-MISC': 7, 'I-MISC': 8\n", + "}\n", + "idx2label = {i: l for l, i in label2idx.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7c0b52af-7eeb-447b-a58e-b9f48da8f7c3", + "metadata": {}, + "outputs": [], + "source": [ + "def encode_data(sentences, labels=None):\n", + " encoded_sentences = [[word2idx.get(word, word2idx['']) for word in sentence] for sentence in sentences]\n", + " if labels is not None:\n", + " encoded_labels = [[label2idx[label] for label in label_list] for label_list in labels]\n", + " return encoded_sentences, encoded_labels\n", + " return encoded_sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "37a9de80-bf7b-455c-9794-2d559b46734d", + "metadata": {}, + "outputs": [], + "source": [ + "X_train, y_train = encode_data(train_sentences, train_labels)\n", + "X_dev, y_dev = encode_data(dev_sentences, dev_labels)\n", + "X_test = encode_data(test_sentences)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c06fe4fa-18f2-467c-9b5b-9ca36128fcac", + "metadata": {}, + "outputs": [], + "source": [ + "max_len = 1000 \n", + "X_train = pad_sequences(X_train, padding='post', maxlen=max_len)\n", + "y_train = pad_sequences(y_train, padding='post', maxlen=max_len)\n", + "\n", + "X_dev = pad_sequences(X_dev, padding='post', maxlen=max_len)\n", + "y_dev = pad_sequences(y_dev, padding='post', maxlen=max_len)\n", + "\n", + "X_test = pad_sequences(X_test, padding='post', maxlen=max_len)\n", + "\n", + "y_train = [to_categorical(i, num_classes=len(label2idx)) for i in y_train]\n", + "y_dev = [to_categorical(i, num_classes=len(label2idx)) for i in y_dev]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "02bf884e-d80e-43e7-b60c-4d2a611d280a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m160s\u001b[0m 2s/step - accuracy: 0.8954 - loss: 0.5933 - val_accuracy: 0.9606 - val_loss: 0.1730\n", + "Epoch 2/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m135s\u001b[0m 2s/step - accuracy: 0.9642 - loss: 0.1531 - val_accuracy: 0.9606 - val_loss: 0.1563\n", + "Epoch 3/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 2s/step - accuracy: 0.9623 - loss: 0.1432 - val_accuracy: 0.9607 - val_loss: 0.1358\n", + "Epoch 4/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m103s\u001b[0m 2s/step - accuracy: 0.9630 - loss: 0.1177 - val_accuracy: 0.9648 - val_loss: 0.1104\n", + "Epoch 5/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 2s/step - accuracy: 0.9707 - loss: 0.0885 - val_accuracy: 0.9727 - val_loss: 0.0901\n", + "Epoch 6/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m141s\u001b[0m 2s/step - accuracy: 0.9790 - loss: 0.0684 - val_accuracy: 0.9779 - val_loss: 0.0751\n", + "Epoch 7/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m141s\u001b[0m 2s/step - accuracy: 0.9871 - loss: 0.0510 - val_accuracy: 0.9831 - val_loss: 0.0625\n", + "Epoch 8/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 2s/step - accuracy: 0.9919 - loss: 0.0377 - val_accuracy: 0.9857 - val_loss: 0.0540\n", + "Epoch 9/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 2s/step - accuracy: 0.9947 - loss: 0.0265 - val_accuracy: 0.9874 - val_loss: 0.0472\n", + "Epoch 10/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 2s/step - accuracy: 0.9963 - loss: 0.0209 - val_accuracy: 0.9885 - val_loss: 0.0431\n", + "Epoch 11/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 2s/step - accuracy: 0.9976 - loss: 0.0148 - val_accuracy: 0.9891 - val_loss: 0.0401\n", + "Epoch 12/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 2s/step - accuracy: 0.9983 - loss: 0.0121 - val_accuracy: 0.9895 - val_loss: 0.0386\n", + "Epoch 13/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 2s/step - accuracy: 0.9986 - loss: 0.0093 - val_accuracy: 0.9897 - val_loss: 0.0376\n", + "Epoch 14/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m142s\u001b[0m 2s/step - accuracy: 0.9989 - loss: 0.0077 - val_accuracy: 0.9896 - val_loss: 0.0385\n", + "Epoch 15/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m108s\u001b[0m 2s/step - accuracy: 0.9991 - loss: 0.0067 - val_accuracy: 0.9896 - val_loss: 0.0385\n", + "Epoch 16/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m142s\u001b[0m 2s/step - accuracy: 0.9992 - loss: 0.0057 - val_accuracy: 0.9899 - val_loss: 0.0371\n", + "Epoch 17/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m142s\u001b[0m 2s/step - accuracy: 0.9995 - loss: 0.0045 - val_accuracy: 0.9897 - val_loss: 0.0392\n", + "Epoch 18/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m113s\u001b[0m 2s/step - accuracy: 0.9995 - loss: 0.0040 - val_accuracy: 0.9899 - val_loss: 0.0385\n", + "Epoch 19/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m113s\u001b[0m 2s/step - accuracy: 0.9996 - loss: 0.0035 - val_accuracy: 0.9896 - val_loss: 0.0404\n", + "Epoch 20/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m114s\u001b[0m 2s/step - accuracy: 0.9997 - loss: 0.0030 - val_accuracy: 0.9898 - val_loss: 0.0391\n", + "Epoch 21/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m141s\u001b[0m 2s/step - accuracy: 0.9997 - loss: 0.0028 - val_accuracy: 0.9898 - val_loss: 0.0406\n", + "Epoch 22/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m142s\u001b[0m 2s/step - accuracy: 0.9998 - loss: 0.0022 - val_accuracy: 0.9896 - val_loss: 0.0421\n", + "Epoch 23/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m142s\u001b[0m 2s/step - accuracy: 0.9998 - loss: 0.0021 - val_accuracy: 0.9897 - val_loss: 0.0417\n", + "Epoch 24/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m113s\u001b[0m 2s/step - accuracy: 0.9998 - loss: 0.0019 - val_accuracy: 0.9898 - val_loss: 0.0415\n", + "Epoch 25/25\n", + "\u001b[1m68/68\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m113s\u001b[0m 2s/step - accuracy: 0.9998 - loss: 0.0019 - val_accuracy: 0.9897 - val_loss: 0.0434\n", + "\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 1s/step\n" + ] + } + ], + "source": [ + "model = tf.keras.models.Sequential([\n", + " tf.keras.layers.Embedding(input_dim=len(word2idx), output_dim=64),\n", + " tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=32, return_sequences=True)),\n", + " tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(len(label2idx), activation='softmax'))\n", + "])\n", + "\n", + "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n", + "\n", + "history = model.fit(X_train, np.array(y_train), validation_data=(X_dev, np.array(y_dev)), epochs=25, batch_size=14)\n", + "\n", + "y_pred = model.predict(X_dev)\n", + "y_pred = np.argmax(y_pred, axis=-1)\n", + "y_true = np.argmax(np.array(y_dev), axis=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "56af96ba-ee8e-4476-a17d-d9a66d03d136", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 due to no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 due to no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " B-PER 0.00 0.00 0.00 0\n", + " I-PER 0.00 0.00 0.00 0\n", + " B-ORG 0.00 0.00 0.00 0\n", + " I-ORG 0.00 0.00 0.00 0\n", + " B-LOC 0.00 0.00 0.00 0\n", + " I-LOC 0.00 0.00 0.00 0\n", + " B-MISC 0.00 0.00 0.00 0\n", + " I-MISC 0.00 0.00 0.00 0\n", + "\n", + " micro avg 0.00 0.00 0.00 0\n", + " macro avg 0.00 0.00 0.00 0\n", + "weighted avg 0.00 0.00 0.00 0\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "/home/michal/.local/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" + ] + } + ], + "source": [ + "y_pred_tags = [[idx2label[i] for i in row] for row in y_pred]\n", + "y_true_tags = [[idx2label[i] for i in row] for row in y_true]\n", + "\n", + "print(classification_report(\n", + " [item for sublist in y_true_tags for item in sublist],\n", + " [item for sublist in y_pred_tags for item in sublist],\n", + " labels=list(label2idx.values())[1:], \n", + " target_names=[idx2label[i] for i in list(label2idx.values())[1:]]\n", + "))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "be795033-3bca-4c89-b3f3-c6104dc305fd", + "metadata": {}, + "outputs": [], + "source": [ + "def correct_iob_labels(predictions):\n", + " corrected = []\n", + " for pred in predictions:\n", + " corrected_sentence = []\n", + " prev_label = 'O'\n", + " for label in pred:\n", + " if label.startswith('I-') and (prev_label == 'O' or prev_label[2:] != label[2:]):\n", + " corrected_sentence.append('B-' + label[2:])\n", + " else:\n", + " corrected_sentence.append(label)\n", + " prev_label = corrected_sentence[-1]\n", + " corrected.append(corrected_sentence)\n", + " return corrected" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c75a75ad-3bef-470d-9882-13a97e5495e8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 95ms/step\n" + ] + } + ], + "source": [ + "y_test_pred = model.predict(X_test)\n", + "y_test_pred = np.argmax(y_test_pred, axis=-1)\n", + "y_test_pred_tags = [[idx2label[i] for i in row] for row in y_test_pred]\n", + "\n", + "y_pred_tags_corrected = correct_iob_labels(y_pred_tags)\n", + "y_test_pred_tags_corrected = correct_iob_labels(y_test_pred_tags)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0f48d0f4-a6f5-424b-a8cc-caed7b14e113", + "metadata": {}, + "outputs": [], + "source": [ + "dev_predictions = [' '.join(tags) for tags in y_pred_tags_corrected]\n", + "with open('en-ner-conll-2003/dev-0/out.tsv', 'w', encoding='utf-8') as f:\n", + " for prediction in dev_predictions:\n", + " f.write(\"%s\\n\" % prediction)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "8ac16ee5-cdf1-4b6a-9c94-003d598921e0", + "metadata": {}, + "outputs": [], + "source": [ + "test_predictions = [' '.join(tags) for tags in y_test_pred_tags_corrected]\n", + "with open('en-ner-conll-2003/test-A/out.tsv', 'w', encoding='utf-8') as f:\n", + " for prediction in test_predictions:\n", + " f.write(\"%s\\n\" % prediction)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84079718-b955-4a64-8e0b-798b5706faca", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}