ireland-news-headlines/run.py

55 lines
1.7 KiB
Python

import re
import sys
from pathlib import Path

import pandas as pd
import vowpalwabbit
# Path of the TSV file whose rows should be classified (first CLI argument).
# Fail with a usage message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit(f'usage: {sys.argv[0]} FILE_TO_PREDICT')
file_to_predict = sys.argv[1]
def _clean_text(value):
    """Normalise a headline into space-separated lower-case word tokens.

    Keeps only ASCII letters, spaces, apostrophes and hyphens, so VW's special
    characters ``|`` and ``:`` can never end up inside a feature. Apostrophes
    at word boundaries become spaces, and runs of spaces are collapsed.
    Missing values (NaN/None) yield an empty string.
    """
    if not isinstance(value, str):
        # e.g. an empty cell read by read_csv arrives as NaN rather than ''.
        value = '' if pd.isna(value) else str(value)
    value = value.lower()
    value = re.sub("[^a-zA-Z '-]", '', value)
    # Drop quote-style apostrophes but keep in-word ones such as "don't".
    value = re.sub("( ')|(' )|('$)|(^')", ' ', value)
    value = value.strip()
    return re.sub(' +', ' ', value)


def to_vw_format(row):
    """Convert one headline row into a Vowpal Wabbit text-format example.

    If the row's ``category`` is one of the seven known classes, it becomes
    the ``--oaa`` label 1-7. Otherwise (no category entry, unknown value or
    NaN) the example is emitted without a label, which VW treats as a
    test-only example. ``category`` is never emitted as a feature.

    The ``text`` entry is cleaned by ``_clean_text`` and written as
    bag-of-words features in a separate ``text`` namespace. Every other
    entry becomes a ``name:value`` feature in the default namespace.

    Args:
        row: a pandas Series, e.g. one row passed by ``DataFrame.apply(axis=1)``.

    Returns:
        A VW input line such as ``'4 | year:2005 |text some headline words'``.
    """
    mapping = {'business': 1, 'culture': 2, 'lifestyle': 3, 'news': 4,
               'opinion': 5, 'removed': 6, 'sport': 7}
    try:
        label = mapping[row['category']]
    except KeyError:
        # Missing or unknown category: --oaa labels must lie in 1..7, so emit
        # an unlabeled example rather than the invalid class 0.
        label = None

    parts = [f'{label} |' if label is not None else '|']
    has_text = False
    words = ''
    # The category is the target, never a feature.
    for idx, value in row.drop(labels=['category'], errors='ignore').items():
        if idx == 'text':
            has_text = True
            words = _clean_text(value)
        else:
            parts.append(f'{idx}:{value}')

    if has_text:
        # Bag-of-words in its own namespace. A bare 'text:<words>' would make
        # VW read the first word as the (non-numeric) value of feature 'text'.
        parts.append('|text')
        if words:
            parts.append(words)
    return ' '.join(parts)
# Training data is read from train/ next to this script. It used to be the
# hard-coded path /home/ked/Desktop/aitech/eks/ireland-news-headlines/train.
train_dir = Path(__file__).resolve().parent / 'train'
in_df = pd.read_csv(train_dir / 'in.tsv', sep='\t', header=None)
exp_df = pd.read_csv(train_dir / 'expected.tsv', sep='\t', header=None)
pred_df = pd.read_csv(file_to_predict, sep='\t', header=None)

# Input files have three columns. Column 1 is not used as a feature, which
# leaves column 0 as 'year' and column 2 as 'text'. Training rows get their
# gold category from expected.tsv, aligned by row index.
# (Axis is passed by keyword: pandas >= 2.0 rejects drop(1, 1).)
in_df[3] = exp_df[0]
in_df = in_df.drop(columns=1)
in_df.columns = ['year', 'text', 'category']

# Rows to predict have no category. to_vw_format copes with a missing
# 'category' entry, so training labels are no longer copied in as filler
# (that hack produced NaN/misaligned labels whenever the row counts differed).
pred_df = pred_df.drop(columns=1)
pred_df.columns = ['year', 'text']

# One-against-all over the 7 categories, trained in a single online pass.
vw = vowpalwabbit.Workspace('--oaa 7')
for example in in_df.apply(to_vw_format, axis=1):
    vw.learn(example)
predictions = [vw.predict(example) for example in pred_df.apply(to_vw_format, axis=1)]
vw.finish()

# Inverse of the category -> label mapping used by to_vw_format.
mapping = {1: 'business', 2: 'culture', 3: 'lifestyle', 4: 'news',
           5: 'opinion', 6: 'removed', 7: 'sport'}
with open('out.tsv', 'w', encoding='utf-8') as f:
    f.writelines(f'{mapping[prediction]}\n' for prediction in predictions)