"""Bigram baseline for word-gap prediction.

Counts, over the training contexts, which token most often follows each
token, then predicts the gap word of every dev row as the most frequent
follower of the last word of its left context.  The bigram table is cached
in a pickle file so later runs skip the training pass.
"""
import sys
import random
from collections import defaultdict
from itertools import islice
import pickle
import os

try:
    from tqdm import tqdm
except ImportError:
    # Progress bars are cosmetic; fall back to plain iteration so the
    # script still runs where tqdm is not installed.
    def tqdm(iterable, *args, **kwargs):
        return iterable


# Index of the first context column in the TSV rows: column 6 is the left
# context and column 7 the right context.
CONTEXT_START = 6

# Distribution written when the model has nothing to propose.
FALLBACK_PREDICTION = 'a:0.6 the:0.2 :0.2'


def read_corpus(path='train/in.tsv'):
    """Return the training text at *path* as a flat list of tokens.

    Each row contributes ``left + 'BLANK' + right``.  The marker is glued on
    without surrounding spaces, so the last left word and the first right
    word end up fused into one token (e.g. ``satBLANKon``); this is kept
    unchanged so previously cached distributions remain equivalent.

    Rows with fewer than two context columns are skipped.
    """
    pieces = []
    with open(path, 'r', encoding='utf-8') as f:
        print('Reading corpus...')
        for line in tqdm(f):
            ctx = line.split('\t')[CONTEXT_START:]
            if len(ctx) < 2:
                # Malformed or blank row: nothing usable to learn from.
                continue
            pieces.append(ctx[0] + 'BLANK' + ctx[1])

    text = ' '.join(pieces)
    text = text.replace('-\n', '')    # drop a hyphen directly before a real newline
    text = text.replace('\\n', ' ')   # literal backslash-n sequences in the data
    text = text.replace('\n', ' ')    # real newlines (row terminators)
    return text.split(' ')


def build_distribution(tokens):
    """Count follower frequencies for every token in the list *tokens*.

    Returns a plain ``dict`` mapping ``word -> {follower: count}``.
    """
    distrib = defaultdict(lambda: defaultdict(int))
    pairs = zip(tokens, islice(tokens, 1, None))
    for word, follower in tqdm(pairs, total=max(len(tokens) - 1, 0)):
        distrib[word][follower] += 1
    # The outer lambda factory cannot be pickled, so hand back a plain dict.
    return dict(distrib)


def load_or_build_distribution(train_path='train/in.tsv',
                               cache_path='distrib.pkl'):
    """Load the bigram table from *cache_path*, or build and cache it.

    The training corpus is only read when no cache file exists.
    """
    if os.path.exists(cache_path):
        print('Loading distribution...')
        # NOTE: unpickling can execute arbitrary code; only point this at a
        # cache file this script produced itself.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    tokens = read_corpus(train_path)
    print('Generating distribution...')
    distrib = build_distribution(tokens)
    with open(cache_path, 'wb') as f:
        print('Saving distribution...')
        pickle.dump(distrib, f)
    return distrib


def predict(distrib, left_context):
    """Return the most frequent follower of the last word of *left_context*.

    Returns ``None`` when that word has no recorded followers.  Ties go to
    the follower that was recorded first.
    """
    last_word = left_context.split(' ')[-1]
    followers = distrib.get(last_word)
    if not followers:
        return None
    return max(followers, key=followers.get)


def format_prediction(word):
    """Render one output row (without the line terminator) for *word*."""
    if word is None:
        return FALLBACK_PREDICTION
    return f'{word}:0.9 :0.1'


def main(train_path='train/in.tsv', dev_path='dev-0/in.tsv',
         out_path='dev-0/out.tsv', cache_path='distrib.pkl'):
    """Predict the gap word for every row of *dev_path* into *out_path*.

    Exactly one output line is written per input row, so the two files stay
    aligned even when a row is malformed or its last word is unknown.
    """
    distrib = load_or_build_distribution(train_path, cache_path)

    results = []
    with open(dev_path, 'r', encoding='utf-8') as f:
        print('Generating output...')
        for line in tqdm(f):
            # Strip the terminator so a row that ends at the left-context
            # column does not carry '\n' into its last word.
            ctx = line.rstrip('\n').split('\t')[CONTEXT_START:]
            results.append(predict(distrib, ctx[0]) if ctx else None)

    with open(out_path, 'w', encoding='utf-8') as f:
        print('Writing output...')
        f.writelines(format_prediction(word) + '\n' for word in tqdm(results))


if __name__ == '__main__':
    main()