102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
|
import lzma
|
||
|
import pickle
|
||
|
from collections import Counter
|
||
|
|
||
|
|
||
|
def words(filename):
    """Yield every whitespace-separated token of each row's two context columns.

    ``filename`` is an xz-compressed, tab-separated text file. For every row,
    columns 6 and 7 have their literal two-character ``\\n`` escapes replaced by
    spaces, are glued together with a single space, and the result is split on
    whitespace; each resulting token is yielded in order.
    """
    with lzma.open(filename, mode='rt', encoding='utf-8') as stream:
        for row in stream:
            columns = row.split('\t')
            # Unescape the in-cell newlines of both context columns.
            left, right = (columns[i].replace(r'\n', ' ') for i in (6, 7))
            yield from ' '.join((left, right)).split()
|
||
|
|
||
|
|
||
|
def bigrams(filename, V):
    """Yield adjacent (previous, word) pairs where both words are in ``V``.

    ``filename`` is an xz-compressed, tab-separated text file whose columns 6
    and 7 hold the text before and after the gap (literal ``\\n`` escapes are
    treated as spaces). ``V`` is a sequence of ``(word, count)`` pairs, e.g.
    the output of ``Counter.most_common``; only its words are used.

    The prefix and suffix are scanned separately: the gap word sits between
    them, so the last prefix word and the first suffix word are not adjacent
    and must not be counted as a bigram.
    """
    # A set gives O(1) membership tests; the previous list version was O(|V|)
    # per token.
    vocabulary = {word for word, _count in V}
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        for line in fid:
            separated = line.split('\t')
            for column in (separated[6], separated[7]):
                # None is never in the vocabulary, so the first word of each
                # column starts a fresh chain instead of pairing across the gap.
                previous = None
                for word in column.replace(r'\n', ' ').split():
                    if previous in vocabulary and word in vocabulary:
                        yield previous, word
                    previous = word
|
||
|
|
||
|
|
||
|
def P(previous_word, word, *, unigram_counts=None, bigram_counts=None):
    """Return the maximum-likelihood bigram estimate P(word | previous_word).

    Computed as count(previous_word, word) / count(previous_word). Returns 0
    when ``previous_word`` has no unigram count or the pair was never seen.

    ``unigram_counts`` maps word -> count and ``bigram_counts`` maps
    (previous, word) -> count; they default to the module-level ``V`` and
    ``V2`` respectively.

    Note: the earlier version divided by count(word), i.e. it computed
    P(previous_word | word). In ``candidates`` the product of two calls only
    differs from the old one by a factor that is constant over the candidate
    word, so rankings and normalised outputs are unaffected.
    """
    if unigram_counts is None:
        unigram_counts = V
    if bigram_counts is None:
        bigram_counts = V2
    previous_count = unigram_counts.get(previous_word, 0)
    if not previous_count:
        return 0
    return bigram_counts.get((previous_word, word), 0) / previous_count
|
||
|
|
||
|
|
||
|
def candidates(w1, w3, *, vocabulary=None, prob=None):
    """Return the five best guesses for the word between ``w1`` and ``w3``.

    Every word ``w2`` in ``vocabulary`` is scored as
    ``prob(w1, w2) * prob(w2, w3)``. The top five (ties keep vocabulary
    order) are normalised to sum to 1; if every score is zero they share the
    mass uniformly. The last entry's word is replaced by ``''`` so it is
    written as ``:p`` (presumably the output format's catch-all for "any
    other word" — confirm against the evaluation spec).

    ``vocabulary`` defaults to the module-level ``V`` (iterated for its
    words) and ``prob`` to the module-level ``P``.

    Returns a space-separated ``'word:prob'`` string, or ``''`` when the
    vocabulary is empty.
    """
    if vocabulary is None:
        vocabulary = V
    if prob is None:
        prob = P
    scores = {w2: prob(w1, w2) * prob(w2, w3) for w2 in vocabulary}
    top = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:5]
    if not top:
        # Previously norm[-1] raised IndexError here.
        return ''
    total = sum(score for _word, score in top)
    if total:
        norm = [(word, float(score) / total) for word, score in top]
    else:
        # Uniform fallback; 1/len(top) equals the old 0.2 when there are five
        # candidates and still sums to 1 when there are fewer.
        norm = [(word, 1 / len(top)) for word, _score in top]
    norm[-1] = ('', norm[-1][1])
    return ' '.join(f'{word}:{p}' for word, p in norm)
|
||
|
|
||
|
# Training step: run once to regenerate V.pickle (the WORD_LIMIT most common
# words with their counts) and V2.pickle (bigram counts restricted to those
# words) from the training data. The code below only loads the pickles.
#
# WORD_LIMIT = 5000
#
# V = Counter(words('train/in.tsv.xz'))
# V = V.most_common(WORD_LIMIT)
# with open('V.pickle', 'wb') as handle:
#     pickle.dump(V, handle, protocol=pickle.HIGHEST_PROTOCOL)
# V2 = Counter(bigrams('train/in.tsv.xz', V))
# print(V2.most_common(100))
# with open('V2.pickle', 'wb') as handle:
#     pickle.dump(V2, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||
|
|
||
|
|
||
|
# Load the pre-computed counts. These pickles are produced locally by the
# (commented-out) training step; never unpickle files from untrusted sources.
with open('V.pickle', 'rb') as handle:
    # A sequence of (word, count) pairs, turned into a word -> count lookup.
    V_tuple = pickle.load(handle)
V = {word: count for word, count in V_tuple}

with open('V2.pickle', 'rb') as handle:
    # (previous_word, word) -> count mapping.
    V2 = pickle.load(handle)
|
||
|
|
||
|
|
||
|
|
||
|
# Predict the gap word for every row of the dev and test sets. Column 6 holds
# the text before the gap and column 7 the text after it (with literal "\n"
# escapes). One prediction line per row is written to <split>/out.tsv.
# The two splits were previously handled by two copy-pasted loops.
for split_name in ('dev-0', 'test-A'):
    with lzma.open(f'{split_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:
        with open(f'{split_name}/out.tsv', 'w', encoding='utf-8') as f:
            for line in fid:
                separated = line.split('\t')
                prefix_words = separated[6].replace(r'\n', ' ').split()
                suffix_words = separated[7].replace(r'\n', ' ').split()
                # An empty context (gap at the very start/end of a text)
                # previously raised IndexError. Fall back to '', which split()
                # never produces, so it has no counts and candidates() uses its
                # uniform fallback.
                w1 = prefix_words[-1] if prefix_words else ''
                w3 = suffix_words[0] if suffix_words else ''
                w2 = candidates(w1, w3)
                # Debug trace of each prediction.
                print(w1)
                print(w2)
                print(w3)
                f.write(w2 + '\n')
|