"""Bigram model for filling a one-word gap between a left and a right context.

Input files are xz-compressed TSV. Column 6 holds the text before the gap and
column 7 the text after it, with line breaks stored as literal backslash-n
sequences. For every input line the script writes one output line of
space-separated ``word:probability`` pairs. The final pair has an empty word
(presumably the "any other word" bucket of the expected format -- confirm).
"""
import lzma
import pickle
from collections import Counter

WORD_LIMIT = 5000  # vocabulary size used when (re)building V.pickle


def _split_line(line):
    """Return the (left, right) context text of one TSV line.

    Literal backslash-n sequences are replaced by spaces so they act as
    word separators.
    """
    columns = line.split('\t')
    left = columns[6].replace(r'\n', ' ')
    right = columns[7].replace(r'\n', ' ')
    return left, right


def _contexts(filename):
    """Yield (left, right) context pairs for every line of an .xz TSV file."""
    with lzma.open(filename, mode='rt', encoding='utf-8') as fid:
        for line in fid:
            yield _split_line(line)


def words(filename):
    """Yield every whitespace-separated token of both contexts, line by line."""
    for left, right in _contexts(filename):
        yield from left.split()
        yield from right.split()


def bigrams(filename, V):
    """Yield adjacent (previous, word) pairs whose tokens are both in V.

    V is either a mapping keyed by word or an iterable of (word, count)
    pairs such as ``Counter.most_common()`` output.

    The left and right contexts are scanned separately. A word is missing
    between them, so the last left token and the first right token are not
    actually adjacent and must not be counted as a bigram.
    """
    if isinstance(V, dict):
        vocabulary = set(V)
    else:
        vocabulary = {word for word, _count in V}  # set: O(1) membership
    for left, right in _contexts(filename):
        for segment in (left, right):
            tokens = segment.split()
            for previous, word in zip(tokens, tokens[1:]):
                if previous in vocabulary and word in vocabulary:
                    yield previous, word


def P(previous_word, word, unigrams=None, bigram_counts=None):
    """Return count(previous_word, word) / count(word), or 0 if unseen.

    ``unigrams`` maps word -> count and ``bigram_counts`` maps
    (previous, word) -> count. Both default to the module globals V and V2
    loaded in the ``__main__`` section.

    The denominator is count(word), not count(previous_word). Inside
    candidates() this does not change the result. The score
    c(w1,w2)/c(w2) * c(w2,w3)/c(w3) differs from P(w2|w1) * P(w3|w2) only
    by the factor c(w1)/c(w3), which does not depend on w2 and cancels
    when the scores are normalized.
    """
    unigrams = V if unigrams is None else unigrams
    bigram_counts = V2 if bigram_counts is None else bigram_counts
    key = (previous_word, word)
    if word not in unigrams or key not in bigram_counts:
        return 0
    return bigram_counts[key] / unigrams[word]


def candidates(w1, w3, unigrams=None, bigram_counts=None, top_k=5):
    """Return the top_k guesses for the word between w1 and w3 as one line.

    Every vocabulary word w2 is scored by P(w1, w2) * P(w2, w3). The top_k
    scores are normalized to sum to 1, or made uniform when they are all
    zero. The lowest-ranked guess is emitted with an empty word, so its
    probability mass goes to the unnamed final entry.
    Example result: ``'b:0.5 d:0.5 a:0.0 c:0.0 :0.0'``.
    """
    unigrams = V if unigrams is None else unigrams
    bigram_counts = V2 if bigram_counts is None else bigram_counts
    scores = {
        w2: (P(w1, w2, unigrams, bigram_counts)
             * P(w2, w3, unigrams, bigram_counts))
        for w2 in unigrams
    }
    # sorted() is stable, so ties keep vocabulary (insertion) order.
    best = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
    if not best:
        return ':1.0'  # empty vocabulary: all mass goes to the unnamed entry
    total = sum(score for _, score in best)
    if total:
        probs = [(w2, float(score) / total) for w2, score in best]
    else:
        # No bigram evidence at all: spread the mass uniformly.
        probs = [(w2, 1 / len(best)) for w2, _ in best]
    probs[-1] = ('', probs[-1][1])
    return ' '.join(f'{w2}:{prob}' for w2, prob in probs)


def _predict_file(in_path, out_path, unigrams=None, bigram_counts=None,
                  verbose=False):
    """Write one candidates() line to out_path for every line of in_path.

    An empty left or right context falls back to '' instead of raising
    IndexError. '' matches no bigram, which gives a uniform distribution.
    """
    with lzma.open(in_path, mode='rt', encoding='utf-8') as fid:
        with open(out_path, 'w', encoding='utf-8') as out:
            for line in fid:
                left, right = _split_line(line)
                left_tokens = left.split()
                right_tokens = right.split()
                w1 = left_tokens[-1] if left_tokens else ''
                w3 = right_tokens[0] if right_tokens else ''
                w2 = candidates(w1, w3, unigrams, bigram_counts)
                if verbose:  # per-line debug trace, off by default
                    print(w1)
                    print(w2)
                    print(w3)
                out.write(w2 + '\n')


def train(train_path='train/in.tsv.xz', word_limit=WORD_LIMIT,
          vocab_path='V.pickle', bigram_path='V2.pickle'):
    """Build and pickle the vocabulary and bigram counts from training data.

    The vocabulary is the ``word_limit`` most common tokens, stored as a
    list of (word, count) pairs. The bigram counts are a Counter over pairs
    of vocabulary words. This function is not run automatically; call it to
    regenerate the pickles loaded by the ``__main__`` section.
    Returns (vocabulary, bigram_counts).
    """
    vocabulary = Counter(words(train_path)).most_common(word_limit)
    with open(vocab_path, 'wb') as handle:
        pickle.dump(vocabulary, handle, protocol=pickle.HIGHEST_PROTOCOL)
    bigram_counts = Counter(bigrams(train_path, vocabulary))
    with open(bigram_path, 'wb') as handle:
        pickle.dump(bigram_counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return vocabulary, bigram_counts


def _load_models(vocab_path='V.pickle', bigram_path='V2.pickle'):
    """Load the pickles written by train().

    Returns (word -> count dict, bigram counts).

    NOTE: pickle.load can execute arbitrary code; only load files you
    produced yourself.
    """
    with open(vocab_path, 'rb') as handle:
        vocabulary = dict(pickle.load(handle))
    with open(bigram_path, 'rb') as handle:
        bigram_counts = pickle.load(handle)
    return vocabulary, bigram_counts


if __name__ == '__main__':
    # P() and candidates() fall back to these module-level globals.
    V, V2 = _load_models()
    for split in ('dev-0', 'test-A'):
        _predict_file(f'{split}/in.tsv.xz', f'{split}/out.tsv', V, V2)