{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3f1c9a20",
   "metadata": {},
   "source": [
    "# Text category classification with Vowpal Wabbit\n",
    "\n",
    "Trains a one-against-all (`--oaa`) Vowpal Wabbit model on `train/` and writes one predicted category name per input line to `out.tsv` in `dev-0/`, `test-A/` and `test-B/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4206eb3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import re\n",
    "\n",
    "import pandas as pd\n",
    "import vowpalwabbit"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d2e4b61",
   "metadata": {},
   "source": [
    "## Feature extraction\n",
    "\n",
    "Each row becomes one line in VW text format: the integer label (empty for unlabeled rows), the `year` as a numeric feature, and the cleaned words in a `text` namespace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27e69709",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_vowpalwabbit(row, categories):\n",
    "    \"\"\"Convert one data row into a Vowpal Wabbit text-format example.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    row : pandas.Series\n",
    "        Must provide ``year`` and ``text``; ``category`` is optional.\n",
    "    categories : dict\n",
    "        Maps category name -> integer label (1..k) used with ``--oaa``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    str\n",
    "        ``<label> | year:<year> |text <words>`` followed by a newline. The\n",
    "        label is empty when the row has no category or an unknown one,\n",
    "        which is how unlabeled examples are passed to ``model.predict``.\n",
    "    \"\"\"\n",
    "    raw_text = row['text']\n",
    "    # pandas reads empty cells as float NaN, which has no .replace().\n",
    "    text = raw_text if isinstance(raw_text, str) else ''\n",
    "    text = text.replace('\\n', ' ').lower().strip()\n",
    "    # Keep lowercase letters, spaces, apostrophes and hyphens. The hyphen must\n",
    "    # come last in the class: written as \" -'\" it is the range 0x20-0x27,\n",
    "    # which keeps !\"#$%& and drops hyphens. Stripping ':' and '|' also keeps\n",
    "    # the VW line format intact.\n",
    "    text = re.sub(r\"[^a-z '-]\", '', text)\n",
    "    text = re.sub(r\" +\", ' ', text).strip()\n",
    "    year = row['year']\n",
    "    # Series.get returns None when the column is absent (prediction inputs).\n",
    "    category = categories.get(row.get('category'), '')\n",
    "\n",
    "    # Words go in their own \"text\" namespace; \"text:<word>\" would make VW read\n",
    "    # the first word as the numeric value of a feature named \"text\".\n",
    "    return f\"{category} | year:{year} |text {text}\\n\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fde46276",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_inputs(path):\n",
    "    \"\"\"Read an in.tsv file into a frame with ``year`` and ``text`` columns.\n",
    "\n",
    "    Column 1 of the file is dropped. NOTE(review): confirm what that column\n",
    "    holds and whether it is worth using as a feature.\n",
    "    \"\"\"\n",
    "    # QUOTE_NONE: a '\"' inside the text would otherwise be read as a CSV quote,\n",
    "    # which can merge rows and misalign inputs with labels/outputs.\n",
    "    frame = pd.read_csv(path, header=None, sep='\\t', quoting=csv.QUOTE_NONE)\n",
    "    frame = frame.drop(columns=1)\n",
    "    frame.columns = ['year', 'text']\n",
    "    return frame\n",
    "\n",
    "\n",
    "def prediction(path_in, path_out, model, categories):\n",
    "    \"\"\"Predict a category for every row of ``path_in`` and save the names.\n",
    "\n",
    "    Writes one category name per line to ``path_out`` (UTF-8), in input order.\n",
    "    ``categories`` is the name -> label mapping used for training; the labels\n",
    "    returned by ``model.predict`` are mapped back to names through it.\n",
    "    \"\"\"\n",
    "    data = load_inputs(path_in)\n",
    "    examples = [to_vowpalwabbit(row, categories) for _, row in data.iterrows()]\n",
    "\n",
    "    # label -> name lookup, built once instead of once per example.\n",
    "    label_to_category = {label: name for name, label in categories.items()}\n",
    "\n",
    "    with open(path_out, 'w', encoding='utf-8') as file:\n",
    "        for example in examples:\n",
    "            predicted = model.predict(example)\n",
    "            file.write(f\"{label_to_category.get(predicted)}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b57a0c13",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c406b425",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train = load_inputs('train/in.tsv')\n",
    "\n",
    "y_train = pd.read_csv('train/expected.tsv', header=None, sep='\\t', quoting=csv.QUOTE_NONE)\n",
    "y_train.columns = ['category']\n",
    "\n",
    "# Features and labels are paired by row position, so the two files must line up.\n",
    "assert len(x_train) == len(y_train), (len(x_train), len(y_train))\n",
    "\n",
    "train_data = pd.concat([x_train, y_train], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1f0d2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# VW's one-against-all reduction expects labels 1..k.\n",
    "categories = {name: label for label, name in enumerate(train_data['category'].unique(), start=1)}\n",
    "categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a4c6e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_examples = [to_vowpalwabbit(row, categories) for _, row in train_data.iterrows()]\n",
    "\n",
    "# --oaa takes the number of classes and must cover every label id;\n",
    "# a hard-coded 3 breaks once labels 4..7 appear.\n",
    "model = vowpalwabbit.Workspace(f'--oaa {len(categories)} --quiet')\n",
    "\n",
    "for example in train_examples:\n",
    "    model.learn(example)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c7b3e92",
   "metadata": {},
   "source": [
    "## Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d08f71b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "for split in ['dev-0', 'test-A', 'test-B']:\n",
    "    prediction(f'{split}/in.tsv', f'{split}/out.tsv', model, categories)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}