ireland-news-headlines/run.py

78 lines
1.2 KiB
Python
Raw Normal View History

2022-05-27 21:21:57 +02:00
#!/usr/bin/env python
# coding: utf-8
2022-05-27 21:33:35 +02:00
# In[1]:
2022-05-27 21:21:57 +02:00
import pandas as pd
import vowpalwabbit
from sklearn import preprocessing
2022-05-27 21:33:35 +02:00
# In[2]:
2022-05-27 21:21:57 +02:00
# Vowpal Wabbit one-against-all multiclass learner configured for 20 classes.
vw = vowpalwabbit.Workspace('--oaa 20')

# Column 2 of train/in.tsv is read as the text column and column 0 of
# train/expected.tsv as the class label.  Forward slashes work on every
# platform and avoid the invalid '\i' / '\e' escape sequences that the
# backslash-separated literals contained.
X_train = pd.read_csv('train/in.tsv', sep='\t', usecols=[2], names=['text'])
Y_train = pd.read_csv('train/expected.tsv', sep='\t', usecols=[0], names=['class'])

# VW's --oaa reduction expects labels in 1..k, while LabelEncoder numbers its
# (sorted) classes from 0.  The empty-string placeholder sorts before every
# real label and therefore takes index 0, so the real classes are encoded as
# 1..n and le.inverse_transform() maps VW's 1-based predictions straight back
# to class names.
le = preprocessing.LabelEncoder()
le.fit(['', *Y_train['class'].unique()])
Y_train['class'] = le.transform(Y_train['class'])

# Online training: one VW example per (label, text) pair.
# NOTE(review): the text is interpolated verbatim after 'text:'; VW's input
# format gives ':' and '|' a special meaning, so confirm this feature layout
# is what was intended.  It must stay identical to the layout used at
# prediction time.
for x, y in zip(X_train['text'], Y_train['class']):
    vw.learn(f'{y} | text:{x}')
2022-05-27 21:33:35 +02:00
# In[16]:
2022-05-27 21:21:57 +02:00
def make_prediction(path_in, path_out):
    """Predict a class for every row of *path_in* and write them to *path_out*.

    Column 2 of the tab-separated input is read as the text.  Each text is
    scored by the module-level ``vw`` model, and the numeric predictions are
    mapped back to class names with the module-level ``le`` encoder.  One
    class name is written per line.

    Args:
        path_in: Path of the tab-separated input file.
        path_out: Path of the output file; it is overwritten if it exists.
    """
    test_set = pd.read_csv(path_in, sep='\t', usecols=[2], names=['text'])
    # The example string uses the same '| text:...' layout as the training loop.
    predictions = [vw.predict(f'| text:{x}') for x in test_set['text']]
    labels = le.inverse_transform(predictions)
    # The context manager closes the file even if a write fails.
    with open(path_out, 'w', encoding='utf-8') as file:
        file.writelines(f'{label}\n' for label in labels)
2022-05-27 21:33:35 +02:00
# In[17]:
2022-05-27 21:21:57 +02:00
# Write predictions for every evaluation split.  Forward slashes work on every
# platform and avoid the invalid '\i' / '\o' escape sequences that the
# backslash-separated literals contained.
for split in ('dev-0', 'test-A', 'test-B'):
    make_prediction(f'{split}/in.tsv', f'{split}/out.tsv')