import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from collections import defaultdict

# Optional niceties: fall back gracefully when not installed so the data
# pipeline below stays importable in minimal environments.  sklearn is only
# needed by later evaluation cells; tqdm only draws progress bars.
try:
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
except ImportError:  # pragma: no cover - evaluation helpers only
    train_test_split = classification_report = None
try:
    from tqdm.notebook import tqdm
except ImportError:  # pragma: no cover - plain-iteration fallback
    def tqdm(iterable, **kwargs):
        return iterable

# Paths into the CoNLL-2003 NER dataset checkout (gonito-style layout:
# train has gold labels, dev-0/test-A in.tsv carry tokens only).
train_path = 'en-ner-conll-2003-master/en-ner-conll-2003/train/train.tsv/train.tsv'
dev_in_path = 'en-ner-conll-2003-master/en-ner-conll-2003/dev-0/in.tsv'
test_in_path = 'en-ner-conll-2003-master/en-ner-conll-2003/test-A/in.tsv'
dev_out_path = 'en-ner-conll-2003-master/en-ner-conll-2003/dev-0/out.tsv'
test_out_path = 'en-ner-conll-2003-master/en-ner-conll-2003/test-A/out.tsv'

# Train on GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_data(file_path):
    """Read a CoNLL-style file into parallel lists of token and label sequences.

    Sentences are separated by blank lines; '-DOCSTART-' document markers are
    treated as sentence boundaries and skipped.  Each remaining line is
    whitespace-split: the first field is the token and the last field the
    label.  For token-only files (dev/test in.tsv) the single field serves as
    both, so the returned "labels" are placeholders there.

    Returns:
        (sentences, labels): two equal-length lists; element i of each is the
        token list / label list of sentence i.
    """
    sentences, labels = [], []
    # explicit encoding: the CoNLL data is UTF-8; default locale encoding may differ
    with open(file_path, 'r', encoding='utf-8') as file:
        sentence, label = [], []
        for line in file:
            if line.startswith('-DOCSTART-') or line == "\n":
                if sentence:
                    sentences.append(sentence)
                    labels.append(label)
                    sentence, label = [], []
            else:
                splits = line.split()
                if not splits:
                    # guard: a whitespace-only line (spaces/tabs, "\r\n") would
                    # otherwise crash on splits[0]
                    continue
                sentence.append(splits[0])
                label.append(splits[-1])
    # flush the final sentence when the file has no trailing blank line
    if sentence:
        sentences.append(sentence)
        labels.append(label)
    return sentences, labels


def build_vocab(sentences):
    """Build a token -> index mapping over *sentences*.

    Index 0 is reserved for the padding token "".  NOTE: the returned mapping
    is a defaultdict that assigns a fresh index to any unseen key on lookup,
    so encoding new data through it silently grows the vocabulary — callers
    rely on this for dev/test words and must size embeddings AFTER encoding.
    """
    vocab = defaultdict(lambda: len(vocab))
    vocab[""]  # force "" -> 0 so it lines up with padding_idx=0 / ignore_index=0
    for sentence in sentences:
        for word in sentence:
            vocab[word]  # touching the key assigns the next free index
    return vocab


def encode_sentences(sentences, vocab):
    """Map every token of every sentence to its vocabulary index."""
    return [[vocab[word] for word in sentence] for sentence in sentences]


def pad_sequences(sequences, max_len):
    """Right-pad integer sequences with 0 into a (len(sequences), max_len) int array.

    Assumes every sequence has length <= max_len (longer ones would raise).
    """
    padded_sequences = np.zeros((len(sequences), max_len), dtype=int)
    for i, seq in enumerate(sequences):
        padded_sequences[i, :len(seq)] = seq
    return padded_sequences


def create_dataloader(inputs, labels, batch_size):
    """Wrap padded input/label arrays into a shuffling DataLoader of (x, y) pairs."""
    inputs_tensor = torch.tensor(inputs, dtype=torch.long)
    labels_tensor = torch.tensor(labels, dtype=torch.long)
    dataset = TensorDataset(inputs_tensor, labels_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
class NER_RNN(nn.Module):
    """Token-level tagger: embedding -> single-layer LSTM -> per-token linear scores."""

    def __init__(self, vocab_size, tagset_size, embedding_dim=128, hidden_dim=256):
        super(NER_RNN, self).__init__()
        # padding_idx=0 matches the "" entry reserved by build_vocab
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, tagset_size)

    def forward(self, x):
        # x: (batch, seq_len) int64 token ids -> (batch, seq_len, tagset_size) logits
        x = self.embedding(x)
        x, _ = self.lstm(x)
        return self.fc(x)


