bigram solution
This commit is contained in:
parent
21ed054e66
commit
a05b52d6d2
3
.gitignore
vendored
3
.gitignore
vendored
@ -8,4 +8,5 @@
|
|||||||
.token
|
.token
|
||||||
|
|
||||||
dev-0/in.tsv
|
dev-0/in.tsv
|
||||||
test-A/in.tsv
|
test-A/in.tsv
|
||||||
|
train/in.tsv
|
80
bigram.py
Normal file
80
bigram.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import collections
|
||||||
|
import re
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
|
||||||
|
input_file_path = "train/in.tsv"
|
||||||
|
bigrams = collections.defaultdict(lambda: collections.defaultdict(int))
|
||||||
|
|
||||||
|
|
||||||
|
def clean_text(text: str):
|
||||||
|
text = text.replace('\n', ' ')
|
||||||
|
text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
|
||||||
|
text = text.lower()
|
||||||
|
text = text.strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
with open('train/expected.tsv', 'r', encoding="utf-8") as f:
|
||||||
|
expected = [line for line in f]
|
||||||
|
|
||||||
|
with open(input_file_path, 'r', encoding="utf-8") as f:
|
||||||
|
data = [line.split('\t') for line in f]
|
||||||
|
|
||||||
|
#data = data[:200000] # total is over 400 000
|
||||||
|
|
||||||
|
combined = []
|
||||||
|
|
||||||
|
for idx, row in enumerate(data):
|
||||||
|
line = clean_text(row[6]) + ' ' + expected[idx] + ' ' + clean_text(row[7])
|
||||||
|
combined.append(line.lower())
|
||||||
|
|
||||||
|
|
||||||
|
for line in combined:
|
||||||
|
tokens = re.findall(r"\b\w+\b", line)
|
||||||
|
|
||||||
|
for i in range(len(tokens) - 1):
|
||||||
|
bigrams[tokens[i]][tokens[i+1]] += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
most_popular_words = [
|
||||||
|
"be:0.5 and:0.2 of:0.1 :0.2",
|
||||||
|
"a:0.5 in:0.2 to:0.1 :0.2",
|
||||||
|
"have:0.5 too:0.2 it:0.1 :0.2",
|
||||||
|
"I:0.5 that:0.2 for:0.1 :0.2",
|
||||||
|
"you:0.5 he:0.2 with:0.1 :0.2",
|
||||||
|
"on:0.5 do:0.2 say:0.1 :0.2",
|
||||||
|
"this:0.5 they:0.2 at:0.1 :0.2",
|
||||||
|
"but:0.5 we:0.2 his:0.1 :0.2"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
with open('test-A/in.tsv', "r", encoding="utf-8") as input_file, open('test-A/out.tsv', "w", encoding="utf-8") as output_file:
|
||||||
|
|
||||||
|
lines = input_file.readlines()
|
||||||
|
|
||||||
|
for idx, line in enumerate(lines):
|
||||||
|
tokens = re.findall(r"\b\w+\b", clean_text(line.split("\t")[6]))
|
||||||
|
|
||||||
|
probabilities = []
|
||||||
|
denominator = sum(bigrams[tokens[-1]].values())
|
||||||
|
|
||||||
|
for possible_word in bigrams[tokens[-1]]:
|
||||||
|
probability = bigrams[tokens[-1]][possible_word] / denominator
|
||||||
|
probabilities.append((possible_word, probability))
|
||||||
|
|
||||||
|
probabilities.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
print(f'Line {idx} of {len(lines)}')
|
||||||
|
|
||||||
|
if len(probabilities) >= 3:
|
||||||
|
out_line = ""
|
||||||
|
out_line += probabilities[0][0] + ":0.6 "
|
||||||
|
out_line += probabilities[1][0] + ":0.2 "
|
||||||
|
out_line += probabilities[2][0] + ":0.1 "
|
||||||
|
out_line += ":0.1"
|
||||||
|
output_file.write(out_line + "\n")
|
||||||
|
|
||||||
|
else:
|
||||||
|
output_file.write(random.choice(most_popular_words) + "\n")
|
||||||
|
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user