Upload files to "/"
This commit is contained in:
parent
58a78647d2
commit
96dfb2fa54
406
RNN.ipynb
Normal file
406
RNN.ipynb
Normal file
@ -0,0 +1,406 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c80ac05e-c22e-4f7f-a48d-ca85173f0f86",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import tensorflow as tf\n",
|
||||
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
||||
"from tensorflow.keras.utils import to_categorical\n",
|
||||
"from sklearn.metrics import classification_report"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "05b3715f-de04-40fe-ae59-3918a665f78b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def load_data(file_path):\n",
|
||||
" with open(file_path, 'r', encoding='utf-8') as file:\n",
|
||||
" lines = file.readlines()\n",
|
||||
" sentences = [line.strip() for line in lines]\n",
|
||||
" return sentences"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "b109e4c3-e781-44f3-bf4b-0b2ce7432195",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data = pd.read_csv('en-ner-conll-2003-master/en-ner-conll-2003/train/train.tsv/train.tsv', sep='\\t', header=None, names=['label', 'sentence'], encoding='utf-8')\n",
|
||||
"dev_sentences = load_data('en-ner-conll-2003-master/en-ner-conll-2003/dev-0/in.tsv')\n",
|
||||
"dev_labels = load_data('en-ner-conll-2003-master/en-ner-conll-2003/dev-0/expected.tsv')\n",
|
||||
"test_sentences = load_data('en-ner-conll-2003-master/en-ner-conll-2003/test-A/in.tsv')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "88fe9aa9-f3f1-4be3-aa93-83f45812c9bb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def preprocess_data(sentences, labels=None):\n",
|
||||
" tokenized_sentences = [sentence.split() for sentence in sentences]\n",
|
||||
" if labels is not None:\n",
|
||||
" tokenized_labels = [label.split() for label in labels]\n",
|
||||
" return tokenized_sentences, tokenized_labels\n",
|
||||
" return tokenized_sentences"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "466f1b4b-bf34-4cca-a2ee-3adbf7da9337",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_sentences, train_labels = preprocess_data(train_data['sentence'].values, train_data['label'].values)\n",
|
||||
"dev_sentences, dev_labels = preprocess_data(dev_sentences, dev_labels)\n",
|
||||
"test_sentences = preprocess_data(test_sentences)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "3cb4de1c-d5e0-46c5-b80d-df69c0adbaa4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"special_tokens = ['<PAD>', '<UNK>', '<BOS>', '<EOS>']\n",
|
||||
"word2idx = {w: i + len(special_tokens) for i, w in enumerate(set(word for sentence in train_sentences for word in sentence))}\n",
|
||||
"for i, token in enumerate(special_tokens):\n",
|
||||
" word2idx[token] = i\n",
|
||||
"\n",
|
||||
"idx2word = {i: w for w, i in word2idx.items()}\n",
|
||||
"\n",
|
||||
"label2idx = {\n",
|
||||
" 'O': 0,\n",
|
||||
" 'B-PER': 1, 'I-PER': 2,\n",
|
||||
" 'B-ORG': 3, 'I-ORG': 4,\n",
|
||||
" 'B-LOC': 5, 'I-LOC': 6,\n",
|
||||
" 'B-MISC': 7, 'I-MISC': 8\n",
|
||||
"}\n",
|
||||
"idx2label = {i: l for l, i in label2idx.items()}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "7c0b52af-7eeb-447b-a58e-b9f48da8f7c3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def encode_data(sentences, labels=None):\n",
|
||||
" encoded_sentences = [[word2idx.get(word, word2idx['<UNK>']) for word in sentence] for sentence in sentences]\n",
|
||||
" if labels is not None:\n",
|
||||
" encoded_labels = [[label2idx[label] for label in label_list] for label_list in labels]\n",
|
||||
" return encoded_sentences, encoded_labels\n",
|
||||
" return encoded_sentences"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "37a9de80-bf7b-455c-9794-2d559b46734d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_train, y_train = encode_data(train_sentences, train_labels)\n",
|
||||
"X_dev, y_dev = encode_data(dev_sentences, dev_labels)\n",
|
||||
"X_test = encode_data(test_sentences)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "c06fe4fa-18f2-467c-9b5b-9ca36128fcac",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"max_len = 1000 \n",
|
||||
"X_train = pad_sequences(X_train, padding='post', maxlen=max_len)\n",
|
||||
"y_train = pad_sequences(y_train, padding='post', maxlen=max_len)\n",
|
||||
"\n",
|
||||
"X_dev = pad_sequences(X_dev, padding='post', maxlen=max_len)\n",
|
||||
"y_dev = pad_sequences(y_dev, padding='post', maxlen=max_len)\n",
|
||||
"\n",
|
||||
"X_test = pad_sequences(X_test, padding='post', maxlen=max_len)\n",
|
||||
"\n",
|
||||
"y_train = [to_categorical(i, num_classes=len(label2idx)) for i in y_train]\n",
|
||||
"y_dev = [to_categorical(i, num_classes=len(label2idx)) for i in y_dev]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "02bf884e-d80e-43e7-b60c-4d2a611d280a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/25\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\keras\\src\\layers\\core\\embedding.py:90: UserWarning: Argument `input_length` is deprecated. Just remove it.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m79s\u001b[0m 773ms/step - accuracy: 0.8902 - loss: 0.6219 - val_accuracy: 0.9606 - val_loss: 0.1716\n",
|
||||
"Epoch 2/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m50s\u001b[0m 825ms/step - accuracy: 0.9656 - loss: 0.1477 - val_accuracy: 0.9606 - val_loss: 0.1553\n",
|
||||
"Epoch 3/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m47s\u001b[0m 788ms/step - accuracy: 0.9650 - loss: 0.1328 - val_accuracy: 0.9607 - val_loss: 0.1352\n",
|
||||
"Epoch 4/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m47s\u001b[0m 780ms/step - accuracy: 0.9646 - loss: 0.1136 - val_accuracy: 0.9648 - val_loss: 0.1134\n",
|
||||
"Epoch 5/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m46s\u001b[0m 772ms/step - accuracy: 0.9705 - loss: 0.0910 - val_accuracy: 0.9716 - val_loss: 0.0963\n",
|
||||
"Epoch 6/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m45s\u001b[0m 746ms/step - accuracy: 0.9776 - loss: 0.0753 - val_accuracy: 0.9761 - val_loss: 0.0827\n",
|
||||
"Epoch 7/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m47s\u001b[0m 778ms/step - accuracy: 0.9837 - loss: 0.0590 - val_accuracy: 0.9804 - val_loss: 0.0707\n",
|
||||
"Epoch 8/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m47s\u001b[0m 777ms/step - accuracy: 0.9880 - loss: 0.0484 - val_accuracy: 0.9833 - val_loss: 0.0614\n",
|
||||
"Epoch 9/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m45s\u001b[0m 751ms/step - accuracy: 0.9914 - loss: 0.0391 - val_accuracy: 0.9852 - val_loss: 0.0538\n",
|
||||
"Epoch 10/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m46s\u001b[0m 772ms/step - accuracy: 0.9937 - loss: 0.0292 - val_accuracy: 0.9865 - val_loss: 0.0486\n",
|
||||
"Epoch 11/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m46s\u001b[0m 766ms/step - accuracy: 0.9950 - loss: 0.0234 - val_accuracy: 0.9878 - val_loss: 0.0447\n",
|
||||
"Epoch 12/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m45s\u001b[0m 750ms/step - accuracy: 0.9960 - loss: 0.0198 - val_accuracy: 0.9885 - val_loss: 0.0420\n",
|
||||
"Epoch 13/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m45s\u001b[0m 743ms/step - accuracy: 0.9971 - loss: 0.0160 - val_accuracy: 0.9890 - val_loss: 0.0397\n",
|
||||
"Epoch 14/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m45s\u001b[0m 754ms/step - accuracy: 0.9975 - loss: 0.0137 - val_accuracy: 0.9892 - val_loss: 0.0386\n",
|
||||
"Epoch 15/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m45s\u001b[0m 745ms/step - accuracy: 0.9981 - loss: 0.0110 - val_accuracy: 0.9896 - val_loss: 0.0369\n",
|
||||
"Epoch 16/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m46s\u001b[0m 759ms/step - accuracy: 0.9983 - loss: 0.0097 - val_accuracy: 0.9897 - val_loss: 0.0365\n",
|
||||
"Epoch 17/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 731ms/step - accuracy: 0.9986 - loss: 0.0082 - val_accuracy: 0.9897 - val_loss: 0.0365\n",
|
||||
"Epoch 18/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m43s\u001b[0m 714ms/step - accuracy: 0.9989 - loss: 0.0070 - val_accuracy: 0.9900 - val_loss: 0.0355\n",
|
||||
"Epoch 19/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m43s\u001b[0m 711ms/step - accuracy: 0.9991 - loss: 0.0059 - val_accuracy: 0.9899 - val_loss: 0.0351\n",
|
||||
"Epoch 20/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 729ms/step - accuracy: 0.9992 - loss: 0.0054 - val_accuracy: 0.9900 - val_loss: 0.0353\n",
|
||||
"Epoch 21/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 731ms/step - accuracy: 0.9993 - loss: 0.0048 - val_accuracy: 0.9900 - val_loss: 0.0350\n",
|
||||
"Epoch 22/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m43s\u001b[0m 714ms/step - accuracy: 0.9993 - loss: 0.0044 - val_accuracy: 0.9901 - val_loss: 0.0353\n",
|
||||
"Epoch 23/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 738ms/step - accuracy: 0.9994 - loss: 0.0040 - val_accuracy: 0.9900 - val_loss: 0.0359\n",
|
||||
"Epoch 24/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 725ms/step - accuracy: 0.9995 - loss: 0.0036 - val_accuracy: 0.9900 - val_loss: 0.0353\n",
|
||||
"Epoch 25/25\n",
|
||||
"\u001b[1m60/60\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m43s\u001b[0m 722ms/step - accuracy: 0.9995 - loss: 0.0033 - val_accuracy: 0.9899 - val_loss: 0.0366\n",
|
||||
"\u001b[1m7/7\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 1s/step \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = tf.keras.models.Sequential([\n",
|
||||
" tf.keras.layers.Embedding(input_dim=len(word2idx), output_dim=64, input_length=max_len),\n",
|
||||
" tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=32, return_sequences=True)),\n",
|
||||
" tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(len(label2idx), activation='softmax'))\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n",
|
||||
"\n",
|
||||
"history = model.fit(X_train, np.array(y_train), validation_data=(X_dev, np.array(y_dev)), epochs=25, batch_size=16)\n",
|
||||
"\n",
|
||||
"y_pred = model.predict(X_dev)\n",
|
||||
"y_pred = np.argmax(y_pred, axis=-1)\n",
|
||||
"y_true = np.argmax(np.array(y_dev), axis=-1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "56af96ba-ee8e-4476-a17d-d9a66d03d136",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 due to no true samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 due to no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" B-PER 0.00 0.00 0.00 0\n",
|
||||
" I-PER 0.00 0.00 0.00 0\n",
|
||||
" B-ORG 0.00 0.00 0.00 0\n",
|
||||
" I-ORG 0.00 0.00 0.00 0\n",
|
||||
" B-LOC 0.00 0.00 0.00 0\n",
|
||||
" I-LOC 0.00 0.00 0.00 0\n",
|
||||
" B-MISC 0.00 0.00 0.00 0\n",
|
||||
" I-MISC 0.00 0.00 0.00 0\n",
|
||||
"\n",
|
||||
" micro avg 0.00 0.00 0.00 0\n",
|
||||
" macro avg 0.00 0.00 0.00 0\n",
|
||||
"weighted avg 0.00 0.00 0.00 0\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
|
||||
"C:\\Users\\Kosmitos\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1517: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n",
|
||||
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"y_pred_tags = [[idx2label[i] for i in row] for row in y_pred]\n",
|
||||
"y_true_tags = [[idx2label[i] for i in row] for row in y_true]\n",
|
||||
"\n",
|
||||
"print(classification_report(\n",
|
||||
" [item for sublist in y_true_tags for item in sublist],\n",
|
||||
" [item for sublist in y_pred_tags for item in sublist],\n",
|
||||
" labels=list(label2idx.values())[1:], \n",
|
||||
" target_names=[idx2label[i] for i in list(label2idx.values())[1:]]\n",
|
||||
"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "be795033-3bca-4c89-b3f3-c6104dc305fd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def correct_iob_labels(predictions):\n",
|
||||
" corrected = []\n",
|
||||
" for pred in predictions:\n",
|
||||
" corrected_sentence = []\n",
|
||||
" prev_label = 'O'\n",
|
||||
" for label in pred:\n",
|
||||
" if label.startswith('I-') and (prev_label == 'O' or prev_label[2:] != label[2:]):\n",
|
||||
" corrected_sentence.append('B-' + label[2:])\n",
|
||||
" else:\n",
|
||||
" corrected_sentence.append(label)\n",
|
||||
" prev_label = corrected_sentence[-1]\n",
|
||||
" corrected.append(corrected_sentence)\n",
|
||||
" return corrected"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "c75a75ad-3bef-470d-9882-13a97e5495e8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 99ms/step\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"y_test_pred = model.predict(X_test)\n",
|
||||
"y_test_pred = np.argmax(y_test_pred, axis=-1)\n",
|
||||
"y_test_pred_tags = [[idx2label[i] for i in row] for row in y_test_pred]\n",
|
||||
"\n",
|
||||
"y_pred_tags_corrected = correct_iob_labels(y_pred_tags)\n",
|
||||
"y_test_pred_tags_corrected = correct_iob_labels(y_test_pred_tags)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "0f48d0f4-a6f5-424b-a8cc-caed7b14e113",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dev_predictions = [' '.join(tags) for tags in y_pred_tags_corrected]\n",
|
||||
"with open('en-ner-conll-2003-master/en-ner-conll-2003/dev-0/out.tsv', 'w', encoding='utf-8') as f:\n",
|
||||
" for prediction in dev_predictions:\n",
|
||||
" f.write(\"%s\\n\" % prediction)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "8ac16ee5-cdf1-4b6a-9c94-003d598921e0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_predictions = [' '.join(tags) for tags in y_test_pred_tags_corrected]\n",
|
||||
"with open('en-ner-conll-2003-master/en-ner-conll-2003/test-A/out.tsv', 'w', encoding='utf-8') as f:\n",
|
||||
" for prediction in test_predictions:\n",
|
||||
" f.write(\"%s\\n\" % prediction)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Loading…
Reference in New Issue
Block a user