From 499702ff9c400060131a92479239c2076ec6b06b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zofia=20Fra=C5=9B?= Date: Tue, 22 Jun 2021 20:21:17 +0200 Subject: [PATCH] add script --- rnn_fras.ipynb | 1403 ++++++++++++++++++++++++++++++++++++++++++++++++ rnn_fras.py | 340 ++++++++++++ 2 files changed, 1743 insertions(+) create mode 100644 rnn_fras.ipynb create mode 100644 rnn_fras.py diff --git a/rnn_fras.ipynb b/rnn_fras.ipynb new file mode 100644 index 0000000..320b3f0 --- /dev/null +++ b/rnn_fras.ipynb @@ -0,0 +1,1403 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Zadanie domowe\n", + "\n", + "\n", + "- sklonować repozytorium https://git.wmi.amu.edu.pl/kubapok/en-ner-conll-2003\n", + "- stworzyć model seq labelling bazujący na sieci neuronowej opisanej w punkcie niżej (można bazować na tym jupyterze lub nie).\n", + "- model sieci to GRU (o dowolnych parametrach) + CRF w pytorchu korzystając z modułu CRF z poprzednich zajęć- - stworzyć predykcje w plikach dev-0/out.tsv oraz test-A/out.tsv\n", + "- wynik fscore sprawdzony za pomocą narzędzia geval (patrz poprzednie zadanie) powinien wynosić conajmniej 0.65\n", + "- proszę umieścić predykcję oraz skrypty generujące (w postaci tekstowej a nie jupyter) w repo, a w MS TEAMS umieścić link do swojego repo\n", + "termin 22.06, 60 punktów, za najlepszy wynik- 100 punktów\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "from torchtext.vocab import Vocab\n", + "from collections import Counter\n", + "from tqdm.notebook import tqdm\n", + "import lzma\n", + "import itertools\n", + "from torchcrf import CRF" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def read_data(filename):\n", + " all_data = lzma.open(filename).read().decode('UTF-8').split('\\n')\n", + " return [line.split('\\t') for line in all_data][:-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def data_process(dt):\n", + " return [torch.tensor([vocab['']] + [vocab[token] for token in document] + [vocab['']], dtype = torch.long) for document in dt]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def labels_process(dt):\n", + " return [ torch.tensor([0] + document + [0], dtype = torch.long) for document in dt]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def build_vocab(dataset):\n", + " counter = Counter()\n", + " for document in dataset:\n", + " counter.update(document)\n", + " return Vocab(counter, specials=['', '', '', ''])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "train_data = read_data('train/train.tsv.xz')\n", + "\n", + "tokens, ner_tags = [], []\n", + "for i in train_data:\n", + " ner_tags.append(i[0].split())\n", + " tokens.append(i[1].split())" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "vocab = build_vocab(tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "train_tokens_ids = data_process(tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']\n" + ] + } + ], + "source": [ + "ner_tags_set = list(set(itertools.chain(*ner_tags)))\n", + "ner_tags_set.sort()\n", + "print(ner_tags_set)\n", + "train_labels = labels_process([[ner_tags_set.index(token) for token in doc] for doc in ner_tags])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "num_tags = max([max(x) for x in train_labels]) + 1 " + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "class GRU(torch.nn.Module):\n", + "\n", + " def __init__(self):\n", + " super(GRU, self).__init__()\n", + " self.emb = torch.nn.Embedding(len(vocab.itos),100)\n", + " self.dropout = torch.nn.Dropout(0.2)\n", + " self.rec = torch.nn.GRU(100, 256, 2, batch_first = True, bidirectional = True)\n", + " self.fc1 = torch.nn.Linear(2* 256 , 9)\n", + " \n", + " def forward(self, x):\n", + " emb = torch.relu(self.emb(x))\n", + " emb = self.dropout(emb)\n", + " gru_output, h_n = self.rec(emb)\n", + " out_weights = self.fc1(gru_output)\n", + " return out_weights" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def get_scores(y_true, y_pred):\n", + " acc_score = 0\n", + " tp = 0\n", + " fp = 0\n", + " selected_items = 0\n", + " relevant_items = 0 \n", + "\n", + " for p,t in zip(y_pred, y_true):\n", + " if p == t:\n", + " acc_score +=1\n", + "\n", + " if p > 0 and p == t:\n", + " tp +=1\n", + "\n", + " if p > 0:\n", + " selected_items += 1\n", + "\n", + " if t > 0 :\n", + " relevant_items +=1\n", + " \n", + " if selected_items == 0:\n", + " precision = 1.0\n", + " else:\n", + " precision = tp / selected_items\n", + " \n", + " if relevant_items == 0:\n", + " recall = 1.0\n", + " else:\n", + " recall = tp / relevant_items\n", + " \n", + " if precision + recall == 0.0 :\n", + " f1 = 0.0\n", + " else:\n", + " f1 = 2* precision * recall / (precision + recall)\n", + "\n", + " return precision, recall, f1" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "def eval_model(dataset_tokens, dataset_labels, model):\n", + " Y_true = []\n", + " Y_pred = []\n", + " for i in tqdm(range(len(dataset_labels))):\n", + " batch_tokens = dataset_tokens[i].unsqueeze(1)\n", + " tags = list(dataset_labels[i].numpy())\n", + " emissions = gru(batch_tokens).squeeze(0)\n", + " Y_pred += crf.decode(emissions)[0]\n", + " Y_true += tags\n", + " return get_scores(Y_true, Y_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "gru = GRU()\n", + "crf = CRF(num_tags)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "params = list(gru.parameters()) + list(crf.parameters())\n", + "optimizer = torch.optim.Adam(params)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "NUM_EPOCHS = 20" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = torch.nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c5da680182d74dbe8a6e6e515f39c304", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/zosia/.local/lib/python3.8/site-packages/torchcrf/__init__.py:249: UserWarning: where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead. (Triggered internally at /pytorch/aten/src/ATen/native/TensorCompare.cpp:255.)\n", + " score = torch.where(mask[i].unsqueeze(1), next_score, score)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3ca55e4b508d4fc9b2d0720e1def2a58", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.8601941656899232, 0.8751514345303986, 0.8676083403589915)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "31afa4456a9240789283af09788a3ed9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a6b3ba1f3b474cf092a826c87a0345be", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.8815602436292092, 0.8897984198549079, 0.8856601748234387)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0d18da57114b4e0ab646fcb52860dabd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9cfb2facab9c4b56924c27e287ba05d4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9144309250302297, 0.919752763828645, 0.9170841238373373)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a0b5a4064324446c8a1a0c70d07cda59", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dea1cbf55a0c43fa84167e376b309125", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9361905528132853, 0.9398110097060626, 0.9379972877369673)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0524c9827f294852a5cd271bfbdbfcd2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "04b306aaa0604677aa251727feacd2ba", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9519541852390448, 0.9547763044748607, 0.9533631563717097)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1404f55a3ce546c2b99ddc12679b5d97", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c848b31135b14155b98aeaea7b8ac2be", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.960722713444972, 0.9632376346282668, 0.961978530336279)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8713c93530b94398a96354138783e326", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "57d75f0d65be4ca4a5401cb7ed3d5fe0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9697570414352719, 0.9714709221947199, 0.9706132252353172)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "47614d60ab4b4f0abd56c395420103a6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2b88316f13234cdeb1486f66cf08d5b6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9760554565110192, 0.9779891394717963, 0.9770213412246582)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9df02fe711de4ccbac63cbcc77b9185b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d45a98359f90443b9de4017908729121", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9811127302761178, 0.9819703829690195, 0.9815413692723396)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "85e862f466b049089dd1ae4d4b3b25b4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3c38cdbf365448a8bec573e7f9bf3831", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.984655071665091, 0.9846831395763159, 0.9846691054206851)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fc3a0e186cb94c649f6d489d8afe0b28", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "aa122bd240074ffb9ad1a8e8df787497", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9871442343767067, 0.9875194192515452, 0.9873317911716786)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bc95fc4c8eb84aa99bc9c09e449edc53", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "36dc7b6450a24bac82ced0efb0d9c4a0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9893908786272786, 0.9889114292094049, 0.9891510958201069)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ec3fe5ad630e42b181096b84f985428f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3b1bcac29f3f4f5d949ee0860c345329", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9911312527046112, 0.9901989196482444, 0.9906648668174991)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4dd0d5d81d5943f1ad7d7451014727d9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "98ef88bd0dcf45fd9ebeaf2ab77b2dbf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9924332083291745, 0.9919900041332719, 0.9922115567382627)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "99089ca7a97a4168a1cc46d1eba49a62", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "294ec84096cd474a80ebff9c36ea0644", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9930640069977942, 0.9924270857582653, 0.9927454442197611)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "130e89d10ba54246a65645381c69538a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "11ae3075d74443a79bbd4789a5bdd9b7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9739162872556146, 0.9674801769230403, 0.9706875636048171)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cfaa4824a18046d898818cbed675450e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ba9b8454320c4d53b041f6f424b773ab", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9848088502477955, 0.9837187094689933, 0.9842634780066597)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a5c8d9a82a9b4b4f8a58b83dcf50dfc8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "286d0b6ad83146e4911b967e9cbde195", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9808100926458495, 0.9802695653413275, 0.9805397545015183)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b6342f18da15402e9c391550235a7ded", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5fc5c4feee604ddea5ead7ff203f07e8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9668917478143436, 0.9694090371376854, 0.968148756174055)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "927b69c7183442aeaf3dc08ae3e20cbe", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "15f90bd1520b4773bcbb699f985ea031", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9793555195345366, 0.9788157938495013, 0.979085582310423)\n" + ] + } + ], + "source": [ + "for i in range(NUM_EPOCHS):\n", + " gru.train()\n", + " crf.train()\n", + " for i in tqdm(range(len(train_labels))):\n", + " batch_tokens = train_tokens_ids[i].unsqueeze(1)\n", + " tags = train_labels[i].unsqueeze(1)\n", + " emissions = gru(batch_tokens).squeeze(0)\n", + " optimizer.zero_grad()\n", + " loss = -crf(emissions,tags.squeeze(0))\n", + " loss.backward()\n", + " optimizer.step()\n", + " gru.eval()\n", + " crf.eval()\n", + " print(eval_model(train_tokens_ids, train_labels, gru))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## dev-0 i test-A" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "def predict_labels(dataset_tokens, dataset_labels, model):\n", + " print(len(dataset_tokens[0]), len(dataset_labels[0]))\n", + " Y_true = []\n", + " Y_pred = []\n", + " result = []\n", + " for i in tqdm(range(len(dataset_labels))):\n", + " batch_tokens = dataset_tokens[i].unsqueeze(1)\n", + " tags = list(dataset_labels[i].numpy())\n", + " emissions = gru(batch_tokens).squeeze(0)\n", + " tmp = crf.decode(emissions)[0]\n", + " Y_pred += tmp\n", + " result += [tmp]\n", + " Y_true += tags\n", + " print(get_scores(Y_true, Y_pred))\n", + " return result" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "with open('dev-0/in.tsv', \"r\", encoding=\"utf-8\") as f:\n", + " dev_0_data = [line.rstrip() for line in f]\n", + " \n", + "dev_0_data = [i.split() for i in dev_0_data]\n", + "dev_0_tokens_ids = data_process(dev_0_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "with open('dev-0/expected.tsv', \"r\", encoding=\"utf-8\") as f:\n", + " dev_0_labels = [line.rstrip() for line in f]\n", + " \n", + "dev_0_labels = [i.split() for i in dev_0_labels]\n", + "dev_0_labels = labels_process([[ner_tags_set.index(token) for token in doc] for doc in dev_0_labels])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "458 458\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e48f16faacc043ac8237af22f32b0af1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=215.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "(0.9501477944520237, 0.9535808009736432, 0.9518612023310112)\n" + ] + } + ], + "source": [ + "tmp = predict_labels(dev_0_tokens_ids, dev_0_labels, gru)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "r = [[ner_tags_set[i] for i in tmp2] for tmp2 in tmp]\n", + "r = [i[1:-1] for i in r]" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "for doc in r:\n", + " if doc[0] != 'O':\n", + " doc[0] = 'B' + doc[0][1:]\n", + " for i in range(len(doc))[:-1]:\n", + " if doc[i] == 'O':\n", + " if doc[i + 1] != 'O':\n", + " doc[i + 1] = 'B' + doc[i + 1][1:]\n", + " elif doc[i + 1] != 'O':\n", + " if doc[i][1:] == doc[i + 1][1:]:\n", + " doc[i + 1] = 'I' + doc[i + 1][1:]\n", + " else:\n", + " doc[i + 1] = 'B' + doc[i + 1][1:]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "f = open(\"dev-0/out.tsv\", \"a\")\n", + "for i in r:\n", + " f.write(' '.join(i) + '\\n')\n", + "f.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9cce1860765e420f9b0bfaa23b651f58", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=215.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "42e2565e95db4efb9343d93f195212d5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=230.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "def predict(path, model):\n", + " with open(path + '/in.tsv', \"r\", encoding=\"utf-8\") as f:\n", + " data = [line.rstrip() for line in f]\n", + " data = [i.split() for i in data]\n", + " tokens_ids = data_process(data)\n", + " \n", + " Y_true = []\n", + " Y_pred = []\n", + " result = []\n", + " for i in tqdm(range(len(tokens_ids))):\n", + " batch_tokens = tokens_ids[i].unsqueeze(1)\n", + " emissions = gru(batch_tokens).squeeze(0)\n", + " tmp = crf.decode(emissions)[0]\n", + " Y_pred += tmp\n", + " result += [tmp]\n", + " r = [[ner_tags_set[i] for i in tmp] for tmp in result]\n", + " r = [i[1:-1] for i in r]\n", + " for doc in r:\n", + " if doc[0] != 'O':\n", + " doc[0] = 'B' + doc[0][1:]\n", + " for i in range(len(doc))[:-1]:\n", + " if doc[i] == 'O':\n", + " if doc[i + 1] != 'O':\n", + " doc[i + 1] = 'B' + doc[i + 1][1:]\n", + " elif doc[i + 1] != 'O':\n", + " if doc[i][1:] == doc[i + 1][1:]:\n", + " doc[i + 1] = 'I' + doc[i + 1][1:]\n", + " else:\n", + " doc[i + 1] = 'B' + doc[i + 1][1:]\n", + " f = open(path + \"/out.tsv\", \"a\")\n", + " for i in r:\n", + " f.write(' '.join(i) + '\\n')\n", + " f.close()\n", + " return result\n", + "\n", + "result = predict('dev-0', gru)\n", + "result = predict('test-A', gru)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/rnn_fras.py b/rnn_fras.py new file mode 100644 index 0000000..a4fe179 --- /dev/null +++ b/rnn_fras.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python +# coding: utf-8 + +# ## Zadanie domowe +# +# +# - sklonować repozytorium https://git.wmi.amu.edu.pl/kubapok/en-ner-conll-2003 +# - stworzyć model seq labelling bazujący na sieci neuronowej opisanej w punkcie niżej (można bazować na tym jupyterze lub nie). +# - model sieci to GRU (o dowolnych parametrach) + CRF w pytorchu korzystając z modułu CRF z poprzednich zajęć- - stworzyć predykcje w plikach dev-0/out.tsv oraz test-A/out.tsv +# - wynik fscore sprawdzony za pomocą narzędzia geval (patrz poprzednie zadanie) powinien wynosić conajmniej 0.65 +# - proszę umieścić predykcję oraz skrypty generujące (w postaci tekstowej a nie jupyter) w repo, a w MS TEAMS umieścić link do swojego repo +# termin 22.06, 60 punktów, za najlepszy wynik- 100 punktów +# + +# In[2]: + + +import numpy as np +import torch +from torchtext.vocab import Vocab +from collections import Counter +from tqdm.notebook import tqdm +import lzma +import itertools +from torchcrf import CRF + + +# In[3]: + + +def read_data(filename): + all_data = lzma.open(filename).read().decode('UTF-8').split('\n') + return [line.split('\t') for line in all_data][:-1] + + +# In[4]: + + +def data_process(dt): + return [torch.tensor([vocab['']] + [vocab[token] for token in document] + [vocab['']], dtype = torch.long) for document in dt] + + +# In[5]: + + +def labels_process(dt): + return [ torch.tensor([0] + document + [0], dtype = torch.long) for document in dt] + + +# In[6]: + + +def build_vocab(dataset): + counter = Counter() + for document in dataset: + counter.update(document) + return Vocab(counter, specials=['', '', '', '']) + + +# In[7]: + + +train_data = read_data('train/train.tsv.xz') + +tokens, ner_tags = [], [] +for i in train_data: + ner_tags.append(i[0].split()) + tokens.append(i[1].split()) + + +# In[8]: + + +vocab = build_vocab(tokens) + + +# In[9]: + + +train_tokens_ids = data_process(tokens) + + +# In[10]: + + +ner_tags_set = list(set(itertools.chain(*ner_tags))) +ner_tags_set.sort() +print(ner_tags_set) +train_labels = labels_process([[ner_tags_set.index(token) for token in doc] for doc in ner_tags]) + + +# In[11]: + + +num_tags = max([max(x) for x in train_labels]) + 1 + + +# In[12]: + + +class GRU(torch.nn.Module): + + def __init__(self): + super(GRU, self).__init__() + self.emb = torch.nn.Embedding(len(vocab.itos),100) + self.dropout = torch.nn.Dropout(0.2) + self.rec = torch.nn.GRU(100, 256, 2, batch_first = True, bidirectional = True) + self.fc1 = torch.nn.Linear(2* 256 , 9) + + def forward(self, x): + emb = torch.relu(self.emb(x)) + emb = self.dropout(emb) + gru_output, h_n = self.rec(emb) + out_weights = self.fc1(gru_output) + return out_weights + + +# In[13]: + + +def get_scores(y_true, y_pred): + acc_score = 0 + tp = 0 + fp = 0 + selected_items = 0 + relevant_items = 0 + + for p,t in zip(y_pred, y_true): + if p == t: + acc_score +=1 + + if p > 0 and p == t: + tp +=1 + + if p > 0: + selected_items += 1 + + if t > 0 : + relevant_items +=1 + + if selected_items == 0: + precision = 1.0 + else: + precision = tp / selected_items + + if relevant_items == 0: + recall = 1.0 + else: + recall = tp / relevant_items + + if precision + recall == 0.0 : + f1 = 0.0 + else: + f1 = 2* precision * recall / (precision + recall) + + return precision, recall, f1 + + +# In[14]: + + +def eval_model(dataset_tokens, dataset_labels, model): + Y_true = [] + Y_pred = [] + for i in tqdm(range(len(dataset_labels))): + batch_tokens = dataset_tokens[i].unsqueeze(1) + tags = list(dataset_labels[i].numpy()) + emissions = gru(batch_tokens).squeeze(0) + Y_pred += crf.decode(emissions)[0] + Y_true += tags + return get_scores(Y_true, Y_pred) + + +# In[15]: + + +gru = GRU() +crf = CRF(num_tags) + + +# In[16]: + + +params = list(gru.parameters()) + list(crf.parameters()) +optimizer = torch.optim.Adam(params) + + +# In[17]: + + +NUM_EPOCHS = 20 + + +# In[18]: + + +criterion = torch.nn.CrossEntropyLoss() + + +# In[19]: + + +for i in range(NUM_EPOCHS): + gru.train() + crf.train() + for i in tqdm(range(len(train_labels))): + batch_tokens = train_tokens_ids[i].unsqueeze(1) + tags = train_labels[i].unsqueeze(1) + emissions = gru(batch_tokens).squeeze(0) + optimizer.zero_grad() + loss = -crf(emissions,tags.squeeze(0)) + loss.backward() + optimizer.step() + gru.eval() + crf.eval() + print(eval_model(train_tokens_ids, train_labels, gru)) + + +# ## dev-0 i test-A + +# In[20]: + + +def predict_labels(dataset_tokens, dataset_labels, model): + print(len(dataset_tokens[0]), len(dataset_labels[0])) + Y_true = [] + Y_pred = [] + result = [] + for i in tqdm(range(len(dataset_labels))): + batch_tokens = dataset_tokens[i].unsqueeze(1) + tags = list(dataset_labels[i].numpy()) + emissions = gru(batch_tokens).squeeze(0) + tmp = crf.decode(emissions)[0] + Y_pred += tmp + result += [tmp] + Y_true += tags + print(get_scores(Y_true, Y_pred)) + return result + + +# In[21]: + + +with open('dev-0/in.tsv', "r", encoding="utf-8") as f: + dev_0_data = [line.rstrip() for line in f] + +dev_0_data = [i.split() for i in dev_0_data] +dev_0_tokens_ids = data_process(dev_0_data) + + +# In[22]: + + +with open('dev-0/expected.tsv', "r", encoding="utf-8") as f: + dev_0_labels = [line.rstrip() for line in f] + +dev_0_labels = [i.split() for i in dev_0_labels] +dev_0_labels = labels_process([[ner_tags_set.index(token) for token in doc] for doc in dev_0_labels]) + + +# In[23]: + + +tmp = predict_labels(dev_0_tokens_ids, dev_0_labels, gru) + + +# In[24]: + + +r = [[ner_tags_set[i] for i in tmp2] for tmp2 in tmp] +r = [i[1:-1] for i in r] + + +# In[25]: + + +for doc in r: + if doc[0] != 'O': + doc[0] = 'B' + doc[0][1:] + for i in range(len(doc))[:-1]: + if doc[i] == 'O': + if doc[i + 1] != 'O': + doc[i + 1] = 'B' + doc[i + 1][1:] + elif doc[i + 1] != 'O': + if doc[i][1:] == doc[i + 1][1:]: + doc[i + 1] = 'I' + doc[i + 1][1:] + else: + doc[i + 1] = 'B' + doc[i + 1][1:] + + +# In[26]: + + +f = open("dev-0/out.tsv", "a") +for i in r: + f.write(' '.join(i) + '\n') +f.close() + + +# In[27]: + + +def predict(path, model): + with open(path + '/in.tsv', "r", encoding="utf-8") as f: + data = [line.rstrip() for line in f] + data = [i.split() for i in data] + tokens_ids = data_process(data) + + Y_true = [] + Y_pred = [] + result = [] + for i in tqdm(range(len(tokens_ids))): + batch_tokens = tokens_ids[i].unsqueeze(1) + emissions = gru(batch_tokens).squeeze(0) + tmp = crf.decode(emissions)[0] + Y_pred += tmp + result += [tmp] + r = [[ner_tags_set[i] for i in tmp] for tmp in result] + r = [i[1:-1] for i in r] + for doc in r: + if doc[0] != 'O': + doc[0] = 'B' + doc[0][1:] + for i in range(len(doc))[:-1]: + if doc[i] == 'O': + if doc[i + 1] != 'O': + doc[i + 1] = 'B' + doc[i + 1][1:] + elif doc[i + 1] != 'O': + if doc[i][1:] == doc[i + 1][1:]: + doc[i + 1] = 'I' + doc[i + 1][1:] + else: + doc[i + 1] = 'B' + doc[i + 1][1:] + f = open(path + "/out.tsv", "a") + for i in r: + f.write(' '.join(i) + '\n') + f.close() + return result + +result = predict('dev-0', gru) +result = predict('test-A', gru) +