def correct_iob_labels(predictions):
    """Repair IOB2 sequences: an I-X not preceded by B-X or I-X becomes B-X."""
    corrected = []
    for pred in predictions:
        corrected_sentence = []
        prev_label = 'O'
        for label in pred:
            # an entity continuation must follow the same entity type
            if label.startswith('I-') and (prev_label == 'O' or prev_label[2:] != label[2:]):
                corrected_sentence.append('B-' + label[2:])
            else:
                corrected_sentence.append(label)
            prev_label = corrected_sentence[-1]
        corrected.append(corrected_sentence)
    return corrected


train_sentences, train_labels = load_data(train_path)
dev_sentences, dev_labels = load_data(dev_in_path)
test_sentences, test_labels = load_data(test_in_path)

word_vocab = build_vocab(train_sentences)
tag_vocab = build_vocab(train_labels)

# BUGFIX: decode predictions with the SAME mapping used to encode the training
# labels.  The original built label_mapping from list(set(all_labels)), whose
# ordering is nondeterministic across runs and disagrees with tag_vocab — and
# since the model's output layer is sized by len(tag_vocab), argmax indices
# outside that small set raised KeyError during decoding.
label_mapping = dict(tag_vocab)
inverse_label_mapping = {i: label for label, i in tag_vocab.items()}


def predict_and_save(model, input_tokens_ids, output_file):
    """Tag each padded token-id sequence and write IOB labels, one sentence per line.

    Args:
        model: trained NER_RNN (or compatible) module.
        input_tokens_ids: iterable of 1-D LongTensors, 0-padded to max_len.
        output_file: destination path; each line holds the space-joined,
            IOB-corrected labels of one sentence.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for tokens in tqdm(input_tokens_ids):
            tokens = tokens.unsqueeze(0)  # add the batch dimension
            tag_scores = model(tokens)
            tag_ids = torch.argmax(tag_scores.squeeze(0), 1).tolist()
            # BUGFIX: keep exactly the real (non-padding) positions.  The
            # original sliced [1:-1], dropping the first and last predictions
            # even though no BOS/EOS tokens were ever added, so output lengths
            # never matched the input sentences.  Real token ids are >= 1
            # because index 0 is reserved for the "" pad entry.
            n_real = int((tokens.squeeze(0) != 0).sum().item())
            predictions.append([inverse_label_mapping[t] for t in tag_ids[:n_real]])

    with open(output_file, 'w', encoding='utf-8') as f:
        for sentence in correct_iob_labels(predictions):
            f.write(" ".join(sentence) + "\n")


train_inputs = encode_sentences(train_sentences, word_vocab)
dev_inputs = encode_sentences(dev_sentences, word_vocab)
test_inputs = encode_sentences(test_sentences, word_vocab)

# NOTE(review): dev/test in.tsv carry no gold labels, so load_data returns the
# tokens themselves as "labels"; encoding them below grows tag_vocab with
# bogus entries.  They only keep tensor shapes aligned for the DataLoaders —
# confirm nothing downstream evaluates against them.
train_labels_encoded = encode_sentences(train_labels, tag_vocab)
dev_labels_encoded = encode_sentences(dev_labels, tag_vocab)
test_labels_encoded = encode_sentences(test_labels, tag_vocab)

max_len = max(
    max(len(seq) for seq in train_inputs),
    max(len(seq) for seq in dev_inputs),
    max(len(seq) for seq in test_inputs),
)

train_inputs_padded = pad_sequences(train_inputs, max_len)
dev_inputs_padded = pad_sequences(dev_inputs, max_len)
test_inputs_padded = pad_sequences(test_inputs, max_len)

train_labels_padded = pad_sequences(train_labels_encoded, max_len)
dev_labels_padded = pad_sequences(dev_labels_encoded, max_len)
test_labels_padded = pad_sequences(test_labels_encoded, max_len)

batch_size = 32

train_dataloader = create_dataloader(train_inputs_padded, train_labels_padded, batch_size)
dev_dataloader = create_dataloader(dev_inputs_padded, dev_labels_padded, batch_size)
test_dataloader = create_dataloader(test_inputs_padded, test_labels_padded, batch_size)

# Vocabulary sizes are read AFTER encoding dev/test on purpose: the defaultdict
# vocabularies grow on unseen words there, and the embedding/output layers must
# cover those indices too (their rows just stay untrained).
vocab_size = len(word_vocab)
tagset_size = len(tag_vocab)

model = NER_RNN(vocab_size, tagset_size).to(device)

# ignore_index=0 masks the padding positions ("" -> 0) out of the loss
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.001)