This commit is contained in:
s430705 2021-06-08 14:56:48 +02:00
parent 7e71e86b87
commit 31fab4d6b0

View File

@ -32,11 +32,8 @@
"id": "7d9d7e79"
},
"source": [
"import os.path\n",
"import pandas as pd\n",
"import numpy as np\n",
"import torch\n",
"import csv\n",
"from collections import Counter\n",
"from torchtext.vocab import Vocab\n"
],
@ -51,6 +48,7 @@
},
"source": [
"def predict(input_tokens, labels):\n",
"\n",
" results = []\n",
" \n",
" for i in range(len(input_tokens)):\n",
@ -234,7 +232,7 @@
"outputId": "38dac368-b5dc-4ad8-ec5a-a8ae7abf11d2"
},
"source": [
"data = pd.read_csv('train.tsv', sep='\\t', names=['iob', 'tokens'])\n",
"data = pd.read_csv('train/train.tsv', sep='\\t', names=['iob', 'tokens'])\n",
"data[\"iob\"] = data[\"iob\"].apply(lambda x: [labels.index(y) for y in x.split()])\n",
"data[\"tokens\"] = data[\"tokens\"].apply(lambda x: x.split())\n",
"\n",
@ -465,7 +463,7 @@
"id": "rXY6j7-qt7gU"
},
"source": [
"dev = pd.read_csv('in.tsv', sep='\\t', names=['tokens'])\n",
"dev = pd.read_csv('dev-0/in.tsv', sep='\\t', names=['tokens'])\n",
"dev[\"tokens\"] = dev[\"tokens\"].apply(lambda x: x.split())\n",
"\n",
"dev_tokens_ids = data_process(dev[\"tokens\"])\n",
@ -476,7 +474,7 @@
"\n",
"results = predict(dev_tensors, labels)\n",
"results_processed = process_output(results)\n",
"save_to_file(\"out.tsv\", results_processed)"
"save_to_file(\"dev-0/out.tsv\", results_processed)"
],
"id": "rXY6j7-qt7gU",
"execution_count": 57,
@ -488,7 +486,18 @@
"id": "1lGYlL6iliGM"
},
"source": [
""
"test = pd.read_csv('test-A/in.tsv', sep='\\t', names=['tokens'])\n",
"test[\"tokens\"] = test[\"tokens\"].apply(lambda x: x.split())\n",
"\n",
"test_tokens_ids = data_process(test[\"tokens\"])\n",
"\n",
"test_extra_tensors = create_tensors_list(test)\n",
"\n",
"test_tensors = extra_features(test_tokens_ids, test_extra_tensors)\n",
"\n",
"results = predict(test_tensors, labels)\n",
"results_processed = process_output(results)\n",
"save_to_file(\"test-A/out.tsv\", results_processed)"
],
"id": "1lGYlL6iliGM",
"execution_count": null,