{ "cells": [
 {
  "cell_type": "code", "execution_count": 2, "id": "coral-camera", "metadata": {}, "outputs": [],
  "source": [
   "def read_data(path, sep=None, encoding='utf-8'):\n",
   "    \"\"\"Read a text file and return one list of fields per line.\n",
   "\n",
   "    Each line is stripped and split with ``str.split(sep)``. With the default\n",
   "    ``sep=None`` any run of whitespace separates fields, so the space-separated\n",
   "    tag/token sequences inside one TSV column are split apart as well; pass a\n",
   "    tab character as ``sep`` to keep whole TSV columns together.\n",
   "\n",
   "    The file is decoded with an explicit ``encoding`` (UTF-8 by default) rather\n",
   "    than whatever the platform's locale default happens to be.\n",
   "    \"\"\"\n",
   "    with open(path, 'r', encoding=encoding) as f:\n",
   "        return [line.strip().split(sep) for line in f]"
  ]
 },
 {
  "cell_type": "code", "execution_count": 3, "id": "cc201a16", "metadata": {}, "outputs": [],
  "source": [
   "dataset = read_data('train/train.tsv')"
  ]
 },
 {
  "cell_type": "code", "execution_count": 4, "id": "sharing-employment", "metadata": {},
  "outputs": [
   { "data": { "text/plain": [ "['B-ORG', 'O', 'B-LOC', 'B-LOC', 'B-MISC']" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }
  ],
  "source": [
   "# NOTE(review): read_data splits on any whitespace, so x[1] and y[0] are the\n",
   "# 2nd and 1st whitespace-separated fields of each line - the preview shows that\n",
   "# y[0] is a single NER tag, not a whole label column. TODO confirm this is the\n",
   "# intended text/label split.\n",
   "train_x = [x[1] for x in dataset]\n",
   "train_y = [y[0] for y in dataset]\n",
   "# Preview only: displaying the full list floods the notebook.\n",
   "train_y[:5]"
  ]
 },
 {
  "cell_type": "code", "execution_count": 5, "id": "elder-trauma", "metadata": {}, "outputs": [],
  "source": [
   "import torchtext.vocab\n",
   "from collections import Counter"
  ]
 },
 {
  "cell_type": "code", "execution_count": 6, "id": "material-timothy", "metadata": {}, "outputs": [],
  "source": [
   "def build_vocab(dataset):\n",
   "    \"\"\"Build a torchtext vocabulary from an iterable of token lists.\n",
   "\n",
   "    The special tokens are placed before the corpus tokens, so '<unk>' gets\n",
   "    index 0, which is also the index returned for out-of-vocabulary tokens.\n",
   "    \"\"\"\n",
   "    counter = Counter()\n",
   "    for document in dataset:\n",
   "        counter.update(document)\n",
   "\n",
   "    # The specials must be distinct strings (torchtext's Vocab rejects duplicate tokens).\n",
   "    vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])\n",
   "    vocab.set_default_index(0)\n",
   "    return vocab"
  ]
 },
 {
  "cell_type": "code", "execution_count": 7, "id": "provincial-reader", "metadata": {}, "outputs": [],
  "source": [
   "train_x = [x.split() for x in train_x]"
  ]
 },