paranormal-or-skeptic/train.py

#!/usr/bin/python3
import re, sys, pickle, random
from nltk.corpus import stopwords


def clear_post(post):
    """Lower-case a raw post, normalise URLs, strip punctuation and markup,
    and drop English stopwords; returns a list of tokens."""
    post = post.replace('\\n', ' ')  # the corpus stores newlines as the literal two characters "\n"
    post = post.lower()
    # replace URLs (optionally wrapped in parentheses) with a single token
    post = re.sub(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\?\+\-\%]+(\)|)', ' internetlink ', post)
    post = re.sub(r'[\.\,\/\~]+', ' ', post)
    # drop HTML entity remnants and @mentions
    post = re.sub(r'(&lt|&gt|\@[a-zA-Z0-9]+)', '', post)
    post = re.sub(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\\\!\=\^]+', '', post)
    post = re.sub(r'( \- |\-\-+)', ' ', post)
    post = re.sub(r' +', ' ', post)
    post = post.rstrip(' ')
    post = post.split(' ')
    stop_words = set(stopwords.words('english'))
    post_no_stop = [w for w in post if w not in stop_words]
    return post_no_stop
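
# Example (assuming NLTK's English stopword list has been downloaded):
#   clear_post("I saw a ghost at https://example.com last night!")
#   -> ['saw', 'ghost', 'internetlink', 'last', 'night']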


# do the words really have to be a set?
def create_vocabulary_and_documents(in_file, expected_file):
    vocabulary = set()
    posts = {}
    with open(in_file) as in_f, open(expected_file) as exp_f:
        for line, exp in zip(in_f, exp_f):
            text, timestamp = line.rstrip('\n').split('\t')
            post = clear_post(text)
            posts[" ".join(post)] = int(exp)
            for word in post:
                vocabulary.add(word)
    # cache the parsed corpus so later runs can skip this step
    with open('data', 'wb') as f:
        pickle.dump([vocabulary, posts], f)
        print("data created")
    return vocabulary, posts
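
# Note: the cached 'data' pickle holds [vocabulary (a set of words), posts (a
# dict mapping each cleaned post string to its integer label)]; delete the
# 'data' file to force re-reading in_file/expected_file on the next run.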


def create_mappings(vocabulary):
    word_to_index_mapping = {}
    index_to_word_mapping = {}
    xi = 1  # indices start at 1; weights[0] is reserved for the bias term
    for word in vocabulary:
        word_to_index_mapping[word] = xi
        index_to_word_mapping[xi] = word
        xi += 1
    return word_to_index_mapping, index_to_word_mapping
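
# For a vocabulary such as {'ghost', 'ufo'} this yields, e.g.,
# word_to_index_mapping == {'ghost': 1, 'ufo': 2}; since set iteration order
# is arbitrary, the exact assignment can vary between runs.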


def main():
    if len(sys.argv) != 4:
        print("syntax: ./train.py model expected_file in_file")
        return
    model = sys.argv[1]
    expected_file = sys.argv[2]
    in_file = sys.argv[3]
    # reuse the cached corpus if it exists, otherwise build it from the
    # input files
    try:
        with open("data", 'rb') as pos:
            pickle_list = pickle.load(pos)
            print("data loaded")
        vocabulary = pickle_list[0]
        posts = pickle_list[1]
    except FileNotFoundError:
        vocabulary, posts = create_vocabulary_and_documents(in_file, expected_file)
    word_to_index_mapping, index_to_word_mapping = create_mappings(vocabulary)
    # one weight per vocabulary word, plus a bias weight at index 0,
    # initialised to small random values
    weights = []
    for xi in range(0, len(vocabulary) + 1):
        weights.append(random.uniform(-0.01, 0.01))
    learning_rate = 0.000000001
    loss_sum = 0.0
    loss_sum_counter = 0
    lowest_loss_sum_weights = []
    lowest_loss_sum = 10000.0
    print(f"len of vocabulary {len(vocabulary)}")
    # the iteration limit below could be set much, much higher
    while loss_sum_counter != 10000:
        try:
            d, y = random.choice(list(posts.items()))
            y_hat = weights[0]
            tokens = d.split(' ')
            for word in tokens:
                # note: each occurrence adds weight * count, so a word seen
                # k times effectively contributes k**2 times; one could do
                # something smarter with the counts here
                y_hat += weights[word_to_index_mapping[word]] * tokens.count(word)
            loss = (y_hat - y) ** 2
            loss_sum += loss
            delta = (y_hat - y) * learning_rate
            if loss_sum_counter % 100 == 0:
                print(f"{loss_sum_counter} : {loss_sum / 1000} : {y_hat} : {delta} : {lowest_loss_sum}")
                loss_sum = 0
            weights[0] -= delta
            for word in tokens:
                weights[word_to_index_mapping[word]] -= tokens.count(word) * delta
            if lowest_loss_sum > loss_sum and loss_sum != 0:
                print(f"it happened, new lowest_sum {loss_sum}")
                lowest_loss_sum = loss_sum
                # copy rather than alias, otherwise later updates would
                # overwrite this snapshot of the best weights
                lowest_loss_sum_weights = weights.copy()
            loss_sum_counter += 1
        except KeyboardInterrupt:
            break
    # save the final weights, the best-so-far snapshot, and the
    # word-to-index mapping needed at prediction time
    with open(model, 'wb') as f:
        pickle.dump([weights, lowest_loss_sum_weights, word_to_index_mapping], f)


if __name__ == '__main__':
    main()
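
# Usage, following the syntax message above:
#   ./train.py model.pkl expected.tsv in.tsv
#
# A minimal sketch of how the pickled model might be used at prediction
# time, assuming the expected labels are 0/1; the 0.5 decision threshold and
# the choice to score with the best-so-far weights are assumptions, not
# something this script defines:
#
#   weights, best_weights, w2i = pickle.load(open('model.pkl', 'rb'))
#   for line in sys.stdin:
#       tokens = clear_post(line.rstrip('\n').split('\t')[0])
#       y_hat = best_weights[0]
#       for t in tokens:
#           if t in w2i:
#               y_hat += best_weights[w2i[t]] * tokens.count(t)
#       print(1 if y_hat > 0.5 else 0)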