bigram solution
This commit is contained in:
parent
21ed054e66
commit
a05b52d6d2
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,3 +9,4 @@
|
||||
|
||||
dev-0/in.tsv
|
||||
test-A/in.tsv
|
||||
train/in.tsv
|
80
bigram.py
Normal file
80
bigram.py
Normal file
@ -0,0 +1,80 @@
|
||||
import collections
|
||||
import re
|
||||
import random
|
||||
import math
|
||||
|
||||
input_file_path = "train/in.tsv"
|
||||
bigrams = collections.defaultdict(lambda: collections.defaultdict(int))
|
||||
|
||||
|
||||
def clean_text(text: str):
|
||||
text = text.replace('\n', ' ')
|
||||
text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
|
||||
text = text.lower()
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
with open('train/expected.tsv', 'r', encoding="utf-8") as f:
|
||||
expected = [line for line in f]
|
||||
|
||||
with open(input_file_path, 'r', encoding="utf-8") as f:
|
||||
data = [line.split('\t') for line in f]
|
||||
|
||||
#data = data[:200000] # total is over 400 000
|
||||
|
||||
combined = []
|
||||
|
||||
for idx, row in enumerate(data):
|
||||
line = clean_text(row[6]) + ' ' + expected[idx] + ' ' + clean_text(row[7])
|
||||
combined.append(line.lower())
|
||||
|
||||
|
||||
for line in combined:
|
||||
tokens = re.findall(r"\b\w+\b", line)
|
||||
|
||||
for i in range(len(tokens) - 1):
|
||||
bigrams[tokens[i]][tokens[i+1]] += 1
|
||||
|
||||
|
||||
|
||||
most_popular_words = [
|
||||
"be:0.5 and:0.2 of:0.1 :0.2",
|
||||
"a:0.5 in:0.2 to:0.1 :0.2",
|
||||
"have:0.5 too:0.2 it:0.1 :0.2",
|
||||
"I:0.5 that:0.2 for:0.1 :0.2",
|
||||
"you:0.5 he:0.2 with:0.1 :0.2",
|
||||
"on:0.5 do:0.2 say:0.1 :0.2",
|
||||
"this:0.5 they:0.2 at:0.1 :0.2",
|
||||
"but:0.5 we:0.2 his:0.1 :0.2"
|
||||
]
|
||||
|
||||
|
||||
with open('test-A/in.tsv', "r", encoding="utf-8") as input_file, open('test-A/out.tsv', "w", encoding="utf-8") as output_file:
|
||||
|
||||
lines = input_file.readlines()
|
||||
|
||||
for idx, line in enumerate(lines):
|
||||
tokens = re.findall(r"\b\w+\b", clean_text(line.split("\t")[6]))
|
||||
|
||||
probabilities = []
|
||||
denominator = sum(bigrams[tokens[-1]].values())
|
||||
|
||||
for possible_word in bigrams[tokens[-1]]:
|
||||
probability = bigrams[tokens[-1]][possible_word] / denominator
|
||||
probabilities.append((possible_word, probability))
|
||||
|
||||
probabilities.sort(key=lambda x: x[1], reverse=True)
|
||||
print(f'Line {idx} of {len(lines)}')
|
||||
|
||||
if len(probabilities) >= 3:
|
||||
out_line = ""
|
||||
out_line += probabilities[0][0] + ":0.6 "
|
||||
out_line += probabilities[1][0] + ":0.2 "
|
||||
out_line += probabilities[2][0] + ":0.1 "
|
||||
out_line += ":0.1"
|
||||
output_file.write(out_line + "\n")
|
||||
|
||||
else:
|
||||
output_file.write(random.choice(most_popular_words) + "\n")
|
||||
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user