61 lines
1.8 KiB
Plaintext
61 lines
1.8 KiB
Plaintext
|
import pickle
|
||
|
import re
|
||
|
|
||
|
|
||
|
def calculate_words(linetxt):
    """Count occurrences of each space-separated token in ``linetxt``.

    The string is split on single spaces, so runs of spaces and leading or
    trailing spaces yield empty-string tokens. Whatever count the empty token
    reaches, its entry is finally forced to 1 (and added if it was absent).

    Args:
        linetxt: Text to count; in this script it is the output of
            ``tokenize_list``.

    Returns:
        dict mapping each token to its number of occurrences, in
        first-occurrence order, always containing ``''`` with value 1.
    """
    counts = {}
    for piece in linetxt.split(' '):
        counts[piece] = counts.get(piece, 0) + 1
    # The empty token is pinned to 1 regardless of how often it occurred.
    counts[''] = 1
    return counts
|
||
|
|
||
|
def tokenize_list(string_input):
    """Normalize raw text into a space-separated string of longer words.

    Despite the name, this returns a single string, not a list. Steps, in
    order (each operates on the previous step's output):

    1. Replace literal two-character ``\\n`` sequences with a space.
    2. Delete ``scheme://host/...`` URLs; blank any remaining ``http...``
       runs and ``/x/`` fragments.
    3. Replace every character outside ``a-z`` with a space. Uppercase
       letters are therefore dropped, so callers should lowercase first.
    4. Collapse whitespace runs to one space.
    5. Remove words of 1-3 characters. Regex matches do not overlap, so in a
       run of consecutive short words some may survive
       (``"this is a test"`` -> ``"this a test"``).
    6. Strip a single leading whitespace character.

    Args:
        string_input: Text to normalize.

    Returns:
        The cleaned string; it may have a trailing space.
    """
    # Ordered (pattern, replacement) pairs; order matters.
    substitutions = (
        (r'\w+:\/{2}[\d\w-]+(\.[\d\w-]+)*(?:(?:\/[^\s/]*))*', ''),
        (r'\\n+', " "),
        (r'http\S+', " "),
        (r'\/[a-z]\/', " "),
        (r'[^a-z]', " "),
        (r'\s{2,}', " "),
        (r'\W\w{1,3}\W|\A\w{1,3}\W', " "),
        (r'^\s', ""),
    )
    cleaned = string_input.replace('\\n', ' ')
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned
|
||
|
|
||
|
def prediction(input, output, model_path='model_linear_reg.pkl'):
    """Score each line of a TSV file with a pickled linear model; write 0/1 labels.

    The model file must unpickle to a 3-tuple ``(weights, word, vocabulary)``.
    ``weights[0]`` is used as the intercept. ``word`` maps a token to its
    index in ``weights``. ``vocabulary`` is used only for membership tests.

    Args:
        input: Path to a UTF-8 TSV file. Each line must split on tabs into
            exactly two fields; the first is the text, the second is ignored.
        output: Path of the file to (over)write with one ``"1"`` or ``"0"``
            line per input line: ``"1"`` when the score exceeds 0.5.
        model_path: Path of the pickled model; defaults to the previously
            hard-coded ``'model_linear_reg.pkl'`` in the working directory.

    Raises:
        ValueError: If an input line does not have exactly two tab-separated
            fields.

    Each line's raw score is also printed to stdout.
    """
    # SECURITY: unpickling can execute arbitrary code; only load model files
    # you produced yourself.
    with open(model_path, 'rb') as model_f:
        weights, word, vocabulary = pickle.load(model_f)

    # Output is opened before input, as before; `with` guarantees both files
    # are closed even if a malformed line raises mid-way.
    with open(output, 'w') as output_f, open(input, encoding='utf-8') as input_f:
        for line in input_f:
            text, _timestamp = line.rstrip('\n').split('\t')
            tokens = tokenize_list(text.lower())
            line_vocabulary = calculate_words(tokens)
            y_hat = weights[0]
            # NOTE(review): every occurrence of a token (duplicates included)
            # adds weight * count, so a word seen k times contributes k*k
            # times its weight. Kept unchanged so scores stay consistent with
            # the training code -- confirm against the training script.
            for token in tokens.split(' '):
                if token in vocabulary:
                    y_hat += weights[word[token]] * line_vocabulary[token]
            output_f.write("1\n" if y_hat > 0.5 else "0\n")
            print(y_hat)
|
||
|
|
||
|
|
||
|
def main():
    """Run ``prediction`` on the dev-0 split, then the test-A split.

    Each split's ``in.tsv`` is read and the labels are written to the
    ``out.tsv`` in the same directory.
    """
    jobs = (
        ("dev-0/in.tsv", "dev-0/out.tsv"),
        ("test-A/in.tsv", "test-A/out.tsv"),
    )
    for source_path, target_path in jobs:
        prediction(source_path, target_path)
|
||
|
|
||
|
# Only run the predictions when executed as a script, so importing this
# module (e.g. to reuse tokenize_list) does not read or write any files.
if __name__ == "__main__":
    main()
|