{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TF-IDF + linear regression baseline\n",
    "\n",
    "Fits `LinearRegression` on TF-IDF features of the training texts, reports RMSE on `dev-0` and `dev-1`, and writes predictions to `dev-0/out.tsv`, `dev-1/out.tsv` and `test-A/out.tsv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import lzma\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import mean_squared_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def read_lines(path):\n",
    "    \"\"\"Return every line of the text file at `path` (trailing newlines kept) and close the file.\"\"\"\n",
    "    with open(path, \"r\") as f:\n",
    "        return f.readlines()\n",
    "\n",
    "\n",
    "def read_targets(path):\n",
    "    \"\"\"Parse one numeric target per line of `path` into a float array.\n",
    "\n",
    "    Converting explicitly avoids scikit-learn's deprecated implicit string-to-number\n",
    "    coercion, which emitted a FutureWarning in 0.24 and is removed in later releases.\n",
    "    \"\"\"\n",
    "    return np.array([float(line) for line in read_lines(path)])\n",
    "\n",
    "\n",
    "def write_predictions(path, predictions):\n",
    "    \"\"\"Write one prediction per line to `path`, closing (and thus flushing) the file.\"\"\"\n",
    "    with open(path, mode=\"w\") as f:\n",
    "        f.writelines(str(p) + \"\\n\" for p in predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "with lzma.open(\"train/train.tsv.xz\", mode=\"rt\") as f:\n",
    "    train_rows = [line.split(\"\\t\") for line in f]\n",
    "# Column 0 of train.tsv is used as the regression target, column 4 as the text fed to the vectorizer.\n",
    "X_train = [row[4] for row in train_rows]\n",
    "y_expected_train = np.array([float(row[0]) for row in train_rows])\n",
    "\n",
    "X_dev0 = read_lines(\"dev-0/in.tsv\")\n",
    "y_expected_dev0 = read_targets(\"dev-0/expected.tsv\")\n",
    "X_dev1 = read_lines(\"dev-1/in.tsv\")\n",
    "y_expected_dev1 = read_targets(\"dev-1/expected.tsv\")\n",
    "X_test = read_lines(\"test-A/in.tsv\")\n",
    "\n",
    "assert len(X_train) == len(y_expected_train), \"train texts and targets differ in length\"\n",
    "assert len(X_dev0) == len(y_expected_dev0), \"dev-0 inputs and targets differ in length\"\n",
    "assert len(X_dev1) == len(y_expected_dev1), \"dev-1 inputs and targets differ in length\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Vocabulary is capped at the 10,000 terms with the highest frequency across the training corpus.\n",
    "vectorizer = TfidfVectorizer(max_features=10000)\n",
    "X_train_tfidf = vectorizer.fit_transform(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): whole in.tsv lines are vectorized, whereas training used only column 4 of\n",
    "# train.tsv; this assumes in.tsv holds just the text column - confirm against the data.\n",
    "X_dev0_tfidf = vectorizer.transform(X_dev0)\n",
    "X_dev1_tfidf = vectorizer.transform(X_dev1)\n",
    "X_test_tfidf = vectorizer.transform(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model = LinearRegression()\n",
    "model.fit(X_train_tfidf, y_expected_train)\n",
    "y_predicted_dev0 = model.predict(X_dev0_tfidf)\n",
    "y_predicted_dev1 = model.predict(X_dev1_tfidf)\n",
    "y_predicted_test = model.predict(X_test_tfidf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# RMSE as sqrt(MSE): the `squared=False` argument is deprecated and removed in newer scikit-learn.\n",
    "rmse_dev0 = np.sqrt(mean_squared_error(y_expected_dev0, y_predicted_dev0))\n",
    "rmse_dev1 = np.sqrt(mean_squared_error(y_expected_dev1, y_predicted_dev1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24.077488520623103 22.447122551358966\n"
     ]
    }
   ],
   "source": [
    "print(rmse_dev0, rmse_dev1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "write_predictions(\"dev-0/out.tsv\", y_predicted_dev0)\n",
    "write_predictions(\"dev-1/out.tsv\", y_predicted_dev1)\n",
    "write_predictions(\"test-A/out.tsv\", y_predicted_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}