Add retro.ipynb
This commit is contained in:
parent
33b70ce7b1
commit
4797f14e20
161
retro.ipynb
Normal file
161
retro.ipynb
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import lzma\n",
|
||||||
|
"import csv\n",
|
||||||
|
"import re\n",
|
||||||
|
"\n",
|
||||||
|
"def readInput(dir):\n",
|
||||||
|
" X = []\n",
|
||||||
|
" if 'xz' in dir:\n",
|
||||||
|
" with lzma.open(dir) as f:\n",
|
||||||
|
" for line in f:\n",
|
||||||
|
" text = line.decode('utf-8')\n",
|
||||||
|
" text = text.split('\\t')\n",
|
||||||
|
" X.append(text)\n",
|
||||||
|
" else:\n",
|
||||||
|
" with open(dir, encoding='utf8', errors='ignore') as f:\n",
|
||||||
|
" for line in f:\n",
|
||||||
|
" X. append(line.replace('\\n',''))\n",
|
||||||
|
" return X\n",
|
||||||
|
"\n",
|
||||||
|
"def writeOutput(output, dir):\n",
|
||||||
|
" with open(dir, 'w', newline='') as f:\n",
|
||||||
|
" writer = csv.writer(f)\n",
|
||||||
|
" writer.writerows(output)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"\n",
|
||||||
|
"train = pd.DataFrame(readInput('train/train.tsv.xz'), columns=['Beginning', 'End', 'Title', 'Source', 'X'])\n",
|
||||||
|
"train['Y'] = train.apply(lambda x: (float(x.Beginning) + float(x.End))/2, axis=1)\n",
|
||||||
|
"train = train.drop(columns=['Beginning', 'End', 'Title', 'Source'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from sklearn import linear_model\n",
|
||||||
|
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||||||
|
"from sklearn.pipeline import Pipeline\n",
|
||||||
|
"\n",
|
||||||
|
"estimators = [('tfidf', TfidfVectorizer()), ('linearRegression', linear_model.LinearRegression())]\n",
|
||||||
|
"model = Pipeline(estimators)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 18,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model.fit(train.X, train.Y)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dev0X = readInput('dev-0/in.tsv')\n",
|
||||||
|
"dev0Expected = readInput('dev-0/expected.tsv')\n",
|
||||||
|
"dev0Predicted = model.predict(dev0X)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print('RMSE = ', np.sqrt(sklearn.metrics.mean_squared_error(dev0Expected, dev0Predicted)))\n",
|
||||||
|
"print('Model score = ', model.score(dev0X, dev0Expected))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"c:\\software\\python3\\lib\\site-packages\\sklearn\\utils\\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead.\n",
|
||||||
|
" return f(*args, **kwargs)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"RMSE = 21.716380888138996\n",
|
||||||
|
"Model score = 0.8585103501633741\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"c:\\software\\python3\\lib\\site-packages\\sklearn\\utils\\validation.py:63: FutureWarning: Arrays of bytes/strings is being converted to decimal numbers if dtype='numeric'. This behavior is deprecated in 0.24 and will be removed in 1.1 (renaming of 0.26). Please convert your data to numeric values explicitly instead.\n",
|
||||||
|
" return f(*args, **kwargs)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import sklearn.metrics\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"\n",
|
||||||
|
"print('RMSE = ', np.sqrt(sklearn.metrics.mean_squared_error(dev0Expected, dev0Predicted)))\n",
|
||||||
|
"print('Model score = ', model.score(dev0X, dev0Expected))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "1b132c2ed43285dcf39f6d01712959169a14a721cf314fe69015adab49bb1fd1"
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3.8.10 64-bit",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.8.10"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user