{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "3312dc2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import vowpalwabbit\n",
    "import pandas as pd\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a5d2718d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prediction(path_in, path_out, model, categories):\n",
    "    \"\"\"Predict a category for every row of `path_in` and write the labels to `path_out`.\n",
    "\n",
    "    path_in    -- tab-separated file without a header row; column 0 becomes 'year',\n",
    "                  column 1 is dropped and the remaining column becomes 'text'.\n",
    "    path_out   -- file opened for writing as UTF-8; receives one line per example.\n",
    "    model      -- object whose `predict(example)` returns a class id.\n",
    "    categories -- dict mapping category name -> class id.\n",
    "    \"\"\"\n",
    "    data = pd.read_csv(path_in, header=None, sep='\\t')\n",
    "    data = data.drop(1, axis=1)\n",
    "    data.columns = ['year', 'text']\n",
    "\n",
    "    data['train_input'] = data.apply(lambda row: to_vowpalwabbit(row, categories), axis=1)\n",
    "\n",
    "    # Build the id -> name lookup once instead of once per example.\n",
    "    id_to_category = {class_id: name for name, class_id in categories.items()}\n",
    "\n",
    "    with open(path_out, 'w', encoding='utf-8') as file:\n",
    "        for example in data['train_input']:\n",
    "            predicted = model.predict(example)\n",
    "            # An id that is missing from the lookup is written as the string 'None'.\n",
    "            file.write(str(id_to_category.get(predicted)) + '\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "2273a549",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_vowpalwabbit(row, categories):\n",
    "    \"\"\"Convert one data row into a Vowpal Wabbit text-format example.\n",
    "\n",
    "    row        -- row with 'text' and 'year' entries and, optionally, 'category'.\n",
    "    categories -- dict mapping category name -> class id.\n",
    "\n",
    "    Returns a newline-terminated string. The label part is empty when looking up\n",
    "    the row's category raises KeyError (no 'category' entry or unknown category).\n",
    "    \"\"\"\n",
    "    text = row['text'].replace('\\n', ' ').lower().strip()\n",
    "    # Keep only letters, spaces, apostrophes and hyphens. The hyphen is placed last\n",
    "    # in the class so that it is a literal: between the space and the apostrophe it\n",
    "    # forms a range that keeps punctuation such as ! # $ % & and drops the hyphen.\n",
    "    text = re.sub(\"[^a-zA-Z '-]\", '', text)\n",
    "    text = re.sub(\" +\", ' ', text)\n",
    "    year = row['year']\n",
    "    try:\n",
    "        category = categories[row['category']]\n",
    "    except KeyError:\n",
    "        category = ''\n",
    "\n",
    "    # The words go into their own namespace: in the VW text format name:value needs\n",
    "    # a numeric value, so writing the first word after 'text:' is not a valid feature.\n",
    "    return f\"{category} | year:{year} |text {text}\\n\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "83f7c5b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'news': 1, 'sport': 2, 'opinion': 3, 'business': 4, 'culture': 5, 'lifestyle': 6, 'removed': 7}\n"
     ]
    }
   ],
   "source": [
    "x_train = pd.read_csv('train/in.tsv', header=None, sep='\\t')\n",
    "x_train = x_train.drop(1, axis=1)\n",
    "x_train.columns = ['year', 'text']\n",
    "\n",
    "y_train = pd.read_csv('train/expected.tsv', header=None, sep='\\t')\n",
    "y_train.columns = ['category']\n",
    "\n",
    "data = pd.concat([x_train, y_train], axis=1)\n",
    "\n",
    "# Number the categories 1..K in order of first appearance.\n",
    "categories = {\n",
    "    name: class_id\n",
    "    for class_id, name in enumerate(data['category'].unique(), start=1)\n",
    "}\n",
    "\n",
    "print(categories)\n",
    "\n",
    "data['train_input'] = data.apply(lambda row: to_vowpalwabbit(row, categories), axis=1)\n",
    "\n",
    "# Take the class count from the data instead of hard-coding it.\n",
    "model = vowpalwabbit.Workspace(f'--oaa {len(categories)} --learning_rate 0.99')\n",
    "\n",
    "for example in data['train_input']:\n",
    "    model.learn(example)\n",
    "\n",
    "# Write predictions for every evaluation split next to its input file.\n",
    "for split in ('dev-0', 'test-A', 'test-B'):\n",
    "    prediction(f'{split}/in.tsv', f'{split}/out.tsv', model, categories)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "caa9bb3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NbConvertApp] Converting notebook run.ipynb to script\n",
      "[NbConvertApp] Writing 1952 bytes to run.py\n"
     ]
    }
   ],
   "source": [
    "!jupyter nbconvert --to script run.ipynb"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}