rnn_dl/rnn.ipynb

605 lines
253 KiB
Plaintext
Raw Permalink Normal View History

2024-09-27 18:22:30 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"import warnings\n",
"from collections import Counter\n",
"\n",
"import pandas as pd\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"from torchtext.vocab import vocab\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"# Training split path.\n",
"# NOTE(review): load_datasets() does not read this variable; it opens the\n",
"# xz-compressed \"train.tsv.xz\" through its own hard-coded path.\n",
"train = \"./en-ner-conll-2003/train/train.tsv\"\n",
"\n",
"# dev-0 split: input sentences, expected tags, and the prediction output file.\n",
"dev_0_in = \"./en-ner-conll-2003/dev-0/in.tsv\"\n",
"expected = \"./en-ner-conll-2003/dev-0/expected.tsv\"\n",
"dev_0_out = \"./en-ner-conll-2003/dev-0/out.tsv\"\n",
"\n",
"# test-A split: input sentences and the prediction output file.\n",
"test_A_in = \"./en-ner-conll-2003/test-A/in.tsv\"\n",
"test_A_out = \"./en-ner-conll-2003/test-A/out.tsv\""
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
"# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
"def load_datasets():\n",
"    \"\"\"Read the train, dev-0 and test-A splits into DataFrames.\n",
"\n",
"    Every file is parsed with ``quoting=csv.QUOTE_NONE`` so that a literal\n",
"    double quote in the text stays an ordinary character. With the pandas\n",
"    default, a field that starts with a quote is treated as quoted, which\n",
"    drops the quote token and can merge several lines into one row, leaving\n",
"    tokens and tags (or inputs and expected rows) misaligned.\n",
"\n",
"    Returns:\n",
"        tuple: ``(train_data, dev_data, dev_labels, test_data)``. ``train_data``\n",
"        has the columns ``Tag`` and ``Sentence``; ``dev_data`` and ``test_data``\n",
"        have ``Sentence``; ``dev_labels`` has ``Tag``.\n",
"    \"\"\"\n",
"    # Shared parser options: tab-separated fields, no quote interpretation.\n",
"    read_options = {\"sep\": \"\\t\", \"quoting\": csv.QUOTE_NONE}\n",
"\n",
"    train_data = pd.read_csv(\n",
"        \"./en-ner-conll-2003/train/train.tsv.xz\",\n",
"        compression=\"xz\",\n",
"        names=[\"Tag\", \"Sentence\"],\n",
"        **read_options,\n",
"    )\n",
"    dev_data = pd.read_csv(dev_0_in, names=[\"Sentence\"], **read_options)\n",
"    dev_labels = pd.read_csv(expected, names=[\"Tag\"], **read_options)\n",
"    test_data = pd.read_csv(test_A_in, names=[\"Sentence\"], **read_options)\n",
"\n",
"    return train_data, dev_data, dev_labels, test_data\n",
"\n",
"(train_data, dev_data, dev_labels, test_data) = load_datasets()\n",
"\n",
"# Hold out 10% of the training rows for validation; the fixed random_state\n",
"# keeps the split identical between runs.\n",
"(train_sentences, val_sentences, train_tags, val_tags) = train_test_split(\n",
"    train_data[\"Sentence\"],\n",
"    train_data[\"Tag\"],\n",
"    random_state=42,\n",
"    test_size=0.1,\n",
")\n",
"\n",
"# Re-assemble each side of the split as a two-column frame.\n",
"train_data = pd.DataFrame(dict(Sentence=train_sentences, Tag=train_tags))\n",
"val_data = pd.DataFrame(dict(Sentence=val_sentences, Tag=val_tags))\n",
"\n",
"def tokenize_column(dataframe, column):\n",
"    \"\"\"Split every string in ``dataframe[column]`` on whitespace.\n",
"\n",
"    Args:\n",
"        dataframe: Frame that holds the column.\n",
"        column: Name of a column whose values provide ``.split()``.\n",
"\n",
"    Returns:\n",
"        A Series, indexed like the input column, of token lists.\n",
"    \"\"\"\n",
"\n",
"    def _split_on_whitespace(text):\n",
"        return text.split()\n",
"\n",
"    return dataframe[column].apply(_split_on_whitespace)\n",
"\n",
"# Labelled splits get a token column and a tag-token column.\n",
"for labelled_frame in (train_data, val_data):\n",
"    labelled_frame[\"tokens\"] = tokenize_column(labelled_frame, \"Sentence\")\n",
"    labelled_frame[\"tag_tokens\"] = tokenize_column(labelled_frame, \"Tag\")\n",
"\n",
"# Input-only splits carry sentences; the dev-0 tags live in their own frame.\n",
"for input_frame in (dev_data, test_data):\n",
"    input_frame[\"tokens\"] = tokenize_column(input_frame, \"Sentence\")\n",
"dev_labels[\"tag_tokens\"] = tokenize_column(dev_labels, \"Tag\")"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['Sentence', 'Tag', 'tokens', 'tag_tokens'], dtype='object')"
]
},
"execution_count": 102,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Confirm the training frame now carries the raw text and the token columns.\n",
"train_data.columns"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
"    \"\"\"Create a torchtext vocabulary from an iterable of token sequences.\n",
"\n",
"    Args:\n",
"        dataset: Iterable of documents, each an iterable of tokens.\n",
"\n",
"    Returns:\n",
"        The object returned by ``torchtext.vocab.vocab`` for the token counts,\n",
"        created with the specials ``<unk>``, ``<pad>``, ``<bos>`` and ``<eos>``.\n",
"    \"\"\"\n",
"    token_counts = Counter(token for document in dataset for token in document)\n",
"    return vocab(token_counts, specials=[\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"])"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
"def data_process(token_lists, v, device):\n",
"    \"\"\"Encode token sequences as index tensors framed by <bos> and <eos>.\n",
"\n",
"    Args:\n",
"        token_lists: Iterable of token sequences.\n",
"        v: Token-to-index lookup supporting ``v[token]``; it must resolve\n",
"            ``<bos>`` and ``<eos>``.\n",
"        device: Device on which every tensor is created.\n",
"\n",
"    Returns:\n",
"        list: One 1-D ``torch.long`` tensor per sequence, two elements longer\n",
"        than the sequence it encodes.\n",
"    \"\"\"\n",
"    encoded = []\n",
"    for tokens in token_lists:\n",
"        indices = [v[\"<bos>\"]]\n",
"        indices.extend(v[token] for token in tokens)\n",
"        indices.append(v[\"<eos>\"])\n",
"        encoded.append(torch.tensor(indices, dtype=torch.long, device=device))\n",
"    return encoded"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
"# Build the vocabulary from the training split's tokens only.\n",
"v = build_vocab(train_data[\"tokens\"])\n",
"# Tokens that are not in the vocabulary resolve to the <unk> index.\n",
"v.set_default_index(v[\"<unk>\"])"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['<unk>', '<pad>', '<bos>', '<eos>', 'SOCCER', '-', 'POLISH', 'FIRST', 'DIVISION', 'RESULTS', '.', '</S>', 'WARSAW', '1996-08-24', 'Results', 'of', 'Polish', 'first', 'division', 'soccer', 'matches', 'on', 'Saturday', ':', 'Amica', 'Wronki', '3', 'Hutnik', 'Krakow', '0', 'Sokol', 'Tychy', '5', 'Lech', 'Poznan', 'Rakow', 'Czestochowa', '1', 'Stomil', 'Olsztyn', '4', 'Wisla', 'Gornik', 'Zabrze', 'Slask', 'Wroclaw', 'Odra', 'Wodzislaw', 'GKS', 'Katowice', 'Polonia', 'Warsaw', 'Zaglebie', 'Lubin', '2', 'LKS', 'Lodz', 'Legia', 'Belchatow', 'CRICKET', 'POLLOCK', 'CONCLUDES', 'WARWICKSHIRE', 'CAREER', 'WITH', 'FLOURISH', 'LONDON', '1996-08-25', 'South', 'African', 'fast', 'bowler', 'Shaun', 'Pollock', 'concluded', 'his', 'Warwickshire', 'career', 'with', 'a', 'flourish', 'Sunday', 'by', 'taking', 'the', 'final', 'three', 'wickets', 'during', 'county', \"'s\", 'league', 'victory', 'over', 'Worcestershire', ',', 'who', 'returns', 'home', 'Tuesday', 'for', 'an', 'ankle', 'operation', 'took', 'last', 'in', 'nine', 'balls', 'as', 'were', 'dismissed', '154', 'After', 'hour', 'interruption', 'rain', 'then', 'reached', 'adjusted', 'target', '109', '13', 'to', 'spare', 'and', 'record', 'their', 'fifth', 'win', 'six', 'games', 'are', 'currently', 'fourth', 'position', 'behind', 'Yorkshire', 'Nottinghamshire', 'Surrey', 'captain', 'David', 'Byas', 'completed', 'third', 'century', 'side', 'swept', 'clear', 'at', 'top', 'table', 'reaching', 'best', '111', 'not', 'out', 'against', 'Lancashire', 'total', '205', 'eight', 'from', '40', 'overs', 'looked', 'reasonable', 'before', 'put', 'attack', 'sword', 'collecting', 'runs', 'just', '100', 'sixes', 'fours', 'eventually', 'only', 'four', 'down', '7.5', 'CYCLING', 'BALLANGER', 'KEEPS', 'SPRINT', 'TITLE', 'IN', 'STYLE', 'Martin', 'Ayres', 'MANCHESTER', 'England', '1996-08-30', 'Felicia', 'Ballanger', 'France', 'confirmed', 'her', 'status', 'world', 'number', 'one', 'woman', 'sprinter', 'when', 'she', 'retained', 'title', 'cycling', 
'championships', 'Friday', 'beat', 'Germany', 'Annett', 'Neumann', '2-0', 'best-of-three', 'add', 'Olympic', 'gold', 'medal', 'won', 'July', 'also', 'place', 'sprint', 'Magali', 'Faure', 'defeating', 'ex-world', 'champion', 'Tanya', 'Dubnicoff', 'Canada', '25', 'will', 'be', 'aiming', 'complete', 'track', 'double', 'defends', '500', 'metres', 'time', 'trial', 'The', 'other', 'night', 'women', '24-kms', 'points', 'race', 'ended', 'success', 'reigning', 'Russia', 'Svetlana', 'Samokhalova', 'fought', 'off', 'spirited', 'challenge', 'American', 'Jane', 'Quigley', 'take', 'second', 'year', 'nation', 'have', 'two', 'riders', 'field', 'made', 'full', 'use', 'numerical', 'superiority', 'Goulnara', 'Fatkoullina', 'helped', 'build', 'unbeatable', 'lead', 'snatching', 'bronze', 'former', 'medallist', 'event', 'led', 'half', 'distance', '\"', 'I', 'went', 'so', 'close', 'this', 'but', 'having', 'certainly', 'gave', 'Russians', 'advantage', 'said', 'lapped', 'which', 'left', 'Ingrid', 'Haringa', 'Netherlands', 'seventh', 'despite', 'highest', 'score', 'Nathalie', 'Lancien', 'missed', 'winning', 'finished', 'disappointing', '10th', 'RUGBY', 'LEAGUE', 'Australian', 'rugby', 'standings', 'SYDNEY', '1996-08-26', 'premiership', 'after', 'played', 'weekend', '(', 'tabulate', 'under', 'drawn', 'lost', ')', 'Manly', '21', '17', '501', '181', '34', 'Brisbane', '16', '569', '257', '32', 'North', 'Sydney', '14', '560', '317', '30', 'City', '20', '487', '293', '29', 'Cronulla', '12', '6', '359', '258', '26', 'Canberra', '8', '502', '374', 'St', 'George', '421', '344', 'Newcastle', '11', '9', '416', '366', '23', 'Western', 'Suburbs', '382', '426', 'Auckland', '10', '406', '389', '22', 'Tigers', '309', '435', 'Parramatta', '388', '391', 'Bulldogs', '325', '356', 'Illawarra', '395', '432', 'Reds', '297', '398', 'Penrith', '339', '448', 'Queensland', '15', '266', '593', 'Gold', 'Coast', '351', '483', '304', '586', '210', '460', '--', 'Newsroom', '61-2', '9373-1800', 'TENNIS', 'AT', 'HAMLET', 
'CUP', 'COMMACK', 'New', 'York', 'Hamlet', 'Cup', 'tennis', 'tournament', 'prefix', 'denotes', 'seed
]
}
],
"source": [
"# Index-to-token list; its length is later used as the model's vocabulary size.\n",
"itos = v.get_itos()\n",
"\n",
"# Show the size and a short sample instead of printing every token, which\n",
"# bloats the saved notebook and buries the surrounding cells.\n",
"print(f\"Vocabulary size: {len(itos)}\")\n",
"print(itos[:20])"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
"# Encode the token column of every split with the shared vocabulary.\n",
"train_tensor, val_tensor, dev_tensor, test_tensor = (\n",
"    data_process(frame[\"tokens\"], v, device)\n",
"    for frame in (train_data, val_data, dev_data, test_data)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
"# Tag inventory; a tag's position in this list is its class index.\n",
"tag_list = [\n",
"    \"O\",\n",
"    \"B-PER\", \"I-PER\",\n",
"    \"B-ORG\", \"I-ORG\",\n",
"    \"B-LOC\", \"I-LOC\",\n",
"    \"B-MISC\", \"I-MISC\",\n",
"]\n",
"\n",
"tag_to_index = dict(zip(tag_list, range(len(tag_list))))\n",
"\n",
"def convert_tags_to_tensor(tag_tokens, tag_to_index, device):\n",
"    \"\"\"Encode tag sequences as index tensors with a 0 added at both ends.\n",
"\n",
"    The two extra zeros line the tags up with token tensors that carry a\n",
"    leading <bos> and a trailing <eos>; 0 is the index of \"O\" in this\n",
"    notebook's tag list.\n",
"\n",
"    Args:\n",
"        tag_tokens: Iterable of tag-string sequences.\n",
"        tag_to_index: Mapping from tag string to class index.\n",
"        device: Device on which every tensor is created.\n",
"\n",
"    Returns:\n",
"        list: One 1-D ``torch.long`` tensor per sequence.\n",
"\n",
"    Raises:\n",
"        KeyError: If a tag is missing from ``tag_to_index``.\n",
"    \"\"\"\n",
"    encoded = []\n",
"    for tags in tag_tokens:\n",
"        indices = [0]\n",
"        indices.extend(tag_to_index[tag] for tag in tags)\n",
"        indices.append(0)\n",
"        encoded.append(torch.tensor(indices, dtype=torch.long, device=device))\n",
"    return encoded\n",
"\n",
"# Encode the tag column of every labelled split.\n",
"train_tag_tensor, val_tag_tensor, dev_tag_tensor = (\n",
"    convert_tags_to_tensor(frame[\"tag_tokens\"], tag_to_index, device)\n",
"    for frame in (train_data, val_data, dev_labels)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): this rebuilds the same tag -> index mapping that is created\n",
"# next to tag_list; it is kept here so the cell still runs on its own.\n",
"tag_to_index = dict(zip(tag_list, range(len(tag_list))))\n",
"# One past the largest class index, i.e. the size of the model's output layer.\n",
"max_tag_index = 1 + max(tag_to_index.values())"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
"class LSTMModel(torch.nn.Module):\n",
"    \"\"\"Sequence tagger: embedding -> ReLU -> LSTM -> per-token linear layer.\"\"\"\n",
"\n",
"    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, output_size):\n",
"        \"\"\"Create the layers.\n",
"\n",
"        Args:\n",
"            vocab_size: Number of rows in the embedding table.\n",
"            embed_size: Embedding width and LSTM input width.\n",
"            hidden_size: LSTM hidden width and linear-layer input width.\n",
"            num_layers: Number of stacked LSTM layers.\n",
"            output_size: Number of logits produced per token.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.embedding = torch.nn.Embedding(\n",
"            num_embeddings=vocab_size, embedding_dim=embed_size\n",
"        )\n",
"        self.lstm = torch.nn.LSTM(\n",
"            input_size=embed_size,\n",
"            hidden_size=hidden_size,\n",
"            num_layers=num_layers,\n",
"            batch_first=True,\n",
"        )\n",
"        self.fc = torch.nn.Linear(in_features=hidden_size, out_features=output_size)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return per-token logits for a batch-first tensor of token indices.\"\"\"\n",
"        activated = self.embedding(x).relu()\n",
"        hidden_states, _final_state = self.lstm(activated)\n",
"        return self.fc(hidden_states)\n"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
"# Seed PyTorch before the model is created so weight initialisation (and the\n",
"# accuracy that follows from it) is as repeatable as the backend allows.\n",
"# 42 matches the random_state used for the train/validation split.\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Model hyperparameters, named so they can be tuned in one place.\n",
"EMBED_SIZE = 100\n",
"HIDDEN_SIZE = 100\n",
"NUM_LAYERS = 1\n",
"\n",
"model = LSTMModel(len(itos), EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, max_tag_index).to(device)\n",
"loss_fn = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
"def eval_model(dataset_tokens, dataset_labels, model, tag_to_index):\n",
"    \"\"\"Predict tags for every sequence and pair them with the reference tags.\n",
"\n",
"    Inputs are moved to the device that holds the model's parameters, so the\n",
"    function runs on CPU-only machines as well as with CUDA. The first and\n",
"    last position of every sequence are dropped before tags are collected.\n",
"\n",
"    Args:\n",
"        dataset_tokens: Indexable collection of 1-D token-index tensors.\n",
"        dataset_labels: Indexable collection of 1-D tag-index tensors; its\n",
"            length decides how many sequences are evaluated.\n",
"        model: Callable module that maps a ``(1, seq_len)`` index tensor to\n",
"            logits with one row per token after the batch axis is removed.\n",
"        tag_to_index: Mapping from tag string to class index.\n",
"\n",
"    Returns:\n",
"        tuple: ``(Y_true, Y_pred)``, two flat lists of tag strings of equal\n",
"        length.\n",
"\n",
"    Warns:\n",
"        UserWarning: When sequences are left out because their token and label\n",
"        lengths differ (they are still skipped, but no longer silently).\n",
"    \"\"\"\n",
"    index_to_tag = {index: tag for tag, index in tag_to_index.items()}\n",
"\n",
"    # Follow the model instead of hard-coding \"cuda\"; a parameter-free model\n",
"    # leaves the inputs on whatever device they already use.\n",
"    first_param = next(model.parameters(), None)\n",
"    target_device = first_param.device if first_param is not None else None\n",
"\n",
"    Y_true = []\n",
"    Y_pred = []\n",
"    skipped = 0\n",
"\n",
"    with torch.no_grad():\n",
"        for i in range(len(dataset_labels)):\n",
"            inputs = dataset_tokens[i].unsqueeze(0)\n",
"            if target_device is not None:\n",
"                inputs = inputs.to(target_device)\n",
"            true_labels = dataset_labels[i].tolist()\n",
"\n",
"            logits = model(inputs).squeeze(0)\n",
"            predicted_labels = torch.argmax(logits, dim=1).tolist()\n",
"\n",
"            # Strip the two boundary positions added around every sequence.\n",
"            true_tags = [index_to_tag[label] for label in true_labels[1:-1]]\n",
"            pred_tags = [index_to_tag[label] for label in predicted_labels[1:-1]]\n",
"\n",
"            if len(true_tags) != len(pred_tags):\n",
"                skipped += 1\n",
"                continue\n",
"            Y_true.extend(true_tags)\n",
"            Y_pred.extend(pred_tags)\n",
"\n",
"    if skipped:\n",
"        warnings.warn(\n",
"            f\"eval_model skipped {skipped} sequence(s) whose token and label lengths differ\",\n",
"            stacklevel=2,\n",
"        )\n",
"\n",
"    return Y_true, Y_pred"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
"def calculate_accuracy(true_labels, pred_labels):\n",
"    \"\"\"Return the share of positions where the two label sequences agree.\n",
"\n",
"    Pairs are formed with ``zip``, so extra items in the longer sequence are\n",
"    never compared, while the denominator is always ``len(true_labels)``.\n",
"\n",
"    Args:\n",
"        true_labels: Sized sequence of reference labels.\n",
"        pred_labels: Iterable of predicted labels.\n",
"\n",
"    Returns:\n",
"        The ratio as a float, or the integer ``0`` when ``true_labels`` is empty.\n",
"    \"\"\"\n",
"    matches = 0\n",
"    for expected_label, predicted_label in zip(true_labels, pred_labels):\n",
"        if expected_label == predicted_label:\n",
"            matches += 1\n",
"\n",
"    total = len(true_labels)\n",
"    if total > 0:\n",
"        return matches / total\n",
"    return 0"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoka 1, Dokładność: 85.41\n",
"Epoka 2, Dokładność: 88.92\n",
"Epoka 3, Dokładność: 90.99\n",
"Epoka 4, Dokładność: 92.06\n",
"Epoka 5, Dokładność: 92.69\n",
"Epoka 6, Dokładność: 93.19\n",
"Epoka 7, Dokładność: 93.50\n",
"Epoka 8, Dokładność: 93.78\n",
"Epoka 9, Dokładność: 93.94\n",
"Epoka 10, Dokładność: 94.08\n"
]
}
],
"source": [
"epochs = 10\n",
"\n",
"for epoch in range(epochs):\n",
"    model.train()\n",
"\n",
"    # One sequence per optimisation step (batch size 1). Tensors are placed on\n",
"    # `device` rather than a hard-coded 'cuda', so the loop also runs without\n",
"    # a GPU.\n",
"    for tokens, tags in zip(train_tensor, train_tag_tensor):\n",
"        inputs = tokens.unsqueeze(0).to(device)\n",
"        targets = tags.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        # Drop the batch axis so logits and targets both have one row per token.\n",
"        logits = model(inputs).squeeze(0)\n",
"        loss = loss_fn(logits, targets)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    model.eval()\n",
"\n",
"    # The validation tensors already live on `device`; no per-epoch copy needed.\n",
"    Y_true, Y_pred = eval_model(val_tensor, val_tag_tensor, model, tag_to_index)\n",
"    accuracy = calculate_accuracy(Y_true, Y_pred)\n",
"    print(f\"Epoka {epoch + 1}, Dokładność: {accuracy * 100:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
"def predict_labels(tokens, model, tag_to_index):\n",
"    \"\"\"Predict one space-separated tag string per token-index sequence.\n",
"\n",
"    The first and last predicted position of every sequence are dropped before\n",
"    the remaining class indices are mapped back to tag strings.\n",
"\n",
"    Args:\n",
"        tokens: Collection of 1-D token-index tensors.\n",
"        model: Callable module that maps a ``(1, seq_len)`` index tensor to\n",
"            logits with one row per token after the batch axis is removed.\n",
"        tag_to_index: Mapping from tag string to class index.\n",
"\n",
"    Returns:\n",
"        list: One string per input sequence, tags joined by single spaces.\n",
"    \"\"\"\n",
"    index_to_tag = {index: tag for tag, index in tag_to_index.items()}\n",
"    predictions = []\n",
"\n",
"    with torch.no_grad():\n",
"        for i in range(len(tokens)):\n",
"            logits = model(tokens[i].unsqueeze(0)).squeeze(0)\n",
"            inner_labels = torch.argmax(logits, dim=1)[1:-1].tolist()\n",
"            predictions.append(\" \".join(index_to_tag[label] for label in inner_labels))\n",
"\n",
"    return predictions"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dokładność: 94.08%\n"
]
}
],
"source": [
"# The validation tensors already live on `device`, so they are evaluated\n",
"# directly instead of being copied with a hard-coded .to('cuda'), which fails\n",
"# on machines without a GPU.\n",
"model.eval()\n",
"\n",
"Y_true, Y_pred = eval_model(val_tensor, val_tag_tensor, model, tag_to_index)\n",
"accuracy = calculate_accuracy(Y_true, Y_pred)\n",
"print(f\"Dokładność: {accuracy * 100:.2f}%\")"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"# dev-0: predict tags and write one row of space-separated tags per sentence,\n",
"# without an index column or header row.\n",
"dev_predictions = predict_labels(dev_tensor, model, tag_to_index)\n",
"dev_predictions_df = pd.DataFrame({\"Tag\": dev_predictions})\n",
"dev_predictions_df.to_csv(dev_0_out, header=False, index=False)\n",
"\n",
"# test-A: same procedure, written to the test-A output path.\n",
"test_predictions = predict_labels(test_tensor, model, tag_to_index)\n",
"test_predictions_df = pd.DataFrame({\"Tag\": test_predictions})\n",
"test_predictions_df.to_csv(test_A_out, header=False, index=False)"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Tag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>O O O O O O O O O O O O B-LOC O O B-LOC I-LOC ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>O O B-MISC B-MISC I-MISC O O O B-LOC O O O O O...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>O O O O O O O B-LOC O O B-LOC O O O O O O O O ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>O O O O O B-LOC O O O B-LOC O O O O O O O B-MI...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>O O O O O O O B-LOC O O O O O O O O O O O O O ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>O O B-ORG O O O O O O O B-LOC O O B-LOC O B-LO...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>O O O O O O O O O O O B-LOC O O O O O O O O O ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>O O O O O O O O O O O O O O B-LOC O O O O O B-...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>O O B-MISC I-MISC I-MISC O O O O B-LOC I-LOC O...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>O O O O B-MISC I-MISC O O O O B-LOC O O O O B-...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Tag\n",
"0 O O O O O O O O O O O O B-LOC O O B-LOC I-LOC ...\n",
"1 O O B-MISC B-MISC I-MISC O O O B-LOC O O O O O...\n",
"2 O O O O O O O B-LOC O O B-LOC O O O O O O O O ...\n",
"3 O O O O O B-LOC O O O B-LOC O O O O O O O B-MI...\n",
"4 O O O O O O O B-LOC O O O O O O O O O O O O O ...\n",
"5 O O B-ORG O O O O O O O B-LOC O O B-LOC O B-LO...\n",
"6 O O O O O O O O O O O B-LOC O O O O O O O O O ...\n",
"7 O O O O O O O O O O O O O O B-LOC O O O O O B-...\n",
"8 O O B-MISC I-MISC I-MISC O O O O B-LOC I-LOC O...\n",
"9 O O O O B-MISC I-MISC O O O O B-LOC O O O O B-..."
]
},
"execution_count": 118,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the first ten rows of dev-0 predictions.\n",
"dev_predictions_df.head(10)"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Tag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>O O O O O O O B-ORG O O O O O O O O O O B-LOC ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>B-ORG I-ORG O O O O O O O O O O B-LOC O O B-LO...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>O O O O O O O O O O O O O B-LOC I-ORG I-LOC O ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>O O B-MISC O O O O O O B-LOC O O O O O B-MISC ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>O O O B-MISC O O O O O O O B-LOC I-ORG I-LOC O...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>O O B-LOC O B-LOC I-LOC O O O O O O O B-MISC O...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>O O B-MISC O I-MISC O O O O O B-LOC O O O O O ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>O O O B-ORG O O O B-LOC O O B-MISC O O O O O O...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>O O B-LOC O O O O O O O O O O B-LOC O O B-ORG ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>O O O O O O B-LOC O O O O O O O O O O O O O O ...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Tag\n",
"0 O O O O O O O B-ORG O O O O O O O O O O B-LOC ...\n",
"1 B-ORG I-ORG O O O O O O O O O O B-LOC O O B-LO...\n",
"2 O O O O O O O O O O O O O B-LOC I-ORG I-LOC O ...\n",
"3 O O B-MISC O O O O O O B-LOC O O O O O B-MISC ...\n",
"4 O O O B-MISC O O O O O O O B-LOC I-ORG I-LOC O...\n",
"5 O O B-LOC O B-LOC I-LOC O O O O O O O B-MISC O...\n",
"6 O O B-MISC O I-MISC O O O O O B-LOC O O O O O ...\n",
"7 O O O B-ORG O O O B-LOC O O B-MISC O O O O O O...\n",
"8 O O B-LOC O O O O O O O O O O B-LOC O O B-ORG ...\n",
"9 O O O O O O B-LOC O O O O O O O O O O O O O O ..."
]
},
"execution_count": 119,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the first ten rows of test-A predictions.\n",
"test_predictions_df.head(10)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}