Word gap prediction - solution
This commit is contained in:
parent
61e88a9c8c
commit
be8f7cc880
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,3 +6,4 @@
|
|||||||
*.o
|
*.o
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.token
|
.token
|
||||||
|
env
|
10519
dev-0/out.tsv
Normal file
10519
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
65
run.py
Normal file
65
run.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import csv
|
||||||
|
from collections import Counter, defaultdict
|
||||||
|
from nltk.tokenize import RegexpTokenizer
|
||||||
|
from nltk import trigrams
|
||||||
|
|
||||||
|
|
||||||
|
class WordGapPrediction:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.tokenizer = RegexpTokenizer(r"\w+")
|
||||||
|
self.model = defaultdict(lambda: defaultdict(lambda: 0))
|
||||||
|
|
||||||
|
def read_train_data(self, file):
|
||||||
|
data = pd.read_csv(file, sep="\t", error_bad_lines=False, index_col=0, header=None)
|
||||||
|
for index, row in data[:90000].iterrows():
|
||||||
|
text = str(row[6]) + ' ' + str(row[7])
|
||||||
|
tokens = self.tokenizer.tokenize(text)
|
||||||
|
for w1, w2, w3 in trigrams(tokens, pad_right=True, pad_left=True):
|
||||||
|
if w1 and w2 and w3:
|
||||||
|
self.model[(w2, w3)][w1] += 1
|
||||||
|
self.model[(w1, w2)][w3] += 1
|
||||||
|
|
||||||
|
for word_pair in self.model:
|
||||||
|
num_n_grams = float(sum(self.model[word_pair].values()))
|
||||||
|
for word in self.model[word_pair]:
|
||||||
|
self.model[word_pair][word] /= num_n_grams
|
||||||
|
|
||||||
|
def generate_outputs(self, input_file, output_file):
|
||||||
|
data = pd.read_csv(input_file, sep='\t', error_bad_lines=False, index_col=0, header=None, quoting=csv.QUOTE_NONE)
|
||||||
|
with open(output_file, 'w') as f:
|
||||||
|
for index, row in data.iterrows():
|
||||||
|
text = str(row[7])
|
||||||
|
tokens = self.tokenizer.tokenize(text)
|
||||||
|
if len(tokens) < 4:
|
||||||
|
prediction = 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
|
||||||
|
else:
|
||||||
|
prediction = word_gap_prediction.predict_probs(tokens[0], tokens[1])
|
||||||
|
f.write(prediction + '\n')
|
||||||
|
|
||||||
|
def predict_probs(self, word1, word2):
|
||||||
|
predictions = dict(self.model[word1, word2])
|
||||||
|
most_common = dict(Counter(predictions).most_common(6))
|
||||||
|
|
||||||
|
total_prob = 0.0
|
||||||
|
str_prediction = ''
|
||||||
|
|
||||||
|
for word, prob in most_common.items():
|
||||||
|
total_prob += prob
|
||||||
|
str_prediction += f'{word}:{prob} '
|
||||||
|
|
||||||
|
if total_prob == 0.0:
|
||||||
|
return 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
|
||||||
|
|
||||||
|
if 1 - total_prob >= 0.01:
|
||||||
|
str_prediction += f":{1-total_prob}"
|
||||||
|
else:
|
||||||
|
str_prediction += f":0.01"
|
||||||
|
|
||||||
|
return str_prediction
|
||||||
|
|
||||||
|
word_gap_prediction = WordGapPrediction()
|
||||||
|
word_gap_prediction.read_train_data('./train/in.tsv')
|
||||||
|
word_gap_prediction.generate_outputs('dev-0/in.tsv', 'dev-0/out.tsv')
|
||||||
|
word_gap_prediction.generate_outputs('test-A/in.tsv', 'test-A/out.tsv')
|
7414
test-A/out.tsv
Normal file
7414
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user