en-ner-conll-2003/seq_labeling.py.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import os.path\n",
"import shutil\n",
"import torch\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from torchtext.vocab import Vocab\n",
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"model_path = \"seq_labeling.model\"\n",
"if not os.path.isfile('train/train.tsv'):\n",
" import lzma\n",
" with lzma.open('train/train.tsv.xz', 'rb') as f_in:\n",
" with open('train/train.tsv', 'wb') as f_out:\n",
" shutil.copyfileobj(f_in, f_out)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>iob</th>\n",
" <th>tokens</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[5, 0, 3, 0, 0, 0, 3, 0, 0, 0, 7, 8, 0, 1, 0, ...</td>\n",
" <td>[EU, rejects, German, call, to, boycott, Briti...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>[0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...</td>\n",
" <td>[Rare, Hendrix, song, draft, sells, for, almos...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>[1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, ...</td>\n",
" <td>[China, says, Taiwan, spoils, atmosphere, for,...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>[1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, ...</td>\n",
" <td>[China, says, time, right, for, Taiwan, talks,...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>[3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...</td>\n",
" <td>[German, July, car, registrations, up, 14.2, p...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>940</th>\n",
" <td>[0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 7, 8, 0, 1, 0, ...</td>\n",
" <td>[CYCLING, -, BALLANGER, KEEPS, SPRINT, TITLE, ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>941</th>\n",
" <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...</td>\n",
" <td>[CYCLING, -, WORLD, TRACK, CHAMPIONSHIP, RESUL...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>942</th>\n",
" <td>[0, 0, 3, 0, 7, 0, 5, 0, 0, 1, 0, 1, 0, 0, 3, ...</td>\n",
" <td>[SOCCER, -, FRENCH, DEFENDER, KOMBOUARE, JOINS...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>943</th>\n",
" <td>[0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 1, 0, 1, 0, 0, ...</td>\n",
" <td>[MOTORCYCLING, -, SAN, MARINO, GRAND, PRIX, PR...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>944</th>\n",
" <td>[0, 0, 3, 4, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, ...</td>\n",
" <td>[GOLF, -, BRITISH, MASTERS, THIRD, ROUND, SCOR...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>945 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" iob \\\n",
"0 [5, 0, 3, 0, 0, 0, 3, 0, 0, 0, 7, 8, 0, 1, 0, ... \n",
"1 [0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ... \n",
"2 [1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, ... \n",
"3 [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, ... \n",
"4 [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ... \n",
".. ... \n",
"940 [0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 7, 8, 0, 1, 0, ... \n",
"941 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ... \n",
"942 [0, 0, 3, 0, 7, 0, 5, 0, 0, 1, 0, 1, 0, 0, 3, ... \n",
"943 [0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 1, 0, 1, 0, 0, ... \n",
"944 [0, 0, 3, 4, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, ... \n",
"\n",
" tokens \n",
"0 [EU, rejects, German, call, to, boycott, Briti... \n",
"1 [Rare, Hendrix, song, draft, sells, for, almos... \n",
"2 [China, says, Taiwan, spoils, atmosphere, for,... \n",
"3 [China, says, time, right, for, Taiwan, talks,... \n",
"4 [German, July, car, registrations, up, 14.2, p... \n",
".. ... \n",
"940 [CYCLING, -, BALLANGER, KEEPS, SPRINT, TITLE, ... \n",
"941 [CYCLING, -, WORLD, TRACK, CHAMPIONSHIP, RESUL... \n",
"942 [SOCCER, -, FRENCH, DEFENDER, KOMBOUARE, JOINS... \n",
"943 [MOTORCYCLING, -, SAN, MARINO, GRAND, PRIX, PR... \n",
"944 [GOLF, -, BRITISH, MASTERS, THIRD, ROUND, SCOR... \n",
"\n",
"[945 rows x 2 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = ['O','B-LOC', 'I-LOC','B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']\n",
"\n",
"data = pd.read_csv('train/train.tsv', sep='\\t', names=['iob', 'tokens'])\n",
"data[\"iob\"]=data[\"iob\"].apply(lambda x: [labels.index(y) for y in x.split()])\n",
"data[\"tokens\"]=data[\"tokens\"].apply(lambda x: x.split())\n",
"data"
]
},
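{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the integer ids in the `iob` column map back to the tag strings via `labels.index` / `labels[...]` (so row 0 starting `[5, 0, 3, ...]` reads `B-ORG O B-MISC ...`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# id -> tag lookup table for the encoding used above\n",
"print({i: tag for i, tag in enumerate(labels)})"
]
},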
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
" counter = Counter()\n",
" for document in dataset:\n",
" counter.update(document)\n",
" return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>']) #, '<alpha>', '<notalpha>'])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"vocab = build_vocab(data['tokens'])"
]
},
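{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, illustrative sanity check of the vocabulary (this relies on the legacy `torchtext` `Vocab` built above, whose lookups fall back to `<unk>` at index 0 for unseen tokens):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the four specials occupy the lowest indices, real tokens follow\n",
"print(vocab.itos[:6])\n",
"# known token -> its index; unseen token -> index of '<unk>' (0)\n",
"print(vocab['EU'], vocab['sometokennotinconll'])"
]
},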
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def labels_process(dt):\n",
" return [ torch.tensor([0] + document + [0], dtype = torch.long) for document in dt]\n",
"\n",
"def data_process(dt):\n",
" return [ torch.tensor([vocab['<bos>']] +[vocab[token] for token in document ] + [vocab['<eos>']], dtype = torch.long) for document in dt]\n",
"\n",
"# def data_process(dt):\n",
"# result = []\n",
"# for document in dt:\n",
"# sentence = [vocab['<bos>'],vocab['<alpha>']]\n",
"# for token in document:\n",
"# sentence += [vocab[token]]\n",
"# sentence += [vocab['<alpha>'] if token.isalpha() else vocab['<notalpha>']]\n",
"# sentence += [vocab['<eos>'],vocab['<alpha>']]\n",
"# result.append(torch.tensor(sentence, dtype = torch.long))\n",
"# return result"
]
},
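{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal illustration of the two helpers on a made-up fragment (the exact token ids depend on the corpus): both outputs gain one position at each end (`<bos>`/`<eos>` resp. `O`), so token and label tensors stay aligned."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"example_tokens = data_process([['EU', 'rejects', 'German', 'call']])[0]\n",
"example_labels = labels_process([[5, 0, 3, 0]])[0]\n",
"# both have length 4 + 2\n",
"print(example_tokens)\n",
"print(example_labels)"
]
},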
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"23628"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vocab.itos)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class NERModel(torch.nn.Module):\n",
" def __init__(self,):\n",
" super(NERModel, self).__init__()\n",
" self.emb = torch.nn.Embedding(23629,200)\n",
" self.fc1 = torch.nn.Linear(1200,9) \n",
"\n",
" def forward(self, x):\n",
" x = self.emb(x)\n",
" x = x.reshape(1200) \n",
" x = self.fc1(x)\n",
" return x"
]
},
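{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check on CPU (illustrative only): a window of three arbitrary token ids goes in, nine label logits come out."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad():\n",
"    print(NERModel()(torch.tensor([2, 5, 7])).shape)  # torch.Size([9])"
]
},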
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# class NERModel(torch.nn.Module):\n",
"# def __init__(self,):\n",
"# super(NERModel, self).__init__()\n",
"# #self.emb = torch.nn.Embedding(23629,200)\n",
"# self.emb = torch.nn.Embedding(23628,200)\n",
"# self.fc1 = torch.nn.Linear(600,9) \n",
"\n",
"# def forward(self, x):\n",
"# x = self.emb(x)\n",
"# x = x.reshape(600) \n",
"# x = self.fc1(x)\n",
"# return x"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"device_gpu = torch.device(\"cuda:0\")\n",
"device_cpu = torch.device(\"cpu\")\n",
"\n",
"ner_model = NERModel().to(device_gpu)\n",
"\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(ner_model.parameters())\n",
"\n",
"train_labels = labels_process(data['iob'])\n",
"train_tokens_ids = data_process(data['tokens'])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"if not os.path.isfile(model_path):\n",
" for epoch in range(5):\n",
" acc_score = 0\n",
" prec_score = 0\n",
" selected_items = 0\n",
" recall_score = 0\n",
" relevant_items = 0\n",
" items_total = 0\n",
" ner_model.train()\n",
" for i in range(len(train_labels)):\n",
" for j in range(1, len(train_labels[i]) - 1):\n",
" #for j in range(2, len(train_labels[i]) - 2, 2):\n",
"\n",
" #X = train_tokens_ids[i][j-2: j+4].to(device_gpu)\n",
" X = train_tokens_ids[i][j-1: j+2].to(device_gpu)\n",
" \n",
" Y = train_labels[i][j: j+1].to(device_gpu)\n",
" Y_predictions = ner_model(X)\n",
" \n",
" acc_score += int(torch.argmax(Y_predictions) == Y)\n",
" if torch.argmax(Y_predictions) != 0:\n",
" selected_items +=1\n",
" if torch.argmax(Y_predictions) != 0 and torch.argmax(Y_predictions) == Y.item():\n",
" prec_score += 1\n",
" if Y.item() != 0:\n",
" relevant_items +=1\n",
" if Y.item() != 0 and torch.argmax(Y_predictions) == Y.item():\n",
" recall_score += 1\n",
"\n",
" items_total += 1\n",
" optimizer.zero_grad()\n",
" loss = criterion(Y_predictions.unsqueeze(0), Y)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" precision = prec_score / selected_items\n",
" recall = recall_score / relevant_items\n",
" f1_score = (2*precision * recall) / (precision + recall)\n",
" print(f'epoch: {epoch}')\n",
" print(f'f1: {f1_score}')\n",
" print(f'acc: {acc_score/ items_total}')\n",
" torch.save(ner_model.state_dict(), model_path)\n",
"else:\n",
" ner_model.load_state_dict(torch.load(model_path))"
]
},
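{
"cell_type": "markdown",
"metadata": {},
"source": [
"The printed scores treat every non-`O` prediction as a selected item and every non-`O` gold label as a relevant item, so\n",
"\n",
"$$\\mathrm{precision} = \\frac{\\text{correct non-O}}{\\text{predicted non-O}}, \\qquad \\mathrm{recall} = \\frac{\\text{correct non-O}}{\\text{gold non-O}}, \\qquad F_1 = \\frac{2 \\cdot \\mathrm{precision} \\cdot \\mathrm{recall}}{\\mathrm{precision} + \\mathrm{recall}}$$\n",
"\n",
"Note that these are token-level scores computed on the training data as it is being fit, so they overestimate dev-set quality."
]
},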
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def process(model, x):\n",
" predicted = model(x)\n",
" result = torch.argmax(predicted)\n",
" return labels[result]\n",
"\n",
"def process_dataset(model, path):\n",
" with open(path, 'r') as f:\n",
" lines = f.readlines()\n",
" X = [x.split() for x in lines]\n",
" data_tokens_ids = data_process(X)\n",
" results = []\n",
" for i in range(len(data_tokens_ids)):\n",
" line_results = []\n",
" #for j in range(1, len(data_tokens_ids[i]) - 1):\n",
" for j in range(2, len(data_tokens_ids[i]) - 3, 2):\n",
" x = data_tokens_ids[i][j-2: j+4].to(device_gpu)\n",
" # x = data_tokens_ids[i][j-1: j+2].to(device_gpu)\n",
" label = process(model, x)\n",
" line_results.append(label)\n",
" results.append(line_results)\n",
" return results\n",
"\n",
"# Przetwarzanie danych z wyjścia modelu (gdy B- i I- nie dotyczą tej samej etykiety)\n",
"def process_output(lines):\n",
" result = []\n",
" for line in lines:\n",
" last_label = None\n",
" new_line = []\n",
" for label in line:\n",
" if(label != \"O\" and label[0:2] == \"I-\"):\n",
" if last_label == None or last_label == \"O\":\n",
" label = label.replace('I-', 'B-')\n",
" else:\n",
" label = \"I-\" + last_label[2:]\n",
" last_label = label\n",
" new_line.append(label)\n",
" result.append(\" \".join(new_line))\n",
" return result\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"results = process_dataset(ner_model,\"dev-0/in.tsv\")\n",
"file_content = process_output(results)\n",
"with open(\"dev-0/out.tsv\", \"w\") as f:\n",
" for line in file_content:\n",
" f.write(line + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# results = process_dataset(ner_model,\"test-A/in.tsv\")\n",
"# file_content = [' '.join(x) for x in results]\n",
"# with open(\"test-A/out.tsv\", \"w\") as f:\n",
"# for line in file_content:\n",
"# print(line)\n",
"# #f.write(line + \"\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}