This commit is contained in:
s430705 2021-06-08 14:56:48 +02:00
parent 7e71e86b87
commit 31fab4d6b0

View File

@ -32,11 +32,8 @@
"id": "7d9d7e79" "id": "7d9d7e79"
}, },
"source": [ "source": [
"import os.path\n",
"import pandas as pd\n", "import pandas as pd\n",
"import numpy as np\n",
"import torch\n", "import torch\n",
"import csv\n",
"from collections import Counter\n", "from collections import Counter\n",
"from torchtext.vocab import Vocab\n" "from torchtext.vocab import Vocab\n"
], ],
@ -51,6 +48,7 @@
}, },
"source": [ "source": [
"def predict(input_tokens, labels):\n", "def predict(input_tokens, labels):\n",
"\n",
" results = []\n", " results = []\n",
" \n", " \n",
" for i in range(len(input_tokens)):\n", " for i in range(len(input_tokens)):\n",
@ -234,7 +232,7 @@
"outputId": "38dac368-b5dc-4ad8-ec5a-a8ae7abf11d2" "outputId": "38dac368-b5dc-4ad8-ec5a-a8ae7abf11d2"
}, },
"source": [ "source": [
"data = pd.read_csv('train.tsv', sep='\\t', names=['iob', 'tokens'])\n", "data = pd.read_csv('train/train.tsv', sep='\\t', names=['iob', 'tokens'])\n",
"data[\"iob\"] = data[\"iob\"].apply(lambda x: [labels.index(y) for y in x.split()])\n", "data[\"iob\"] = data[\"iob\"].apply(lambda x: [labels.index(y) for y in x.split()])\n",
"data[\"tokens\"] = data[\"tokens\"].apply(lambda x: x.split())\n", "data[\"tokens\"] = data[\"tokens\"].apply(lambda x: x.split())\n",
"\n", "\n",
@ -465,7 +463,7 @@
"id": "rXY6j7-qt7gU" "id": "rXY6j7-qt7gU"
}, },
"source": [ "source": [
"dev = pd.read_csv('in.tsv', sep='\\t', names=['tokens'])\n", "dev = pd.read_csv('dev-0/in.tsv', sep='\\t', names=['tokens'])\n",
"dev[\"tokens\"] = dev[\"tokens\"].apply(lambda x: x.split())\n", "dev[\"tokens\"] = dev[\"tokens\"].apply(lambda x: x.split())\n",
"\n", "\n",
"dev_tokens_ids = data_process(dev[\"tokens\"])\n", "dev_tokens_ids = data_process(dev[\"tokens\"])\n",
@ -476,7 +474,7 @@
"\n", "\n",
"results = predict(dev_tensors, labels)\n", "results = predict(dev_tensors, labels)\n",
"results_processed = process_output(results)\n", "results_processed = process_output(results)\n",
"save_to_file(\"out.tsv\", results_processed)" "save_to_file(\"dev-0/out.tsv\", results_processed)"
], ],
"id": "rXY6j7-qt7gU", "id": "rXY6j7-qt7gU",
"execution_count": 57, "execution_count": 57,
@ -488,7 +486,18 @@
"id": "1lGYlL6iliGM" "id": "1lGYlL6iliGM"
}, },
"source": [ "source": [
"" "test = pd.read_csv('test-A/in.tsv', sep='\\t', names=['tokens'])\n",
"test[\"tokens\"] = test[\"tokens\"].apply(lambda x: x.split())\n",
"\n",
"test_tokens_ids = data_process(test[\"tokens\"])\n",
"\n",
"test_extra_tensors = create_tensors_list(test)\n",
"\n",
"test_tensors = extra_features(test_tokens_ids, test_extra_tensors)\n",
"\n",
"results = predict(test_tensors, labels)\n",
"results_processed = process_output(results)\n",
"save_to_file(\"test-A/out.tsv\", results_processed)"
], ],
"id": "1lGYlL6iliGM", "id": "1lGYlL6iliGM",
"execution_count": null, "execution_count": null,