This commit is contained in:
theta00 2022-05-27 21:33:35 +02:00
parent c9207fce50
commit f3f14088fb
3 changed files with 119446 additions and 190287 deletions

22
run.py
View File

@ -1,7 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# coding: utf-8 # coding: utf-8
# In[208]: # In[1]:
import pandas as pd import pandas as pd
@ -9,26 +9,26 @@ import vowpalwabbit
from sklearn import preprocessing from sklearn import preprocessing
# In[209]: # In[2]:
vw = vowpalwabbit.Workspace('--oaa 20') vw = vowpalwabbit.Workspace('--oaa 20')
# In[210]: # In[3]:
X_train = pd.read_csv('train\in.tsv', sep='\t', usecols=[2], names=['text']) X_train = pd.read_csv('train\in.tsv', sep='\t', usecols=[2], names=['text'])
Y_train = pd.read_csv('train\expected.tsv', sep='\t', usecols=[0], names=['class']) Y_train = pd.read_csv('train\expected.tsv', sep='\t', usecols=[0], names=['class'])
# In[211]: # In[4]:
Y_train['class'].unique() Y_train['class'].unique()
# In[212]: # In[5]:
le = preprocessing.LabelEncoder() le = preprocessing.LabelEncoder()
@ -36,20 +36,20 @@ le.fit(['business', 'culture', 'lifestyle', 'news', 'opinion', 'removed', 'sport
Y_train['class'] = le.fit_transform(Y_train['class']) Y_train['class'] = le.fit_transform(Y_train['class'])
# In[213]: # In[6]:
for x, y in zip(X_train['text'], Y_train['class']): for x, y in zip(X_train['text'], Y_train['class']):
vw.learn(f'{y} | text:{x}') vw.learn(f'{y} | text:{x}')
# In[216]: # In[16]:
def make_prediction(path_in, path_out): def make_prediction(path_in, path_out):
test_set = pd.read_csv(path_in, sep='\t', usecols=[2], names=['text']) test_set = pd.read_csv(path_in, sep='\t', usecols=[2], names=['text'])
predictions = [] predictions = []
for x in X_dev0['text']: for x in test_set['text']:
predictions.append(vw.predict(f'| text:{x}')) predictions.append(vw.predict(f'| text:{x}'))
predictions = le.inverse_transform(predictions) predictions = le.inverse_transform(predictions)
file = open(path_out, 'w') file = open(path_out, 'w')
@ -58,19 +58,19 @@ def make_prediction(path_in, path_out):
file.close() file.close()
# In[217]: # In[17]:
make_prediction('dev-0\in.tsv', 'dev-0\out.tsv') make_prediction('dev-0\in.tsv', 'dev-0\out.tsv')
# In[218]: # In[18]:
make_prediction('test-A\in.tsv', 'test-A\out.tsv') make_prediction('test-A\in.tsv', 'test-A\out.tsv')
# In[219]: # In[19]:
make_prediction('test-B\in.tsv', 'test-B\out.tsv') make_prediction('test-B\in.tsv', 'test-B\out.tsv')

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff