From d4ba49aa162c8454c07b5f6f98fee9c0275a4082 Mon Sep 17 00:00:00 2001 From: s45157 Date: Sun, 21 Jan 2018 19:11:37 +0100 Subject: [PATCH] Enhancement to predict function --- labs06/task02.py | 42 +++++++++++++----------------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/labs06/task02.py b/labs06/task02.py index b9ad400..d2ccacf 100755 --- a/labs06/task02.py +++ b/labs06/task02.py @@ -39,33 +39,18 @@ def find_13(dane): def find_best_flats(dane): return dane.loc[(dane['Borough'] == 'Winogrady') & (dane['Rooms'] == 3) & (dane['Floor'] == 1)] -def predict(dane, col_name): +def predict(dane, rooms, sqrMeters): from sklearn import linear_model - from sklearn.metrics import mean_squared_error, r2_score - d_X = pd.DataFrame(dane[col_name]) - d_X_train = d_X[4000:] - d_X_test = d_X[:4000] - d_y = pd.DataFrame(dane['Expected']) - d_y_train = d_y[4000:] - d_y_test = d_y[:4000] - regr = linear_model.LinearRegression() - regr.fit(d_X_train, d_y_train) - y_pred = regr.predict(d_X_test) - print('MODEL(%s): pred_y = %f * x + %f' % (col_name, regr.coef_[0], regr.intercept_) ) - print('Mean squared error: %.2f' % mean_squared_error(d_y_test, y_pred)) - - import matplotlib.pyplot as plt - - plt.clf() - dataLine, = plt.plot(d_X_test, d_y_test, 'ro', label='collected data') - predLine, = plt.plot(d_X_test, y_pred, color='blue', linestyle='--', linewidth = 2, label='predictions') - ax = plt.gca().add_artist(plt.legend(handles=[dataLine], loc=1)) - plt.legend(handles=[predLine], loc=4) - plt.xticks(()) - plt.yticks(()) - plt.xlabel(col_name) - plt.ylabel('Price') - plt.show() + import numpy as np + data = dane + df = pd.DataFrame(data, columns=np.array(['Rooms','SqrMeters'])) + target = pd.DataFrame(data, columns=["Expected"]) + X = df + y = target["Expected"] + lm = linear_model.LinearRegression() + model = lm.fit(X, y) + inData = pd.DataFrame.from_records([(rooms, sqrMeters)], columns=['Rooms', 'SqrMeters']) + return lm.predict(inData)[0] def main(): dane = wczytaj_dane() @@ -87,9 +72,8 @@ def main(): print('"Najlepsze" mieszkania: ') print(find_best_flats(dane)) - predict(dane, 'Rooms') - predict(dane, 'SqrMeters') + print('Predicted price(actual 146000): ', predict(dane,1,31.21)) if __name__ == "__main__": - main() + main() \ No newline at end of file