diff --git a/retro.ipynb b/retro.ipynb new file mode 100644 index 0000000..d948448 --- /dev/null +++ b/retro.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import lzma\n", + "import csv\n", + "import re\n", + "\n", + "def readInput(dir):\n", + " X = []\n", + " if 'xz' in dir:\n", + " with lzma.open(dir) as f:\n", + " for line in f:\n", + " text = line.decode('utf-8')\n", + " text = text.split('\\t')\n", + " X.append(text)\n", + " else:\n", + " with open(dir, encoding='utf8', errors='ignore') as f:\n", + " for line in f:\n", + " X. append(line.replace('\\n',''))\n", + " return X\n", + "\n", + "def writeOutput(output, dir):\n", + " with open(dir, 'w', newline='') as f:\n", + " writer = csv.writer(f)\n", + " writer.writerows(output)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "train = pd.DataFrame(readInput('train/train.tsv.xz'), columns=['Beginning', 'End', 'Title', 'Source', 'X'])\n", + "train['Y'] = train.apply(lambda x: (float(x.Beginning) + float(x.End))/2, axis=1)\n", + "train = train.drop(columns=['Beginning', 'End', 'Title', 'Source'])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import linear_model\n", + "from sklearn.feature_extraction.text import TfidfVectorizer\n", + "from sklearn.pipeline import Pipeline\n", + "\n", + "estimators = [('tfidf', TfidfVectorizer()), ('linearRegression', linear_model.LinearRegression())]\n", + "model = Pipeline(estimators)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "model.fit(train.X, train.Y)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "dev0X = readInput('dev-0/in.tsv')\n", + "dev0Expected = readInput('dev-0/expected.tsv')\n", + "dev0Predicted = model.predict(dev0X)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print('RMSE = ', np.sqrt(sklearn.metrics.mean_squared_error(dev0Expected, dev0Predicted)))\n", + "print('Model score = ', model.score(dev0X, dev0Expected))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\software\\python3\\lib\\site-packages\\sklearn\\utils\\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead.\n", + " return f(*args, **kwargs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RMSE = 21.716380888138996\n", + "Model score = 0.8585103501633741\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\software\\python3\\lib\\site-packages\\sklearn\\utils\\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead.\n", + " return f(*args, **kwargs)\n" + ] + } + ], + "source": [ + "import sklearn.metrics\n", + "import numpy as np\n", + "\n", + "print('RMSE = ', np.sqrt(sklearn.metrics.mean_squared_error(dev0Expected, dev0Predicted)))\n", + "print('Model score = ', model.score(dev0X, dev0Expected))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "1b132c2ed43285dcf39f6d01712959169a14a721cf314fe69015adab49bb1fd1" + }, + "kernelspec": { + "display_name": "Python 3.8.10 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}