forked from kubapok/en-ner-conll-2003
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import os.path\n",
"import shutil\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"from torchtext.vocab import Vocab\n",
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"model_path = \"seq_labeling.model\"\n",
"if not os.path.isfile('train/train.tsv'):\n",
"    import lzma\n",
"    with lzma.open('train/train.tsv.xz', 'rb') as f_in:\n",
"        with open('train/train.tsv', 'wb') as f_out:\n",
"            shutil.copyfileobj(f_in, f_out)"
]
},
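{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (added here, not part of the original pipeline): `train.tsv` is expected to hold two tab-separated columns per line, the IOB tags and the tokens. The peek below only assumes that layout."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical peek at the raw file; assumes the tags<TAB>tokens layout described above.\n",
"with open('train/train.tsv') as f:\n",
"    first = f.readline().rstrip('\\n').split('\\t')\n",
"print('tag column starts with:  ', first[0][:40])\n",
"print('token column starts with:', first[1][:40])"
]
},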
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>iob</th>\n",
"      <th>tokens</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>[5, 0, 3, 0, 0, 0, 3, 0, 0, 0, 7, 8, 0, 1, 0, ...</td>\n",
"      <td>[EU, rejects, German, call, to, boycott, Briti...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>[0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...</td>\n",
"      <td>[Rare, Hendrix, song, draft, sells, for, almos...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>[1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, ...</td>\n",
"      <td>[China, says, Taiwan, spoils, atmosphere, for,...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>[1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, ...</td>\n",
"      <td>[China, says, time, right, for, Taiwan, talks,...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>[3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...</td>\n",
"      <td>[German, July, car, registrations, up, 14.2, p...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>...</th>\n",
"      <td>...</td>\n",
"      <td>...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>940</th>\n",
"      <td>[0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 7, 8, 0, 1, 0, ...</td>\n",
"      <td>[CYCLING, -, BALLANGER, KEEPS, SPRINT, TITLE, ...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>941</th>\n",
"      <td>[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...</td>\n",
"      <td>[CYCLING, -, WORLD, TRACK, CHAMPIONSHIP, RESUL...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>942</th>\n",
"      <td>[0, 0, 3, 0, 7, 0, 5, 0, 0, 1, 0, 1, 0, 0, 3, ...</td>\n",
"      <td>[SOCCER, -, FRENCH, DEFENDER, KOMBOUARE, JOINS...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>943</th>\n",
"      <td>[0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 1, 0, 1, 0, 0, ...</td>\n",
"      <td>[MOTORCYCLING, -, SAN, MARINO, GRAND, PRIX, PR...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>944</th>\n",
"      <td>[0, 0, 3, 4, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, ...</td>\n",
"      <td>[GOLF, -, BRITISH, MASTERS, THIRD, ROUND, SCOR...</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>945 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
"                                                   iob  \\\n",
"0    [5, 0, 3, 0, 0, 0, 3, 0, 0, 0, 7, 8, 0, 1, 0, ...   \n",
"1    [0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...   \n",
"2    [1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, ...   \n",
"3    [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, ...   \n",
"4    [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...   \n",
"..                                                 ...   \n",
"940  [0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 7, 8, 0, 1, 0, ...   \n",
"941  [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...   \n",
"942  [0, 0, 3, 0, 7, 0, 5, 0, 0, 1, 0, 1, 0, 0, 3, ...   \n",
"943  [0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 1, 0, 1, 0, 0, ...   \n",
"944  [0, 0, 3, 4, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, ...   \n",
"\n",
"                                                tokens  \n",
"0    [EU, rejects, German, call, to, boycott, Briti...  \n",
"1    [Rare, Hendrix, song, draft, sells, for, almos...  \n",
"2    [China, says, Taiwan, spoils, atmosphere, for,...  \n",
"3    [China, says, time, right, for, Taiwan, talks,...  \n",
"4    [German, July, car, registrations, up, 14.2, p...  \n",
"..                                                 ...  \n",
"940  [CYCLING, -, BALLANGER, KEEPS, SPRINT, TITLE, ...  \n",
"941  [CYCLING, -, WORLD, TRACK, CHAMPIONSHIP, RESUL...  \n",
"942  [SOCCER, -, FRENCH, DEFENDER, KOMBOUARE, JOINS...  \n",
"943  [MOTORCYCLING, -, SAN, MARINO, GRAND, PRIX, PR...  \n",
"944  [GOLF, -, BRITISH, MASTERS, THIRD, ROUND, SCOR...  \n",
"\n",
"[945 rows x 2 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = ['O', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']\n",
"\n",
"data = pd.read_csv('train/train.tsv', sep='\\t', names=['iob', 'tokens'])\n",
"data[\"iob\"] = data[\"iob\"].apply(lambda x: [labels.index(y) for y in x.split()])\n",
"data[\"tokens\"] = data[\"tokens\"].apply(lambda x: x.split())\n",
"data"
]
},
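{
"cell_type": "markdown",
"metadata": {},
"source": [
"The integer tags above are just positions in the `labels` list, so 5 decodes to `B-ORG` (the tag on `EU` in row 0) and 1 to `B-LOC`. A minimal check of that mapping:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decode the first tag of the first document back to its IOB label.\n",
"print({i: l for i, l in enumerate(labels)})\n",
"print(labels[data['iob'][0][0]], data['tokens'][0][0])  # expected: B-ORG EU"
]
},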
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
"    counter = Counter()\n",
"    for document in dataset:\n",
"        counter.update(document)\n",
"    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])  # , '<alpha>', '<notalpha>'])"
]
},
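{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the legacy torchtext `Vocab` used here (the 0.8-style API), the specials are prepended at indices 0-3 in the order given, and `vocab[token]` falls back to the `<unk>` index for out-of-vocabulary words. A short sketch, assuming that API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Specials come first, in the order passed to Vocab(...).\n",
"print(vocab['<unk>'], vocab['<pad>'], vocab['<bos>'], vocab['<eos>'])  # 0 1 2 3\n",
"print(vocab['EU'])                       # index of a known training token\n",
"print(vocab['definitely-not-a-token'])   # OOV maps to the <unk> index, i.e. 0"
]
},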
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"vocab = build_vocab(data['tokens'])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def labels_process(dt):\n",
"    return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]\n",
"\n",
"def data_process(dt):\n",
"    return [torch.tensor([vocab['<bos>']] + [vocab[token] for token in document] + [vocab['<eos>']], dtype=torch.long) for document in dt]\n",
"\n",
"# Variant that interleaves an <alpha>/<notalpha> marker after every token\n",
"# (requires those specials in the vocabulary and a doubled window downstream):\n",
"# def data_process(dt):\n",
"#     result = []\n",
"#     for document in dt:\n",
"#         sentence = [vocab['<bos>'], vocab['<alpha>']]\n",
"#         for token in document:\n",
"#             sentence += [vocab[token]]\n",
"#             sentence += [vocab['<alpha>'] if token.isalpha() else vocab['<notalpha>']]\n",
"#         sentence += [vocab['<eos>'], vocab['<alpha>']]\n",
"#         result.append(torch.tensor(sentence, dtype=torch.long))\n",
"#     return result"
]
},
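{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both helpers pad the sequence ends so that position j in the label tensor lines up with position j in the token tensor: labels get a dummy 0 (`O`) on each side, tokens get `<bos>`/`<eos>`. A quick illustration on a toy two-token document:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy document: tokens ['EU', 'rejects'] with tags [B-ORG, O] encoded as [5, 0].\n",
"sample_tokens = data_process([['EU', 'rejects']])[0]\n",
"sample_labels = labels_process([[5, 0]])[0]\n",
"print(sample_tokens)                             # [<bos>, id(EU), id(rejects), <eos>]\n",
"print(sample_labels)                             # tensor([0, 5, 0, 0])\n",
"print(len(sample_tokens) == len(sample_labels))  # aligned lengths"
]
},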
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"23628"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vocab.itos)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class NERModel(torch.nn.Module):\n",
"    def __init__(self):\n",
"        super(NERModel, self).__init__()\n",
"        self.emb = torch.nn.Embedding(23629, 200)\n",
"        # 3-token window * 200-dim embeddings = 600 input features\n",
"        # (must match the window sliced out in the training loop below)\n",
"        self.fc1 = torch.nn.Linear(600, 9)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.emb(x)\n",
"        x = x.reshape(600)\n",
"        x = self.fc1(x)\n",
"        return x"
]
},
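{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model scores the middle token of a 3-token window: three 200-dimensional embeddings are flattened into 600 features and mapped linearly onto the 9 tags. (The embedding table has 23629 rows, one more than `len(vocab.itos)`, so every vocabulary index fits, with one row left unused.) A shape check on a fresh CPU instance:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fresh, untrained instance just to confirm the expected shapes.\n",
"probe = NERModel()\n",
"window = torch.tensor([2, 10, 3], dtype=torch.long)  # <bos>, an arbitrary token id, <eos>\n",
"print(probe(window).shape)  # torch.Size([9]): one score per IOB tag"
]
},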
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Variant sized for the interleaved <alpha>/<notalpha> encoding (6-position window):\n",
"# class NERModel(torch.nn.Module):\n",
"#     def __init__(self):\n",
"#         super(NERModel, self).__init__()\n",
"#         self.emb = torch.nn.Embedding(23628, 200)\n",
"#         self.fc1 = torch.nn.Linear(1200, 9)\n",
"\n",
"#     def forward(self, x):\n",
"#         x = self.emb(x)\n",
"#         x = x.reshape(1200)\n",
"#         x = self.fc1(x)\n",
"#         return x"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"device_gpu = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"device_cpu = torch.device(\"cpu\")\n",
"\n",
"ner_model = NERModel().to(device_gpu)\n",
"\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(ner_model.parameters())\n",
"\n",
"train_labels = labels_process(data['iob'])\n",
"train_tokens_ids = data_process(data['tokens'])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"if not os.path.isfile(model_path):\n",
"    for epoch in range(5):\n",
"        acc_score = 0\n",
"        prec_score = 0\n",
"        selected_items = 0\n",
"        recall_score = 0\n",
"        relevant_items = 0\n",
"        items_total = 0\n",
"        ner_model.train()\n",
"        for i in range(len(train_labels)):\n",
"            for j in range(1, len(train_labels[i]) - 1):\n",
"            # for j in range(2, len(train_labels[i]) - 2, 2):\n",
"\n",
"                # X = train_tokens_ids[i][j-2: j+4].to(device_gpu)\n",
"                X = train_tokens_ids[i][j-1: j+2].to(device_gpu)\n",
"\n",
"                Y = train_labels[i][j: j+1].to(device_gpu)\n",
"                Y_predictions = ner_model(X)\n",
"\n",
"                acc_score += int(torch.argmax(Y_predictions) == Y)\n",
"                if torch.argmax(Y_predictions) != 0:\n",
"                    selected_items += 1\n",
"                if torch.argmax(Y_predictions) != 0 and torch.argmax(Y_predictions) == Y.item():\n",
"                    prec_score += 1\n",
"                if Y.item() != 0:\n",
"                    relevant_items += 1\n",
"                if Y.item() != 0 and torch.argmax(Y_predictions) == Y.item():\n",
"                    recall_score += 1\n",
"\n",
"                items_total += 1\n",
"                optimizer.zero_grad()\n",
"                loss = criterion(Y_predictions.unsqueeze(0), Y)\n",
"                loss.backward()\n",
"                optimizer.step()\n",
"\n",
"        precision = prec_score / selected_items\n",
"        recall = recall_score / relevant_items\n",
"        f1_score = (2 * precision * recall) / (precision + recall)\n",
"        print(f'epoch: {epoch}')\n",
"        print(f'f1: {f1_score}')\n",
"        print(f'acc: {acc_score / items_total}')\n",
"    torch.save(ner_model.state_dict(), model_path)\n",
"else:\n",
"    ner_model.load_state_dict(torch.load(model_path))"
]
},
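{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once trained (or loaded from disk), the model can be probed on a single window taken straight from the training data; a minimal sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Predict the tag of the first real token of document 0 (position 1, after <bos>).\n",
"ner_model.eval()\n",
"with torch.no_grad():\n",
"    x = train_tokens_ids[0][0:3].to(device_gpu)  # window centred on position 1\n",
"    pred = torch.argmax(ner_model(x)).item()\n",
"print('predicted:', labels[pred], '| gold:', labels[train_labels[0][1].item()])"
]
},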
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def process(model, x):\n",
"    predicted = model(x)\n",
"    result = torch.argmax(predicted)\n",
"    return labels[result]\n",
"\n",
"def process_dataset(model, path):\n",
"    with open(path, 'r') as f:\n",
"        lines = f.readlines()\n",
"    X = [x.split() for x in lines]\n",
"    data_tokens_ids = data_process(X)\n",
"    results = []\n",
"    for i in range(len(data_tokens_ids)):\n",
"        line_results = []\n",
"        # Same 3-token window as in training; the commented variant is for the\n",
"        # interleaved <alpha>/<notalpha> encoding.\n",
"        for j in range(1, len(data_tokens_ids[i]) - 1):\n",
"        # for j in range(2, len(data_tokens_ids[i]) - 3, 2):\n",
"            x = data_tokens_ids[i][j-1: j+2].to(device_gpu)\n",
"            # x = data_tokens_ids[i][j-2: j+4].to(device_gpu)\n",
"            label = process(model, x)\n",
"            line_results.append(label)\n",
"        results.append(line_results)\n",
"    return results\n",
"\n",
"# Post-process the model output (when the B- and I- tags do not refer to the same label)\n",
"def process_output(lines):\n",
"    result = []\n",
"    for line in lines:\n",
"        last_label = None\n",
"        new_line = []\n",
"        for label in line:\n",
"            if label != \"O\" and label[0:2] == \"I-\":\n",
"                if last_label is None or last_label == \"O\":\n",
"                    label = label.replace('I-', 'B-')\n",
"                else:\n",
"                    label = \"I-\" + last_label[2:]\n",
"            last_label = label\n",
"            new_line.append(label)\n",
"        result.append(\" \".join(new_line))\n",
"    return result"
]
},
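{
"cell_type": "markdown",
"metadata": {},
"source": [
"`process_output` enforces IOB consistency on the raw predictions: an `I-` tag that opens an entity is promoted to `B-`, and an `I-` tag whose type disagrees with the running entity inherits that entity's type. For example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# I-PER opening an entity -> B-PER; I-ORG right after B-LOC inherits LOC.\n",
"print(process_output([['I-PER', 'I-PER', 'O', 'I-LOC', 'I-ORG']]))\n",
"# expected: ['B-PER I-PER O B-LOC I-LOC']"
]
},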
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"results = process_dataset(ner_model, \"dev-0/in.tsv\")\n",
"file_content = process_output(results)\n",
"with open(\"dev-0/out.tsv\", \"w\") as f:\n",
"    for line in file_content:\n",
"        f.write(line + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Same pipeline for the test set; uncomment to write test-A/out.tsv.\n",
"# results = process_dataset(ner_model, \"test-A/in.tsv\")\n",
"# file_content = process_output(results)\n",
"# with open(\"test-A/out.tsv\", \"w\") as f:\n",
"#     for line in file_content:\n",
"#         f.write(line + \"\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}