Commit by s440054 on 2022-04-04 15:07:07 +02:00
parent 42d25a2e0f
commit 5af6e29a07
5 changed files with 6497 additions and 356 deletions

File diff suppressed because it is too large

run.py (65 changed lines)

@@ -1,64 +1,38 @@
 import pandas as pd
-import csv
-import regex as re
 import nltk
 from collections import Counter, defaultdict
-import string
-import unicodedata
+from utils import get_csv, check_prerequisites, ENCODING, clean_text


 def main():
-    try:
-        nltk.data.find('tokenizers/punkt')
-    except LookupError:
-        nltk.download('punkt')
-
-    with open("in-header.tsv") as f:
-        in_cols = f.read().strip().split("\t")
-
-    with open("out-header.tsv") as f:
-        out_cols = f.read().strip().split("\t")
-
-    data = pd.read_csv(
-        "train/in.tsv.xz",
-        sep="\t",
-        on_bad_lines='skip',
-        header=None,
-        # names=in_cols,
-        quoting=csv.QUOTE_NONE,
-    )
-    train_labels = pd.read_csv(
-        "train/expected.tsv",
-        sep="\t",
-        on_bad_lines='skip',
-        header=None,
-        # names=out_cols,
-        quoting=csv.QUOTE_NONE,
-    )
+    check_prerequisites()
+
+    data = get_csv("train/in.tsv.xz")
+    train_words = get_csv("train/expected.tsv")

     train_data = data[[7, 6]]
-    train_data = pd.concat([train_data, train_labels], axis=1)
-    train_data["final"] = train_data[7] + train_data[0] + train_data[6]
+    train_data = pd.concat([train_data, train_words], axis=1)
+    train_data[760] = train_data[7] + train_data[0] + train_data[6]

     model = defaultdict(lambda: defaultdict(lambda: 0))
     train_model(train_data, model)

     predict_data("dev-0/in.tsv.xz", "dev-0/out.tsv", model)
     predict_data("test-A/in.tsv.xz", "test-A/out.tsv", model)


-def clean_text(text):
-    return re.sub(r"\p{P}", "", str(text).lower().replace("-\\n", "").replace("\\n", " "))
-
-
 def train_model(data, model):
     for _, row in data.iterrows():
-        words = nltk.word_tokenize(clean_text(row["final"]))
+        words = nltk.word_tokenize(clean_text(row[760]))
         for w1, w2 in nltk.bigrams(words, pad_left=True, pad_right=True):
             if w1 and w2:
                 model[w2][w1] += 1
-    for w1 in model:
-        total_count = float(sum(model[w1].values()))
-        for w2 in model[w1]:
+    for w2 in model:
+        total_count = float(sum(model[w2].values()))
+        for w1 in model[w2]:
             model[w2][w1] /= total_count
@@ -85,21 +59,16 @@ def predict(word, model):

 def predict_data(read_path, save_path, model):
-    data = pd.read_csv(
-        read_path,
-        sep="\t",
-        error_bad_lines=False,
-        header=None,
-        quoting=csv.QUOTE_NONE
-    )
-    with open(save_path, "w") as file:
+    data = get_csv(read_path)
+
+    with open(save_path, "w", encoding=ENCODING) as f:
         for _, row in data.iterrows():
-            words = nltk.word_tokenize(clean_text(row[6]))
+            words = nltk.word_tokenize(clean_text(row[7]))
             if len(words) < 3:
                 prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
             else:
                 prediction = predict(words[-1], model)
-            file.write(prediction + "\n")
+            f.write(prediction + "\n")


 if __name__ == "__main__":
     main()
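The predict() helper called in predict_data() sits in the unchanged part of run.py, so it does not appear in either hunk. Going only by how train_model() builds the model (after normalization, model[w2][w1] holds the probability of w1 occurring directly next to w2) and by the "word:probability" output format used above, a minimal sketch of such a lookup could look like the following; the candidate cap of 5, the probability formatting, and the fallback string are assumptions, not the committed code.

def predict(word, model):
    # model[word] maps candidate neighbouring words to their normalized bigram probability.
    candidates = sorted(model.get(word, {}).items(), key=lambda kv: kv[1], reverse=True)[:5]
    if not candidates:
        # No statistics for this word: fall back to the fixed distribution also used in predict_data().
        return "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
    # Reserve whatever probability mass is left for the "unknown word" bucket at the end.
    remainder = max(1.0 - sum(p for _, p in candidates), 0.01)
    return " ".join(f"{w}:{p:.4f}" for w, p in candidates) + f" :{remainder:.4f}"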

File diff suppressed because it is too large

(Jupyter notebook)

@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 18,
"id": "21c9b695", "id": "21c9b695",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -59,16 +59,16 @@
" error_bad_lines=False,\n", " error_bad_lines=False,\n",
" header=None,\n", " header=None,\n",
" quoting=csv.QUOTE_NONE,\n", " quoting=csv.QUOTE_NONE,\n",
" encoding=\"utf8\"\n", " encoding=\"utf-8\"\n",
" )\n", " )\n",
" with open(save_path, \"w\") as file:\n", " with open(save_path, \"w\", encoding=\"utf-8\") as f:\n",
" for _, row in data.iterrows():\n", " for _, row in data.iterrows():\n",
" words = nltk.word_tokenize(clean_text(row[7]))\n", " words = nltk.word_tokenize(clean_text(row[7]))\n",
" if len(words) < 3:\n", " if len(words) < 3:\n",
" prediction = \"the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1\"\n", " prediction = \"the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1\"\n",
" else:\n", " else:\n",
" prediction = predict(words[-1], model)\n", " prediction = predict(words[-1], model)\n",
" file.write(prediction + \"\\n\")\n" " f.write(prediction + \"\\n\")\n"
] ]
}, },
{ {
@ -141,6 +141,7 @@
" header=None,\n", " header=None,\n",
" # names=in_cols,\n", " # names=in_cols,\n",
" quoting=csv.QUOTE_NONE,\n", " quoting=csv.QUOTE_NONE,\n",
" encoding=\"utf-8\"\n",
")\n", ")\n",
"\n", "\n",
"train_words = pd.read_csv(\n", "train_words = pd.read_csv(\n",
@ -149,7 +150,8 @@
" on_bad_lines='skip',\n", " on_bad_lines='skip',\n",
" header=None,\n", " header=None,\n",
" # names=out_cols,\n", " # names=out_cols,\n",
" quoting=csv.QUOTE_NONE,\n", " quoting=csv.QUOTE_NONE,,\n",
" encoding=\"utf-8\"\n",
")\n", ")\n",
"\n", "\n",
"train_data = data[[7, 6]]\n", "train_data = data[[7, 6]]\n",
@ -390,10 +392,21 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 19,
"id": "195cb6cf", "id": "195cb6cf",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Norbert\\AppData\\Local\\Temp\\ipykernel_15436\\751703071.py:47: FutureWarning: The error_bad_lines argument has been deprecated and will be removed in a future version. Use on_bad_lines in the future.\n",
"\n",
"\n",
" data = pd.read_csv(\n"
]
}
],
"source": [ "source": [
"predict_data(\"test-A/in.tsv.xz\", \"test-A/out.tsv\", model)" "predict_data(\"test-A/in.tsv.xz\", \"test-A/out.tsv\", model)"
] ]
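The stderr output captured in that last cell is the stock pandas deprecation notice: error_bad_lines was deprecated (in pandas 1.3) in favour of on_bad_lines, which the new utils.get_csv() already uses. A minimal before/after illustration, assuming a pandas version that still accepts both spellings:

import pandas as pd

# Deprecated spelling, still used in the notebook's predict_data cell; emits the FutureWarning above.
old = pd.read_csv("dev-0/in.tsv.xz", sep="\t", header=None, error_bad_lines=False)

# Current spelling, matching what utils.get_csv() does.
new = pd.read_csv("dev-0/in.tsv.xz", sep="\t", header=None, on_bad_lines="skip")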

utils.py (new file, 28 lines)

@@ -0,0 +1,28 @@
+import nltk
+import pandas as pd
+import regex as re
+from csv import QUOTE_NONE
+
+ENCODING = "utf-8"
+
+
+def clean_text(text):
+    return re.sub(r"\p{P}", "", str(text).lower().replace("-\\n", "").replace("\\n", " "))
+
+
+def get_csv(fname):
+    return pd.read_csv(
+        fname,
+        sep="\t",
+        on_bad_lines='skip',
+        header=None,
+        quoting=QUOTE_NONE,
+        encoding=ENCODING
+    )
+
+
+def check_prerequisites():
+    try:
+        nltk.data.find('tokenizers/punkt')
+    except LookupError:
+        nltk.download('punkt')
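One note on the new helper module: clean_text() needs the third-party regex package rather than the standard-library re, because the \p{P} property class (any Unicode punctuation) is only supported by regex. A quick usage sketch with a made-up input line; the literal \n mirrors the escaped line breaks that clean_text() strips from the data:

from utils import clean_text

# Lowercases, joins words hyphenated across an escaped line break, turns remaining "\n" markers
# into spaces, and strips all Unicode punctuation.
print(clean_text("Mr. Smith went to Washing-\\nton, D.C.!"))
# -> mr smith went to washington dc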