paranormal-or-skeptic/linearregression.py
2020-04-05 16:28:00 +02:00

88 lines
2.8 KiB
Python

import csv
import re
import random
import json
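# Posts are tokenised with re.findall(r"[\w']+", text); for example
# "Hey you, what are you doing here?" -> ['Hey', 'you', 'what', 'are', 'you', 'doing', 'here']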
def make_dict(path):
    """Build the initial vocabulary of per-word weights from a TSV file.

    Only the text before the first tab on each line of ``path`` is used. It is
    split into runs of word characters and apostrophes. The first time a token
    is seen it receives ``round(random.random() % 0.2 - 0.1, 2)``, a value in
    [-0.1, 0.1]. Later occurrences keep that weight and draw no further random
    numbers.

    Returns a dict mapping token -> initial weight, in first-seen order.
    """
    weights = {}
    with open(path) as handle:
        # Lazily walk every token of every post; it must be consumed while
        # the file is still open.
        tokens = (
            token
            for raw in handle
            for token in re.findall(r"[\w']+", raw.split('\t')[0])
        )
        for token in tokens:
            if token not in weights:
                weights[token] = round(random.random() % 0.2 - 0.1, 2)
    return weights
# Bias added to every post's score. It is not learned and not stored in the
# weights file, so training and prediction must use the same value.
BIAS = 0.1
# Decision boundary for the 0/1 labels: predict() emits "1" when a score
# exceeds it, and training on a post stops once |label - score| is within it.
THRESHOLD = 0.5


def _tokens(line):
    """Return the word tokens of the first tab-separated field of *line*."""
    return re.findall(r"[\w']+", line.split('\t')[0])


def train_model(in_path, exp_path, w0=BIAS, lr=0.0001, model_path='dict.txt'):
    """Fit per-word weights on a training split and save them as JSON.

    Initial weights come from ``make_dict(in_path)``. Line i of ``in_path`` (a
    post) is paired with line i of ``exp_path`` (an integer label). Pairing
    stops at the end of the shorter file. A post's score is the sum of its
    token weights plus ``w0``. While ``|label - score|`` exceeds THRESHOLD,
    every distinct word in the post is moved by ``+lr`` or ``-lr``, whichever
    lowers the error. Training on the post stops when the error is within
    THRESHOLD or when neither direction helps.

    The final weights are written to ``model_path`` with ``json.dump``.
    """
    weights = make_dict(in_path)
    with open(in_path) as in_file, open(exp_path) as exp_file:
        for in_line, exp_line in zip(in_file, exp_file):
            words = _tokens(in_line)
            target = int(exp_line)
            distinct = set(words)
            # Moving each distinct word by lr shifts the score by exactly
            # len(words) * lr, since a repeated token counts once per occurrence.
            step = len(words) * lr
            while True:
                # Recompute from scratch each pass so the score reflects the
                # weights as they are now, not a running total of old passes.
                score = sum(weights[word] for word in words) + w0
                delta = abs(target - score)
                if delta <= THRESHOLD:
                    break
                if abs(target - (score - step)) < delta:
                    sign = -1
                elif abs(target - (score + step)) < delta:
                    sign = 1
                else:
                    break  # no step improves the fit (e.g. a post with no tokens)
                for word in distinct:
                    weights[word] += sign * lr
    with open(model_path, 'w') as file:
        json.dump(weights, file)


def predict(path, w0=BIAS, model_path='dict.txt'):
    """Label every post in ``<path>/in.tsv`` and write ``<path>/out.tsv``.

    Weights are loaded from ``model_path``, the JSON file written by
    train_model. A post's score is the sum of the weights of its known tokens
    (unknown tokens contribute 0) plus ``w0``. Its label is "1" if the score
    exceeds THRESHOLD and "0" otherwise. One label is written per row, in
    input order.
    """
    with open(model_path, 'r') as file:
        weights = json.load(file)
    with open(path + "/in.tsv") as in_file:
        labels = [
            "1" if sum(weights.get(word, 0.0) for word in _tokens(line)) + w0 > THRESHOLD else "0"
            for line in in_file
        ]
    # newline='' is required by the csv module; otherwise its "\r\n" row
    # terminator is translated to "\r\r\n" on Windows.
    with open(path + "/out.tsv", 'w', newline='') as tsvfile:
        tsv_writer = csv.writer(tsvfile, delimiter='\t')
        for label in labels:
            tsv_writer.writerow([label])
# Script entry: label test-A/in.tsv into test-A/out.tsv using the saved weights.
# NOTE(review): this also runs whenever the module is imported; consider an
# `if __name__ == "__main__":` guard.
predict(path="test-A")
def check_dev(path="dev-0"):
    """Print and return the accuracy of ``<path>/out.tsv`` against ``<path>/expected.tsv``.

    Line i of one file is compared with line i of the other after stripping
    surrounding whitespace. A missing final newline or ``\\r\\n`` line endings
    therefore do not count as mismatches. Comparison stops at the end of the
    shorter file. Returns 0.0 when there is nothing to compare.
    """
    with open(path + "/out.tsv") as out_file, open(path + "/expected.tsv") as exp_file:
        pairs = [(out.strip(), exp.strip()) for out, exp in zip(out_file, exp_file)]
    correct = sum(out == exp for out, exp in pairs)
    accuracy = correct / len(pairs) if pairs else 0.0
    print(accuracy)
    return accuracy