
 {
  "cell_type": "code", "execution_count": 8, "id": "invalid-nursing", "metadata": {}, "outputs": [],
  "source": [
   "vocab = build_vocab(train_x)"
  ]
 },
 { "cell_type":
"code", "execution_count": 9, "id": "accredited-observation", "metadata": {}, "outputs": [],
  "source": [
   "import torch\n",
   "\n",
   "\n",
   "def data_process(dt):\n",
   "    \"\"\"Map each token list in ``dt`` to a LongTensor of vocab ids framed by <bos>/<eos>.\n",
   "\n",
   "    Tokens that are not in ``vocab`` map to its default index.\n",
   "    \"\"\"\n",
   "    bos, eos = vocab['<bos>'], vocab['<eos>']\n",
   "    return [\n",
   "        torch.tensor([bos] + [vocab[token] for token in document] + [eos], dtype=torch.long)\n",
   "        for document in dt\n",
   "    ]\n",
   "\n",
   "\n",
   "def labels_process(dt):\n",
   "    \"\"\"Map each label item in ``dt`` to a LongTensor with tag id 0 added on both ends.\n",
   "\n",
   "    The extra 0s line up with the <bos>/<eos> ids added by data_process (0 is the\n",
   "    'O' tag in ner_tags). An item may be a single tag id, as in the train_y/dev_y\n",
   "    lists built below, or a list/tuple of tag ids for a whole sentence; sequences\n",
   "    are spliced in, because a nested list cannot be converted by torch.tensor.\n",
   "    \"\"\"\n",
   "    labels = []\n",
   "    for document in dt:\n",
   "        tag_ids = list(document) if isinstance(document, (list, tuple)) else [document]\n",
   "        labels.append(torch.tensor([0] + tag_ids + [0], dtype=torch.long))\n",
   "    return labels"
  ]
 },
 {
  "cell_type": "code", "execution_count": 10, "id": "united-local", "metadata": {}, "outputs": [],
  "source": [
   "ner_tags = {'O': 0, 'B-ORG': 1, 'I-ORG': 2, 'B-PER': 3, 'I-PER': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}"
  ]
 },
 {
  "cell_type": "code", "execution_count": 11, "id": "reported-afghanistan", "metadata": {}, "outputs": [],
  "source": [
   "train_tokens_ids = data_process(train_x)"
  ]
 },
 {
  "cell_type": "code", "execution_count": 12, "id": "southern-nirvana", "metadata": {}, "outputs": [],
  "source": [
   "dev_x = read_data('dev-0/in.tsv')\n",
   "dev_y = read_data('dev-0/expected.tsv')\n",
   "\n",
   "test_x = read_data('test-A/in.tsv')\n",
   "\n",
   "dev_x = [x[0].split() for x in dev_x]\n",
   "dev_y = [y[0].split() for y in dev_y]\n",
   "test_x = [x[0].split() for x in test_x]"
  ]
 },
 {
  "cell_type": "code", "execution_count": 13, "id": "played-transparency", "metadata": {},
  "outputs": [
   { "data": { "text/plain": [ "['B-ORG', 'O', 'B-LOC', 'B-LOC', 'B-MISC']" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "[1, 0, 5, 5, 7]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }
  ],
  "source": [
   "train_y = [y[0] for y in dataset]\n",
   "display(train_y[:5])\n",
   "# Plain indexing: an unknown tag raises KeyError here instead of silently becoming None.\n",
   "train_y = [ner_tags[tag] for tag in train_y]\n",
   "train_y[:5]"
  ]
 },
 {
  "cell_type": "code", "execution_count": 14, "id": "assured-colonial", "metadata": {}, "outputs": [],
  "source": [
   "dev_y = [ner_tags[tag] for y in dev_y for tag in y]"
  ]
 },
 {
  "cell_type": "code", "execution_count": 15, "id": "identical-subsection", "metadata": { "tags": [] }, "outputs": [],
  "source": [
   "# NOTE: the test_* tensors below are built from the dev-0 split.\n",
   "test_tokens_ids = data_process(dev_x)\n",
   "train_labels = labels_process(train_y)\n",
   "test_labels = labels_process(dev_y)"
  ]
 },