This commit is contained in:
Bartosz Karwacki 2022-04-03 19:30:55 +02:00
parent 2101b5f61c
commit a2064b7ed9
3 changed files with 17897 additions and 17879 deletions

File diff suppressed because it is too large Load Diff

30
run.py
View File

@ -3,6 +3,8 @@ import csv
import regex as re import regex as re
from nltk import bigrams, word_tokenize from nltk import bigrams, word_tokenize
from collections import Counter, defaultdict from collections import Counter, defaultdict
import string
import unicodedata
data = pd.read_csv( data = pd.read_csv(
"train/in.tsv.xz", "train/in.tsv.xz",
@ -10,7 +12,7 @@ data = pd.read_csv(
error_bad_lines=False, error_bad_lines=False,
header=None, header=None,
quoting=csv.QUOTE_NONE, quoting=csv.QUOTE_NONE,
nrows=200000, nrows=250000
) )
train_labels = pd.read_csv( train_labels = pd.read_csv(
"train/expected.tsv", "train/expected.tsv",
@ -18,7 +20,7 @@ train_labels = pd.read_csv(
error_bad_lines=False, error_bad_lines=False,
header=None, header=None,
quoting=csv.QUOTE_NONE, quoting=csv.QUOTE_NONE,
nrows=200000, nrows=250000
) )
train_data = data[[6, 7]] train_data = data[[6, 7]]
@ -30,8 +32,24 @@ model = defaultdict(lambda: defaultdict(lambda: 0))
def clean(text): def clean(text):
text = str(text).lower().replace("-\\n", "").replace("\\n", " ") text = str(text)
return re.sub(r"\p{P}", "", text) # normalize text
text = (
unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode(
'utf-8', 'ignore'))
# replace html chars with ' '
text = re.sub('<.*?>', ' ', text)
# remove punctuation
text = text.translate(str.maketrans(' ', ' ', string.punctuation))
# only alphabets and numerics
text = re.sub('[^a-zA-Z]', ' ', text)
# replace newline with space
text = re.sub("\n", " ", text)
# lower case
text = text.lower()
# split and join the words
text = ' '.join(text.split())
return text
def train_model(data): def train_model(data):
@ -74,11 +92,11 @@ def predict_data(read_path, save_path):
) )
with open(save_path, "w") as file: with open(save_path, "w") as file:
for _, row in data.iterrows(): for _, row in data.iterrows():
words = word_tokenize(clean(row[7])) words = word_tokenize(clean(row[6]))
if len(words) < 3: if len(words) < 3:
prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1" prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
else: else:
prediction = predict(words[0]) prediction = predict(words[-1])
file.write(prediction + "\n") file.write(prediction + "\n")

File diff suppressed because it is too large Load Diff