81 lines
2.2 KiB
Python
81 lines
2.2 KiB
Python
|
import collections
|
||
|
import re
|
||
|
import random
|
||
|
import math
|
||
|
|
||
|
# Location of the tab-separated training input read by the training step.
input_file_path = "train/in.tsv"

# Nested counter: bigrams[first_word][second_word] is the number of times
# `second_word` was seen directly after `first_word`. Both levels are
# defaultdicts, so a missing key is created on first access.
bigrams = collections.defaultdict(
    lambda: collections.defaultdict(int)
)
|
||
|
|
||
|
|
||
|
def clean_text(text: str):
    """Normalize *text* for tokenization.

    Newline characters become spaces, every character that is not an ASCII
    letter, a digit, or whitespace is dropped, and the result is lower-cased
    with leading and trailing whitespace removed.
    """
    without_newlines = text.replace('\n', ' ')
    alphanumeric_only = re.sub(r'[^a-zA-Z0-9\s]', '', without_newlines)
    return alphanumeric_only.lower().strip()
|
||
|
|
||
|
|
||
|
# --- Training: count word bigrams over the training set ---------------------

# One entry per line of the expected file; each entry keeps its trailing
# newline, which the word tokenizer below ignores.
with open('train/expected.tsv', 'r', encoding="utf-8") as f:
    expected = f.readlines()

# Every training row split into its tab-separated columns.
with open(input_file_path, 'r', encoding="utf-8") as f:
    data = [raw_row.split('\t') for raw_row in f]

# NOTE: to train on a subset, slice `data` here (e.g. data[:200000]);
# the original author noted the full set has over 400 000 rows.

# Rebuild each training sentence as: column 6 + expected word + column 7
# (presumably the text before and after the gap -- confirm against the
# dataset description), lower-cased as a whole.
combined = [
    f"{clean_text(row[6])} {expected[idx]} {clean_text(row[7])}".lower()
    for idx, row in enumerate(data)
]

# Count every pair of adjacent word tokens.
for sentence in combined:
    words = re.findall(r"\b\w+\b", sentence)
    for first, second in zip(words, words[1:]):
        bigrams[first][second] += 1
|
||
|
|
||
|
|
||
|
|
||
|
# Pre-formatted fallback predictions. Each entry is a complete output line of
# "word:probability" pairs, ending with a bare ":0.2" entry that has no word.
# One entry is picked at random when the bigram model has fewer than three
# candidates for a context.
most_popular_words = [
    "be:0.5 and:0.2 of:0.1 :0.2",
    "a:0.5 in:0.2 to:0.1 :0.2",
    "have:0.5 too:0.2 it:0.1 :0.2",
    "I:0.5 that:0.2 for:0.1 :0.2",
    "you:0.5 he:0.2 with:0.1 :0.2",
    "on:0.5 do:0.2 say:0.1 :0.2",
    "this:0.5 they:0.2 at:0.1 :0.2",
    "but:0.5 we:0.2 his:0.1 :0.2",
]
|
||
|
|
||
|
|
||
|
# --- Prediction: write one output line per row of test-A/in.tsv -------------
#
# For each row, the last word token of column 6 is looked up in the bigram
# model. If at least three different followers were seen for it, the three
# most frequent are written with fixed weights 0.6 / 0.2 / 0.1 plus a bare
# ":0.1" entry with no word. Otherwise a random entry from
# `most_popular_words` is written.
with open('test-A/in.tsv', "r", encoding="utf-8") as input_file, \
        open('test-A/out.tsv', "w", encoding="utf-8") as output_file:

    lines = input_file.readlines()
    total_lines = len(lines)

    for idx, line in enumerate(lines):
        tokens = re.findall(r"\b\w+\b", clean_text(line.split("\t")[6]))

        # The cleaned context may contain no word tokens at all; indexing
        # tokens[-1] unconditionally would raise IndexError and abort the run
        # with a truncated output file. Such rows use the fallback instead.
        #
        # .get() is used instead of bigrams[...] because subscripting a
        # defaultdict inserts a new empty entry for every unseen word, which
        # would grow the model during inference.
        followers = bigrams.get(tokens[-1], {}) if tokens else {}

        # Ranking by raw count gives the same order as ranking by
        # count / total, since every count shares the same positive
        # denominator. sorted() is stable, so ties keep insertion order.
        top_words = sorted(followers, key=followers.get, reverse=True)[:3]

        print(f'Line {idx} of {total_lines}')

        if len(top_words) == 3:
            first, second, third = top_words
            out_line = f"{first}:0.6 {second}:0.2 {third}:0.1 :0.1"
        else:
            out_line = random.choice(most_popular_words)

        output_file.write(out_line + "\n")
|
||
|
|