This commit is contained in:
Krystian Wasilewski 2023-04-03 18:45:40 +02:00
parent 40e1c67559
commit 3aed881335
3 changed files with 18034 additions and 17933 deletions

101
cw4zad2.py Normal file
View File

@ -0,0 +1,101 @@
import lzma
import pickle
from collections import Counter
def words(filename):
with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
for line in fid:
separated = line.split('\t')
prefix = separated[6].replace(r'\n', ' ')
suffix = separated[7].replace(r'\n', ' ')
text = prefix + ' ' + suffix
for word in text.split():
yield word
def bigrams(filename, V):
V = [word for word, count in V]
with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
for line in fid:
separated = line.split('\t')
prefix = separated[6].replace(r'\n', ' ')
suffix = separated[7].replace(r'\n', ' ')
text = prefix + ' ' + suffix
previous = ''
for word in text.split():
if word in V and previous in V:
yield previous, word
previous = word
def P(previous_word, word):
if word not in V:
return 0
if (previous_word, word) not in V2:
return 0
return V2[(previous_word, word)] / V[word]
def candidates(w1, w3):
cand = {}
for w2 in V:
cand[w2] = P(w1, w2) * P(w2, w3)
cand = sorted(list(cand.items()), key=lambda x: x[1], reverse=True)[:5]
try:
norm = [(x[0], float(x[1]) / sum([y[1] for y in cand])) for x in cand]
except ZeroDivisionError:
norm = [(x[0], 0.2) for x in cand]
norm[-1] = ('', norm[-1][1])
return ' '.join([f'{x[0]}:{x[1]}' for x in norm])
# WORD_LIMIT = 5000
#
# V = Counter(words('train/in.tsv.xz'))
# V = V.most_common(WORD_LIMIT)
# with open('V.pickle', 'wb') as handle:
# pickle.dump(V, handle, protocol=pickle.HIGHEST_PROTOCOL)
# V2 = Counter(bigrams('train/in.tsv.xz', V))
# print(V2.most_common(100))
# with open('V2.pickle', 'wb') as handle:
# pickle.dump(V2, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('V.pickle', 'rb') as handle:
V_tuple = pickle.load(handle)
V = {}
for key, value in V_tuple:
V[key] = value
with open('V2.pickle', 'rb') as handle:
V2 = pickle.load(handle)
with lzma.open('dev-0/in.tsv.xz', mode='rt', encoding='utf-8') as fid:
with open('dev-0/out.tsv', 'w', encoding='utf-8') as f:
for line in fid:
separated = line.split('\t')
prefix = separated[6].replace(r'\n', ' ')
suffix = separated[7].replace(r'\n', ' ')
w1 = prefix.split()[-1]
w3 = suffix.split()[0]
w2 = candidates(w1, w3)
print(w1)
print(w2)
print(w3)
f.write(w2 + '\n')
with lzma.open('test-A/in.tsv.xz', mode='rt', encoding='utf-8') as fid:
with open('test-A/out.tsv', 'w', encoding='utf-8') as f:
for line in fid:
separated = line.split('\t')
prefix = separated[6].replace(r'\n', ' ')
suffix = separated[7].replace(r'\n', ' ')
w1 = prefix.split()[-1]
w3 = suffix.split()[0]
w2 = candidates(w1, w3)
print(w1)
print(w2)
print(w3)
f.write(w2 + '\n')

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff