diff --git a/RNN.ipynb b/RNN.ipynb deleted file mode 100644 index 07afcbb..0000000 --- a/RNN.ipynb +++ /dev/null @@ -1,531 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "c80ac05e-c22e-4f7f-a48d-ca85173f0f86", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "from torch.utils.data import DataLoader, TensorDataset\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import classification_report\n", - "from tqdm.notebook import tqdm\n", - "import numpy as np\n", - "from collections import defaultdict" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "c8a9db2f-9e34-4246-8a71-5b7c504cfcb5", - "metadata": {}, - "outputs": [], - "source": [ - "train_path = 'en-ner-conll-2003-master/en-ner-conll-2003/train/train.tsv/train.tsv'\n", - "dev_in_path = 'en-ner-conll-2003-master/en-ner-conll-2003/dev-0/in.tsv'\n", - "test_in_path = 'en-ner-conll-2003-master/en-ner-conll-2003/test-A/in.tsv'\n", - "dev_out_path = 'en-ner-conll-2003-master/en-ner-conll-2003/dev-0/out.tsv'\n", - "test_out_path = 'en-ner-conll-2003-master/en-ner-conll-2003/test-A/out.tsv'" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "06e936cf-3022-4d9e-a6d8-03c736bd7586", - "metadata": {}, - "outputs": [], - "source": [ - "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "f1c3b706-0e38-4b3d-814e-79449ba696f6", - "metadata": {}, - "outputs": [], - "source": [ - "def load_data(file_path):\n", - " sentences, labels = [], []\n", - " with open(file_path, 'r') as file:\n", - " sentence, label = [], []\n", - " for line in file:\n", - " if line.startswith('-DOCSTART-') or line == \"\\n\":\n", - " if sentence:\n", - " sentences.append(sentence)\n", - " labels.append(label)\n", - " sentence, label = [], []\n", - " else:\n", - " splits = line.split()\n", - " sentence.append(splits[0])\n", - " label.append(splits[-1])\n", - " if sentence:\n", - " sentences.append(sentence)\n", - " labels.append(label)\n", - " return sentences, labels" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "e9aa2f41-8625-44d9-815e-f8d010d7ea6b", - "metadata": {}, - "outputs": [], - "source": [ - "def build_vocab(sentences):\n", - " vocab = defaultdict(lambda: len(vocab))\n", - " vocab[\"\"]\n", - " for sentence in sentences:\n", - " for word in sentence:\n", - " vocab[word]\n", - " return vocab" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "011e8daa-8bd0-4fba-8785-cf25b31c3343", - "metadata": {}, - "outputs": [], - "source": [ - "def encode_sentences(sentences, vocab):\n", - " return [[vocab[word] for word in sentence] for sentence in sentences]\n", - "\n", - "def pad_sequences(sequences, max_len):\n", - " padded_sequences = np.zeros((len(sequences), max_len), dtype=int)\n", - " for i, seq in enumerate(sequences):\n", - " padded_sequences[i, :len(seq)] = seq\n", - " return padded_sequences" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "e1c4c06b-fb61-48aa-b6bd-53b8b6b0b7b1", - "metadata": {}, - "outputs": [], - "source": [ - "def create_dataloader(inputs, labels, batch_size):\n", - " inputs_tensor = torch.tensor(inputs, dtype=torch.long)\n", - " labels_tensor = torch.tensor(labels, dtype=torch.long)\n", - " dataset = TensorDataset(inputs_tensor, labels_tensor)\n", - " dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", - " return dataloader" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "c1012f38-c925-4b3a-97b9-ec2ff3a66bae", - "metadata": {}, - "outputs": [], - "source": [ - "class NER_RNN(nn.Module):\n", - " def __init__(self, vocab_size, tagset_size, embedding_dim=128, hidden_dim=256):\n", - " super(NER_RNN, self).__init__()\n", - " self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)\n", - " self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n", - " self.fc = nn.Linear(hidden_dim, tagset_size)\n", - " \n", - " def forward(self, x):\n", - " x = self.embedding(x)\n", - " x, _ = self.lstm(x)\n", - " x = self.fc(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "f2f09d01-11c7-4cc0-91f7-af10c992be9a", - "metadata": {}, - "outputs": [], - "source": [ - "def correct_iob_labels(predictions): \n", - " corrected = []\n", - " for pred in predictions:\n", - " corrected_sentence = []\n", - " prev_label = 'O'\n", - " for label in pred:\n", - " if label.startswith('I-') and (prev_label == 'O' or prev_label[2:] != label[2:]): \n", - " corrected_sentence.append('B-' + label[2:])\n", - " else:\n", - " corrected_sentence.append(label)\n", - " prev_label = corrected_sentence[-1]\n", - " corrected.append(corrected_sentence)\n", - " return corrected" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "50971cde-d176-4fce-af14-1a7530c1c453", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "train_sentences, train_labels = load_data(train_path)\n", - "\n", - "all_labels = [label for sentence_labels in train_labels for label in sentence_labels]\n", - "unique_labels = list(set(all_labels))\n", - "label_mapping = {label: i for i, label in enumerate(unique_labels)}\n", - "\n", - "inverse_label_mapping = {i: label for label, i in label_mapping.items()}\n", - "\n", - "def predict_and_save(model, input_tokens_ids, output_file):\n", - " model.eval() \n", - " predictions = []\n", - " for tokens in tqdm(input_tokens_ids):\n", - " tokens = tokens.unsqueeze(0)\n", - " with torch.no_grad():\n", - " predicted_tags = model(tokens)\n", - " predicted_tags = torch.argmax(predicted_tags.squeeze(0), 1).tolist() \n", - " predicted_labels = [inverse_label_mapping[tag] for tag in predicted_tags]\n", - " predictions.append(predicted_labels[1:-1]) \n", - "\n", - " with open(output_file, 'w') as f:\n", - " for sentence in correct_iob_labels(predictions):\n", - " f.write(\" \".join(sentence) + \"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "2542fb92-9c15-4abe-8b4d-d002df57fe99", - "metadata": {}, - "outputs": [], - "source": [ - "train_sentences, train_labels = load_data(train_path)\n", - "dev_sentences, dev_labels = load_data(dev_in_path)\n", - "test_sentences, test_labels = load_data(test_in_path)\n", - "word_vocab = build_vocab(train_sentences)\n", - "tag_vocab = build_vocab(train_labels)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "ff64590b-08cd-42ac-8e32-363f16038c0b", - "metadata": {}, - "outputs": [], - "source": [ - "train_inputs = encode_sentences(train_sentences, word_vocab)\n", - "dev_inputs = encode_sentences(dev_sentences, word_vocab)\n", - "test_inputs = encode_sentences(test_sentences, word_vocab)\n", - "\n", - "train_labels_encoded = encode_sentences(train_labels, tag_vocab)\n", - "dev_labels_encoded = encode_sentences(dev_labels, tag_vocab)\n", - "test_labels_encoded = encode_sentences(test_labels, tag_vocab)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "4c591b46-6cda-4f72-9574-e284faa58312", - "metadata": {}, - "outputs": [], - "source": [ - "max_len = max(max(len(seq) for seq in train_inputs), max(len(seq) for seq in dev_inputs), max(len(seq) for seq in test_inputs))\n", - "\n", - "train_inputs_padded = pad_sequences(train_inputs, max_len)\n", - "dev_inputs_padded = pad_sequences(dev_inputs, max_len)\n", - "test_inputs_padded = pad_sequences(test_inputs, max_len)\n", - "\n", - "train_labels_padded = pad_sequences(train_labels_encoded, max_len)\n", - "dev_labels_padded = pad_sequences(dev_labels_encoded, max_len)\n", - "test_labels_padded = pad_sequences(test_labels_encoded, max_len)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "228cf2c6-6fa0-47fd-97cc-3a0495b5059e", - "metadata": {}, - "outputs": [], - "source": [ - "batch_size = 32\n", - "\n", - "train_dataloader = create_dataloader(train_inputs_padded, train_labels_padded, batch_size)\n", - "dev_dataloader = create_dataloader(dev_inputs_padded, dev_labels_padded, batch_size)\n", - "test_dataloader = create_dataloader(test_inputs_padded, test_labels_padded, batch_size)\n", - "\n", - "vocab_size = len(word_vocab)\n", - "tagset_size = len(tag_vocab)\n", - "\n", - "model = NER_RNN(vocab_size, tagset_size).to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "eafe2cf8-547d-4245-b32a-de24792f33ec", - "metadata": {}, - "outputs": [], - "source": [ - "criterion = nn.CrossEntropyLoss(ignore_index=0)\n", - "optimizer = optim.Adam(model.parameters(), lr=0.001)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "a178c476-b566-42ec-b2ca-fb59ed85f6ee", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8f7c956b746a454d82306b90bb3d3abf", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1 [00:00