project_python_rynekNieruch.../home_pricing/main.py

47 lines
2.0 KiB
Python
Raw Permalink Normal View History

2024-02-26 16:54:44 +01:00
from DataCollectingScraper.DataScrapers.OtoDomDataScraperImpl import OtoDomDataScraperImpl
from DataCollectingScraper.DataCollectingScraper import DataCollectingScraper
from DataPreprocessor.helpers.OffersCSVReader import OffersCSVReader
from DataPreprocessor.DataPreprocessor import DataPreprocessor
from Prediction.Trainer.PredictionModelTrainer import PredictionModelTrainer
from sklearn.neural_network import MLPRegressor
from pandas.core.frame import DataFrame
import pandas as pd
import joblib
# Pipeline switches: set to True to re-scrape the offers and/or retrain the model.
download_data = False
train_model = False

# CSV file read back below; presumably the scraper writes its output here
# (TODO confirm against DataCollectingScraper).
OFFERS_CSV_PATH = "output.csv"
# File the trained regressor is saved to / loaded from with joblib.
MODEL_PATH = "trained_model.pkl"
# Column order of the sample fed to the model.
SAMPLE_COLUMNS = ['Area', 'Rooms', 'Floor', 'Property form', 'State', 'Location', 'Construction year']

# Downloading raw data
if download_data:
    offers_sublink = "wyniki/sprzedaz/mieszkanie/dolnoslaskie/wroclaw/wroclaw/wroclaw?viewType=listing"
    scraper = DataCollectingScraper(OtoDomDataScraperImpl(offers_sublink))
    scraper()

# Reading downloaded data
data_frame: DataFrame = OffersCSVReader.read_from_file(OFFERS_CSV_PATH)

# Prepare data for neural network (data preprocessing).
# This runs even when train_model is False, because the get_value() calls below
# use this same preprocessor instance (presumably relying on state fitted by
# preprocess_data() — TODO confirm).
data_preprocessor = DataPreprocessor(data_frame)
data_preprocessor.preprocess_data()

if train_model:
    preprocessed_data: DataFrame = data_preprocessor.get_preprocessed_data()
    # Train neural network with preprocessed data and persist it.
    trainer = PredictionModelTrainer(preprocessed_data)
    trainer.train()
    trained_model: MLPRegressor = trainer.get_trained_model()
    joblib.dump(trained_model, MODEL_PATH)

# When train_model is False, MODEL_PATH must already exist from an earlier run.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code;
# only load model files produced by this script.
trained_model = joblib.load(MODEL_PATH)

# Build one example offer to price. Area, construction year and location go through
# the preprocessor (presumably scaling / encoding them like the training data).
scaled_area = data_preprocessor.get_value('Area', pd.DataFrame({'Area': [56.0]}))
scaled_construction_year = data_preprocessor.get_value('Construction year', pd.DataFrame({'Construction year': [1980]}))
encoded_location = data_preprocessor.get_value("Location", ['Krzyki'])
# Positional values follow SAMPLE_COLUMNS: Rooms=3, Floor=8, and 'Property form'=0 /
# 'State'=2 passed as already-encoded codes (TODO confirm the code mapping).
sample_data = [[scaled_area, 3, 8, 0, 2, encoded_location, scaled_construction_year]]
sample = pd.DataFrame(sample_data, columns=SAMPLE_COLUMNS)

prediction = trained_model.predict(sample)
# predict() returns an ndarray. Extract its single element explicitly: calling
# float() on an array with ndim > 0 is deprecated since NumPy 1.25.
predicted_price = float(prediction.item())
print('Predicted price: ', round(predicted_price, 0), '')