"""Bigram baseline that guesses the word missing from a gap in a text.

Pipeline:

1. Build (or load from ``distrib.pkl``) a table mapping each token to the
   counts of the tokens that directly follow it in ``train/in.tsv``.
2. For every row of ``dev-0/in.tsv`` take the last space-separated token of
   the first context column and predict its most frequent follower.
3. Write one prediction line per input row to ``dev-0/out.tsv``.

Rows are tab-separated; the two context columns are at index 6 and 7
(presumably the text left and right of the gap -- confirm against the
dataset description).
"""
import os
import pickle
import random
import sys
from collections import defaultdict
from itertools import islice

try:
    from tqdm import tqdm
except ImportError:
    # Progress bars are purely cosmetic; keep the script usable without tqdm.
    def tqdm(iterable, **_kwargs):
        return iterable

TRAIN_PATH = 'train/in.tsv'
DEV_IN_PATH = 'dev-0/in.tsv'
DEV_OUT_PATH = 'dev-0/out.tsv'
DISTRIB_PATH = 'distrib.pkl'

# Index of the first context column in a tab-separated input row.
CONTEXT_COLUMN = 6

# Line emitted when the model has no prediction for a row.
FALLBACK_PREDICTION = 'a:0.6 the:0.2 :0.2'


def tokenize(contexts):
    """Join the context strings and split them into space-separated tokens.

    Before splitting, a hyphen directly followed by a real newline is
    removed, the two-character sequence backslash + ``n`` becomes a space,
    and any remaining real newline becomes a space.  Consecutive spaces
    yield empty-string tokens, exactly as ``str.split(' ')`` does.
    """
    text = ' '.join(contexts)
    # NOTE(review): this matches '-' + a real newline; if the data encodes
    # line breaks as the literal text "\n", '-\\n' may have been intended.
    text = text.replace('-\n', '')
    text = text.replace('\\n', ' ')
    text = text.replace('\n', ' ')
    return text.split(' ')


def read_corpus(path=TRAIN_PATH):
    """Read the training file at *path* and return its token list.

    Raises:
        IndexError: if a row has fewer than ``CONTEXT_COLUMN + 2`` columns.
    """
    contexts = []
    with open(path, 'r') as f:
        print('Reading corpus...')
        for line in tqdm(f):
            ctx = line.split('\t')[CONTEXT_COLUMN:]
            # NOTE(review): 'BLANK' is concatenated without surrounding
            # spaces, so it fuses with the neighbouring words into a single
            # token.  Kept as-is to leave the model unchanged.
            contexts.append(ctx[0] + 'BLANK' + ctx[1])
    return tokenize(contexts)


def build_distribution(tokens):
    """Count, for every token, how often each token directly follows it.

    Returns a plain ``dict`` mapping token -> {next token -> count}.  The
    outer mapping is a plain dict so that it can be pickled (the lambda
    default factory cannot) and so that lookups never insert new keys.
    An input with fewer than two tokens yields an empty dict.
    """
    distrib = defaultdict(lambda: defaultdict(int))
    pairs = zip(tokens, islice(tokens, 1, None))
    for current, following in tqdm(pairs, total=max(len(tokens) - 1, 0)):
        distrib[current][following] += 1
    return dict(distrib)


def load_or_build_distribution(cache_path=DISTRIB_PATH, train_path=TRAIN_PATH):
    """Return the follower-count table, using *cache_path* as a cache.

    The training corpus is only read when the cache file does not exist.
    """
    if os.path.exists(cache_path):
        print('Loading distribution...')
        # SECURITY: pickle can execute arbitrary code while loading; only
        # use a cache file that this script itself produced.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    tokens = read_corpus(train_path)
    print('Generating distribution...')
    distrib = build_distribution(tokens)
    with open(cache_path, 'wb') as f:
        print('Saving distribution...')
        pickle.dump(distrib, f)
    return distrib


def predict_word(distrib, left_context):
    """Return the most frequent follower of the last token of *left_context*.

    Returns ``None`` when that token was never seen with a follower.  Ties
    are resolved in favour of the follower that was inserted first.
    """
    last_word = left_context.split(' ')[-1]
    # .get() instead of indexing: works for both dict and defaultdict and
    # never adds an empty entry for an unseen word.
    followers = distrib.get(last_word)
    if not followers:
        return None
    return max(followers, key=followers.get)


def predict_for_line(distrib, line):
    """Return the predicted word for one raw input row, or ``None``.

    A row with too few columns yields ``None`` rather than an error so that
    the output keeps exactly one line per input row.
    """
    columns = line.split('\t')
    if len(columns) <= CONTEXT_COLUMN:
        return None
    return predict_word(distrib, columns[CONTEXT_COLUMN])


def format_output_line(word):
    """Return the newline-terminated output line for a predicted *word*.

    ``None`` (no prediction) and the empty string produce the fallback
    line.  An empty word is not emitted as a prediction because it would
    serialise as ``:0.9``, the same form the fallback line uses for its
    final, word-less entry.
    """
    if not word:
        return FALLBACK_PREDICTION + '\n'
    return f'{word}:0.9 :0.1\n'


def main():
    """Run the full pipeline: model, predictions, output file."""
    distrib = load_or_build_distribution()

    results = []
    with open(DEV_IN_PATH, 'r') as f:
        print('Generating output...')
        for line in tqdm(f):
            results.append(predict_for_line(distrib, line))

    with open(DEV_OUT_PATH, 'w') as f:
        print('Writing output...')
        # Each prediction must be on its own line, aligned with the input.
        f.writelines(format_output_line(word) for word in tqdm(results))


if __name__ == '__main__':
    main()