From ca339fcfcc658d8c0728e7e9c4e8e0ab398007a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20J=C4=99dyk?= Date: Sat, 9 Apr 2022 14:54:19 +0200 Subject: [PATCH] change script for fine-tuning alpha --- run.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/run.py b/run.py index 02711de..e8ceefb 100644 --- a/run.py +++ b/run.py @@ -1,5 +1,6 @@ import pandas as pd import csv +import sys import regex as re from collections import Counter, defaultdict from nltk import trigrams, word_tokenize @@ -17,7 +18,7 @@ class Model(): self.vocab = set() def train(self, data): - for _, row in data.iterrows(): + for index, row in data.iterrows(): text = clean_text(str(row['text'])) words = word_tokenize(text) for w1, w2, w3 in trigrams(words, pad_right=True, pad_left=True): @@ -26,6 +27,9 @@ class Model(): self.vocab.add(w2) self.vocab.add(w3) self.probs[(w1, w3)][w2] += 1 + # limit number of data rows used for training + if index > 10000: + break for w1_w3 in self.probs: total_count = float(sum(self.probs[w1_w3].values())) @@ -46,15 +50,19 @@ class Model(): str_prediction += f'{word}:{prob} ' remaining_prob = 1 - total_prob - - if remaining_prob == 0: - remaining_prob = 0.01 str_prediction += f':{remaining_prob}' return str_prediction +# check arguments +if len(sys.argv) != 2: + print('Wrong number of arguments. Expected 1 - alpha smoothing parameter.') + quit() +else: + alpha = sys.argv[1] + # load training data train_data = pd.read_csv('train/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE) train_labels = pd.read_csv('train/expected.tsv', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE) @@ -66,7 +74,7 @@ train_data['text'] = train_data[6] + train_data[0] + train_data[7] train_data = train_data[['text']] # init model with given aplha -model = Model(0.01) +model = Model(alpha) # train model probs model.train(train_data)