paranormal-or-skeptic/train.py

#!/usr/bin/python3
import re
import sys
import pickle
import random

from nltk.corpus import stopwords

# Build the stop-word set once at import time instead of on every call
# to clear_post().
STOP_WORDS = set(stopwords.words('english'))

def clear_post(post):
    """Normalise a raw post and return its tokens with stop words removed."""
    post = post.replace('\\n', ' ')  # the input encodes newlines as literal "\n"
    post = post.lower()
    # Collapse URLs (optionally parenthesised) into a single placeholder token.
    post = re.sub(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\?\+\-\%]+(\)|)', ' internetlink ', post)
    post = re.sub(r'[\.\,\/\~]+', ' ', post)
    # Strip leftover HTML entities and @mentions.
    post = re.sub(r'(&lt|&gt|\@[a-zA-Z0-9]+)', '', post)
    # Drop remaining punctuation and digits.
    post = re.sub(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\\\!\=\^]+', '', post)
    post = re.sub(r'( \- |\-\-+)', ' ', post)
    post = re.sub(r' +', ' ', post)
    post = post.rstrip(' ')
    post = post.split(' ')
    post_no_stop = [w for w in post if w not in STOP_WORDS]
    return post_no_stop
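
# A traced example (illustrative, assuming the English NLTK stop-word list
# is installed; "this" and "out" are stop words):
#   clear_post("Check this\\nout: https://example.com AMAZING!!!")
#   -> ['check', 'internetlink', 'amazing']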

# Does the vocabulary really need to be a set?
def create_vocabulary_and_documents(in_file, expected_file):
    vocabulary = set()
    posts = {}
    with open(in_file) as in_f, open(expected_file) as exp_f:
        for line, exp in zip(in_f, exp_f):
            text, timestamp = line.rstrip('\n').split('\t')
            post = clear_post(text)
            posts[" ".join(post)] = int(exp)
            for word in post:
                vocabulary.add(word)
    # Cache the parsed data so later runs can skip the preprocessing step.
    with open('data', 'wb') as f:
        pickle.dump([vocabulary, posts], f)
    print("data created")
    return vocabulary, posts
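
# Expected input format, inferred from the parsing above: in_file holds one
# tab-separated "text<TAB>timestamp" pair per line, and expected_file holds
# the matching integer class label (e.g. 0/1) on the corresponding line.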

def create_mappings(vocabulary):
    word_to_index_mapping = {}
    index_to_word_mapping = {}
    xi = 1  # index 0 is reserved for the bias weight
    for word in vocabulary:
        word_to_index_mapping[word] = xi
        index_to_word_mapping[xi] = word
        xi += 1
    return word_to_index_mapping, index_to_word_mapping
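
# Illustration (set iteration order is arbitrary): for vocabulary
# {'ghost', 'ufo'} this may return
# ({'ghost': 1, 'ufo': 2}, {1: 'ghost', 2: 'ufo'}).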

def main():
    if len(sys.argv) != 4:
        print("usage: ./train.py model expected_file in_file")
        return
    model = str(sys.argv[1])
    expected_file = str(sys.argv[2])
    in_file = str(sys.argv[3])
    try:
        # Reuse the cached vocabulary and posts from a previous run.
        with open("data", 'rb') as pos:
            pickle_list = pickle.load(pos)
            print("data loaded")
            vocabulary = pickle_list[0]
            posts = pickle_list[1]
    except FileNotFoundError:
        vocabulary, posts = create_vocabulary_and_documents(in_file, expected_file)
    word_to_index_mapping, index_to_word_mapping = create_mappings(vocabulary)
    # One weight per vocabulary word, plus weights[0] as the bias term.
    weights = []
    for xi in range(0, len(vocabulary) + 1):
        weights.append(random.uniform(-0.01, 0.01))
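
    # The model is a linear bag-of-words regressor:
    #   y_hat = weights[0] + sum_i weights[i] * count(word_i in post)
    # trained below with stochastic gradient descent on the squared error.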
    learning_rate = 0.000000001
    loss_sum = 0.0
    loss_sum_counter = 0
    lowest_loss_sum_weights = []
    lowest_loss_sum = 10000.0
    print(f"len of vocabulary {len(vocabulary)}")
    # The iteration cap could be set much, much higher.
    while loss_sum_counter != 10000:
        try:
            # Pick a random training document and its label.
            d, y = random.choice(list(posts.items()))
            y_hat = weights[0]
            tokens = d.split(' ')
            for word in tokens:
                # One could also do something smarter with the counts to make this work better.
                y_hat += weights[word_to_index_mapping[word]] * tokens.count(word)
            loss = (y_hat - y)**2
            loss_sum += loss
            delta = (y_hat - y) * learning_rate
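            # delta scales the gradient of the squared error,
            #   d/dw_i (y_hat - y)^2 = 2 * (y_hat - y) * x_i,
            # with the constant factor of 2 folded into the tiny learning
            # rate; the update below multiplies it by the token count x_i.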
            if loss_sum_counter % 100 == 0:
                # Report the mean loss over the last 100 samples.
                print(f"{loss_sum_counter} : {loss_sum / 100} : {y_hat} : {delta} : {lowest_loss_sum}")
                loss_sum = 0
            weights[0] -= delta
            for word in tokens:
                weights[word_to_index_mapping[word]] -= tokens.count(word) * delta
            if lowest_loss_sum > loss_sum and loss_sum != 0:
                print(f"it happened, new lowest_sum {loss_sum}")
                lowest_loss_sum = loss_sum
                # Take a copy: a plain assignment would alias `weights`, and the
                # snapshot would be silently overwritten by later updates.
                lowest_loss_sum_weights = weights.copy()
            loss_sum_counter += 1
        except KeyboardInterrupt:
            break
    # Persist the final weights, the best snapshot seen, and the vocabulary mapping.
    with open(model, 'wb') as f:
        pickle.dump([weights, lowest_loss_sum_weights, word_to_index_mapping], f)


if __name__ == '__main__':
    main()
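
# A minimal, hypothetical prediction sketch (not part of this script) would
# mirror the scoring loop above:
#
#   with open('model', 'rb') as f:
#       weights, best_weights, word_to_index_mapping = pickle.load(f)
#   tokens = clear_post(text)
#   y_hat = weights[0] + sum(weights[word_to_index_mapping[w]] * tokens.count(w)
#                            for w in tokens if w in word_to_index_mapping)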