RNN/RNN_NER.ipynb

259 lines
9.3 KiB
Plaintext
Raw Normal View History

2024-05-27 18:44:12 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"from collections import Counter\n",
"\n",
"import pandas as pd\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"from torchtext.vocab import vocab\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def load_datasets():\n",
"    \"\"\"Read the train, dev-0 and test-A splits from disk.\n",
"\n",
"    Returns:\n",
"        (train_data, dev_data, dev_labels, test_data): DataFrames with\n",
"        columns [Tag, Sentence], [Sentence], [Tag] and [Sentence]\n",
"        respectively, one row per non-blank line of the source file.\n",
"    \"\"\"\n",
"    # The corpus contains literal double-quote tokens (some sentences even\n",
"    # start with one). With pandas' default CSV quoting those are parsed as\n",
"    # quote characters, which can swallow tabs/newlines, merge rows and\n",
"    # misalign sentences with their tags, so quoting is disabled.\n",
"    # keep_default_na=False keeps a field such as NA or null as a string\n",
"    # instead of turning it into NaN (which has no .split()).\n",
"    read_options = {\"sep\": \"\\t\", \"quoting\": csv.QUOTE_NONE, \"keep_default_na\": False}\n",
"    train_data = pd.read_csv(\n",
"        \"train/train.tsv.xz\", compression=\"xz\", names=[\"Tag\", \"Sentence\"], **read_options\n",
"    )\n",
"    dev_data = pd.read_csv(\"dev-0/in.tsv\", names=[\"Sentence\"], **read_options)\n",
"    dev_labels = pd.read_csv(\"dev-0/expected.tsv\", names=[\"Tag\"], **read_options)\n",
"    test_data = pd.read_csv(\"test-A/in.tsv\", names=[\"Sentence\"], **read_options)\n",
"\n",
"    return train_data, dev_data, dev_labels, test_data\n",
"\n",
"train_data, dev_data, dev_labels, test_data = load_datasets()\n",
"\n",
"# Hold out 10% of the training sentences as a validation split. Sentences and\n",
"# tags are passed together so they are split consistently, and the fixed\n",
"# random_state makes the split identical across runs.\n",
"(\n",
"    train_sentences,\n",
"    val_sentences,\n",
"    train_tags,\n",
"    val_tags,\n",
") = train_test_split(\n",
"    train_data[\"Sentence\"],\n",
"    train_data[\"Tag\"],\n",
"    test_size=0.1,\n",
"    random_state=42,\n",
")\n",
"\n",
"# Note: train_data is rebound to the 90% training portion from here on.\n",
"train_data = pd.DataFrame({\"Sentence\": train_sentences, \"Tag\": train_tags})\n",
"val_data = pd.DataFrame({\"Sentence\": val_sentences, \"Tag\": val_tags})\n",
"\n",
"def tokenize_column(dataframe, column):\n",
"    \"\"\"Split every string in dataframe[column] on whitespace.\n",
"\n",
"    Returns a Series (same index as the frame) whose cells are token lists.\n",
"    \"\"\"\n",
"    return dataframe[column].map(lambda text: text.split())\n",
"\n",
"# Whitespace-tokenize the sentence and tag strings of every split.\n",
"for split_frame in (train_data, val_data):\n",
"    split_frame[\"tokens\"] = tokenize_column(split_frame, \"Sentence\")\n",
"    split_frame[\"tag_tokens\"] = tokenize_column(split_frame, \"Tag\")\n",
"\n",
"dev_data[\"tokens\"] = tokenize_column(dev_data, \"Sentence\")\n",
"dev_labels[\"tag_tokens\"] = tokenize_column(dev_labels, \"Tag\")\n",
"test_data[\"tokens\"] = tokenize_column(test_data, \"Sentence\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def create_vocab(token_list):\n",
"    \"\"\"Build a torchtext vocabulary from an iterable of token lists.\n",
"\n",
"    Token frequencies are counted over all sentences and passed to\n",
"    torchtext's vocab() together with the special tokens\n",
"    <unk>, <pad>, <bos> and <eos>.\n",
"    \"\"\"\n",
"    frequencies = Counter(token for tokens in token_list for token in tokens)\n",
"    return vocab(frequencies, specials=[\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"])\n",
"\n",
"vocab_obj = create_vocab(train_data[\"tokens\"])\n",
"# Any token not seen in training is looked up as <unk>.\n",
"unk_index = vocab_obj[\"<unk>\"]\n",
"vocab_obj.set_default_index(unk_index)\n",
"\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"def convert_to_tensor(token_lists, vocab_obj, device):\n",
"    \"\"\"Encode tokenized sentences as LongTensors of vocabulary ids.\n",
"\n",
"    Each sentence becomes a 1-D tensor laid out as <bos>, token ids, <eos>,\n",
"    i.e. two elements longer than its token list, created on device.\n",
"\n",
"    Args:\n",
"        token_lists: iterable of token lists.\n",
"        vocab_obj: mapping supporting vocab_obj[token] -> int id; must\n",
"            resolve \"<bos>\" and \"<eos>\".\n",
"        device: torch device for the created tensors.\n",
"\n",
"    Returns:\n",
"        list of tensors, one per input sentence, in input order.\n",
"    \"\"\"\n",
"\n",
"    def encode(tokens):\n",
"        ids = [vocab_obj[\"<bos>\"]]\n",
"        ids.extend(vocab_obj[token] for token in tokens)\n",
"        ids.append(vocab_obj[\"<eos>\"])\n",
"        return torch.tensor(ids, dtype=torch.long, device=device)\n",
"\n",
"    return [encode(tokens) for tokens in token_lists]\n",
"\n",
"# Encode every split with the training vocabulary.\n",
"train_tensor, val_tensor, dev_tensor, test_tensor = (\n",
"    convert_to_tensor(frame[\"tokens\"], vocab_obj, device)\n",
"    for frame in (train_data, val_data, dev_data, test_data)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Tag inventory; a tag's position in this list is its class index (\"O\" -> 0).\n",
"tag_list = [\"O\", \"B-PER\", \"I-PER\", \"B-ORG\", \"I-ORG\", \"B-LOC\", \"I-LOC\", \"B-MISC\", \"I-MISC\"]\n",
"\n",
"tag_to_index = dict(zip(tag_list, range(len(tag_list))))\n",
"\n",
"def convert_tags_to_tensor(tag_tokens, tag_to_index, device):\n",
"    \"\"\"Encode per-sentence tag lists as LongTensors of class indices.\n",
"\n",
"    Each tensor is framed by a 0 label at both ends, so its length matches\n",
"    a sentence encoded with one extra leading and one extra trailing token.\n",
"\n",
"    Args:\n",
"        tag_tokens: iterable of tag-string lists.\n",
"        tag_to_index: mapping from tag string to class index.\n",
"        device: torch device for the created tensors.\n",
"\n",
"    Returns:\n",
"        list of 1-D tensors, one per input sentence, in input order.\n",
"    \"\"\"\n",
"\n",
"    def encode(tags):\n",
"        return torch.tensor(\n",
"            [0, *(tag_to_index[tag] for tag in tags), 0],\n",
"            dtype=torch.long,\n",
"            device=device,\n",
"        )\n",
"\n",
"    return [encode(tags) for tags in tag_tokens]\n",
"\n",
"# Encode gold tags for every split that has labels.\n",
"train_tag_tensor, val_tag_tensor, dev_tag_tensor = (\n",
"    convert_tags_to_tensor(frame[\"tag_tokens\"], tag_to_index, device)\n",
"    for frame in (train_data, val_data, dev_labels)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def calculate_metrics(true_labels, predicted_labels):\n",
"    \"\"\"Token-level precision, recall and F1 for the non-zero (non-O) tags.\n",
"\n",
"    Index 0 is the negative class. A token counts as selected when its\n",
"    predicted index is > 0, as relevant when its gold index is > 0, and as a\n",
"    true positive when both agree on the same index > 0.\n",
"\n",
"    Args:\n",
"        true_labels: iterable of gold tag indices.\n",
"        predicted_labels: iterable of predicted tag indices, aligned with\n",
"            true_labels (zip stops at the shorter of the two).\n",
"\n",
"    Returns:\n",
"        (precision, recall, f1). Precision and recall fall back to 1.0 when\n",
"        nothing was selected / relevant; f1 is 0.0 when both are 0.\n",
"    \"\"\"\n",
"    true_positives = 0\n",
"    total_selected = 0\n",
"    total_relevant = 0\n",
"\n",
"    for pred, true in zip(predicted_labels, true_labels):\n",
"        if pred > 0:\n",
"            total_selected += 1\n",
"            # Only a correct non-O prediction is a true positive; counting\n",
"            # O == O matches here would let precision and recall exceed 1.\n",
"            if pred == true:\n",
"                true_positives += 1\n",
"        if true > 0:\n",
"            total_relevant += 1\n",
"\n",
"    precision = true_positives / total_selected if total_selected > 0 else 1.0\n",
"    recall = true_positives / total_relevant if total_relevant > 0 else 1.0\n",
"    f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0\n",
"\n",
"    return precision, recall, f1_score\n",
"\n",
"# One past the largest tag index, i.e. the number of output classes.\n",
"max_tag_index = 1 + max(tag_to_index.values())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class BiLSTMModel(torch.nn.Module):\n",
"    \"\"\"Embedding -> bidirectional LSTM -> per-token linear classifier.\n",
"\n",
"    Args:\n",
"        vocab_size: number of rows in the embedding table.\n",
"        embed_size: embedding dimension.\n",
"        hidden_size: LSTM hidden size per direction.\n",
"        num_layers: number of stacked LSTM layers.\n",
"        output_size: number of logits produced per token.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, output_size):\n",
"        super().__init__()\n",
"        self.embedding = torch.nn.Embedding(vocab_size, embed_size)\n",
"        self.lstm = torch.nn.LSTM(\n",
"            embed_size,\n",
"            hidden_size,\n",
"            num_layers,\n",
"            batch_first=True,\n",
"            bidirectional=True,\n",
"        )\n",
"        # Forward and backward outputs are concatenated: 2 * hidden_size features.\n",
"        self.fc = torch.nn.Linear(2 * hidden_size, output_size)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map token ids (batch, seq_len) to logits (batch, seq_len, output_size).\"\"\"\n",
"        # NOTE(review): ReLU on the embeddings zeroes every negative component;\n",
"        # kept for identical behaviour, but worth revisiting.\n",
"        features = torch.relu(self.embedding(x))\n",
"        hidden_states, _ = self.lstm(features)\n",
"        return self.fc(hidden_states)\n",
"\n",
"# Fixed seed so weight initialisation (and therefore every reported metric)\n",
"# is reproducible; 42 matches the train/validation split seed. cuDNN LSTM\n",
"# kernels on GPU may still be slightly non-deterministic.\n",
"SEED = 42\n",
"EMBED_SIZE = 100\n",
"HIDDEN_SIZE = 100\n",
"NUM_LAYERS = 1\n",
"\n",
"torch.manual_seed(SEED)\n",
"model = BiLSTMModel(len(vocab_obj.get_itos()), EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, max_tag_index).to(device)\n",
"loss_fn = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters())\n",
"\n",
"def evaluate_model(tokens, labels, model):\n",
"    \"\"\"Tag each sentence with model and score the result with calculate_metrics.\n",
"\n",
"    Sentences are run one at a time (batch of 1) under torch.no_grad().\n",
"    Predictions and gold labels from every position of every sentence are\n",
"    pooled before scoring.\n",
"\n",
"    Args:\n",
"        tokens: indexable sequence of 1-D token-id tensors.\n",
"        labels: indexable sequence of 1-D gold tag-id tensors aligned with tokens;\n",
"            its length decides how many sentences are evaluated.\n",
"        model: module returning per-token logits for a (1, seq_len) input.\n",
"\n",
"    Returns:\n",
"        (precision, recall, f1) as returned by calculate_metrics.\n",
"    \"\"\"\n",
"    gold_ids = []\n",
"    predicted_ids = []\n",
"    for idx in tqdm(range(len(labels))):\n",
"        batch = tokens[idx].unsqueeze(0)\n",
"        gold_ids.extend(labels[idx].cpu().numpy())\n",
"\n",
"        with torch.no_grad():\n",
"            scores = model(batch).squeeze(0)\n",
"        predicted_ids.extend(scores.argmax(dim=1).cpu().numpy())\n",
"\n",
"    return calculate_metrics(gold_ids, predicted_ids)\n",
"\n",
"def predict_labels(tokens, model, tag_to_index):\n",
"    \"\"\"Predict a space-separated tag string for every encoded sentence.\n",
"\n",
"    The first and last predicted positions are dropped, since those are the\n",
"    <bos>/<eos> slots added when sentences are encoded.\n",
"\n",
"    Args:\n",
"        tokens: indexable sequence of 1-D token-id tensors.\n",
"        model: module returning per-token logits for a (1, seq_len) input.\n",
"        tag_to_index: mapping from tag string to class index (inverted here).\n",
"\n",
"    Returns:\n",
"        list of strings, one per sentence, in input order.\n",
"    \"\"\"\n",
"    index_to_tag = {index: tag for tag, index in tag_to_index.items()}\n",
"    predictions = []\n",
"    for idx in tqdm(range(len(tokens))):\n",
"        batch = tokens[idx].unsqueeze(0)\n",
"        with torch.no_grad():\n",
"            scores = model(batch).squeeze(0)\n",
"        best = scores.argmax(dim=1)[1:-1]\n",
"        predictions.append(\" \".join(index_to_tag[label.item()] for label in best))\n",
"\n",
"    return predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"EPOCHS = 10\n",
"\n",
"for epoch in range(EPOCHS):\n",
"    # One optimizer step per training sentence (batch size 1), in fixed order.\n",
"    model.train()\n",
"    for idx in tqdm(range(len(train_tag_tensor))):\n",
"        inputs = train_tensor[idx].unsqueeze(0)  # (1, seq_len)\n",
"        targets = train_tag_tensor[idx]  # (seq_len,)\n",
"\n",
"        optimizer.zero_grad()\n",
"        logits = model(inputs).squeeze(0)  # (seq_len, num_tags)\n",
"        loss = loss_fn(logits, targets)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    # Validation scores (precision, recall, f1) after each epoch.\n",
"    model.eval()\n",
"    print(evaluate_model(val_tensor, val_tag_tensor, model))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# evaluate_model returns (precision, recall, f1). Only a cell's last\n",
"# expression is displayed, so the results are printed explicitly; bare calls\n",
"# here would be computed and then discarded.\n",
"model.eval()\n",
"val_metrics = evaluate_model(val_tensor, val_tag_tensor, model)\n",
"dev_metrics = evaluate_model(dev_tensor, dev_tag_tensor, model)\n",
"print(\"validation (precision, recall, f1):\", val_metrics)\n",
"print(\"dev-0 (precision, recall, f1):\", dev_metrics)\n",
"\n",
"# One line of space-separated tags per input sentence.\n",
"dev_predictions = predict_labels(dev_tensor, model, tag_to_index)\n",
"dev_predictions_df = pd.DataFrame(dev_predictions, columns=[\"Tag\"])\n",
"dev_predictions_df.to_csv(\"dev-0/out.tsv\", index=False, header=False)\n",
"\n",
"test_predictions = predict_labels(test_tensor, model, tag_to_index)\n",
"test_predictions_df = pd.DataFrame(test_predictions, columns=[\"Tag\"])\n",
"test_predictions_df.to_csv(\"test-A/out.tsv\", index=False, header=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}