cw4 zad2
This commit is contained in:
parent
40e1c67559
commit
3aed881335
101
cw4zad2.py
Normal file
101
cw4zad2.py
Normal file
@ -0,0 +1,101 @@
|
||||
import lzma
|
||||
import pickle
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def words(filename):
|
||||
with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
|
||||
for line in fid:
|
||||
separated = line.split('\t')
|
||||
prefix = separated[6].replace(r'\n', ' ')
|
||||
suffix = separated[7].replace(r'\n', ' ')
|
||||
text = prefix + ' ' + suffix
|
||||
for word in text.split():
|
||||
yield word
|
||||
|
||||
|
||||
def bigrams(filename, V):
|
||||
V = [word for word, count in V]
|
||||
with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
|
||||
for line in fid:
|
||||
separated = line.split('\t')
|
||||
prefix = separated[6].replace(r'\n', ' ')
|
||||
suffix = separated[7].replace(r'\n', ' ')
|
||||
text = prefix + ' ' + suffix
|
||||
previous = ''
|
||||
for word in text.split():
|
||||
if word in V and previous in V:
|
||||
yield previous, word
|
||||
previous = word
|
||||
|
||||
|
||||
def P(previous_word, word):
|
||||
if word not in V:
|
||||
return 0
|
||||
if (previous_word, word) not in V2:
|
||||
return 0
|
||||
return V2[(previous_word, word)] / V[word]
|
||||
|
||||
|
||||
def candidates(w1, w3):
|
||||
cand = {}
|
||||
for w2 in V:
|
||||
cand[w2] = P(w1, w2) * P(w2, w3)
|
||||
cand = sorted(list(cand.items()), key=lambda x: x[1], reverse=True)[:5]
|
||||
try:
|
||||
norm = [(x[0], float(x[1]) / sum([y[1] for y in cand])) for x in cand]
|
||||
except ZeroDivisionError:
|
||||
norm = [(x[0], 0.2) for x in cand]
|
||||
norm[-1] = ('', norm[-1][1])
|
||||
return ' '.join([f'{x[0]}:{x[1]}' for x in norm])
|
||||
|
||||
# WORD_LIMIT = 5000
|
||||
#
|
||||
# V = Counter(words('train/in.tsv.xz'))
|
||||
# V = V.most_common(WORD_LIMIT)
|
||||
# with open('V.pickle', 'wb') as handle:
|
||||
# pickle.dump(V, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
# V2 = Counter(bigrams('train/in.tsv.xz', V))
|
||||
# print(V2.most_common(100))
|
||||
# with open('V2.pickle', 'wb') as handle:
|
||||
# pickle.dump(V2, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
with open('V.pickle', 'rb') as handle:
|
||||
V_tuple = pickle.load(handle)
|
||||
V = {}
|
||||
for key, value in V_tuple:
|
||||
V[key] = value
|
||||
|
||||
with open('V2.pickle', 'rb') as handle:
|
||||
V2 = pickle.load(handle)
|
||||
|
||||
|
||||
|
||||
with lzma.open('dev-0/in.tsv.xz', mode='rt', encoding='utf-8') as fid:
|
||||
with open('dev-0/out.tsv', 'w', encoding='utf-8') as f:
|
||||
for line in fid:
|
||||
separated = line.split('\t')
|
||||
prefix = separated[6].replace(r'\n', ' ')
|
||||
suffix = separated[7].replace(r'\n', ' ')
|
||||
w1 = prefix.split()[-1]
|
||||
w3 = suffix.split()[0]
|
||||
w2 = candidates(w1, w3)
|
||||
print(w1)
|
||||
print(w2)
|
||||
print(w3)
|
||||
f.write(w2 + '\n')
|
||||
|
||||
with lzma.open('test-A/in.tsv.xz', mode='rt', encoding='utf-8') as fid:
|
||||
with open('test-A/out.tsv', 'w', encoding='utf-8') as f:
|
||||
for line in fid:
|
||||
separated = line.split('\t')
|
||||
prefix = separated[6].replace(r'\n', ' ')
|
||||
suffix = separated[7].replace(r'\n', ' ')
|
||||
w1 = prefix.split()[-1]
|
||||
w3 = suffix.split()[0]
|
||||
w2 = candidates(w1, w3)
|
||||
print(w1)
|
||||
print(w2)
|
||||
print(w3)
|
||||
f.write(w2 + '\n')
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user