
 {
  "cell_type": "code", "execution_count": 31, "id": "demanding-bonus", "metadata": {}, "outputs": [],
  "source": [
   "class NERModel(torch.nn.Module):\n",
   "    \"\"\"Linear tagger over one flattened window of embedded ids.\n",
   "\n",
   "    forward() embeds a LongTensor holding exactly ``n_inputs`` ids, flattens the\n",
   "    result into a single vector of n_inputs * emb_dim features and maps it to\n",
   "    ``num_tags`` logits. Defaults: 23627 embedding rows, 200-dim embeddings,\n",
   "    12 ids (2400 features) and 9 tags.\n",
   "\n",
   "    No softmax is applied: torch.nn.CrossEntropyLoss, used as the criterion\n",
   "    below, expects raw logits\n",
   "    (https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html).\n",
   "\n",
   "    NOTE(review): vocab_size presumably should equal len(vocab), and the 12 ids\n",
   "    look like a 3-token window plus the 9 flags from add_features - TODO confirm\n",
   "    both against the training loop.\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, vocab_size=23627, emb_dim=200, n_inputs=12, num_tags=9):\n",
   "        super().__init__()\n",
   "        self.emb = torch.nn.Embedding(vocab_size, emb_dim)\n",
   "        self.fc1 = torch.nn.Linear(n_inputs * emb_dim, num_tags)\n",
   "\n",
   "    def forward(self, x):\n",
   "        x = self.emb(x)\n",
   "        # Flatten every embedded id into one feature vector; this raises if x\n",
   "        # does not hold exactly n_inputs ids.\n",
   "        x = x.reshape(self.fc1.in_features)\n",
   "        return self.fc1(x)"
  ]
 },
 {
  "cell_type": "code", "execution_count": 32, "id": "statistical-barbados", "metadata": {}, "outputs": [],
  "source": [
   "ner_model = NERModel()"
  ]
 },
 {
  "cell_type": "code", "execution_count": 33, "id": "impressive-insert", "metadata": {}, "outputs": [],
  "source": [
   "criterion = torch.nn.CrossEntropyLoss()"
  ]
 },
 {
  "cell_type": "code", "execution_count": 34, "id": "speaking-seeking", "metadata": {}, "outputs": [],
  "source": [
   "optimizer = torch.optim.Adam(ner_model.parameters())"
  ]
 },
 {
  "cell_type": "code", "execution_count": 35, "id": "8161d438", "metadata": {}, "outputs": [],
  "source": [
   "import string\n",
   "\n",
   "\n",
   "def add_features(tens, tokens):\n",
   "    \"\"\"Append nine 0/1 orthographic flags describing tokens[1] to ``tens``.\n",
   "\n",
   "    NOTE(review): tokens is presumably the window of word strings around the\n",
   "    word being tagged, with that word at index 1 - TODO confirm against the\n",
   "    training loop. All flags stay 0 when tokens has fewer than two entries or\n",
   "    tokens[1] is empty. Flag positions:\n",
   "        0: first character is uppercase    5: contains '-'\n",
   "        1: word.isalnum()                  6: contains '/'\n",
   "        2: contains string.punctuation     7: len(word) > 3\n",
   "        3: word.isnumeric()                8: len(word) > 6\n",
   "        4: word.isupper()\n",
   "\n",
   "    Returns torch.cat((tens, flags), 0), so ``tens`` must be 1-D.\n",
   "    \"\"\"\n",
   "    array = [0] * 9\n",
   "    if len(tokens) >= 2 and len(tokens[1]) >= 1:\n",
   "        word = tokens[1]\n",
   "        array[0] = int(word[0].isupper())\n",
   "        array[1] = int(word.isalnum())\n",
   "        array[2] = int(any(ch in string.punctuation for ch in word))\n",
   "        array[3] = int(word.isnumeric())\n",
   "        array[4] = int(word.isupper())\n",
   "        array[5] = int('-' in word)\n",
   "        array[6] = int('/' in word)\n",
   "        array[7] = int(len(word) > 3)\n",
   "        array[8] = int(len(word) > 6)\n",
   "    x = torch.tensor(array)\n",
   "    return torch.cat((tens, x), 0)"
  ]
 },