RNN/rnn_2.ipynb

461 lines
249 KiB
Plaintext
Raw Normal View History

2024-05-27 14:39:47 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['<unk>', '<pad>', '<bos>', '<eos>', 'SOCCER', '-', 'POLISH', 'FIRST', 'DIVISION', 'RESULTS', '.', '</S>', 'WARSAW', '1996-08-24', 'Results', 'of', 'Polish', 'first', 'division', 'soccer', 'matches', 'on', 'Saturday', ':', 'Amica', 'Wronki', '3', 'Hutnik', 'Krakow', '0', 'Sokol', 'Tychy', '5', 'Lech', 'Poznan', 'Rakow', 'Czestochowa', '1', 'Stomil', 'Olsztyn', '4', 'Wisla', 'Gornik', 'Zabrze', 'Slask', 'Wroclaw', 'Odra', 'Wodzislaw', 'GKS', 'Katowice', 'Polonia', 'Warsaw', 'Zaglebie', 'Lubin', '2', 'LKS', 'Lodz', 'Legia', 'Belchatow', 'CRICKET', 'POLLOCK', 'CONCLUDES', 'WARWICKSHIRE', 'CAREER', 'WITH', 'FLOURISH', 'LONDON', '1996-08-25', 'South', 'African', 'fast', 'bowler', 'Shaun', 'Pollock', 'concluded', 'his', 'Warwickshire', 'career', 'with', 'a', 'flourish', 'Sunday', 'by', 'taking', 'the', 'final', 'three', 'wickets', 'during', 'county', \"'s\", 'league', 'victory', 'over', 'Worcestershire', ',', 'who', 'returns', 'home', 'Tuesday', 'for', 'an', 'ankle', 'operation', 'took', 'last', 'in', 'nine', 'balls', 'as', 'were', 'dismissed', '154', 'After', 'hour', 'interruption', 'rain', 'then', 'reached', 'adjusted', 'target', '109', '13', 'to', 'spare', 'and', 'record', 'their', 'fifth', 'win', 'six', 'games', 'are', 'currently', 'fourth', 'position', 'behind', 'Yorkshire', 'Nottinghamshire', 'Surrey', 'captain', 'David', 'Byas', 'completed', 'third', 'century', 'side', 'swept', 'clear', 'at', 'top', 'table', 'reaching', 'best', '111', 'not', 'out', 'against', 'Lancashire', 'total', '205', 'eight', 'from', '40', 'overs', 'looked', 'reasonable', 'before', 'put', 'attack', 'sword', 'collecting', 'runs', 'just', '100', 'sixes', 'fours', 'eventually', 'only', 'four', 'down', '7.5', 'CYCLING', 'BALLANGER', 'KEEPS', 'SPRINT', 'TITLE', 'IN', 'STYLE', 'Martin', 'Ayres', 'MANCHESTER', 'England', '1996-08-30', 'Felicia', 'Ballanger', 'France', 'confirmed', 'her', 'status', 'world', 'number', 'one', 'woman', 'sprinter', 'when', 'she', 'retained', 'title', 'cycling', 
'championships', 'Friday', 'beat', 'Germany', 'Annett', 'Neumann', '2-0', 'best-of-three', 'add', 'Olympic', 'gold', 'medal', 'won', 'July', 'also', 'place', 'sprint', 'Magali', 'Faure', 'defeating', 'ex-world', 'champion', 'Tanya', 'Dubnicoff', 'Canada', '25', 'will', 'be', 'aiming', 'complete', 'track', 'double', 'defends', '500', 'metres', 'time', 'trial', 'The', 'other', 'night', 'women', '24-kms', 'points', 'race', 'ended', 'success', 'reigning', 'Russia', 'Svetlana', 'Samokhalova', 'fought', 'off', 'spirited', 'challenge', 'American', 'Jane', 'Quigley', 'take', 'second', 'year', 'nation', 'have', 'two', 'riders', 'field', 'made', 'full', 'use', 'numerical', 'superiority', 'Goulnara', 'Fatkoullina', 'helped', 'build', 'unbeatable', 'lead', 'snatching', 'bronze', 'former', 'medallist', 'event', 'led', 'half', 'distance', '\"', 'I', 'went', 'so', 'close', 'this', 'but', 'having', 'certainly', 'gave', 'Russians', 'advantage', 'said', 'lapped', 'which', 'left', 'Ingrid', 'Haringa', 'Netherlands', 'seventh', 'despite', 'highest', 'score', 'Nathalie', 'Lancien', 'missed', 'winning', 'finished', 'disappointing', '10th', 'RUGBY', 'LEAGUE', 'Australian', 'rugby', 'standings', 'SYDNEY', '1996-08-26', 'premiership', 'after', 'played', 'weekend', '(', 'tabulate', 'under', 'drawn', 'lost', ')', 'Manly', '21', '17', '501', '181', '34', 'Brisbane', '16', '569', '257', '32', 'North', 'Sydney', '14', '560', '317', '30', 'City', '20', '487', '293', '29', 'Cronulla', '12', '6', '359', '258', '26', 'Canberra', '8', '502', '374', 'St', 'George', '421', '344', 'Newcastle', '11', '9', '416', '366', '23', 'Western', 'Suburbs', '382', '426', 'Auckland', '10', '406', '389', '22', 'Tigers', '309', '435', 'Parramatta', '388', '391', 'Bulldogs', '325', '356', 'Illawarra', '395', '432', 'Reds', '297', '398', 'Penrith', '339', '448', 'Queensland', '15', '266', '593', 'Gold', 'Coast', '351', '483', '304', '586', '210', '460', '--', 'Newsroom', '61-2', '9373-1800', 'TENNIS', 'AT', 'HAMLET', 
'CUP', 'COMMACK', 'New', 'York', 'Hamlet', 'Cup', 'tennis', 'tournament', 'prefix', 'denotes', 'seed
"22154\n",
"9\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a9e31a5d56f947bab11440afba13f2b2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c08aba575f4435ca98af63d5dc96fc0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.5546171171171171, 0.2832901926948519, 0.37502379592613744)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a042bbe87f404a0e8463a641335693b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "11115dc17f3e4f9195a1b598a5cc1feb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.5892077354959451, 0.543284440609721, 0.5653149783031572)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d0734c011bb341cd9352338fbc3be59c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c5af3a6e70d44efc9e36bce4874e18d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.5969561794804513, 0.6542996836353178, 0.6243139407244785)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "adbd98d4c11647f3bd348bca22b8e98d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "804fdc7cbd5c4d64877813c63c5339e4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.5954920019389239, 0.7066436583261432, 0.6463238195449165)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "adb5363844f74ef8b792704b71366b9b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "165c7d86d6614f29a6fa4ca4ce8b9c0f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.5790436970944864, 0.727926373310325, 0.6450050968399592)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b58d5293d2624cceb1e71bb7aebbb915",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ca543149052943ac9dc1bb054ab08ec3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/215 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "064d06a3ae394f2698deffd4727a7ec0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/215 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dc5d4d0379a9489db22c8d3b727403ad",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/230 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from collections import Counter\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from torchtext.vocab import vocab\n",
"from tqdm.notebook import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# Load the train / dev / test splits\n",
"def load_datasets():\n",
"    \"\"\"Read the train, dev-0 and test-A files into DataFrames.\n",
"\n",
"    CSV quote handling is switched off (``csv.QUOTE_NONE``): the files hold\n",
"    whitespace-tokenised text in which a double quote is an ordinary token,\n",
"    and with the default quoting a field that begins with a double quote may\n",
"    be parsed as a quoted CSV field instead of being kept verbatim.\n",
"\n",
"    Returns:\n",
"        tuple: ``(train_df, dev_df, dev_labels, test_df)``. ``train_df`` has\n",
"        the columns ``Category`` and ``Sentence``; ``dev_df`` and ``test_df``\n",
"        have ``Sentence``; ``dev_labels`` has ``Category``.\n",
"    \"\"\"\n",
"    import csv  # stdlib; imported locally so this function is self-contained\n",
"\n",
"    train_df = pd.read_csv(\n",
"        \"train/train.tsv.xz\",\n",
"        compression=\"xz\",\n",
"        sep=\"\\t\",\n",
"        names=[\"Category\", \"Sentence\"],\n",
"        quoting=csv.QUOTE_NONE,\n",
"    )\n",
"    dev_df = pd.read_csv(\n",
"        \"dev-0/in.tsv\", sep=\"\\t\", names=[\"Sentence\"], quoting=csv.QUOTE_NONE\n",
"    )\n",
"    dev_labels = pd.read_csv(\n",
"        \"dev-0/expected.tsv\", sep=\"\\t\", names=[\"Category\"], quoting=csv.QUOTE_NONE\n",
"    )\n",
"    test_df = pd.read_csv(\n",
"        \"test-A/in.tsv\", sep=\"\\t\", names=[\"Sentence\"], quoting=csv.QUOTE_NONE\n",
"    )\n",
"    return train_df, dev_df, dev_labels, test_df\n",
"\n",
"train_df, dev_df, dev_labels, test_df = load_datasets()\n",
"\n",
"# Hold out 10% of the training rows as a validation split (fixed random_state)\n",
"train_texts, val_texts, train_labels, val_labels = train_test_split(\n",
"    train_df[\"Sentence\"],\n",
"    train_df[\"Category\"],\n",
"    test_size=0.1,\n",
"    random_state=42,\n",
")\n",
"train_df = pd.DataFrame({\"Sentence\": train_texts, \"Category\": train_labels})\n",
"val_df = pd.DataFrame({\"Sentence\": val_texts, \"Category\": val_labels})\n",
"\n",
"# Tokenise on whitespace: every frame gets a \"tokens\" column ...\n",
"for frame in (train_df, test_df, val_df, dev_df):\n",
"    frame[\"tokens\"] = frame[\"Sentence\"].apply(lambda sentence: sentence.split())\n",
"\n",
"# ... and every frame with gold labels gets a \"label_tokens\" column\n",
"# (the dev labels come from the separate dev_labels frame).\n",
"for frame, categories in (\n",
"    (train_df, train_df[\"Category\"]),\n",
"    (val_df, val_df[\"Category\"]),\n",
"    (dev_df, dev_labels[\"Category\"]),\n",
"):\n",
"    frame[\"label_tokens\"] = categories.apply(lambda tags: tags.split())\n",
"\n",
"# Build the vocabulary\n",
"def create_vocab(token_list):\n",
"    \"\"\"Count every token in ``token_list`` and wrap the counts in a torchtext vocab.\n",
"\n",
"    Args:\n",
"        token_list: iterable of token sequences.\n",
"\n",
"    Returns:\n",
"        The object returned by ``torchtext.vocab.vocab`` for the token counts,\n",
"        created with the specials ``<unk>``, ``<pad>``, ``<bos>`` and ``<eos>``.\n",
"    \"\"\"\n",
"    token_counts = Counter(token for sentence in token_list for token in sentence)\n",
"    special_tokens = [\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"]\n",
"    return vocab(token_counts, specials=special_tokens)\n",
"\n",
"vocabulary = create_vocab(train_df[\"tokens\"])\n",
"index_to_string = vocabulary.get_itos()\n",
"# Show only a short preview: printing the complete index -> token list would\n",
"# store every vocabulary entry in the saved cell output.\n",
"print(index_to_string[:20])\n",
"print(len(index_to_string))\n",
"\n",
"# Look up tokens that are missing from the vocabulary as <unk>\n",
"vocabulary.set_default_index(vocabulary[\"<unk>\"])\n",
"\n",
"# Convert token sequences into index tensors\n",
"def vectorize_data(data_tokens):\n",
"    \"\"\"Map each token sequence to a ``torch.long`` tensor of vocabulary indices.\n",
"\n",
"    Every sequence is wrapped in ``<bos>`` ... ``<eos>``. The tensors are\n",
"    created on the global ``device`` - the device ``vectorize_labels`` uses and\n",
"    the model is moved to - so they can be fed to the model without a device\n",
"    mismatch when CUDA is selected.\n",
"\n",
"    Args:\n",
"        data_tokens: iterable of token sequences.\n",
"\n",
"    Returns:\n",
"        list: one 1-D tensor per input sequence, two elements longer than the\n",
"        sequence itself.\n",
"    \"\"\"\n",
"    bos_index = vocabulary[\"<bos>\"]\n",
"    eos_index = vocabulary[\"<eos>\"]\n",
"    return [\n",
"        torch.tensor(\n",
"            [bos_index] + [vocabulary[token] for token in tokens] + [eos_index],\n",
"            dtype=torch.long,\n",
"            device=device,\n",
"        )\n",
"        for tokens in data_tokens\n",
"    ]\n",
"\n",
"def vectorize_labels(data_labels, label_map):\n",
"    \"\"\"Encode each label sequence as a ``torch.long`` tensor on the global ``device``.\n",
"\n",
"    A 0 is placed at both ends of every sequence, matching the ``<bos>`` and\n",
"    ``<eos>`` positions that ``vectorize_data`` adds to the token tensors.\n",
"\n",
"    Args:\n",
"        data_labels: iterable of label-string sequences.\n",
"        label_map: mapping from label string to integer id.\n",
"\n",
"    Returns:\n",
"        list: one 1-D tensor per input sequence.\n",
"    \"\"\"\n",
"    encoded = []\n",
"    for labels in data_labels:\n",
"        indices = [0]\n",
"        indices.extend(label_map[label] for label in labels)\n",
"        indices.append(0)\n",
"        encoded.append(torch.tensor(indices, dtype=torch.long, device=device))\n",
"    return encoded\n",
"\n",
"# Use the GPU when CUDA is available, otherwise the CPU\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"# Index tensors for every split (same order as before: train, test, val, dev)\n",
"train_tokens, test_tokens, val_tokens, dev_tokens = (\n",
"    vectorize_data(frame[\"tokens\"]) for frame in (train_df, test_df, val_df, dev_df)\n",
")\n",
"\n",
"# Label inventory; a label's position in the list is its integer id\n",
"label_list = [\n",
"    \"O\",\n",
"    \"B-PER\", \"I-PER\",\n",
"    \"B-ORG\", \"I-ORG\",\n",
"    \"B-LOC\", \"I-LOC\",\n",
"    \"B-MISC\", \"I-MISC\",\n",
"]\n",
"label_map = dict(zip(label_list, range(len(label_list))))\n",
"\n",
"train_label_tokens, val_label_tokens, dev_label_tokens = (\n",
"    vectorize_labels(frame[\"label_tokens\"], label_map)\n",
"    for frame in (train_df, val_df, dev_df)\n",
")\n",
"\n",
"# Funkcja do obliczania metryk\n",
"def calculate_metrics(true_labels, pred_labels):\n",
" accuracy = 0\n",
" true_positive = 0\n",
" false_positive = 0\n",
" total_selected = 0\n",
" total_relevant = 0\n",
"\n",
" for pred, true in zip(pred_labels, true_labels):\n",
" if pred == true:\n",
" accuracy += 1\n",
"\n",
" if pred > 0 and pred == true:\n",
" true_positive += 1\n",
"\n",
" if pred > 0:\n",
" total_selected += 1\n",
"\n",
" if true > 0:\n",
" total_relevant += 1\n",
"\n",
" precision = true_positive / total_selected if total_selected > 0 else 1.0\n",
" recall = true_positive / total_relevant if total_relevant > 0 else 1.0\n",
" f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0\n",
"\n",
" return precision, recall, f1_score\n",
"\n",
"# Size the output layer from the full label inventory rather than from the\n",
"# labels that happen to occur in the training split, so that every id in\n",
"# label_map has an output unit.\n",
"num_classes = len(label_map)\n",
"print(num_classes)\n",
"\n",
"# Definicja modelu LSTM\n",
"class BiLSTM(nn.Module):\n",
" def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, num_classes):\n",
" super(BiLSTM, self).__init__()\n",
" self.embedding = nn.Embedding(vocab_size, embed_dim)\n",
" self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)\n",
" self.fc = nn.Linear(hidden_dim * 2, num_classes)\n",
"\n",
" def forward(self, x):\n",
" x_embed = torch.relu(self.embedding(x))\n",
" lstm_out, _ = self.lstm(x_embed)\n",
" logits = self.fc(lstm_out)\n",
" return logits\n",
"\n",
"# Seed PyTorch's RNG so that parameter initialisation starts from the same\n",
"# state on every execution of the notebook.\n",
"torch.manual_seed(42)\n",
"\n",
"# Positional arguments: vocab_size, embed_dim=100, hidden_dim=100, num_layers=1, num_classes\n",
"model = BiLSTM(len(index_to_string), 100, 100, 1, num_classes).to(device)\n",
"loss_function = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters())\n",
"\n",
"# Evaluate a model\n",
"def evaluate_model(data_tokens, data_labels, model):\n",
"    \"\"\"Score ``model`` on paired token / label tensors.\n",
"\n",
"    Each token tensor is run through the model on its own (batch of one); the\n",
"    argmax over the last dimension of the logits is taken as the prediction.\n",
"    All gold and predicted ids are pooled and handed to ``calculate_metrics``.\n",
"\n",
"    Args:\n",
"        data_tokens: indexable collection of 1-D token-index tensors.\n",
"        data_labels: indexable collection of 1-D label-id tensors; its length\n",
"            determines how many items are evaluated.\n",
"        model: callable mapping a ``(1, seq)`` index tensor to logits.\n",
"\n",
"    Returns:\n",
"        The ``(precision, recall, f1_score)`` tuple from ``calculate_metrics``.\n",
"    \"\"\"\n",
"    true_labels = []\n",
"    pred_labels = []\n",
"    # Inference only: disable autograd so no graph is recorded for the forward passes\n",
"    with torch.no_grad():\n",
"        for i in tqdm(range(len(data_labels))):\n",
"            true_labels.extend(data_labels[i].tolist())\n",
"\n",
"            pred_logits = model(data_tokens[i].unsqueeze(0)).squeeze(0)\n",
"            pred_labels.extend(torch.argmax(pred_logits, 1).tolist())\n",
"\n",
"    return calculate_metrics(true_labels, pred_labels)\n",
"\n",
"# Predict label strings\n",
"def predict_labels(data_tokens, model, label_map):\n",
"    \"\"\"Predict one space-joined label string per token tensor.\n",
"\n",
"    Args:\n",
"        data_tokens: indexable collection of 1-D token-index tensors.\n",
"        model: callable mapping a ``(1, seq)`` index tensor to logits.\n",
"        label_map: mapping from label string to integer id; it is inverted to\n",
"            turn predicted ids back into label strings.\n",
"\n",
"    Returns:\n",
"        list: one string per input tensor, with the predictions for the first\n",
"        and last position left out.\n",
"    \"\"\"\n",
"    inv_label_map = {index: label for label, index in label_map.items()}\n",
"    pred_labels = []\n",
"\n",
"    # Inference only: disable autograd so no graph is recorded for the forward passes\n",
"    with torch.no_grad():\n",
"        for i in tqdm(range(len(data_tokens))):\n",
"            pred_logits = model(data_tokens[i].unsqueeze(0)).squeeze(0)\n",
"            pred_ids = torch.argmax(pred_logits, 1).tolist()\n",
"            # Skip the first and last position (<bos>/<eos> in tensors built by vectorize_data)\n",
"            pred_labels.append(\" \".join(inv_label_map[pred_id] for pred_id in pred_ids[1:-1]))\n",
"\n",
"    return pred_labels\n",
"\n",
"# Train the model\n",
"EPOCHS = 5\n",
"for epoch in range(EPOCHS):\n",
"    model.train()\n",
"    # One optimisation step per training sentence (batch of one)\n",
"    for i in tqdm(range(len(train_label_tokens))):\n",
"        sentence = train_tokens[i].unsqueeze(0)\n",
"        gold = train_label_tokens[i]\n",
"\n",
"        logits = model(sentence).squeeze(0)\n",
"        optimizer.zero_grad()\n",
"        loss = loss_function(logits, gold)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    # Report (precision, recall, f1) on the validation split after every epoch\n",
"    model.eval()\n",
"    val_metrics = evaluate_model(val_tokens, val_label_tokens, model)\n",
"    print(val_metrics)\n",
"\n",
"# Evaluate on the validation and dev sets. The metrics are printed explicitly:\n",
"# a bare call in the middle of a cell computes its result but displays nothing.\n",
"print(\"val (precision, recall, f1):\", evaluate_model(val_tokens, val_label_tokens, model))\n",
"print(\"dev-0 (precision, recall, f1):\", evaluate_model(dev_tokens, dev_label_tokens, model))\n",
"\n",
"# Write the predictions for the dev and test sets, one line per input row\n",
"dev_predictions = predict_labels(dev_tokens, model, label_map)\n",
"dev_predictions_df = pd.DataFrame(dev_predictions, columns=[\"Category\"])\n",
"dev_predictions_df.to_csv(\"dev-0/out.tsv\", index=False, header=False)\n",
"\n",
"test_predictions = predict_labels(test_tokens, model, label_map)\n",
"test_predictions_df = pd.DataFrame(test_predictions, columns=[\"Category\"])\n",
"test_predictions_df.to_csv(\"test-A/out.tsv\", index=False, header=False)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}