forked from kubapok/en-ner-conll-2003

ffix

commit 31fab4d6b0 (parent 7e71e86b87)

seq.ipynb · 993 lines changed

@@ -1,498 +1,507 @@
notebook metadata (nbformat 4, nbformat_minor 5): kernelspec Python 3 (python3); language_info: python 3.8.5, codemirror ipython 3, file extension .py, mimetype text/x-python, nbconvert exporter python, pygments lexer ipython3; colab notebook "Copy of final.ipynb", provenance []; accelerator GPU

cell 7d9d7e79 (execution count 1):

-import os.path
 import pandas as pd
-import numpy as np
 import torch
-import csv
 from collections import Counter
 from torchtext.vocab import Vocab
cell 8pcDKHIVkgAE (execution count 49):

def predict(input_tokens, labels):

    results = []

    for i in range(len(input_tokens)):
        line_results = []
        for j in range(1, len(input_tokens[i]) - 1):
            x = input_tokens[i][j-1: j+2].to(device_gpu)
            predicted = ner_model(x.long())
            result = torch.argmax(predicted)
            label = labels[result]
            line_results.append(label)
        results.append(line_results)

    return results
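For every token, predict classifies a window of three consecutive entries (previous, current, next), skipping the first and last positions, which data_process fills with <bos>/<eos>. A minimal standalone sketch of that windowing, on plain lists with no model involved:

seq = ['<bos>', 'EU', 'rejects', 'German', '<eos>']
windows = [seq[j-1: j+2] for j in range(1, len(seq) - 1)]
print(windows)
# [['<bos>', 'EU', 'rejects'], ['EU', 'rejects', 'German'], ['rejects', 'German', '<eos>']]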
cell DqQGtDRyh3vd (execution count 37):

def create_tensors_list(data):
    tensors = []

    for sent in data["tokens"]:
        sent_tensor = torch.tensor(())
        for word in sent:
            temp = torch.tensor([word[0].isupper(), word[0].isdigit()])
            sent_tensor = torch.cat((sent_tensor, temp))

        tensors.append(sent_tensor)

    return tensors
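Each token contributes two handcrafted indicators: first character uppercase, first character a digit (the cell relies on torch.cat promoting the boolean pairs to the float dtype of the empty starting tensor). A quick check on a made-up sentence, with the promotion written out explicitly:

import torch

sent = ["EU", "rejects", "14.2"]
feats = torch.cat([torch.tensor([w[0].isupper(), w[0].isdigit()], dtype=torch.float)
                   for w in sent])
print(feats)  # tensor([1., 0., 0., 0., 0., 1.])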
cell yFR38yG4hk_B (execution count 29):

def save_to_file(path, results):
    with open(path, "w") as f:
        for line in results:
            f.write(line + "\n")
cell mFSG7d15hHWJ (execution count 31):

def extra_features(tokens_ids, tensors_list):
    return [torch.cat((token, tensors_list[i])) for i, token in enumerate(tokens_ids)]
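A note on the indexing this concatenation produces: for a sentence of n tokens, the combined tensor holds the n + 2 token ids (with <bos>/<eos>) at indices 0 .. n+1, followed by the 2n indicator features. The training and prediction loops below slice [j-1 : j+2] for j in range(1, n + 1), so the highest index a window ever touches is n + 1; the appended features ride along in the tensor but never enter a window. The arithmetic, spelled out:

n = 5                               # tokens in some sentence
total_len = (n + 2) + 2 * n         # ids plus appended indicator features
max_window_index = n + 1            # last j is n; slice j-1 : j+2 ends at index j+1
print(max_window_index, "<", total_len)  # 6 < 17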
cell pVQCgy3JhAiF (execution count 28):

def process_output(lines):
    result = []
    for line in lines:
        last_label = None
        new_line = []
        for label in line:
            if(label != "O" and label[0:2] == "I-"):
                if last_label == None or last_label == "O":
                    label = label.replace('I-', 'B-')
                else:
                    label = "I-" + last_label[2:]
            last_label = label
            new_line.append(label)
        x = (" ".join(new_line))
        result.append(" ".join(new_line))
    return result
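process_output normalizes the predicted IOB tags before they are written out: every entity must open with B-, and an I- tag that directly follows a different entity type inherits the running type (the `x = ...` assignment is unused). Worked examples, assuming the cell above has run:

print(process_output([["I-ORG", "I-ORG", "O", "I-PER"]]))
# ['B-ORG I-ORG O B-PER']
print(process_output([["I-ORG", "I-PER"]]))
# ['B-ORG I-ORG']  -- the I-PER is coerced to the running ORG type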
cell ea129277 (execution count 2):

def build_vocab(dataset):
    counter = Counter()
    for document in dataset:
        counter.update(document)
    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
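Vocab(counter, specials=...) is the legacy torchtext constructor (removed in torchtext 0.12); its stoi maps unknown tokens to index 0, i.e. '<unk>'. A rough equivalent on current torchtext, as a sketch assuming torchtext >= 0.12:

from torchtext.vocab import build_vocab_from_iterator

def build_vocab(dataset):
    v = build_vocab_from_iterator(dataset, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
    v.set_default_index(v['<unk>'])  # out-of-vocabulary tokens map to <unk>, as the old Vocab did
    return v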
cell 961b8a50 (execution count 3):

def labels_process(dt):
    return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]
cell afe14c32 (execution count 4):

def data_process(dt):
    return [torch.tensor([vocab['<bos>']] + [vocab[token] for token in document] + [vocab['<eos>']], dtype=torch.long)
            for document in dt]
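labels_process pads every tag sequence with a 0 ('O') at both ends so it stays aligned with the <bos>/<eos> ids that data_process adds; with the specials list above, <bos> gets index 2 and <eos> index 3. A small alignment check, assuming the vocab and helpers above:

t = data_process([["EU", "rejects"]])[0]
print(t[0].item(), t[-1].item())    # 2 3 -- <bos> ... <eos>
print(labels_process([[5, 0]])[0])  # tensor([0, 5, 0, 0])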
cell 63798281 (execution count 6):

class NERModel(torch.nn.Module):
    def __init__(self, ):
        super(NERModel, self).__init__()
        self.emb = torch.nn.Embedding(23628, 200)
        self.fc1 = torch.nn.Linear(600, 9)

    def forward(self, x):
        x = self.emb(x)
        x = x.reshape(600)
        x = self.fc1(x)
        return x
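The hard-coded sizes spell out the design: 23628 matches the training vocabulary (presumably len(vocab)), each of the 3 ids in a window is embedded into 200 dimensions, the 3 x 200 block is flattened to 600, and the linear layer maps it to the 9 IOB labels. A shape walk-through on a fresh CPU instance with a hypothetical window of ids:

m = NERModel()
x = torch.tensor([2, 967, 3])  # hypothetical window: <bos>, some token id, <eos>
print(m.emb(x).shape)          # torch.Size([3, 200])
print(m(x).shape)              # torch.Size([9]) -- one score per label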
cell de6551ba (execution count 50):

labels = ['O', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']
cell 9829ad04 (execution count 55):

-data = pd.read_csv('train.tsv', sep='\t', names=['iob', 'tokens'])
+data = pd.read_csv('train/train.tsv', sep='\t', names=['iob', 'tokens'])
 data["iob"] = data["iob"].apply(lambda x: [labels.index(y) for y in x.split()])
 data["tokens"] = data["tokens"].apply(lambda x: x.split())

 extra_tensors = create_tensors_list(data)

 vocab = build_vocab(data['tokens'])

 device_gpu = torch.device("cuda:0")
 ner_model = NERModel().to(device_gpu)
 criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(ner_model.parameters())

 train_labels = labels_process(data['iob'])
 train_tokens_ids = data_process(data['tokens'])

 train_tensors = extra_features(train_tokens_ids, extra_tensors)

 print(train_tensors[0])

stdout:
tensor([2.0000e+00, 9.6700e+02, 2.2410e+04,  ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00])
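The printed tensor is readable against the pieces above: the leading 2. is vocab['<bos>'], the next entries are token ids (967, 22410, ...), and the trailing zeros are appended isupper/isdigit indicator slots. Two quick consistency checks, assuming the cells above have run:

n = len(data['tokens'][0])
print(train_tensors[0][0].item() == vocab['<bos>'])  # True
print(len(train_tensors[0]) == (n + 2) + 2 * n)      # True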
cell bccaf9f7:

type(train_labels)

output (execution count 8):
list
cell 11d6a6d9:

data["iob"]

output (execution count 10):
0      [5, 0, 3, 0, 0, 0, 3, 0, 0, 0, 7, 8, 0, 1, 0, ...
1      [0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...
2      [1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, ...
3      [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, ...
4      [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...
                             ...
940    [0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 7, 8, 0, 1, 0, ...
941    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...
942    [0, 0, 3, 0, 7, 0, 5, 0, 0, 1, 0, 1, 0, 0, 3, ...
943    [0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 1, 0, 1, 0, 0, ...
944    [0, 0, 3, 4, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, ...
Name: iob, Length: 945, dtype: object
cell 50e0cbaf:

data["tokens"]

output (execution count 11):
0      [EU, rejects, German, call, to, boycott, Briti...
1      [Rare, Hendrix, song, draft, sells, for, almos...
2      [China, says, Taiwan, spoils, atmosphere, for,...
3      [China, says, time, right, for, Taiwan, talks,...
4      [German, July, car, registrations, up, 14.2, p...
                             ...
940    [CYCLING, -, BALLANGER, KEEPS, SPRINT, TITLE, ...
941    [CYCLING, -, WORLD, TRACK, CHAMPIONSHIP, RESUL...
942    [SOCCER, -, FRENCH, DEFENDER, KOMBOUARE, JOINS...
943    [MOTORCYCLING, -, SAN, MARINO, GRAND, PRIX, PR...
944    [GOLF, -, BRITISH, MASTERS, THIRD, ROUND, SCOR...
Name: tokens, Length: 945, dtype: object
cell 24c3f4f8 (execution count 20):

for epoch in range(5):
    acc_score = 0
    prec_score = 0
    selected_items = 0
    recall_score = 0
    relevant_items = 0
    items_total = 0
    ner_model.train()
    for i in range(len(train_labels)):
        for j in range(1, len(train_labels[i]) - 1):
            X = train_tensors[i][j - 1: j + 2].to(device_gpu)

            Y = train_labels[i][j: j + 1].to(device_gpu)

            # Had to add .long() to fit types
            Y_predictions = ner_model(X.long())

            acc_score += int(torch.argmax(Y_predictions) == Y)
            if torch.argmax(Y_predictions) != 0:
                selected_items += 1
            if torch.argmax(Y_predictions) != 0 and torch.argmax(Y_predictions) == Y.item():
                prec_score += 1
            if Y.item() != 0:
                relevant_items += 1
            if Y.item() != 0 and torch.argmax(Y_predictions) == Y.item():
                recall_score += 1

            items_total += 1
            optimizer.zero_grad()
            loss = criterion(Y_predictions.unsqueeze(0), Y)
            loss.backward()
            optimizer.step()

    precision = prec_score / selected_items
    recall = recall_score / relevant_items
    f1_score = (2 * precision * recall) / (precision + recall)
    print(f'epoch: {epoch}')
    print(f'f1: {f1_score}')
    print(f'acc: {acc_score / items_total}')

stdout:
0
epoch: 0
f1: 0.6370749322900994
acc: 0.9114627847775542
1
epoch: 1
f1: 0.7994615623567001
acc: 0.954334500473289
2
epoch: 2
f1: 0.8643503374296407
acc: 0.9701919807375957
3
epoch: 3
f1: 0.9025574619618
acc: 0.9791431170907888
4
epoch: 4
f1: 0.9295360263614699
acc: 0.9851580233979396
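The epoch metrics treat label 0 ('O') as the negative class: precision counts non-O predictions that match, recall counts gold non-O tags recovered, both accumulated over the same pass that trains the model one token at a time. The divisions at the end of an epoch assume the model predicted at least one non-O tag and the data contains at least one entity; a guarded variant of those three lines, as a sketch:

precision = prec_score / selected_items if selected_items else 0.0
recall = recall_score / relevant_items if relevant_items else 0.0
f1_score = (2 * precision * recall) / (precision + recall) if precision + recall else 0.0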
cell rXY6j7-qt7gU (execution count 57):

-dev = pd.read_csv('in.tsv', sep='\t', names=['tokens'])
+dev = pd.read_csv('dev-0/in.tsv', sep='\t', names=['tokens'])
 dev["tokens"] = dev["tokens"].apply(lambda x: x.split())

 dev_tokens_ids = data_process(dev["tokens"])

 dev_extra_tensors = create_tensors_list(dev)

 dev_tensors = extra_features(dev_tokens_ids, dev_extra_tensors)

 results = predict(dev_tensors, labels)
 results_processed = process_output(results)
-save_to_file("out.tsv", results_processed)
+save_to_file("dev-0/out.tsv", results_processed)
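Together with the train/train.tsv change above, these path fixes point the notebook at the challenge's directory layout (train/, dev-0/, test-A/, each holding that split's in/out TSVs), so it runs from the repository root. A sanity check one could run first, as a sketch:

import os.path
for p in ['train/train.tsv', 'dev-0/in.tsv', 'test-A/in.tsv']:
    print(p, os.path.exists(p))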
cell 1lGYlL6iliGM (previously empty; the test-A pass is new in this commit):

+test = pd.read_csv('test-A/in.tsv', sep='\t', names=['tokens'])
+test["tokens"] = test["tokens"].apply(lambda x: x.split())
+
+test_tokens_ids = data_process(test["tokens"])
+
+test_extra_tensors = create_tensors_list(test)
+
+test_tensors = extra_features(test_tokens_ids, test_extra_tensors)
+
+results = predict(test_tensors, labels)
+results_processed = process_output(results)
+save_to_file("test-A/out.tsv", results_processed)