cw5 tetragram
parent 3aed881335
commit b7e27a1f1d
182
cw5tetragram.py
Normal file
@@ -0,0 +1,182 @@
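# cw5tetragram.py
# A tetragram (4-gram) language model for filling a word gap: cached
# unigram/bigram/trigram/tetragram counts over a 3000-word vocabulary
# (everything else becomes UNK) are combined by linear interpolation to
# score every vocabulary word as a filler between a 3-word left context
# and a 3-word right context; the top candidates are written to
# dev-0/out.tsv and test-A/out.tsv.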
import lzma
import pickle
# from collections import Counter  # needed only when rebuilding the n-gram counts below


def clean_line(line: str):
    # Columns 6 and 7 of the tab-separated input hold the text before and
    # after the gap; join them into one space-separated string.
    separated = line.split('\t')
    prefix = separated[6].replace(r'\n', ' ')
    suffix = separated[7].replace(r'\n', ' ')
    return prefix + ' ' + suffix


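# Streaming n-gram generators over the training corpus: each reads the
# xz-compressed TSV line by line, maps words missing from the vocabulary V
# to 'UNK', and yields unigrams, bigrams, trigrams or tetragrams.
# The hard-coded 432022 looks like the line count of train/in.tsv.xz and is
# only used for the progress display.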
def words(filename):
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        index = 1
        print('Words')
        for line in fid:
            print(f'\rProgress: {(index / 432022 * 100):.2f}%', end='')
            text = clean_line(line)
            for word in text.split():
                yield word
            index += 1
        print()


def bigrams(filename, V: dict):
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        index = 1
        print('Bigrams')
        for line in fid:
            print(f'\rProgress: {(index / 432022 * 100):.2f}%', end='')
            text = clean_line(line)
            first_word = ''
            for second_word in text.split():
                # Out-of-vocabulary words are mapped to UNK.
                if V.get(second_word) is None:
                    second_word = 'UNK'
                # Yield only complete windows.
                if first_word:
                    yield first_word, second_word
                first_word = second_word
            index += 1
        print()


def trigrams(filename, V: dict):
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        index = 1
        print('Trigrams')
        for line in fid:
            print(f'\rProgress: {(index / 432022 * 100):.2f}%', end='')
            text = clean_line(line)
            first_word = ''
            second_word = ''
            for third_word in text.split():
                if V.get(third_word) is None:
                    third_word = 'UNK'
                if first_word:
                    yield first_word, second_word, third_word
                first_word = second_word
                second_word = third_word
            index += 1
        print()


def tetragrams(filename, V: dict):
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        index = 1
        print('Tetragrams')
        for line in fid:
            print(f'\rProgress: {(index / 432022 * 100):.2f}%', end='')
            text = clean_line(line)
            first_word = ''
            second_word = ''
            third_word = ''
            for fourth_word in text.split():
                if V.get(fourth_word) is None:
                    fourth_word = 'UNK'
                if first_word:
                    yield first_word, second_word, third_word, fourth_word
                first_word = second_word
                second_word = third_word
                third_word = fourth_word
            index += 1
        print()


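# Maximum-likelihood estimates read straight off the cached counts:
#   P(w1)             = count(w1) / total
#   P(w2 | w1)        = count(w1 w2) / count(w1)
#   P(w3 | w1 w2)     = count(w1 w2 w3) / count(w1 w2)
#   P(w4 | w1 w2 w3)  = count(w1 w2 w3 w4) / count(w1 w2 w3)
# Any n-gram missing from the counts gets probability 0.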
def P(first_word, second_word=None, third_word=None, fourth_word=None):
    try:
        if second_word is None:
            return V_common_dict[first_word] / total
        if third_word is None:
            return V2_dict[(first_word, second_word)] / V_common_dict[first_word]
        if fourth_word is None:
            return V3_dict[(first_word, second_word, third_word)] / V2_dict[(first_word, second_word)]
        return (V4_dict[(first_word, second_word, third_word, fourth_word)]
                / V3_dict[(first_word, second_word, third_word)])
    except KeyError:
        return 0


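# Linear interpolation of the four orders, with weights summing to 1.0:
#   P_smoothed(w4 | w1 w2 w3) = 0.45 * P(w4 | w1 w2 w3)
#                             + 0.30 * P(w4 | w2 w3)
#                             + 0.15 * P(w4 | w3)
#                             + 0.10 * P(w4)
# so an unseen higher-order context backs off smoothly to the shorter ones.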
def smoothed(tetragram):
    first, second, third, fourth = tetragram
    return (0.45 * P(first, second, third, fourth)
            + 0.30 * P(second, third, fourth)
            + 0.15 * P(third, fourth)
            + 0.10 * P(fourth))


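# A candidate word w for the gap is scored by the product of the interpolated
# probabilities of the four tetragram windows it takes part in; with left
# context (w1, w2, w3) and right context (w5, w6, w7):
#   score(w) = P_s(w | w1 w2 w3) * P_s(w5 | w2 w3 w)
#            * P_s(w6 | w3 w w5) * P_s(w7 | w w5 w6)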
def candidates(left_context, right_context):
    cand = {}
    first, second, third = left_context
    fifth, sixth, seventh = right_context
    for word in V_common_dict:
        # Slide the candidate word through the four tetragram windows that
        # span the gap and combine their interpolated probabilities.
        p1 = smoothed((first, second, third, word))
        p2 = smoothed((second, third, word, fifth))
        p3 = smoothed((third, word, fifth, sixth))
        p4 = smoothed((word, fifth, sixth, seventh))
        cand[word] = p1 * p2 * p3 * p4
    # Keep the five best candidates and normalize their scores.
    cand = sorted(cand.items(), key=lambda x: x[1], reverse=True)[:5]
    total_score = sum(score for _, score in cand)
    norm = [(word, score / total_score) for word, score in cand]
    # Report the UNK candidate as the empty string.
    for index, elem in enumerate(norm):
        if elem[0] == 'UNK':
            unk = norm.pop(index)
            norm.append(('', unk[1]))
            break
    return ' '.join(f'{word}:{score}' for word, score in norm)


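# For every line of {folder_name}/in.tsv.xz, take the last three words of the
# prefix and the first three of the suffix (out-of-vocabulary words become
# UNK), rank candidate fillers, and write one "word:score ..." line to
# {folder_name}/out.tsv. candidates() unpacks exactly three words per side,
# so every prefix and suffix is assumed to contain at least three tokens.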
def create_outputs(folder_name):
    print(f'Creating outputs in {folder_name}')
    with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:
        with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as f:
            for line in fid:
                separated = line.split('\t')
                prefix = separated[6].replace(r'\n', ' ').split()
                suffix = separated[7].replace(r'\n', ' ').split()
                # Last three words before the gap and first three after it.
                left_context = [x if V_common_dict.get(x) else 'UNK' for x in prefix[-3:]]
                right_context = [x if V_common_dict.get(x) else 'UNK' for x in suffix[:3]]
                w = candidates(left_context, right_context)
                f.write(w + '\n')


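# One-off preprocessing (commented out below): build the 3000-word vocabulary
# and the bigram/trigram/tetragram Counters from train/in.tsv.xz and cache
# them as pickles. The active code only loads the cached V.pickle .. V4.pickle
# files, which are assumed to sit next to this script; uncomment the blocks
# (and the Counter import) to regenerate them.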
# WORD_LIMIT = 3000
# V = Counter(words('train/in.tsv.xz'))
# V_common_dict = dict(V.most_common(WORD_LIMIT))
# UNK = 0
# for key, value in V.items():
#     if V_common_dict.get(key) is None:
#         UNK += value
# V_common_dict['UNK'] = UNK
# with open('V.pickle', 'wb') as handle:
#     pickle.dump(V_common_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('V.pickle', 'rb') as handle:
    V_common_dict = pickle.load(handle)

total = sum(V_common_dict.values())

# V2 = Counter(bigrams('train/in.tsv.xz', V_common_dict))
# V2_dict = dict(V2)
# with open('V2.pickle', 'wb') as handle:
#     pickle.dump(V2_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('V2.pickle', 'rb') as handle:
    V2_dict = pickle.load(handle)

# V3 = Counter(trigrams('train/in.tsv.xz', V_common_dict))
# V3_dict = dict(V3)
# with open('V3.pickle', 'wb') as handle:
#     pickle.dump(V3_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('V3.pickle', 'rb') as handle:
    V3_dict = pickle.load(handle)

# V4 = Counter(tetragrams('train/in.tsv.xz', V_common_dict))
# V4_dict = dict(V4)
# with open('V4.pickle', 'wb') as handle:
#     pickle.dump(V4_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('V4.pickle', 'rb') as handle:
    V4_dict = pickle.load(handle)

create_outputs('dev-0')
create_outputs('test-A')
21038
dev-0/out.tsv
File diff suppressed because it is too large
14828
test-A/out.tsv
File diff suppressed because it is too large