2020-03-22 10:15:36 +01:00
|
|
|
|
#!/usr/bin/python3
|
2020-04-04 22:07:48 +02:00
|
|
|
|
import re, sys, pickle, nltk, math, random
|
2020-03-29 23:29:19 +02:00
|
|
|
|
from nltk.corpus import stopwords
|
2020-03-22 10:15:36 +01:00
|
|
|
|
|
2020-03-29 13:39:47 +02:00
|
|
|
|
# Cleaning patterns are compiled once at import time instead of on every call.
# Each pattern is byte-identical to the one originally passed to re.sub.
_URL_RE = re.compile(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\?\+\-\%]+(\)|)')
_SEPARATOR_PUNCT_RE = re.compile(r'[\.\,\/\~]+')
_ANGLE_OR_MENTION_RE = re.compile(r'(<|>|\@[a-zA-Z0-9]+)')
_DROPPED_PUNCT_RE = re.compile(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\–\”\!\=\^]+')
_DASH_RE = re.compile(r'( \- |\-\-+)')
_SPACES_RE = re.compile(r' +')

_english_stop_words_cache = None


def _english_stop_words():
    """Return NLTK's English stop words as a set, loading the corpus only once."""
    global _english_stop_words_cache
    if _english_stop_words_cache is None:
        _english_stop_words_cache = set(stopwords.words('english'))
    return _english_stop_words_cache


def clear_post(post, stop_words=None):
    """Normalize a raw post and return its tokens with stop words removed.

    Pipeline, in order:
      1. literal two-character ``\\n`` sequences become spaces (presumably
         newlines that were escaped in the input file — confirm with the data),
      2. lowercase,
      3. URLs starting with http/https/www (optionally wrapped in parentheses)
         become the token ``internetlink``,
      4. runs of ``. , / ~`` become a space,
      5. ``<``, ``>`` and ``@mention`` handles are deleted,
      6. quotes, brackets, digits and assorted punctuation are deleted,
      7. standalone `` - `` and runs of ``--`` become a space,
      8. runs of spaces collapse to one and trailing spaces are stripped.

    Args:
        post: raw post text.
        stop_words: container of words to drop (checked with ``in``).
            Defaults to NLTK's English stop-word list.

    Returns:
        List of tokens split on single spaces. Only trailing spaces are
        stripped, so a post that begins with a separator yields a leading
        ``''`` token, and an all-empty post yields ``['']``.
    """
    if stop_words is None:
        stop_words = _english_stop_words()

    post = post.replace('\\n', ' ')
    post = post.lower()
    # URL replacement must run before step 4, which would split URLs on '.' and '/'.
    post = _URL_RE.sub(' internetlink ', post)
    post = _SEPARATOR_PUNCT_RE.sub(' ', post)
    post = _ANGLE_OR_MENTION_RE.sub('', post)
    post = _DROPPED_PUNCT_RE.sub('', post)
    post = _DASH_RE.sub(' ', post)
    post = _SPACES_RE.sub(' ', post)
    post = post.rstrip(' ')

    return [word for word in post.split(' ') if word not in stop_words]
|
2020-03-22 10:15:36 +01:00
|
|
|
|
|
2020-04-04 22:07:48 +02:00
|
|
|
|
# Author's question (translated from Polish): "do the words have to be a set?"
# The set is what deduplicates the vocabulary here. The callers in this file
# only iterate over it (create_mappings) and take its len() (main).
def create_vocabulary_and_documents(in_file, expected_file):
    """Read posts and their labels and build the training data.

    Args:
        in_file: path to a text file in which every line has exactly two
            tab-separated fields: the post text and a second field the code
            names a timestamp (unused). Any other field count raises ValueError.
        expected_file: path to a file with one integer label per line, aligned
            line-by-line with ``in_file``.

    Returns:
        ``(vocabulary, posts)``:
        - ``vocabulary`` is the set of all tokens produced by ``clear_post``.
        - ``posts`` maps each cleaned post (its tokens joined with single
          spaces) to its integer label.

    Notes:
        - ``zip`` stops at the shorter file, so extra lines in either file are
          silently ignored.
        - Posts that clean to the same token sequence share one dict key, so
          only the last label for that text is kept.
    """
    vocabulary = set()
    posts = {}
    # clear_post matches non-ASCII characters ('–', '”'), so decode explicitly
    # rather than depending on the platform's default encoding.
    with open(in_file, encoding='utf-8') as in_f, \
            open(expected_file, encoding='utf-8') as exp_f:
        for line, exp in zip(in_f, exp_f):
            text, _timestamp = line.rstrip('\n').split('\t')
            post = clear_post(text)
            posts[" ".join(post)] = int(exp)
            vocabulary.update(post)
    return vocabulary, posts
