{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "vw-classifier-title",
   "metadata": {},
   "source": [
    "# Text category classification with Vowpal Wabbit\n",
    "\n",
    "Trains a one-against-all Vowpal Wabbit model on `train/in.tsv` + `train/expected.tsv` and writes one predicted category per input line to `dev-0/out.tsv`, `test-A/out.tsv` and `test-B/out.tsv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "57debdd5-4760-4524-9c77-409652cfb52e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import re\n",
    "\n",
    "import pandas as pd\n",
    "import vowpalwabbit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "48357fb3-9d6f-48d0-9869-bce7a87e3ba1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_inputs(path):\n",
    "    \"\"\"Read an ``in.tsv`` file into a DataFrame with ``year`` and ``text`` columns.\n",
    "\n",
    "    Column 1 of the raw file is dropped (NOTE(review): its meaning is not shown here).\n",
    "    Quoting is disabled because the files are plain TSV: with pandas' default quoting a\n",
    "    text field that starts with a double quote is read as a quoted field and can swallow\n",
    "    the following lines, shifting every later row.\n",
    "    \"\"\"\n",
    "    frame = pd.read_csv(path, header=None, sep='\\t', quoting=csv.QUOTE_NONE)\n",
    "    frame = frame.drop(columns=1)\n",
    "    frame.columns = ['year', 'text']\n",
    "    return frame\n",
    "\n",
    "\n",
    "def prediction(path_in, path_out, model, categories):\n",
    "    \"\"\"Predict a category for every row of ``path_in`` and write the names to ``path_out``.\n",
    "\n",
    "    Args:\n",
    "        path_in: input TSV in the format accepted by ``read_inputs``.\n",
    "        path_out: output file; receives one category name per input row, in input order.\n",
    "        model: trained VW workspace; ``model.predict`` is expected to return one of the\n",
    "            integer labels in ``categories``.\n",
    "        categories: mapping of category name -> integer label used for training.\n",
    "    \"\"\"\n",
    "    data = read_inputs(path_in)\n",
    "    # Invert the mapping once, not once per example.\n",
    "    label_to_name = {label: name for name, label in categories.items()}\n",
    "\n",
    "    with open(path_out, 'w', encoding='utf-8') as file:\n",
    "        for _, row in data.iterrows():\n",
    "            predicted = model.predict(to_vowpalwabbit(row, categories))\n",
    "            # An unknown label is written as None, as before.\n",
    "            file.write(str(label_to_name.get(predicted)) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f47c4ff9-2078-43f7-b06c-99c9fdd2022e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_vowpalwabbit(row, categories):\n",
    "    \"\"\"Format one DataFrame row as a Vowpal Wabbit text-format example.\n",
    "\n",
    "    The result is ``<label> |year year:<year> |text <words>`` followed by a newline.\n",
    "    ``<label>`` is ``categories[row['category']]``; it is left empty when the row has\n",
    "    no ``category`` entry or its value is not a key of ``categories``, which gives an\n",
    "    unlabelled example (used for prediction).\n",
    "    \"\"\"\n",
    "    # Missing text cells come back from pandas as NaN (a float), not a string.\n",
    "    raw_text = row['text'] if isinstance(row['text'], str) else ''\n",
    "    text = raw_text.replace('\\n', ' ').lower()\n",
    "    # Keep only lowercase letters, spaces, apostrophes and hyphens. The hyphen is last\n",
    "    # in the class so it is literal; previously it formed a range from space to\n",
    "    # apostrophe, which kept symbols such as ! # $ % & and removed real hyphens.\n",
    "    # ':' and '|' never survive, which matters because both are VW syntax.\n",
    "    text = re.sub(\"[^a-z '-]\", '', text)\n",
    "    text = re.sub(\" +\", ' ', text).strip()\n",
    "    try:\n",
    "        category = categories[row['category']]\n",
    "    except KeyError:\n",
    "        category = ''\n",
    "\n",
    "    # Separate namespaces: VW reads name:value with a numeric value, so the old\n",
    "    # text:<words> form was a malformed feature rather than a bag of words, and one\n",
    "    # shared namespace would let the word year in a headline collide with the\n",
    "    # numeric year feature.\n",
    "    return f\"{category} |year year:{row['year']} |text {text}\\n\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "015a5ccb-8fe0-45cf-bf59-416ce9e59dad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the first N_TRAIN rows of the training set are used.\n",
    "N_TRAIN = 800_000\n",
    "\n",
    "x_train = read_inputs('train/in.tsv').iloc[:N_TRAIN]\n",
    "y_train = pd.read_csv('train/expected.tsv', header=None, sep='\\t',\n",
    "                      quoting=csv.QUOTE_NONE, names=['category']).iloc[:N_TRAIN]\n",
    "\n",
    "# Features and labels are paired purely by row position.\n",
    "assert len(x_train) == len(y_train), 'train/in.tsv and train/expected.tsv differ in length'\n",
    "\n",
    "data = pd.concat([x_train, y_train], axis=1)"
   ]
  },
  {
"cell_type": "code",
   "execution_count": 12,
   "id": "d8a2c80c-3b93-410d-98e4-ac651d0933a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>year</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2004.508197</td>\n",
       "      <td>Sudan claims it is disarming militias</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2008.442623</td>\n",
       "      <td>Bluffer's guide to Euro 2008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2012.587432</td>\n",
       "      <td>Ennis tallies her highest first day total</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2009.071233</td>\n",
       "      <td>Sri Lanka continues to battle Tamil Tigers</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1997.345205</td>\n",
       "      <td>Talks today to avert new health service strike</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>799995</th>\n",
       "      <td>2010.876712</td>\n",
       "      <td>Top league stars among 135 listed online</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>799996</th>\n",
       "      <td>2000.879452</td>\n",
       "      <td>Cabinet to consider options for animal disposal</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>799997</th>\n",
       "      <td>2004.915068</td>\n",
       "      <td>Last orders for Bewley's this evening</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>799998</th>\n",
       "      <td>2014.797260</td>\n",
       "      <td>Toulon; Ospreys and Toulouse win Champions Cup...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>799999</th>\n",
       "      <td>1999.019178</td>\n",
       "      <td>Volatile year in store for the markets</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>800000 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       " year text\n",
       "0 2004.508197 Sudan claims it is disarming militias\n",
       "1 2008.442623 Bluffer's guide to Euro 2008\n",
       "2 2012.587432 Ennis tallies her highest first day total\n",
       "3 2009.071233 Sri Lanka continues to battle Tamil Tigers\n",
       "4 1997.345205 Talks today to avert new health service strike\n",
       "... ... ...\n",
       "799995 2010.876712 Top league stars among 135 listed online\n",
       "799996 2000.879452 Cabinet to consider options for animal disposal\n",
       "799997 2004.915068 Last orders for Bewley's this evening\n",
       "799998 2014.797260 Toulon; Ospreys and Toulouse win Champions Cup...\n",
       "799999 1999.019178 Volatile year in store for the markets\n",
       "\n",
       "[800000 rows x 2 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "422673c2-4de6-446a-816c-1c35ba43c373",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'news': 1, 'sport': 2, 'opinion': 3, 'business': 4, 'culture': 5, 'lifestyle': 6, 'removed': 7}\n"
     ]
    }
   ],
   "source": [
    "# Integer labels 1..K in order of first appearance; VW's --oaa expects labels 1..K.\n",
    "categories = {name: label for label, name in enumerate(data['category'].unique(), start=1)}\n",
    "\n",
    "print(categories)\n",
    "\n",
    "train_examples = data.apply(lambda row: to_vowpalwabbit(row, categories), axis=1)\n",
    "\n",
    "# Derive the class count from the data so it cannot drift from `categories`.\n",
    "model = vowpalwabbit.Workspace(f'--oaa {len(categories)} --learning_rate 0.99')\n",
    "\n",
    "for example in train_examples:\n",
    "    model.learn(example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "29f424f4-19fd-43f9-a8bf-39beb3fc408d",
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction('dev-0/in.tsv', 'dev-0/out.tsv', model, categories)\n",
    "prediction('test-A/in.tsv', 'test-A/out.tsv', model, categories)\n",
    "prediction('test-B/in.tsv', 'test-B/out.tsv', model, categories)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ecf1726c-56ee-4476-bf88-136fa588feec",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
"output_type": "stream", "text": [ "[NbConvertApp] Converting notebook run.ipynb to script\n", "[NbConvertApp] Writing 2030 bytes to run.py\n" ] } ], "source": [ "!jupyter nbconvert --to script run.ipynb" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 5 }