Regression

This commit is contained in:
Bartusiak 2020-04-06 19:11:16 +02:00
parent 72f56d6b42
commit 0e0d33afb4
4 changed files with 294932 additions and 294900 deletions

View File

@ -9,6 +9,8 @@ import random
vocabulary = [] vocabulary = []
file_to_save = open("test.tsv", "w", encoding='utf-8') file_to_save = open("test.tsv", "w", encoding='utf-8')
def define_vocabulary(file_to_learn_new_words): def define_vocabulary(file_to_learn_new_words):
word_counts = {'count': defaultdict(int)} word_counts = {'count': defaultdict(int)}
with open(file_to_learn_new_words, encoding='utf-8') as in_file: with open(file_to_learn_new_words, encoding='utf-8') as in_file:
@ -19,6 +21,7 @@ def define_vocabulary(file_to_learn_new_words):
word_counts['count'][token] += 1 word_counts['count'][token] += 1
return word_counts return word_counts
def read_input(file_path): def read_input(file_path):
read_word_counts = {'count': defaultdict(int)} read_word_counts = {'count': defaultdict(int)}
with open(file_path, encoding='utf-8') as in_file: with open(file_path, encoding='utf-8') as in_file:
@ -29,7 +32,9 @@ def read_input(file_path):
read_word_counts['count'][token] += 1 read_word_counts['count'][token] += 1
return read_word_counts return read_word_counts
def training(vocabulary, read_input, expected): def training(vocabulary, read_input, expected):
file_to_write = open(expected, 'w+', encoding='utf8')
learning_rate = 0.00001 learning_rate = 0.00001
learning_precision = 0.0000001 learning_precision = 0.0000001
weights = [] weights = []
@ -46,7 +51,7 @@ def training(vocabulary,read_input,expected):
# max_iteration=len(vocabulary['count'])+1 # max_iteration=len(vocabulary['count'])+1
max_iteration = 1000 max_iteration = 1000
delta = 1 delta = 1
while (delta>learning_precision and iteration<max_iteration): while delta > learning_precision and iteration < max_iteration:
d, y = random.choice(list(read_input['count'].items())) # d-word, y-value of d, y = random.choice(list(read_input['count'].items())) # d-word, y-value of
y_hat = weights[0] y_hat = weights[0]
i = 0 i = 0
@ -54,10 +59,14 @@ def training(vocabulary,read_input,expected):
if word_d in vocabulary['count'].keys(): if word_d in vocabulary['count'].keys():
# print(vocabulary['count'][d]) # print(vocabulary['count'][d])
y_hat += weights[vocabulary['count'][word_d]] * readed_words_values[i] y_hat += weights[vocabulary['count'][word_d]] * readed_words_values[i]
delta=abs(y_hat-y)*learning_rate i += 1
weights[0]=weights[0]-delta if y_hat > 0.0:
i+=i file_to_write.write('1\n')
else:
file_to_write.write('0\n')
i = 0 i = 0
delta = (y_hat - y) * learning_rate
weights[0] = weights[0] - delta
for word_w in d: for word_w in d:
if word_w in vocabulary['count'].keys(): if word_w in vocabulary['count'].keys():
weights[vocabulary['count'][word_w]] -= readed_words_values[i] * delta weights[vocabulary['count'][word_w]] -= readed_words_values[i] * delta
@ -73,10 +82,16 @@ def training(vocabulary,read_input,expected):
iteration = 0 iteration = 0
loss_sum = 0.0 loss_sum = 0.0
iteration += 1 iteration += 1
file_to_write.close
return weights, vocabulary
def main(): def main():
vocabulary = define_vocabulary('train/in.tsv') vocabulary = define_vocabulary('train/in.tsv')
readed_words = read_input('dev-0/in.tsv') readed_words = read_input('dev-0/in.tsv')
readed_words_test_a = read_input('test-A/in.tsv/in.tsv')
training(vocabulary, readed_words, 'test.tsv') training(vocabulary, readed_words, 'test.tsv')
training(vocabulary,readed_words_test_a, 'test_a.tsv')
# def cost_function(y_hat,y): # def cost_function(y_hat,y):
@ -88,7 +103,6 @@ def main():
# loss_sum=0.0 # loss_sum=0.0
# def main(): # def main():
# --------------- initialization --------------------------------- # --------------- initialization ---------------------------------
# vocabulary = define_vocabulary('train/in.tsv') # vocabulary = define_vocabulary('train/in.tsv')
@ -126,4 +140,3 @@ def main():
main() main()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,19 @@
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
1 1
1 1
2 1
3 1
4 1
5 1
6 1
7 1
8 1
9 1
10 1
11 0
12 0
13 0
14 0
15 0
16 0
17 0
18 0
19 0

File diff suppressed because it is too large Load Diff