
 {
  "cell_type": "code", "execution_count": 36, "id": "sized-mobile", "metadata": {},
  "outputs": [
   { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "0" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "2.811322446731947" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "0.48" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "0.18604651162790697" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "0.1702127659574468" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "1" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "2.5642633876085097" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" },
   { "data": { "text/plain": [ "0.43" ] }, "metadata": {},
"output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.1702127659574468" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.1702127659574468" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "2" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "2.2745147155076166" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.54" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "3" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "2.3077734905840908" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, 
"output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "4" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "2.2327055485211984" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.51" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.24489795918367346" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2553191489361702" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "5" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "2.032022957816762" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.58" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2978723404255319" ] }, 
"metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2978723404255319" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "6" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.9094040171859614" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.57" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "7" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.8801336237322357" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.54" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2391304347826087" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] 
}, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.23404255319148937" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "8" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.853852765722122" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2222222222222222" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "9" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.8288560365608282" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2222222222222222" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ 
"0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "10" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.8022360281114742" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2222222222222222" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "11" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.775143896874324" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.22727272727272727" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "12" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.748205496848568" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.22727272727272727" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "13" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.7217520299459284" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.54" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.24444444444444444" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.23404255319148937" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { 
"data": { "text/plain": [ "14" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.6958234204566542" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.55" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2608695652173913" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2553191489361702" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "15" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.6712835076041666" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.55" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2608695652173913" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2553191489361702" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "16" ] }, "metadata": {}, "output_type": "display_data" }, { 
"data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.6480589298788255" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.55" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2608695652173913" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2553191489361702" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "17" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.6257991834238783" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.55" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2608695652173913" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2553191489361702" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "18" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, 
{ "data": { "text/plain": [ "1.604242644636688" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.55" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.26666666666666666" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2553191489361702" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "19" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.5837320431148691" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.55" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.26666666666666666" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2553191489361702" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "20" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.5641135577083332" ] }, "metadata": {}, "output_type": 
"display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.54" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.24444444444444444" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.23404255319148937" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "21" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.5454202877922216" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.54" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.24444444444444444" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.23404255319148937" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "22" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.5275404900076683" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, 
"output_type": "display_data" }, { "data": { "text/plain": [ "0.54" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.24444444444444444" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.23404255319148937" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "23" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.5099791488426855" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.22727272727272727" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "24" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.4932281806698302" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, 
"output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.22727272727272727" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "25" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.4772486361171469" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.22727272727272727" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "26" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.4617241937015206" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": 
{}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.22727272727272727" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "27" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.447145789535134" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.22727272727272727" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "28" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.433001351452549" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.53" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.22727272727272727" ] }, 
"metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2127659574468085" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "29" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.4193171267636353" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.54" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.25" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.23404255319148937" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "30" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.4062099850556116" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.55" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2727272727272727" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, 
"metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2553191489361702" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "31" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.3941219550243114" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.55" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2727272727272727" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2553191489361702" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "32" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.3825843345944304" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.29545454545454547" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ 
"0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "33" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.3714176365407185" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "34" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.3609581479639745" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "35" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.3509947879862738" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "36" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.3424826521927025" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": 
{ "text/plain": [ "37" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.3336302372731734" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "38" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.3246490387670928" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "39" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.316349835752626" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.57" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.30952380952380953" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "40" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.3090153592341813" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.57" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.30952380952380953" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "41" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": 
{ "text/plain": [ "1.3016801220795606" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "42" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.2947906140016858" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "43" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.2887717709777644" ] }, "metadata": {}, "output_type": "display_data" 
}, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "44" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.2825759449476026" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "45" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.2770079451325" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, 
{ "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.3023255813953488" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "46" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.2715366566940793" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.30952380952380953" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "47" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.266929566776089" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { 
"data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.30952380952380953" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "48" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.2626329537964194" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.30952380952380953" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'epoch: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "49" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'loss: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "1.2600517650827532" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'acc: '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.56" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'prec: '" ] }, "metadata": {}, "output_type": "display_data" 
}, { "data": { "text/plain": [ "0.30952380952380953" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'recall: : '" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "0.2765957446808511" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "# Train the NER model. Metrics are token-level and are computed on the training\n",
 "# predictions (before each optimizer step); label id 0 ('O') is the negative class.\n",
 "N_EPOCHS = 50\n",
 "# Only the first N_TRAIN_SENTENCES sentences are used; set to len(train_labels) for a full pass.\n",
 "N_TRAIN_SENTENCES = 100\n",
 "\n",
 "for epoch in range(N_EPOCHS):\n",
 "    loss_score = 0\n",
 "    acc_score = 0\n",
 "    prec_score = 0\n",
 "    selected_items = 0\n",
 "    recall_score = 0\n",
 "    relevant_items = 0\n",
 "    items_total = 0\n",
 "    ner_model.train()\n",
 "    for i in range(min(N_TRAIN_SENTENCES, len(train_labels))):\n",
 "        # Each sample is a 3-token window centred on position j, so the first and last tokens are skipped.\n",
 "        for j in range(1, len(train_labels[i]) - 1):\n",
 "            X_base = train_tokens_ids[i][j-1: j+2]\n",
 "            X_add = train_x[i][j-1: j+2]\n",
 "            X_final = add_features(X_base, X_add)\n",
 "            Y = train_labels[i][j: j+1]\n",
 "\n",
 "            Y_predictions = ner_model(X_final)\n",
 "\n",
 "            # Evaluate the argmax once per token instead of up to four times.\n",
 "            predicted = int(torch.argmax(Y_predictions))\n",
 "            gold = Y.item()\n",
 "\n",
 "            acc_score += int(predicted == gold)\n",
 "            if predicted != 0:\n",
 "                selected_items += 1\n",
 "                if predicted == gold:\n",
 "                    prec_score += 1\n",
 "            if gold != 0:\n",
 "                relevant_items += 1\n",
 "                if predicted == gold:\n",
 "                    recall_score += 1\n",
 "            items_total += 1\n",
 "\n",
 "            optimizer.zero_grad()\n",
 "            loss = criterion(Y_predictions.unsqueeze(0), Y)\n",
 "            loss.backward()\n",
 "            optimizer.step()\n",
 "\n",
 "            loss_score += loss.item()\n",
 "\n",
 "    # Guard the ratios: an epoch in which the model predicts only 'O' (or sees no\n",
 "    # entity tokens) would otherwise raise ZeroDivisionError and abort training.\n",
 "    mean_loss = loss_score / items_total if items_total else 0.0\n",
 "    accuracy = acc_score / items_total if items_total else 0.0\n",
 "    precision = prec_score / selected_items if selected_items else 0.0\n",
 "    recall = recall_score / relevant_items if relevant_items else 0.0\n",
 "    # F1 is computed (and kept for later inspection) but not displayed, so the per-epoch output format is unchanged.\n",
 "    f1_score = (2 * precision * recall) / (precision + recall) if precision + recall else 0.0\n",
 "    display('epoch: ', epoch)\n",
 "    display('loss: ', mean_loss)\n",
 "    display('acc: ', accuracy)\n",
 "    display('prec: ', precision)\n",
 "    display('recall: : ', recall)"
 ] }, { "cell_type": "code", "execution_count": 37, "id": "defensive-discretion", "metadata": {}, "outputs": [ {
"data": { "text/plain": [ "0.29213483146067415" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# F1 (harmonic mean of the last epoch's precision and recall); 0.0 when both are 0.\n",
 "(2 * precision * recall) / (precision + recall) if precision + recall else 0.0"
 ] }, { "cell_type": "code", "execution_count": 38, "id": "common-national", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ -0.5326, -1.1218, -10.0297, -0.1610, -10.9741, 1.3533, -11.9781,\n", " 0.6097, -9.1263], grad_fn=)" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y_predictions" ] }, { "cell_type": "code", "execution_count": 45, "id": "isolated-excess", "metadata": {}, "outputs": [], "source": [
 "# Inverse of the label encoding: model output index -> BIO tag string.\n",
 "ner_tags_re = {\n",
 "    0: 'O',\n",
 "    1: 'B-PER',\n",
 "    2: 'B-LOC',\n",
 "    3: 'I-PER',\n",
 "    4: 'B-MISC',\n",
 "    5: 'I-MISC',\n",
 "    6: 'I-LOC',\n",
 "    7: 'B-ORG',\n",
 "    8: 'I-ORG'\n",
 "}\n",
 "\n",
 "\n",
 "def fix_bio(labels):\n",
 "    \"\"\"Return a copy of `labels` with inconsistent 'I-' tags repaired.\n",
 "\n",
 "    An 'I-' tag at the start of the sequence or right after 'O' becomes 'B-' of the same type;\n",
 "    any other 'I-' tag takes the entity type of the (already repaired) tag before it.\n",
 "    \"\"\"\n",
 "    fixed = []\n",
 "    last_label = None\n",
 "    for label in labels:\n",
 "        if label.startswith('I-'):\n",
 "            if last_label is None or last_label == 'O':\n",
 "                label = 'B-' + label[2:]\n",
 "            else:\n",
 "                label = 'I-' + last_label[2:]\n",
 "        fixed.append(label)\n",
 "        last_label = label\n",
 "    return fixed\n",
 "\n",
 "\n",
 "def generate_out(folder_path):\n",
 "    \"\"\"Tag every token window of `{folder_path}/in.tsv` and write the tags to `{folder_path}/out.tsv`.\n",
 "\n",
 "    Input lines are split on single spaces; each output line is the space-joined tags that\n",
 "    `ner_model` predicts for that input line, passed through `fix_bio`.\n",
 "    Side effect: switches `ner_model` to eval mode and moves it to the CPU.\n",
 "    \"\"\"\n",
 "    ner_model.eval()\n",
 "    ner_model.cpu()\n",
 "    print('Generating out')\n",
 "    with open(f'{folder_path}/in.tsv', 'r', encoding='utf-8') as file:\n",
 "        X_dev = [line.strip().split(' ') for line in file]\n",
 "    test_tokens_ids = data_process(X_dev)\n",
 "\n",
 "    lines = []\n",
 "    with torch.no_grad():  # inference only, no autograd graph needed\n",
 "        for i in range(len(test_tokens_ids)):\n",
 "            labels = []\n",
 "            # NOTE(review): same 3-token windowing as training; presumably data_process pads each\n",
 "            # sentence so positions 1..len-2 line up with X_dev tokens -- confirm alignment.\n",
 "            for j in range(1, len(test_tokens_ids[i]) - 1):\n",
 "                X = test_tokens_ids[i][j - 1: j + 2]\n",
 "                X_raw_single = X_dev[i][j - 1: j + 2]\n",
 "                X = add_features(X, X_raw_single)\n",
 "                try:\n",
 "                    Y_predictions = ner_model(X)\n",
 "                    labels.append(ner_tags_re[int(torch.argmax(Y_predictions))])\n",
 "                except Exception as e:\n",
 "                    # Emit 'O' rather than dropping the position, so the line does not silently lose a tag.\n",
 "                    print('Error', e)\n",
 "                    labels.append('O')\n",
 "            lines.append(' '.join(fix_bio(labels)))\n",
 "\n",
 "    with open(f'{folder_path}/out.tsv', 'w', encoding='utf-8') as f:\n",
 "        for line in lines:\n",
 "            f.write(line + '\\n')"
 ] }, { "cell_type": "code", "execution_count": 46, "id": "d362007a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating out\n", "step 6\n", "Generating out\n", "step 6\n" ] } ], "source": [ "generate_out('dev-0')\n", "generate_out('test-A')" ] }, { "cell_type": "code", "execution_count": null, "id": "c7bdb256", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 5 }