79 lines
2.0 KiB
Python
79 lines
2.0 KiB
Python
import kenlm
|
|
import csv
|
|
|
|
def predict_probability(sentence, lm=None):
    """Return the language-model score of *sentence*.

    Args:
        sentence: Whitespace-separated text to score.
        lm: Object exposing ``score(str)`` (e.g. a ``kenlm.Model``). When
            omitted, the module-level ``model`` is used so existing
            one-argument calls keep working.

    Returns:
        Whatever ``lm.score`` returns for the sentence (for a kenlm model,
        presumably a log10 probability — see the kenlm docs).
    """
    if lm is None:
        # Looked up at call time: the one-argument form only works after the
        # module-level ``model`` has been assigned by the script.
        lm = model
    return lm.score(sentence)
|
|
|
|
def load_candidate_words(file_path):
    """Read the candidate vocabulary, one candidate per line.

    Args:
        file_path: Path to a UTF-8 text file with one candidate per line.

    Returns:
        A set of the whitespace-stripped, non-empty lines.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # Skip blank lines: an empty candidate would make predict_word_between
        # score "text1  text2" with no gap word at all. With one token fewer,
        # that sentence can outscore every real word and yield empty predictions.
        return {word for word in (line.strip() for line in file) if word}
|
|
|
|
def predict_word_between(text1, text2, model, candidate_words):
    """Choose the candidate that best fills the gap between two fragments.

    Every candidate is spliced in as ``"{text1} {candidate} {text2}"`` and the
    resulting string is scored with ``model.score``; the highest score wins.

    Args:
        text1: Text preceding the gap.
        text2: Text following the gap.
        model: Object with a ``score(str)`` method (a ``kenlm.Model`` in this
            script).
        candidate_words: Iterable of candidate words to try.

    Returns:
        The best-scoring candidate, or ``None`` if *candidate_words* is empty.
        On ties the earliest candidate in iteration order is kept. In this
        script the candidates are a set, so that order is set iteration order.
    """
    best_score, best_candidate = float("-inf"), None
    for candidate in candidate_words:
        score = model.score(f"{text1} {candidate} {text2}")
        # Strict comparison: a later candidate must beat, not merely match.
        if score > best_score:
            best_score, best_candidate = score, candidate
    return best_candidate
|
|
|
|
def _read_rows(csv_path):
    """Return every row of the comma-separated UTF-8 file *csv_path* as lists of strings."""
    with open(csv_path, 'r', newline='', encoding='utf-8') as handle:
        return list(csv.reader(handle, delimiter=','))


# Gap-filling inputs for the dev and test splits, loaded in that order.
dev = _read_rows('dev-0/in_1.csv')
test = _read_rows('test-A/in_1.csv')
|
|
|
|
# Locations of the kenlm model file and the candidate-word list.
model_path = "model.binary"
candidate_words_file = "words_3.txt"

# Load both once; they are reused for every gap prediction that follows.
model = kenlm.Model(model_path)
candidate_words = load_candidate_words(candidate_words_file)
|
|
|
|
def _predict_gaps(rows, lm, candidates):
    """Predict the missing word for every parsed CSV row.

    Column 0 of each row is the text before the gap and column 1 the text
    after it; any further columns are ignored. A progress percentage is
    printed every 500 rows, measured against ``len(rows)``.

    Args:
        rows: Parsed CSV rows.
        lm: Language model passed through to ``predict_word_between``.
        candidates: Candidate words passed through to ``predict_word_between``.

    Returns:
        One prediction per input row, in input order (``None`` when there are
        no candidates).
    """
    predictions = []
    total = len(rows)
    for idx, row in enumerate(rows):
        predictions.append(predict_word_between(row[0], row[1], lm, candidates))
        if idx % 500 == 0:
            # Relative to this split's own size (the test pass must not use len(dev)).
            print(f'{idx/total*100}%')
    return predictions


def _write_predictions(path, predictions):
    """Write one prediction per line to *path*, UTF-8 encoded.

    Writes each word as a whole line. Passing a bare string to
    ``csv.writer.writerow`` would iterate it and emit every character as its
    own tab-separated field. ``None`` becomes an empty line so the output
    stays aligned with the input rows.
    """
    # newline='\n' keeps plain LF line endings on every platform.
    with open(path, 'w', encoding='utf-8', newline='\n') as out_file:
        for word in predictions:
            out_file.write(("" if word is None else word) + "\n")


# Dev split first (predict, then write), then the test split, as before.
predicted_dev = _predict_gaps(dev, model, candidate_words)
_write_predictions('dev-0/out.tsv', predicted_dev)

predicted_test = _predict_gaps(test, model, candidate_words)
_write_predictions('test-A/out.tsv', predicted_test)