Add retro.ipynb

This commit is contained in:
s443930 2022-05-08 23:30:43 +02:00
parent 33b70ce7b1
commit 4797f14e20
1 changed files with 161 additions and 0 deletions

161
retro.ipynb Normal file
View File

@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import lzma\n",
"import csv\n",
"import re\n",
"\n",
"def readInput(dir):\n",
" X = []\n",
" if 'xz' in dir:\n",
" with lzma.open(dir) as f:\n",
" for line in f:\n",
" text = line.decode('utf-8')\n",
" text = text.split('\\t')\n",
" X.append(text)\n",
" else:\n",
" with open(dir, encoding='utf8', errors='ignore') as f:\n",
" for line in f:\n",
" X. append(line.replace('\\n',''))\n",
" return X\n",
"\n",
"def writeOutput(output, dir):\n",
" with open(dir, 'w', newline='') as f:\n",
" writer = csv.writer(f)\n",
" writer.writerows(output)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"train = pd.DataFrame(readInput('train/train.tsv.xz'), columns=['Beginning', 'End', 'Title', 'Source', 'X'])\n",
"train['Y'] = train.apply(lambda x: (float(x.Beginning) + float(x.End))/2, axis=1)\n",
"train = train.drop(columns=['Beginning', 'End', 'Title', 'Source'])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import linear_model\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"estimators = [('tfidf', TfidfVectorizer()), ('linearRegression', linear_model.LinearRegression())]\n",
"model = Pipeline(estimators)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"model.fit(train.X, train.Y)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"dev0X = readInput('dev-0/in.tsv')\n",
"dev0Expected = readInput('dev-0/expected.tsv')\n",
"dev0Predicted = model.predict(dev0X)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('RMSE = ', np.sqrt(sklearn.metrics.mean_squared_error(dev0Expected, dev0Predicted)))\n",
"print('Model score = ', model.score(dev0X, dev0Expected))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\software\\python3\\lib\\site-packages\\sklearn\\utils\\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead.\n",
" return f(*args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMSE = 21.716380888138996\n",
"Model score = 0.8585103501633741\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\software\\python3\\lib\\site-packages\\sklearn\\utils\\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead.\n",
" return f(*args, **kwargs)\n"
]
}
],
"source": [
"import sklearn.metrics\n",
"import numpy as np\n",
"\n",
"print('RMSE = ', np.sqrt(sklearn.metrics.mean_squared_error(dev0Expected, dev0Predicted)))\n",
"print('Model score = ', model.score(dev0X, dev0Expected))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "1b132c2ed43285dcf39f6d01712959169a14a721cf314fe69015adab49bb1fd1"
},
"kernelspec": {
"display_name": "Python 3.8.10 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}