This commit is contained in:
Maciej(Linux) 2022-04-11 00:56:48 +02:00
parent 58127c0cf0
commit b20466a23f

36
run.py
View File

@ -1,11 +1,11 @@
from nltk import trigrams, word_tokenize from nltk import tris, word_tokenize
import pandas as pd import pandas as pd
import csv import csv
import regex as re import regex as re
from collections import Counter, defaultdict from collections import Counter, defaultdict
train_set = pd.read_csv( train = pd.read_csv(
'train/in.tsv.xz', 'train/in.tsv.xz',
sep='\t', sep='\t',
on_bad_lines='skip', on_bad_lines='skip',
@ -14,7 +14,7 @@ train_set = pd.read_csv(
nrows=50000) nrows=50000)
train_labels = pd.read_csv( labels = pd.read_csv(
'train/expected.tsv', 'train/expected.tsv',
sep='\t', sep='\t',
on_bad_lines='skip', on_bad_lines='skip',
@ -28,7 +28,7 @@ def data_preprocessing(text):
def predict(before, after): def predict(before, after):
prediction = dict(Counter(dict(trigram[before, after])).most_common(5)) prediction = dict(Counter(dict(tri[before, after])).most_common(5))
result = '' result = ''
prob = 0.0 prob = 0.0
for key, value in prediction.items(): for key, value in prediction.items():
@ -52,28 +52,28 @@ def make_prediction(file):
file_out.write(prediction + '\n') file_out.write(prediction + '\n')
train_set = train_set[[6, 7]] train = train[[6, 7]]
train_set = pd.concat([train_set, train_labels], axis=1) train = pd.concat([train, labels], axis=1)
train_set['line'] = train_set[6] + train_set[0] + train_set[7] train['line'] = train[6] + train[0] + train[7]
trigram = defaultdict(lambda: defaultdict(lambda: 0)) tri = defaultdict(lambda: defaultdict(lambda: 0))
rows = train_set.iterrows() rows = train.iterrows()
rows_len = len(train_set) rows_len = len(train)
for index, (_, row) in enumerate(rows): for index, (_, row) in enumerate(rows):
text = data_preprocessing(str(row['line'])) text = data_preprocessing(str(row['line']))
words = word_tokenize(text) words = word_tokenize(text)
for word_1, word_2, word_3 in trigrams(words, pad_right=True, pad_left=True): for word_1, word_2, word_3 in tris(words, pad_right=True, pad_left=True):
if word_1 and word_2 and word_3: if word_1 and word_2 and word_3:
trigram[(word_1, word_3)][word_2] += 1 tri[(word_1, word_3)][word_2] += 1
model_len = len(trigram) model_len = len(tri)
for index, words_1_3 in enumerate(trigram): for index, words_1_3 in enumerate(tri):
count = sum(trigram[words_1_3].values()) count = sum(tri[words_1_3].values())
for word_2 in trigram[words_1_3]: for word_2 in tri[words_1_3]:
trigram[words_1_3][word_2] += 0.25 tri[words_1_3][word_2] += 0.25
trigram[words_1_3][word_2] /= float(count + 0.25 + len(word_2)) tri[words_1_3][word_2] /= float(count + 0.25 + len(word_2))
make_prediction('test-A') make_prediction('test-A')