# FinTech_app/charts/algorithm_4.py (63 lines, 1.9 KiB, Python)
# Last modified: 2023-01-24 14:48:22 +01:00
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
# data = pd.read_csv("static/obligacjePL.csv")
#
# months = data['Data']
# prices = data['Otwarcie']
#
# months = np.array(months).reshape(-1, 1)
#
# model = LinearRegression()
# model.fit(months, prices)
#
# future_months = np.array([12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]).reshape(-1, 1)
# predictions = model.predict(future_months)
# plt.plot(months, prices, label='actual')
# plt.plot(future_months, predictions, label='prediction')
# plt.legend()
# plt.xlabel('Months')
# plt.ylabel('Prices')
# plt.title('Stock Prices Prediction')
# plt.grid()
# plt.show()
# Fit a degree-2 polynomial regression of opening price against month, plot
# the observed prices together with an extrapolation over months 12..18, then
# report hold-out MSE / R^2 for the same model family.
#
# NOTE(review): this assumes the 'Date' column already holds numeric month
# indices (the forecast below starts at month 12 and the ticks start at 1).
# PolynomialFeatures cannot expand date strings such as "2022-01-03" -- confirm
# the format of static/AAPL2.csv.
data = pd.read_csv("static/AAPL2.csv")
months = data['Date']
prices = data['Open']
# scikit-learn expects a 2-D feature matrix: one row per sample, one column.
months = np.array(months).reshape(-1, 1)

poly_feat = PolynomialFeatures(degree=2)
# Fit the feature expansion (1, x, x^2) on the observed months only ...
months_poly = poly_feat.fit_transform(months)
model = LinearRegression()
model.fit(months_poly, prices)

# Months to extrapolate into: 12..18 inclusive.
future_months = np.arange(12, 19).reshape(-1, 1)
# ... and reuse that fitted expansion for new inputs. Calling fit_transform
# here would re-fit the transformer on the forecast inputs. That happens to be
# harmless for PolynomialFeatures with the same column count, but it is the
# wrong pattern for any transformer that learns from its data.
future_months_poly = poly_feat.transform(future_months)
predictions = model.predict(future_months_poly)

plt.plot(months, prices, color='cyan', label='actual')
plt.plot(future_months, predictions, color='pink', label='prediction')
plt.title('Polynomial Regression')
plt.xlabel('Months')
plt.ylabel('Prices')
# One tick per month up to the last forecast month. The previous
# range(1, 18) stopped at 17 and left forecast month 18 unlabeled.
plt.xticks(range(1, int(future_months.max()) + 1))
plt.legend()
plt.show()

# Hold-out evaluation. Rows are kept in file order (shuffle=False), so the
# test set is the final 20% of the series. A forecaster should be scored on
# data that comes after its training data. A random split lets it interpolate
# between neighbouring training months, which overstates forecast accuracy.
# NOTE(review): assumes the CSV rows are in chronological order, which the
# line plot above also relies on.
X_train, X_test, y_train, y_test = train_test_split(
    months_poly, prices, test_size=0.2, shuffle=False
)
# Use a separate estimator so `model` keeps referring to the full-data fit
# that produced `predictions` above.
eval_model = LinearRegression()
eval_model.fit(X_train, y_train)
# Predict on the held-out tail of the series.
y_pred = eval_model.predict(X_test)

# Score the hold-out predictions.
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print("Mean Squared Error:", mse)
print("R-Squared:", r2)