test 5 version
This commit is contained in:
parent
fa9858ae9a
commit
4e91ffc077
16882
dev-0/out.tsv
16882
dev-0/out.tsv
File diff suppressed because one or more lines are too long
12
model.py
12
model.py
@ -21,14 +21,16 @@ def preprocess(text):
|
|||||||
|
|
||||||
|
|
||||||
def predict(word_before, word_after):
|
def predict(word_before, word_after):
|
||||||
prob_list = model[(word_before, word_after)].items()
|
prob_list = dict(Counter(model[(word_before, word_after)]).most_common(5)).items()
|
||||||
predictions = []
|
predictions = []
|
||||||
total = 0.0
|
prob_sum = 0.0
|
||||||
for key, value in prob_list:
|
for key, value in prob_list:
|
||||||
total += value
|
prob_sum += value
|
||||||
predictions.append(f'{key}:{value}')
|
predictions.append(f'{key}:{value}')
|
||||||
if total == 0.0:
|
if prob_sum == 0.0:
|
||||||
return 'the:1.0'
|
return 'the:0:2 be:0.2 to:0.2 of:0.15 and:0.15 :0.1'
|
||||||
|
elif prob_sum < 1.0:
|
||||||
|
predictions.append(f':{1.0 - prob_sum}')
|
||||||
return ' '.join(predictions)
|
return ' '.join(predictions)
|
||||||
|
|
||||||
|
|
||||||
|
11928
test-A/out.tsv
11928
test-A/out.tsv
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user