|
|
|
|
|
|
|
|
|
|
def create_mappings(vocabulary):
    """Assign each distinct word a reproducible integer index, starting at 1.

    Index 0 is deliberately left free: main() uses ``weights[0]`` as the bias.

    Words are numbered in sorted order. Iterating a set of ``str`` directly
    gives an order that changes between interpreter runs (string hash
    randomization), which would make the word→index assignment differ from
    run to run.

    Args:
        vocabulary: iterable of mutually comparable, hashable words
            (duplicates are ignored).

    Returns:
        ``(word_to_index_mapping, index_to_word_mapping)``: two dicts that are
        exact inverses of each other, with indices ``1..len(distinct words)``.
    """
    ordered_words = sorted(set(vocabulary))
    word_to_index_mapping = {
        word: index for index, word in enumerate(ordered_words, start=1)
    }
    index_to_word_mapping = {
        index: word for word, index in word_to_index_mapping.items()
    }
    return word_to_index_mapping, index_to_word_mapping
|
2020-03-29 23:29:19 +02:00
|
|
|
|
|
2020-03-22 10:15:36 +01:00
|
|
|
|
def _document_features(document, word_to_index_mapping):
    """Turn a space-joined cleaned post into ``(weight index, term count)`` pairs.

    Tokens missing from the mapping (e.g. the ``''`` token that an empty post
    splits into) are skipped rather than raising KeyError.
    """
    from collections import Counter

    counts = Counter(document.split(' '))
    return [
        (word_to_index_mapping[word], count)
        for word, count in counts.items()
        if word in word_to_index_mapping
    ]


def _predict(weights, features):
    """Return the linear model output: bias ``weights[0]`` plus sum of weight * count."""
    return weights[0] + sum(weights[index] * count for index, count in features)


def main():
    """Train a bag-of-words linear regressor with SGD until interrupted (Ctrl-C).

    Usage: ./train.py model expected_file in_file

    Each step samples one post uniformly at random, predicts
    ``bias + sum(weight[word] * count(word))`` and takes a gradient step on the
    squared error. Every 100 steps it prints the mean loss of that window and
    snapshots the weights if that mean is the lowest seen so far. On Ctrl-C it
    prints the best snapshot, or the current weights if no window finished.
    """
    if len(sys.argv) != 4:
        print("syntax ./train.py model expected_file in_file")
        return

    # NOTE(review): `model` is parsed but never used, and `pickle` is imported
    # at the top of the file but never called. Presumably the trained weights
    # (and word_to_index_mapping) were meant to be pickled to this path —
    # TODO confirm the format the prediction side expects before adding that.
    model = sys.argv[1]
    expected_file = sys.argv[2]
    in_file = sys.argv[3]

    vocabulary, posts = create_vocabulary_and_documents(in_file, expected_file)
    word_to_index_mapping, _ = create_mappings(vocabulary)

    # Featurize every post once, instead of re-splitting it, re-counting tokens
    # and rebuilding list(posts.items()) on every SGD step.
    samples = [
        (_document_features(document, word_to_index_mapping), y)
        for document, y in posts.items()
    ]
    if not samples:
        print("no training examples found")
        return

    # weights[0] is the bias. create_mappings hands out word indices starting
    # at 1, so len(vocabulary) + 1 slots cover every word.
    weights = [random.uniform(-0.01, 0.01) for _ in range(len(vocabulary) + 1)]

    learning_rate = 0.000001
    report_every = 100

    loss_sum = 0.0
    loss_sum_counter = 0
    lowest_loss = float('inf')
    lowest_loss_weights = []

    print(f"len of vocabulary {len(vocabulary)}")

    # Runs until interrupted; Ctrl-C ends training.
    try:
        while True:
            features, y = random.choice(samples)
            y_hat = _predict(weights, features)
            loss_sum += (y_hat - y) ** 2
            loss_sum_counter += 1

            # Gradient of (y_hat - y)**2 is 2 * (y_hat - y) * x; the constant
            # factor 2 is folded into the learning rate.
            delta = (y_hat - y) * learning_rate
            weights[0] -= delta
            for index, count in features:
                weights[index] -= count * delta

            if loss_sum_counter == report_every:
                mean_loss = loss_sum / loss_sum_counter
                print(f"{mean_loss} : {loss_sum_counter} : {y_hat} : {delta}")
                # Compare complete-window means only; comparing a partial
                # running sum rewards whichever step happens to follow a reset.
                if mean_loss < lowest_loss:
                    print("it happened")
                    lowest_loss = mean_loss
                    # Copy: keeping `weights` itself would alias the list
                    # that SGD keeps mutating.
                    lowest_loss_weights = list(weights)
                loss_sum = 0.0
                loss_sum_counter = 0
    except KeyboardInterrupt:
        pass

    # If interrupted before the first full window, fall back to the current weights.
    print(lowest_loss_weights or weights)
|
2020-03-22 10:15:36 +01:00
|
|
|
|
# Only train when executed as a script; importing this module (e.g. to reuse
# clear_post) must not start the endless training loop.
if __name__ == "__main__":
    main()
|