import lzma
import pickle
from tqdm import tqdm
from collections import Counter, defaultdict
def clean_text(line: str):
    # Columns 6 and 7 of the TSV hold the left and right context of the gap.
    # Drop literal "\n" markers and basic punctuation before tokenising.
    separated = line.split('\t')
    prefix = separated[6].replace(r'\n', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '')
    suffix = separated[7].replace(r'\n', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '')
    return prefix + ' ' + suffix
def unigrams(filename):
    # Yield every token of the training corpus; the progress bar total is the number of training lines.
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        with tqdm(total=432022) as pbar:
            for line in fid:
                text = clean_text(line)
                for word in text.split():
                    yield word
                pbar.update(1)
def bigrams(filename, V: dict):
    # Yield (previous word, current word) pairs, mapping out-of-vocabulary tokens to 'UNK'.
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        pbar = tqdm(total=432022)
        first_word = ''
        for line in fid:
            text = clean_text(line)
            for second_word in text.split():
                if V.get(second_word) is None:
                    second_word = 'UNK'
                if first_word:
                    yield first_word, second_word
                first_word = second_word
            pbar.update(1)
        pbar.close()
def trigrams(filename, V: dict):
    # Yield consecutive word triples, mapping out-of-vocabulary tokens to 'UNK'.
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        print('Trigrams')
        for line in tqdm(fid, total=432022):
            text = clean_text(line)
            words = [word if word in V else 'UNK' for word in text.split()]
            for i in range(len(words) - 2):
                yield tuple(words[i:i + 3])
def tetragrams(filename, V: dict):
    # Yield consecutive four-word windows, mapping out-of-vocabulary tokens to 'UNK'.
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        print('Tetragrams')
        for line in tqdm(fid, total=432022):
            text = clean_text(line)
            words = [word if word in V else 'UNK' for word in text.split()]
            for first_word, second_word, third_word, fourth_word in zip(words, words[1:], words[2:], words[3:]):
                yield first_word, second_word, third_word, fourth_word
def P(first_word, second_word=None, third_word=None, fourth_word=None):
    # Maximum-likelihood unigram, bigram, trigram and tetragram probabilities;
    # an unseen history (zero denominator) yields probability 0.
    try:
        if second_word is None:
            return V_common_dict.get(first_word, 0) / total
        elif third_word is None:
            return V2_bigrams_dict.get((first_word, second_word), 0) / V_common_dict.get(first_word, 0)
        elif fourth_word is None:
            return V3_trigrams_dict.get((first_word, second_word, third_word), 0) / V2_bigrams_dict.get((first_word, second_word), 0)
        else:
            return V4_tetragrams_dict.get((first_word, second_word, third_word, fourth_word), 0) / V3_trigrams_dict.get((first_word, second_word, third_word), 0)
    except ZeroDivisionError:
        return 0
def compute_tetragram_probability(tetragram):
    # Linear interpolation of tetragram, trigram, bigram and unigram estimates (weights sum to 1).
    return 0.5 * P(*tetragram) + 0.35 * P(tetragram[1], *tetragram[2:]) + \
           0.1 * P(tetragram[2], *tetragram[3:]) + 0.05 * P(tetragram[3])
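# Hypothetical sanity check of the interpolation above: with component probabilities
# 0.2 (tetragram), 0.1 (trigram), 0.05 (bigram) and 0.01 (unigram), the mixture is
# 0.5*0.2 + 0.35*0.1 + 0.1*0.05 + 0.05*0.01 = 0.1405; because the weights sum to 1,
# the result always stays within [0, 1].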
def get_context(position, sentence):
    # Collect the three positions before `position` and the four positions starting
    # at `position`, padding with empty strings at the sentence boundaries.
    context = []
    for i in range(position - 3, position):
        if i < 0:
            context.append('')
        else:
            context.append(sentence[i])
    for i in range(position, position + 4):
        if i >= len(sentence):
            context.append('')
        else:
            context.append(sentence[i])
    return context
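# For example (hypothetical call): get_context(1, ['I', 'like', 'cats']) returns
# ['', '', 'I', 'like', 'cats', '', ''], i.e. the three positions before the gap
# (padded at the sentence start) followed by the word at `position` and the next
# three positions (padded at the end).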
def compute_candidates(left_context, right_context):
    # Score every vocabulary word as the gap filler. The scoring tetragram is the
    # three left-context words followed by the candidate; the right context is not
    # used by the forward n-gram model. Left contexts shorter than three words are
    # padded with empty strings.
    left_context = ([''] * 3 + left_context)[-3:]
    candidate_probabilities = {}
    for word in V_common_dict:
        tetragram = left_context + [word]
        candidate_probabilities[word] = compute_tetragram_probability(tetragram)
    sorted_candidates = sorted(candidate_probabilities.items(), key=lambda x: x[1], reverse=True)[:5]
    total_probability = sum(c[1] for c in sorted_candidates)
    normalized_candidates = [(c[0], c[1] / total_probability) for c in sorted_candidates]
    # An empty word stands in for all remaining tokens: 'UNK' (or, failing that,
    # the weakest candidate) is replaced by ''.
    for index, elem in enumerate(normalized_candidates):
        if 'UNK' in elem:
            normalized_candidates.pop(index)
            normalized_candidates.append(('', elem[1]))
            break
    else:
        normalized_candidates[-1] = ('', normalized_candidates[-1][1])
    return ' '.join(f'{x[0]}:{x[1]}' for x in normalized_candidates)
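# The returned string follows the "word:probability" layout written to out.tsv,
# e.g. something like 'the:0.42 of:0.21 and:0.13 to:0.09 :0.15' (hypothetical
# numbers), where the empty word carries the leftover mass and the values sum to 1.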
def candidates(left_context, right_context):
    # Map out-of-vocabulary context words to 'UNK' before scoring.
    left_context = [w if V_common_dict.get(w) else 'UNK' for w in left_context]
    right_context = [w if V_common_dict.get(w) else 'UNK' for w in right_context]
    return compute_candidates(left_context, right_context)
def create_vocab(filename, word_limit):
    # Keep the `word_limit` most frequent words; all remaining counts are pooled under 'UNK'.
    V = Counter(unigrams(filename))
    V_common_dict = dict(V.most_common(word_limit))
    UNK = sum(v for k, v in V.items() if k not in V_common_dict)
    V_common_dict['UNK'] = UNK
    V_common_tuple = tuple(V_common_dict.items())
    with open('V.pickle', 'wb') as handle:
        pickle.dump(V_common_tuple, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return V_common_dict
def load_pickle(filename):
    with open(filename, 'rb') as handle:
        return pickle.load(handle)
def save_pickle(obj, filename):
    with open(filename, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
# Build the vocabulary and the n-gram count tables; each table is pickled so it
# can be reloaded without recounting.
create_vocab('train/in.tsv.xz', 1000)
V_common_tuple = load_pickle('V.pickle')
V_common_dict = dict(V_common_tuple)
total = sum(V_common_dict.values())
V2_bigrams = Counter(bigrams('train/in.tsv.xz', V_common_dict))
V2_bigrams_dict = dict(V2_bigrams)
save_pickle(V2_bigrams_dict, 'V2_bigrams.pickle')
V2_bigrams_dict = load_pickle('V2_bigrams.pickle')
V3_trigrams = Counter(trigrams('train/in.tsv.xz', V_common_dict))
V3_trigrams_dict = dict(V3_trigrams)
save_pickle(V3_trigrams_dict, 'V3_trigrams.pickle')
V3_trigrams_dict = load_pickle('V3_trigrams.pickle')
V4_tetragrams = Counter(tetragrams('train/in.tsv.xz', V_common_dict))
V4_tetragrams_dict = dict(V4_tetragrams)
save_pickle(V4_tetragrams_dict, 'V4_tetragrams.pickle')
V4_tetragrams_dict = load_pickle('V4_tetragrams.pickle')
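# Optional smoke test once the counts above are loaded (hypothetical context words):
# print(candidates(['one', 'of', 'the'], ['states', 'in', 'the']))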
def save_outs(folder_name):
    # Write the top candidate words and their probabilities for every gap in the split.
    with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:
        with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as f:
            for line in tqdm(fid):
                separated = line.split('\t')
                prefix = separated[6].replace(r'\n', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '').split()
                suffix = separated[7].replace(r'\n', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '').split()
                left_context = [x if V_common_dict.get(x) else 'UNK' for x in prefix[-3:]]
                right_context = [x if V_common_dict.get(x) else 'UNK' for x in suffix[:3]]
                w = candidates(left_context, right_context)
                f.write(w + '\n')
save_outs('dev-0')
save_outs('test-A')