Updated outs
parent 505e948550
commit 85b1e9509c
dev-0/out.tsv (254 changed lines)
File diff suppressed because it is too large
predict.py (26 changed lines)
@@ -7,7 +7,7 @@ import sys
 import nltk
 from nltk.corpus import stopwords
 
-def calc_post_class(post, paranormal_class_logprob, sceptic_class_logprob, bigrams_logprobs, words_logprobs):
+def calc_post_class(post, paranormal_class_logprob, sceptic_class_logprob, bigrams_logprobs, words_logprobs, total_sceptic_bigram, total_paranormal_bigram, total_sceptic_word, total_paranormal_word):
     text, timestap = post.rstrip('\n').split('\t')
     tokens = clear_post(text)
     #tokens = text.lower().split(' ')
@@ -22,11 +22,20 @@ def calc_post_class(post, paranormal_class_logprob, sceptic_class_logprob, bigra
                 product += bigrams_logprobs[class_][bigram]
             except KeyError:
                 product += 0
+                # if class_ == 'sceptic':
+                #     product += math.log(1/total_sceptic_bigram)
+                # elif class_ == 'paranormal':
+                #     product += math.log(1/total_paranormal_bigram)
         for token in tokens:
             try:
-                product += words_logprobs[class_][token]
+                product += words_logprobs[class_][token]/7
             except KeyError:
-                product += 0
+                product +=0
+                #if class_ == 'sceptic':
+                #    product += math.log(1/total_sceptic_word)
+                #elif class_ == 'paranormal':
+                #    product += math.log(1/total_paranormal_word)
+
         if class_ == 'sceptic':
             product += sceptic_class_logprob
         elif class_ == 'paranormal':
@@ -65,10 +74,17 @@ def main():
     sceptic_class_logprob = pickle_list[1]
     bigrams_logprobs = pickle_list[2]
     words_logprobs = pickle_list[3]
-
+    total_sceptic_bigram = pickle_list[4]
+    total_paranormal_bigram = pickle_list[5]
+    total_sceptic_word = pickle_list[6]
+    total_paranormal_word = pickle_list[7]
+    print(math.log(1/total_sceptic_bigram))
+    print(math.log(1/total_paranormal_bigram))
+    print(math.log(1/total_sceptic_word))
+    print(math.log(1/total_paranormal_word))
     with open(in_file) as in_f, open(out_file, 'w') as out_f:
         for line in in_f:
-            hyp = calc_post_class(line, paranormal_class_logprob, sceptic_class_logprob, bigrams_logprobs, words_logprobs)
+            hyp = calc_post_class(line, paranormal_class_logprob, sceptic_class_logprob, bigrams_logprobs, words_logprobs, total_sceptic_bigram, total_paranormal_bigram, total_sceptic_word, total_paranormal_word)
             if hyp == 'sceptic':
                 out_f.write(' S\n')
             elif hyp == 'paranormal':
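Note on the commented-out fallback above: the total_* values now read from the pickle are, by construction in train.py, a class's token (or bigram) count plus its vocabulary size, i.e. the denominator of an add-one smoothed estimate, so an unseen item would contribute log(1/total) instead of 0. A minimal sketch of that fallback for the word loop, assuming the same words_logprobs layout; smoothed_token_logprob is a hypothetical helper, not code from this repository:

import math

def smoothed_token_logprob(token, class_, words_logprobs,
                           total_sceptic_word, total_paranormal_word):
    # Seen token: use the stored log-probability for this class.
    try:
        return words_logprobs[class_][token]
    except KeyError:
        # Unseen token: fall back to log(1/(N + V)), as the commented-out lines suggest.
        if class_ == 'sceptic':
            return math.log(1 / total_sceptic_word)
        return math.log(1 / total_paranormal_word)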
test-A/out.tsv (264 changed lines)
File diff suppressed because it is too large
train.py (6 changed lines)
@@ -147,7 +147,11 @@ def main():
     bigrams_count, words_count = launch_bigrams_and_words(in_file, expected_file)
     bigram_logprobs = calc_bigram_logprobs(bigrams_count)
     word_logprobs = calc_word_logprobs(words_count)
+    total_sceptic_bigram = sum(bigrams_count['sceptic'].values()) + len(bigrams_count['sceptic'].keys())
+    total_paranormal_bigram = sum(bigrams_count['paranormal'].values()) + len(bigrams_count['paranormal'].keys())
+    total_sceptic_word = sum(words_count['sceptic'].values()) + len(words_count['sceptic'].keys())
+    total_paranormal_word = sum(words_count['paranormal'].values())+ len(words_count['paranormal'].keys())
     with open(model, 'wb') as f:
-        pickle.dump([paranormal_class_logprob, sceptic_class_logprob, bigram_logprobs, word_logprobs],f)
+        pickle.dump([paranormal_class_logprob, sceptic_class_logprob, bigram_logprobs, word_logprobs, total_sceptic_bigram, total_paranormal_bigram, total_sceptic_word, total_paranormal_word],f)
 
 main()
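For reference, each new total is the observed count for a class plus the number of distinct items, the usual add-one smoothing denominator. A toy calculation with made-up counts (not real data from this task), showing what total_sceptic_word would hold and the unseen-word fallback it implies:

import math

words_count = {'sceptic': {'ghost': 2, 'ufo': 1}}   # made-up counts for illustration
N = sum(words_count['sceptic'].values())            # 3 tokens observed
V = len(words_count['sceptic'].keys())              # 2 distinct words
total_sceptic_word = N + V                          # 5
print(math.log(1 / total_sceptic_word))             # ~ -1.61, the unseen-word log-probability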