Delete RNN.ipynb
This commit is contained in:
parent
58246f2c12
commit
d51744a41f
531
RNN.ipynb
531
RNN.ipynb
@ -1,531 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "c80ac05e-c22e-4f7f-a48d-ca85173f0f86",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import os\n",
|
|
||||||
"import torch\n",
|
|
||||||
"import torch.nn as nn\n",
|
|
||||||
"import torch.optim as optim\n",
|
|
||||||
"from torch.utils.data import DataLoader, TensorDataset\n",
|
|
||||||
"from sklearn.model_selection import train_test_split\n",
|
|
||||||
"from sklearn.metrics import classification_report\n",
|
|
||||||
"from tqdm.notebook import tqdm\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"from collections import defaultdict"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "c8a9db2f-9e34-4246-8a71-5b7c504cfcb5",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"train_path = 'en-ner-conll-2003-master/en-ner-conll-2003/train/train.tsv/train.tsv'\n",
|
|
||||||
"dev_in_path = 'en-ner-conll-2003-master/en-ner-conll-2003/dev-0/in.tsv'\n",
|
|
||||||
"test_in_path = 'en-ner-conll-2003-master/en-ner-conll-2003/test-A/in.tsv'\n",
|
|
||||||
"dev_out_path = 'en-ner-conll-2003-master/en-ner-conll-2003/dev-0/out.tsv'\n",
|
|
||||||
"test_out_path = 'en-ner-conll-2003-master/en-ner-conll-2003/test-A/out.tsv'"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "06e936cf-3022-4d9e-a6d8-03c736bd7586",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"id": "f1c3b706-0e38-4b3d-814e-79449ba696f6",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def load_data(file_path):\n",
|
|
||||||
" sentences, labels = [], []\n",
|
|
||||||
" with open(file_path, 'r') as file:\n",
|
|
||||||
" sentence, label = [], []\n",
|
|
||||||
" for line in file:\n",
|
|
||||||
" if line.startswith('-DOCSTART-') or line == \"\\n\":\n",
|
|
||||||
" if sentence:\n",
|
|
||||||
" sentences.append(sentence)\n",
|
|
||||||
" labels.append(label)\n",
|
|
||||||
" sentence, label = [], []\n",
|
|
||||||
" else:\n",
|
|
||||||
" splits = line.split()\n",
|
|
||||||
" sentence.append(splits[0])\n",
|
|
||||||
" label.append(splits[-1])\n",
|
|
||||||
" if sentence:\n",
|
|
||||||
" sentences.append(sentence)\n",
|
|
||||||
" labels.append(label)\n",
|
|
||||||
" return sentences, labels"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "e9aa2f41-8625-44d9-815e-f8d010d7ea6b",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def build_vocab(sentences):\n",
|
|
||||||
" vocab = defaultdict(lambda: len(vocab))\n",
|
|
||||||
" vocab[\"<PAD>\"]\n",
|
|
||||||
" for sentence in sentences:\n",
|
|
||||||
" for word in sentence:\n",
|
|
||||||
" vocab[word]\n",
|
|
||||||
" return vocab"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "011e8daa-8bd0-4fba-8785-cf25b31c3343",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def encode_sentences(sentences, vocab):\n",
|
|
||||||
" return [[vocab[word] for word in sentence] for sentence in sentences]\n",
|
|
||||||
"\n",
|
|
||||||
"def pad_sequences(sequences, max_len):\n",
|
|
||||||
" padded_sequences = np.zeros((len(sequences), max_len), dtype=int)\n",
|
|
||||||
" for i, seq in enumerate(sequences):\n",
|
|
||||||
" padded_sequences[i, :len(seq)] = seq\n",
|
|
||||||
" return padded_sequences"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"id": "e1c4c06b-fb61-48aa-b6bd-53b8b6b0b7b1",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def create_dataloader(inputs, labels, batch_size):\n",
|
|
||||||
" inputs_tensor = torch.tensor(inputs, dtype=torch.long)\n",
|
|
||||||
" labels_tensor = torch.tensor(labels, dtype=torch.long)\n",
|
|
||||||
" dataset = TensorDataset(inputs_tensor, labels_tensor)\n",
|
|
||||||
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
|
|
||||||
" return dataloader"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"id": "c1012f38-c925-4b3a-97b9-ec2ff3a66bae",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"class NER_RNN(nn.Module):\n",
|
|
||||||
" def __init__(self, vocab_size, tagset_size, embedding_dim=128, hidden_dim=256):\n",
|
|
||||||
" super(NER_RNN, self).__init__()\n",
|
|
||||||
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)\n",
|
|
||||||
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n",
|
|
||||||
" self.fc = nn.Linear(hidden_dim, tagset_size)\n",
|
|
||||||
" \n",
|
|
||||||
" def forward(self, x):\n",
|
|
||||||
" x = self.embedding(x)\n",
|
|
||||||
" x, _ = self.lstm(x)\n",
|
|
||||||
" x = self.fc(x)\n",
|
|
||||||
" return x"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "f2f09d01-11c7-4cc0-91f7-af10c992be9a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def correct_iob_labels(predictions): \n",
|
|
||||||
" corrected = []\n",
|
|
||||||
" for pred in predictions:\n",
|
|
||||||
" corrected_sentence = []\n",
|
|
||||||
" prev_label = 'O'\n",
|
|
||||||
" for label in pred:\n",
|
|
||||||
" if label.startswith('I-') and (prev_label == 'O' or prev_label[2:] != label[2:]): \n",
|
|
||||||
" corrected_sentence.append('B-' + label[2:])\n",
|
|
||||||
" else:\n",
|
|
||||||
" corrected_sentence.append(label)\n",
|
|
||||||
" prev_label = corrected_sentence[-1]\n",
|
|
||||||
" corrected.append(corrected_sentence)\n",
|
|
||||||
" return corrected"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"id": "50971cde-d176-4fce-af14-1a7530c1c453",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"\n",
|
|
||||||
"train_sentences, train_labels = load_data(train_path)\n",
|
|
||||||
"\n",
|
|
||||||
"all_labels = [label for sentence_labels in train_labels for label in sentence_labels]\n",
|
|
||||||
"unique_labels = list(set(all_labels))\n",
|
|
||||||
"label_mapping = {label: i for i, label in enumerate(unique_labels)}\n",
|
|
||||||
"\n",
|
|
||||||
"inverse_label_mapping = {i: label for label, i in label_mapping.items()}\n",
|
|
||||||
"\n",
|
|
||||||
"def predict_and_save(model, input_tokens_ids, output_file):\n",
|
|
||||||
" model.eval() \n",
|
|
||||||
" predictions = []\n",
|
|
||||||
" for tokens in tqdm(input_tokens_ids):\n",
|
|
||||||
" tokens = tokens.unsqueeze(0)\n",
|
|
||||||
" with torch.no_grad():\n",
|
|
||||||
" predicted_tags = model(tokens)\n",
|
|
||||||
" predicted_tags = torch.argmax(predicted_tags.squeeze(0), 1).tolist() \n",
|
|
||||||
" predicted_labels = [inverse_label_mapping[tag] for tag in predicted_tags]\n",
|
|
||||||
" predictions.append(predicted_labels[1:-1]) \n",
|
|
||||||
"\n",
|
|
||||||
" with open(output_file, 'w') as f:\n",
|
|
||||||
" for sentence in correct_iob_labels(predictions):\n",
|
|
||||||
" f.write(\" \".join(sentence) + \"\\n\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 11,
|
|
||||||
"id": "2542fb92-9c15-4abe-8b4d-d002df57fe99",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"train_sentences, train_labels = load_data(train_path)\n",
|
|
||||||
"dev_sentences, dev_labels = load_data(dev_in_path)\n",
|
|
||||||
"test_sentences, test_labels = load_data(test_in_path)\n",
|
|
||||||
"word_vocab = build_vocab(train_sentences)\n",
|
|
||||||
"tag_vocab = build_vocab(train_labels)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"id": "ff64590b-08cd-42ac-8e32-363f16038c0b",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"train_inputs = encode_sentences(train_sentences, word_vocab)\n",
|
|
||||||
"dev_inputs = encode_sentences(dev_sentences, word_vocab)\n",
|
|
||||||
"test_inputs = encode_sentences(test_sentences, word_vocab)\n",
|
|
||||||
"\n",
|
|
||||||
"train_labels_encoded = encode_sentences(train_labels, tag_vocab)\n",
|
|
||||||
"dev_labels_encoded = encode_sentences(dev_labels, tag_vocab)\n",
|
|
||||||
"test_labels_encoded = encode_sentences(test_labels, tag_vocab)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 13,
|
|
||||||
"id": "4c591b46-6cda-4f72-9574-e284faa58312",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"max_len = max(max(len(seq) for seq in train_inputs), max(len(seq) for seq in dev_inputs), max(len(seq) for seq in test_inputs))\n",
|
|
||||||
"\n",
|
|
||||||
"train_inputs_padded = pad_sequences(train_inputs, max_len)\n",
|
|
||||||
"dev_inputs_padded = pad_sequences(dev_inputs, max_len)\n",
|
|
||||||
"test_inputs_padded = pad_sequences(test_inputs, max_len)\n",
|
|
||||||
"\n",
|
|
||||||
"train_labels_padded = pad_sequences(train_labels_encoded, max_len)\n",
|
|
||||||
"dev_labels_padded = pad_sequences(dev_labels_encoded, max_len)\n",
|
|
||||||
"test_labels_padded = pad_sequences(test_labels_encoded, max_len)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 14,
|
|
||||||
"id": "228cf2c6-6fa0-47fd-97cc-3a0495b5059e",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"batch_size = 32\n",
|
|
||||||
"\n",
|
|
||||||
"train_dataloader = create_dataloader(train_inputs_padded, train_labels_padded, batch_size)\n",
|
|
||||||
"dev_dataloader = create_dataloader(dev_inputs_padded, dev_labels_padded, batch_size)\n",
|
|
||||||
"test_dataloader = create_dataloader(test_inputs_padded, test_labels_padded, batch_size)\n",
|
|
||||||
"\n",
|
|
||||||
"vocab_size = len(word_vocab)\n",
|
|
||||||
"tagset_size = len(tag_vocab)\n",
|
|
||||||
"\n",
|
|
||||||
"model = NER_RNN(vocab_size, tagset_size).to(device)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 15,
|
|
||||||
"id": "eafe2cf8-547d-4245-b32a-de24792f33ec",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"criterion = nn.CrossEntropyLoss(ignore_index=0)\n",
|
|
||||||
"optimizer = optim.Adam(model.parameters(), lr=0.001)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 16,
|
|
||||||
"id": "a178c476-b566-42ec-b2ca-fb59ed85f6ee",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "8f7c956b746a454d82306b90bb3d3abf",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 1/5, Loss: 0.6613934636116028\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "f9c3700d6d6343bc882e082504745ec3",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 2/5, Loss: 0.4284511208534241\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "539a905707ef48b3a33406616d163695",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 3/5, Loss: 0.26125839352607727\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "0d9f7098cd9348d8964a083f91b5684b",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 4/5, Loss: 0.14203576743602753\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "eb7df70c5e1e441e9afb8032301a3808",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 5/5, Loss: 0.06412970274686813\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"num_epochs = 5\n",
|
|
||||||
"\n",
|
|
||||||
"for epoch in range(num_epochs):\n",
|
|
||||||
" model.train()\n",
|
|
||||||
" total_loss = 0\n",
|
|
||||||
" for inputs, labels in tqdm(train_dataloader):\n",
|
|
||||||
" inputs, labels = inputs.to(device), labels.to(device)\n",
|
|
||||||
" optimizer.zero_grad()\n",
|
|
||||||
" outputs = model(inputs)\n",
|
|
||||||
" outputs = outputs.view(-1, tagset_size)\n",
|
|
||||||
" labels = labels.view(-1)\n",
|
|
||||||
" \n",
|
|
||||||
" loss = criterion(outputs, labels)\n",
|
|
||||||
" loss.backward()\n",
|
|
||||||
" optimizer.step()\n",
|
|
||||||
" \n",
|
|
||||||
" total_loss += loss.item()\n",
|
|
||||||
" print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_dataloader)}\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 17,
|
|
||||||
"id": "bdd4c839-9df1-4f85-a1f4-2f5ca6436f89",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def evaluate(dataloader):\n",
|
|
||||||
" model.eval()\n",
|
|
||||||
" all_preds = []\n",
|
|
||||||
" all_labels = []\n",
|
|
||||||
" with torch.no_grad():\n",
|
|
||||||
" for inputs, labels in dataloader:\n",
|
|
||||||
" inputs, labels = inputs.to(device), labels.to(device)\n",
|
|
||||||
" outputs = model(inputs)\n",
|
|
||||||
" _, preds = torch.max(outputs, dim=2)\n",
|
|
||||||
" \n",
|
|
||||||
" all_preds.extend(preds.cpu().numpy())\n",
|
|
||||||
" all_labels.extend(labels.cpu().numpy())\n",
|
|
||||||
" \n",
|
|
||||||
" return all_preds, all_labels"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 18,
|
|
||||||
"id": "be8eee4f-4f28-433c-9e23-a3730ac11780",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"dev_preds, dev_labels_true = evaluate(dev_dataloader)\n",
|
|
||||||
"test_preds, test_labels_true = evaluate(test_dataloader)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 19,
|
|
||||||
"id": "961dcea7-d873-4aac-82c6-435953c449c5",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def calculate_accuracy(predictions, labels):\n",
|
|
||||||
" correct = 0\n",
|
|
||||||
" total = 0\n",
|
|
||||||
" for preds, true_labels in zip(predictions, labels):\n",
|
|
||||||
" for pred, true_label in zip(preds, true_labels):\n",
|
|
||||||
" if pred == true_label:\n",
|
|
||||||
" correct += 1\n",
|
|
||||||
" total += 1\n",
|
|
||||||
" accuracy = correct / total\n",
|
|
||||||
" return accuracy"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 20,
|
|
||||||
"id": "dc8791ef-ddfa-45d3-8b81-9e1eec38ad1a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Accuracy on Dev Set: 0.14708994708994708\n",
|
|
||||||
"Accuracy on Test Set: 0.17566137566137566\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"dev_accuracy = calculate_accuracy(dev_preds, dev_labels_true)\n",
|
|
||||||
"test_accuracy = calculate_accuracy(test_preds, test_labels_true)\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"Accuracy on Dev Set:\", dev_accuracy)\n",
|
|
||||||
"print(\"Accuracy on Test Set:\", test_accuracy)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 21,
|
|
||||||
"id": "4965d2be-eb8e-437d-b53b-3a038ea77e91",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def save_predictions(predictions, file_path):\n",
|
|
||||||
" with open(file_path, 'w') as file:\n",
|
|
||||||
" for preds in predictions:\n",
|
|
||||||
" for pred in preds:\n",
|
|
||||||
" file.write(f\"{pred}\\n\")\n",
|
|
||||||
" file.write(\"\\n\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 22,
|
|
||||||
"id": "9fefb237-ee84-430a-9e72-9f606c2bc7d3",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"save_predictions(dev_preds, dev_out_path)\n",
|
|
||||||
"save_predictions(test_preds, test_out_path)"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.12.3"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user