Enhancement to predict function
This commit is contained in:
parent
e20007b2df
commit
d4ba49aa16
@ -39,33 +39,18 @@ def find_13(dane):
|
|||||||
def find_best_flats(dane):
|
def find_best_flats(dane):
|
||||||
return dane.loc[(dane['Borough'] == 'Winogrady') & (dane['Rooms'] == 3) & (dane['Floor'] == 1)]
|
return dane.loc[(dane['Borough'] == 'Winogrady') & (dane['Rooms'] == 3) & (dane['Floor'] == 1)]
|
||||||
|
|
||||||
def predict(dane, col_name):
|
def predict(dane, rooms, sqrMeters):
|
||||||
from sklearn import linear_model
|
from sklearn import linear_model
|
||||||
from sklearn.metrics import mean_squared_error, r2_score
|
import numpy as np
|
||||||
d_X = pd.DataFrame(dane[col_name])
|
data = dane
|
||||||
d_X_train = d_X[4000:]
|
df = pd.DataFrame(data, columns=np.array(['Rooms','SqrMeters']))
|
||||||
d_X_test = d_X[:4000]
|
target = pd.DataFrame(data, columns=["Expected"])
|
||||||
d_y = pd.DataFrame(dane['Expected'])
|
X = df
|
||||||
d_y_train = d_y[4000:]
|
y = target["Expected"]
|
||||||
d_y_test = d_y[:4000]
|
lm = linear_model.LinearRegression()
|
||||||
regr = linear_model.LinearRegression()
|
model = lm.fit(X, y)
|
||||||
regr.fit(d_X_train, d_y_train)
|
inData = pd.DataFrame.from_records([(rooms, sqrMeters)], columns=['Rooms', 'SqrMeters'])
|
||||||
y_pred = regr.predict(d_X_test)
|
return lm.predict(inData)[0]
|
||||||
print('MODEL(%s): pred_y = %f * x + %f' % (col_name, regr.coef_[0], regr.intercept_) )
|
|
||||||
print('Mean squared error: %.2f' % mean_squared_error(d_y_test, y_pred))
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
plt.clf()
|
|
||||||
dataLine, = plt.plot(d_X_test, d_y_test, 'ro', label='collected data')
|
|
||||||
predLine, = plt.plot(d_X_test, y_pred, color='blue', linestyle='--', linewidth = 2, label='predictions')
|
|
||||||
ax = plt.gca().add_artist(plt.legend(handles=[dataLine], loc=1))
|
|
||||||
plt.legend(handles=[predLine], loc=4)
|
|
||||||
plt.xticks(())
|
|
||||||
plt.yticks(())
|
|
||||||
plt.xlabel(col_name)
|
|
||||||
plt.ylabel('Price')
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
dane = wczytaj_dane()
|
dane = wczytaj_dane()
|
||||||
@ -87,9 +72,8 @@ def main():
|
|||||||
print('"Najlepsze" mieszkania: ')
|
print('"Najlepsze" mieszkania: ')
|
||||||
print(find_best_flats(dane))
|
print(find_best_flats(dane))
|
||||||
|
|
||||||
predict(dane, 'Rooms')
|
print('Predicted price(actual 146000): ', predict(dane,1,31.21))
|
||||||
predict(dane, 'SqrMeters')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
Loading…
Reference in New Issue
Block a user