648 KiB
648 KiB
This dataset contains information on patients with lung cancer, including their age, gender, air pollution exposure, alcohol use, dust allergy, occupational hazards, genetic risk, chronic lung disease, balanced diet, obesity, smoking, passive smoker, chest pain, coughing of blood, fatigue, weight loss ,shortness of breath ,wheezing ,swallowing difficulty ,clubbing of finger nails and snoring
https://www.kaggle.com/datasets/thedevastator/cancer-patients-and-air-pollution-a-new-link/data
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.figure_factory as ff
import seaborn as sns
sns.set()
import plotly.express as px
import numpy as np
import sklearn
pip install plotnine
import plotnine
dane = pd.read_csv(r'C:\Users\HP\Desktop\podyplomówka\cancer_patient_data_sets.csv', index_col = 0)
dane.head()
Patient Id | Age | Gender | Air Pollution | Alcohol use | Dust Allergy | OccuPational Hazards | Genetic Risk | chronic Lung Disease | Balanced Diet | ... | Fatigue | Weight Loss | Shortness of Breath | Wheezing | Swallowing Difficulty | Clubbing of Finger Nails | Frequent Cold | Dry Cough | Snoring | Level | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
index | |||||||||||||||||||||
0 | P1 | 33 | 1 | 2 | 4 | 5 | 4 | 3 | 2 | 2 | ... | 3 | 4 | 2 | 2 | 3 | 1 | 2 | 3 | 4 | Low |
1 | P10 | 17 | 1 | 3 | 1 | 5 | 3 | 4 | 2 | 2 | ... | 1 | 3 | 7 | 8 | 6 | 2 | 1 | 7 | 2 | Medium |
2 | P100 | 35 | 1 | 4 | 5 | 6 | 5 | 5 | 4 | 6 | ... | 8 | 7 | 9 | 2 | 1 | 4 | 6 | 7 | 2 | High |
3 | P1000 | 37 | 1 | 7 | 7 | 7 | 7 | 6 | 7 | 7 | ... | 4 | 2 | 3 | 1 | 4 | 5 | 6 | 7 | 5 | High |
4 | P101 | 46 | 1 | 6 | 8 | 7 | 7 | 7 | 6 | 7 | ... | 3 | 2 | 4 | 1 | 4 | 2 | 4 | 2 | 3 | High |
5 rows × 25 columns
dane.info()
<class 'pandas.core.frame.DataFrame'> Index: 1000 entries, 0 to 999 Data columns (total 25 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Patient Id 1000 non-null object 1 Age 1000 non-null int64 2 Gender 1000 non-null int64 3 Air Pollution 1000 non-null int64 4 Alcohol use 1000 non-null int64 5 Dust Allergy 1000 non-null int64 6 OccuPational Hazards 1000 non-null int64 7 Genetic Risk 1000 non-null int64 8 chronic Lung Disease 1000 non-null int64 9 Balanced Diet 1000 non-null int64 10 Obesity 1000 non-null int64 11 Smoking 1000 non-null int64 12 Passive Smoker 1000 non-null int64 13 Chest Pain 1000 non-null int64 14 Coughing of Blood 1000 non-null int64 15 Fatigue 1000 non-null int64 16 Weight Loss 1000 non-null int64 17 Shortness of Breath 1000 non-null int64 18 Wheezing 1000 non-null int64 19 Swallowing Difficulty 1000 non-null int64 20 Clubbing of Finger Nails 1000 non-null int64 21 Frequent Cold 1000 non-null int64 22 Dry Cough 1000 non-null int64 23 Snoring 1000 non-null int64 24 Level 1000 non-null object dtypes: int64(23), object(2) memory usage: 203.1+ KB
dane.describe().T
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
Age | 1000.0 | 37.174 | 12.005493 | 14.0 | 27.75 | 36.0 | 45.0 | 73.0 |
Gender | 1000.0 | 1.402 | 0.490547 | 1.0 | 1.00 | 1.0 | 2.0 | 2.0 |
Air Pollution | 1000.0 | 3.840 | 2.030400 | 1.0 | 2.00 | 3.0 | 6.0 | 8.0 |
Alcohol use | 1000.0 | 4.563 | 2.620477 | 1.0 | 2.00 | 5.0 | 7.0 | 8.0 |
Dust Allergy | 1000.0 | 5.165 | 1.980833 | 1.0 | 4.00 | 6.0 | 7.0 | 8.0 |
OccuPational Hazards | 1000.0 | 4.840 | 2.107805 | 1.0 | 3.00 | 5.0 | 7.0 | 8.0 |
Genetic Risk | 1000.0 | 4.580 | 2.126999 | 1.0 | 2.00 | 5.0 | 7.0 | 7.0 |
chronic Lung Disease | 1000.0 | 4.380 | 1.848518 | 1.0 | 3.00 | 4.0 | 6.0 | 7.0 |
Balanced Diet | 1000.0 | 4.491 | 2.135528 | 1.0 | 2.00 | 4.0 | 7.0 | 7.0 |
Obesity | 1000.0 | 4.465 | 2.124921 | 1.0 | 3.00 | 4.0 | 7.0 | 7.0 |
Smoking | 1000.0 | 3.948 | 2.495902 | 1.0 | 2.00 | 3.0 | 7.0 | 8.0 |
Passive Smoker | 1000.0 | 4.195 | 2.311778 | 1.0 | 2.00 | 4.0 | 7.0 | 8.0 |
Chest Pain | 1000.0 | 4.438 | 2.280209 | 1.0 | 2.00 | 4.0 | 7.0 | 9.0 |
Coughing of Blood | 1000.0 | 4.859 | 2.427965 | 1.0 | 3.00 | 4.0 | 7.0 | 9.0 |
Fatigue | 1000.0 | 3.856 | 2.244616 | 1.0 | 2.00 | 3.0 | 5.0 | 9.0 |
Weight Loss | 1000.0 | 3.855 | 2.206546 | 1.0 | 2.00 | 3.0 | 6.0 | 8.0 |
Shortness of Breath | 1000.0 | 4.240 | 2.285087 | 1.0 | 2.00 | 4.0 | 6.0 | 9.0 |
Wheezing | 1000.0 | 3.777 | 2.041921 | 1.0 | 2.00 | 4.0 | 5.0 | 8.0 |
Swallowing Difficulty | 1000.0 | 3.746 | 2.270383 | 1.0 | 2.00 | 4.0 | 5.0 | 8.0 |
Clubbing of Finger Nails | 1000.0 | 3.923 | 2.388048 | 1.0 | 2.00 | 4.0 | 5.0 | 9.0 |
Frequent Cold | 1000.0 | 3.536 | 1.832502 | 1.0 | 2.00 | 3.0 | 5.0 | 7.0 |
Dry Cough | 1000.0 | 3.853 | 2.039007 | 1.0 | 2.00 | 4.0 | 6.0 | 7.0 |
Snoring | 1000.0 | 2.926 | 1.474686 | 1.0 | 2.00 | 3.0 | 4.0 | 7.0 |
dane.columns
Index(['Patient Id', 'Age', 'Gender', 'Air Pollution', 'Alcohol use', 'Dust Allergy', 'OccuPational Hazards', 'Genetic Risk', 'chronic Lung Disease', 'Balanced Diet', 'Obesity', 'Smoking', 'Passive Smoker', 'Chest Pain', 'Coughing of Blood', 'Fatigue', 'Weight Loss', 'Shortness of Breath', 'Wheezing', 'Swallowing Difficulty', 'Clubbing of Finger Nails', 'Frequent Cold', 'Dry Cough', 'Snoring', 'Level'], dtype='object')
level_counts = dane['Level'].value_counts()
ax = level_counts.plot(kind = 'pie', autopct='%1.1f%%', startangle=90)
ax = plt.title('Risk distribiution')
#stworzenie 'binow' dla pokazania wieku pacjentów
bins = []
for i in range (0, 101, 10):
bins.append(i)
plt.hist(dane['Age'], bins, histtype='bar', rwidth=0.8)
for i in range(len(bins) - 1):
count = ((dane['Age'] >= bins[i]) & (dane['Age'] < bins[i+1])).sum()
plt.text(bins[i] + 5, count, str(count), ha='center', va='bottom')
plt.xlabel('Age')
plt.ylabel('Number of patients')
plt.title('Age of the patients')
plt.show()
mean_age = dane['Age'].mean()
mean_age
37.174
gender_counts = dane['Gender'].value_counts()
ax = gender_counts.plot(kind='bar', color=['blue', 'pink'])
# Dodawanie wartości do słupków
for i, value in enumerate(gender_counts):
ax.text(i, value + 0.1, str(value), ha='center', va='bottom')
# Zmiana etykiet osi x
ax.set_xticks([0, 1])
ax.set_xticklabels(['Man', 'Woman'])
ax.set_title ("Distribution of patients' gender")
# Dodanie legendy
plt.legend()
# Wyświetlenie wykresu
plt.show()
# Grupowanie danych
grouped_data = dane.groupby(['Gender', 'Level']).size().unstack()
# Ustawienia kategorii i szerokości słupków
categories = grouped_data.columns
bar_width = 0.35
bar_positions_man = np.arange(len(categories))
bar_positions_woman = [pos + bar_width for pos in bar_positions_man]
# Wygenerowanie wykresu słupkowego
fig, ax = plt.subplots()
ax.bar(bar_positions_man, grouped_data.loc[1], width=bar_width, label='Man')
ax.bar(bar_positions_woman, grouped_data.loc[2], width=bar_width, label='Woman')
# Dodanie wartości procentowych do słupków
for i, column in enumerate(categories):
for j, value in enumerate(grouped_data.index):
total = grouped_data[column].sum()
percent = grouped_data.loc[value, column] / total
height = grouped_data.loc[value, column]
ax.text(i + j * bar_width, height + 0.2, f'{percent:.0%}', ha='center', va='bottom')
# Ustawienia etykiet i tytułów
plt.xlabel('Gender')
plt.ylabel('Count')
plt.title('Distribution of level by gender')
# Dodanie legendy
plt.legend(title='Level')
# Zmiana etykiet osi x
ax.set_xticks([pos + bar_width / 2 for pos in bar_positions_man])
ax.set_xticklabels(categories)
# Wyświetlenie wykresu
plt.show()
dane3 = [dane.groupby('Smoking').size()]
dane3
[Smoking 1 181 2 222 3 172 4 59 5 10 6 60 7 207 8 89 dtype: int64]
dane['Gender'] = dane['Gender'].replace({1: 'Man', 2: 'Woman'})
smoking_counts = dane.groupby(['Smoking', 'Gender']).size()
# Zamiana liczby na procent
smoking_percentages = smoking_counts / smoking_counts.groupby('Gender').sum() * 100
# Sortowanie danych według stopnia 'Smoking'
smoking_percentages_sorted = smoking_percentages.sort_index(level='Smoking', sort_remaining=False)
plt.figure(figsize=(14, 6))
# Tworzenie wykresu słupkowego poziomego
ax = smoking_percentages_sorted.plot(kind='barh')
# Dodawanie wartości procentowych do słupków
for i, value in enumerate(smoking_percentages_sorted):
ax.text(value + 0.1, i, f'{value:.2f}%', ha='left', va='center')
# Dodanie legendy
plt.legend()
# Ustawienia etykiet i tytułów
plt.xlabel('Percentage')
plt.ylabel('Smoking, Gender')
plt.title('Distribution of smoking by gender (%)')
plt.show()
No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
dane3 = [dane.groupby('Passive Smoker').size()]
dane3
[Passive Smoker 1 60 2 284 3 140 4 161 5 30 6 30 7 187 8 108 dtype: int64]
dane['Gender'] = dane['Gender'].replace({1: 'Man', 2: 'Woman'})
smoking_counts = dane.groupby(['Passive Smoker', 'Gender']).size()
# Zamiana liczby na procent
smoking_percentages = smoking_counts / smoking_counts.groupby('Gender').sum() * 100
# Sortowanie danych według stopnia 'Passive smoker'
smoking_percentages_sorted = smoking_percentages.sort_index(level='Passive Smoker', sort_remaining=False)
plt.figure(figsize=(15, 6))
# Tworzenie wykresu słupkowego poziomego
ax = smoking_percentages_sorted.plot(kind='barh')
# Dodawanie wartości procentowych do słupków
for i, value in enumerate(smoking_percentages_sorted):
ax.text(value + 0.1, i, f'{value:.2f}%', ha='left', va='center')
# Dodanie legendy
plt.legend()
# Ustawienia etykiet i tytułów
plt.xlabel('Percentage')
plt.ylabel('Passive Smoker, Gender')
plt.title('Distribution of passive smokers by gender (%)')
plt.show()
No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
dane['Gender'] = dane['Gender'].replace({1: 'Man', 2: 'Woman'})
Genetic_risk_counts = dane.groupby(['Genetic Risk', 'Gender']).size()
Genetic_risk_percentages= Genetic_risk_counts / Genetic_risk_counts.groupby('Gender').sum() * 100
Genetic_risk_percentages_sorted = Genetic_risk_percentages.sort_index(level='Genetic Risk', sort_remaining=False)
plt.figure(figsize=(15, 6))
# Tworzenie wykresu słupkowego poziomego
ax = Genetic_risk_percentages_sorted.plot(kind='barh')
# Dodawanie wartości procentowych do słupków
for i, value in enumerate(Genetic_risk_percentages_sorted):
ax.text(value + 0.1, i, f'{value:.2f}%', ha='left', va='center')
# Dodanie legendy
plt.legend()
# Ustawienia etykiet i tytułów
plt.xlabel('Percentage')
plt.ylabel('Genetic risk, Gender')
plt.title('Distribution of genetic risk by gender (%)')
# Wyświetlenie wykresu
plt.show()
No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Genetic_risk_counts = dane.groupby(['Genetic Risk', 'Level']).size()
# Sortowanie danych według ryzyka genetycznego i liczby w odwrotnej kolejności
Genetic_risk_counts_sorted = Genetic_risk_counts.sort_index(level=['Genetic Risk', 'Level'], key=lambda x: x.map({'High': 1, 'Medium': 2, 'Low': 3}))
plt.figure(figsize=(10, 6))
# Tworzenie wykresu słupkowego horyzontalnego
ax = Genetic_risk_counts_sorted.plot(kind='barh')
# Dodawanie wartości do słupków
for i, value in enumerate(Genetic_risk_counts_sorted):
ax.text(value + 0.1, i, str(value), ha='left', va='center')
# Ustawienia etykiet i tytułów
plt.xlabel('Count')
plt.ylabel('Genetic Risk, Level')
plt.title('Distribution of level by genetic risk')
# Wyświetlenie wykresu
plt.show()
x = dane[(dane['Smoking'] >= 7) & (dane['Alcohol use'] >= 7)]
x_sorted = x[['Smoking', 'Alcohol use', 'Level']].sort_values(by=['Smoking', 'Alcohol use'], ascending=False)
x_sorted
Smoking | Alcohol use | Level | |
---|---|---|---|
index | |||
4 | 8 | 8 | High |
20 | 8 | 8 | High |
22 | 8 | 8 | High |
46 | 8 | 8 | High |
68 | 8 | 8 | High |
... | ... | ... | ... |
989 | 7 | 7 | High |
992 | 7 | 7 | High |
993 | 7 | 7 | High |
994 | 7 | 7 | High |
995 | 7 | 7 | High |
256 rows × 3 columns
dane7 = dane['Air Pollution'].value_counts()
dane7.sort_values()
Air Pollution 8 19 5 20 7 30 4 90 1 141 3 173 2 201 6 326 Name: count, dtype: int64
air_pollution = dane.groupby(['Air Pollution', 'Level']).size()
# Sortowanie danych według ryzyka genetycznego i liczby w odwrotnej kolejności
air_pollution_sorted = air_pollution.sort_index(level=['Air Pollution', 'Level'], key=lambda x: x.map({'High': 1, 'Medium': 2, 'Low': 3}))
plt.figure(figsize=(10, 6))
# Tworzenie wykresu słupkowego horyzontalnego
ax = air_pollution_sorted.plot(kind='barh')
# Dodawanie wartości do słupków
for i, value in enumerate(air_pollution_sorted):
ax.text(value + 0.1, i, str(value), ha='left', va='center')
# Ustawienia etykiet i tytułów
plt.xlabel('Count')
plt.ylabel('Air Pollution, Level')
plt.title('Distribution of level by air pollution')
# Wyświetlenie wykresu
plt.show()
data = dane.replace({'Level':{'High' : 3, 'Medium' : 2, 'Low' : 1}})
data = data.drop(['Patient Id', 'Gender', 'Age'], axis=1)
corr_matrix = data.corr()
corr_df = corr_matrix.stack().reset_index()
corr_df.columns = ['x', 'y', 'value']
fig = px.imshow(corr_matrix, x=corr_matrix.columns.tolist(), y=corr_matrix.columns.tolist(), zmin=-1, zmax=1, color_continuous_scale=['blue', 'white', 'red'])
fig.update_layout(
width=800,
height=800,
coloraxis_colorbar=dict(
title='Correlation',
tickvals=[-1, -0.5, 0, 0.5, 1],
ticktext=['-1', '-0.5', '0', '0.5', '1'],
),
coloraxis=dict(
colorscale=[
[0, 'rgb(58, 89, 156)'], # Niebieski (cool)
[0.5, 'rgb(255, 255, 255)'], # Biały
[1, 'rgb(179, 35, 26)'] # Czerwony (warm)
],
),
font=dict(
family="Arial",
size=12,
color="black",
))
# Wyświetlenie wykresu
fig.show()
import numpy as np
import sklearn
data = dane.replace({'Level':{'High' : 3, 'Medium' : 2, 'Low' : 1}})
data['Gender'] = data['Gender'].replace({'Man' : 1, 'Woman' : 2})
data.head()
Patient Id | Age | Gender | Air Pollution | Alcohol use | Dust Allergy | OccuPational Hazards | Genetic Risk | chronic Lung Disease | Balanced Diet | ... | Fatigue | Weight Loss | Shortness of Breath | Wheezing | Swallowing Difficulty | Clubbing of Finger Nails | Frequent Cold | Dry Cough | Snoring | Level | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
index | |||||||||||||||||||||
0 | P1 | 33 | 1 | 2 | 4 | 5 | 4 | 3 | 2 | 2 | ... | 3 | 4 | 2 | 2 | 3 | 1 | 2 | 3 | 4 | 1 |
1 | P10 | 17 | 1 | 3 | 1 | 5 | 3 | 4 | 2 | 2 | ... | 1 | 3 | 7 | 8 | 6 | 2 | 1 | 7 | 2 | 2 |
2 | P100 | 35 | 1 | 4 | 5 | 6 | 5 | 5 | 4 | 6 | ... | 8 | 7 | 9 | 2 | 1 | 4 | 6 | 7 | 2 | 3 |
3 | P1000 | 37 | 1 | 7 | 7 | 7 | 7 | 6 | 7 | 7 | ... | 4 | 2 | 3 | 1 | 4 | 5 | 6 | 7 | 5 | 3 |
4 | P101 | 46 | 1 | 6 | 8 | 7 | 7 | 7 | 6 | 7 | ... | 3 | 2 | 4 | 1 | 4 | 2 | 4 | 2 | 3 | 3 |
5 rows × 25 columns
np.random.seed(10)
np.set_printoptions(precision=6, suppress=True)
X = data.drop(['Level', 'Patient Id'], axis=1)
y = data['Level']
print("Y shape:", y.shape)
print("X shape:", X.shape)
Y shape: (1000,) X shape: (1000, 23)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split (X, y)
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_test shape:", X_test.shape)
print("y_test shape:", y_test.shape)
X_train shape: (750, 23) y_train shape: (750,) X_test shape: (250, 23) y_test shape: (250,)
from sklearn.linear_model import LogisticRegression
#classifier = LogisticRegression(multi_class='multinomial', solver='lbfgs')
classifier = LogisticRegression(max_iter = 200)
classifier.fit(X_train, y_train)
c:\Users\HP\anaconda3\lib\site-packages\sklearn\linear_model\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1): STOP: TOTAL NO. of ITERATIONS REACHED LIMIT. Increase the number of iterations (max_iter) or scale the data as shown in: https://scikit-learn.org/stable/modules/preprocessing.html Please also refer to the documentation for alternative solver options: https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
LogisticRegression(max_iter=200)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LogisticRegression(max_iter=200)
y_prob = classifier.predict_proba(X_test)
y_prob
array([[0.01354 , 0.98631 , 0.00015 ], [0. , 0.000005, 0.999995], [0.000005, 0.00068 , 0.999315], [0.999999, 0.000001, 0. ], [0.000003, 0.000947, 0.99905 ], [0.001809, 0.982182, 0.016009], [0. , 0.012048, 0.987952], [0.059547, 0.831086, 0.109367], [0. , 0.075388, 0.924612], [0.039136, 0.960851, 0.000013], [0.000022, 0.003654, 0.996323], [0.006091, 0.993801, 0.000108], [0.017895, 0.982104, 0.000001], [0.998827, 0.001173, 0. ], [0.928696, 0.070643, 0.000662], [0.033735, 0.966263, 0.000003], [0. , 0.000113, 0.999887], [0.934007, 0.009674, 0.056319], [0. , 0.000031, 0.999969], [0.999648, 0.000306, 0.000047], [0.000004, 0.00404 , 0.995956], [0.999251, 0.000216, 0.000533], [0. , 0.00512 , 0.99488 ], [0. , 0.004695, 0.995305], [0. , 0.000124, 0.999876], [0.1259 , 0.859667, 0.014433], [0.059547, 0.831086, 0.109367], [0.991172, 0.008807, 0.000021], [0.999909, 0.00009 , 0.000001], [0.1259 , 0.859667, 0.014433], [0.976724, 0.023247, 0.000028], [0.008947, 0.990345, 0.000708], [0. , 0.004695, 0.995305], [0.999251, 0.000216, 0.000533], [0.005118, 0.994787, 0.000095], [0. , 0.00008 , 0.99992 ], [0.015368, 0.98458 , 0.000052], [0. , 0.000003, 0.999997], [0.00583 , 0.994147, 0.000023], [0.993283, 0.00006 , 0.006656], [0.1259 , 0.859667, 0.014433], [0.004655, 0.989938, 0.005407], [0. , 0.000134, 0.999866], [0.000007, 0.999991, 0.000002], [0. , 0.012048, 0.987952], [0.000022, 0.003654, 0.996323], [0.931524, 0.068283, 0.000192], [0. , 0.000031, 0.999969], [0. , 0.000165, 0.999835], [0.009961, 0.988549, 0.00149 ], [0.006484, 0.974177, 0.019339], [0.00523 , 0.994198, 0.000572], [0.021549, 0.975853, 0.002598], [0. , 0.000012, 0.999988], [0.009961, 0.988549, 0.00149 ], [0.000001, 0.000025, 0.999974], [0. , 0.016222, 0.983778], [0.00003 , 0.010307, 0.989664], [0.993286, 0.006713, 0.000001], [0.000175, 0.040597, 0.959228], [0.000004, 0.00404 , 0.995956], [0.984367, 0.012436, 0.003197], [0.000005, 0.00068 , 0.999315], [0.967662, 0.024355, 0.007983], [0. , 0.000134, 0.999866], [0.976724, 0.023247, 0.000028], [0.007089, 0.99289 , 0.000021], [0. , 0.000165, 0.999835], [0. , 0.00512 , 0.99488 ], [0.016788, 0.983198, 0.000014], [0.988196, 0.002318, 0.009486], [0.999947, 0.000021, 0.000032], [0.990009, 0.009991, 0. ], [0.016788, 0.983198, 0.000014], [0.015399, 0.984596, 0.000005], [0.022511, 0.975873, 0.001616], [0.008947, 0.990345, 0.000708], [0.007156, 0.985278, 0.007566], [0. , 0.00068 , 0.99932 ], [0.991172, 0.008807, 0.000021], [0. , 0.000106, 0.999894], [0.993286, 0.006713, 0.000001], [0. , 0.004695, 0.995305], [0.006484, 0.974177, 0.019339], [0. , 0.00001 , 0.99999 ], [0. , 0.03253 , 0.967469], [0. , 0.000014, 0.999986], [0.000011, 0.999989, 0. ], [0.991172, 0.008807, 0.000021], [0. , 0.00068 , 0.99932 ], [0.013247, 0.986728, 0.000025], [0.003605, 0.971443, 0.024952], [0.999999, 0.000001, 0. ], [0.931524, 0.068283, 0.000192], [0.00583 , 0.994147, 0.000023], [0.00583 , 0.994147, 0.000023], [0.1259 , 0.859667, 0.014433], [0.022062, 0.977641, 0.000297], [0.016788, 0.983198, 0.000014], [0.957888, 0.04211 , 0.000001], [0. , 0.016222, 0.983778], [0.000003, 0.999994, 0.000003], [0.9999 , 0.0001 , 0. ], [0. , 0.012048, 0.987952], [0.999909, 0.00009 , 0.000001], [0.928696, 0.070643, 0.000662], [0.931524, 0.068283, 0.000192], [0.000003, 0.99999 , 0.000007], [0. , 0.004695, 0.995305], [0. , 0.000014, 0.999986], [0.993283, 0.00006 , 0.006656], [0.000003, 0.000947, 0.99905 ], [0.008109, 0.979381, 0.01251 ], [0.934007, 0.009674, 0.056319], [0.999648, 0.000306, 0.000047], [0.000022, 0.003654, 0.996323], [0.013247, 0.986728, 0.000025], [0.000011, 0.999989, 0. ], [0.993286, 0.006713, 0.000001], [0.9999 , 0.0001 , 0. ], [0. , 0.000113, 0.999887], [0. , 0.000063, 0.999937], [0. , 0.00512 , 0.99488 ], [0.011601, 0.988358, 0.000041], [0.999909, 0.00009 , 0.000001], [0.000005, 0.00068 , 0.999315], [0.000175, 0.040597, 0.959228], [0.990009, 0.009991, 0. ], [0.978169, 0.021712, 0.000119], [0.015368, 0.98458 , 0.000052], [0.022062, 0.977641, 0.000297], [0. , 0.021094, 0.978906], [0.973999, 0.013972, 0.012029], [0. , 0.004695, 0.995305], [0.996352, 0.003647, 0.000001], [0.967662, 0.024355, 0.007983], [0.985581, 0.000103, 0.014316], [0. , 0.00512 , 0.99488 ], [0.984367, 0.012436, 0.003197], [0. , 0.000468, 0.999532], [0.008109, 0.979381, 0.01251 ], [0.991172, 0.008807, 0.000021], [0.013247, 0.986728, 0.000025], [0.013247, 0.986728, 0.000025], [0.999999, 0.000001, 0. ], [0.985581, 0.000103, 0.014316], [0.007156, 0.985278, 0.007566], [0.000022, 0.003654, 0.996323], [0.999947, 0.000021, 0.000032], [0.973999, 0.013972, 0.012029], [0.000022, 0.003654, 0.996323], [0.993286, 0.006713, 0.000001], [0.1259 , 0.859667, 0.014433], [0.008803, 0.991195, 0.000002], [0.008947, 0.990345, 0.000708], [0. , 0.002567, 0.997433], [0.999999, 0.000001, 0. ], [0. , 0.004695, 0.995305], [0. , 0.000013, 0.999987], [0.98178 , 0.01822 , 0. ], [0.891574, 0.108425, 0. ], [0.006622, 0.993257, 0.000122], [0.972232, 0.027758, 0.00001 ], [0.008109, 0.979381, 0.01251 ], [0.928696, 0.070643, 0.000662], [0. , 0.000196, 0.999804], [0.000003, 0.999994, 0.000003], [0.991172, 0.008807, 0.000021], [0.021549, 0.975853, 0.002598], [0.999947, 0.000021, 0.000032], [0. , 0.000134, 0.999866], [0.013247, 0.986728, 0.000025], [0. , 0.00512 , 0.99488 ], [0.991172, 0.008807, 0.000021], [0. , 0.00512 , 0.99488 ], [0.153033, 0.8185 , 0.028467], [0.008947, 0.990345, 0.000708], [0.009961, 0.988549, 0.00149 ], [0.015399, 0.984596, 0.000005], [0.000011, 0.999989, 0. ], [0.022062, 0.977641, 0.000297], [0.999999, 0.000001, 0. ], [0.967662, 0.024355, 0.007983], [0.006622, 0.993257, 0.000122], [0.015399, 0.984596, 0.000005], [0. , 0.000196, 0.999804], [0. , 0.016222, 0.983778], [0. , 0.000124, 0.999876], [0. , 0.001328, 0.998672], [0.000011, 0.999989, 0. ], [0.022511, 0.975873, 0.001616], [0.999251, 0.000216, 0.000533], [0.999999, 0.000001, 0. ], [0. , 0.03253 , 0.967469], [0. , 0.000063, 0.999937], [0.993283, 0.00006 , 0.006656], [0.007089, 0.99289 , 0.000021], [0.009961, 0.988549, 0.00149 ], [0.931524, 0.068283, 0.000192], [0.007156, 0.985278, 0.007566], [0.006622, 0.993257, 0.000122], [0.021549, 0.975853, 0.002598], [0. , 0.000003, 0.999997], [0.009393, 0.990606, 0.000001], [0. , 0.001767, 0.998233], [0.000175, 0.040597, 0.959228], [0.931524, 0.068283, 0.000192], [0.01354 , 0.98631 , 0.00015 ], [0.000011, 0.999989, 0. ], [0.07487 , 0.902033, 0.023097], [0. , 0.000113, 0.999887], [0.999999, 0.000001, 0. ], [0. , 0.000031, 0.999969], [0. , 0.0011 , 0.9989 ], [0.891574, 0.108425, 0. ], [0.957888, 0.04211 , 0.000001], [0.004898, 0.993937, 0.001166], [0. , 0. , 1. ], [0.891574, 0.108425, 0. ], [0.990009, 0.009991, 0. ], [0.976724, 0.023247, 0.000028], [0.000007, 0.999991, 0.000002], [0.891574, 0.108425, 0. ], [0. , 0.000014, 0.999986], [0.015399, 0.984596, 0.000005], [0.003605, 0.971443, 0.024952], [0.1259 , 0.859667, 0.014433], [0.991172, 0.008807, 0.000021], [0. , 0.00001 , 0.99999 ], [0.004655, 0.989938, 0.005407], [0. , 0.000063, 0.999937], [0.999999, 0.000001, 0. ], [0.007156, 0.985278, 0.007566], [0.000007, 0.999991, 0.000002], [0. , 0.001638, 0.998362], [0.006622, 0.993257, 0.000122], [0. , 0.000468, 0.999532], [0.978169, 0.021712, 0.000119], [0.891574, 0.108425, 0. ], [0. , 0.000196, 0.999804], [0.934007, 0.009674, 0.056319], [0.008109, 0.979381, 0.01251 ], [0.000004, 0.00404 , 0.995956], [0.022511, 0.975873, 0.001616], [0. , 0.03253 , 0.967469], [0.00007 , 0.024082, 0.975848], [0. , 0.000031, 0.999969], [0. , 0.012048, 0.987952], [0.999947, 0.000021, 0.000032], [0.004898, 0.993937, 0.001166]])
y_pred = classifier.predict(X_test)
y_pred
array([2, 3, 3, 1, 3, 2, 3, 2, 3, 2, 3, 2, 2, 1, 1, 2, 3, 1, 3, 1, 3, 1, 3, 3, 3, 2, 2, 1, 1, 2, 1, 2, 3, 1, 2, 3, 2, 3, 2, 1, 2, 2, 3, 2, 3, 3, 1, 3, 3, 2, 2, 2, 2, 3, 2, 3, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1, 2, 3, 3, 2, 1, 1, 1, 2, 2, 2, 2, 2, 3, 1, 3, 1, 3, 2, 3, 3, 3, 2, 1, 3, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 3, 2, 1, 3, 1, 1, 1, 2, 3, 3, 1, 3, 2, 1, 1, 3, 2, 2, 1, 1, 3, 3, 3, 2, 1, 3, 3, 1, 1, 2, 2, 3, 1, 3, 1, 1, 1, 3, 1, 3, 2, 1, 2, 2, 1, 1, 2, 3, 1, 1, 3, 1, 2, 2, 2, 3, 1, 3, 3, 1, 1, 2, 1, 2, 1, 3, 2, 1, 2, 1, 3, 2, 3, 1, 3, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 1, 1, 3, 3, 1, 2, 2, 1, 2, 2, 2, 3, 2, 3, 3, 1, 2, 2, 2, 3, 1, 3, 3, 1, 1, 2, 3, 1, 1, 1, 2, 1, 3, 2, 2, 2, 1, 3, 2, 3, 1, 2, 2, 3, 2, 3, 1, 1, 3, 1, 2, 3, 2, 3, 3, 3, 3, 1, 2], dtype=int64)
pip install mlxtend
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from mlxtend.plotting import plot_confusion_matrix
import seaborn as sns
sns.set()
#cm = confusion_matrix(y_test, y_pred)
#plot_confusion_matrix(cm)
#acc = accuracy_score(y_test, y_pred)
#print('Accuracy',':', acc)
acc = accuracy_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred)
def plot_confusion_matrix(cm):
cm = cm[::-1]
cm = pd.DataFrame(cm, columns=['pred_1', 'pred_2', 'pred_3'], index=['true_3', 'true_2', 'true_1'])
fig = ff.create_annotated_heatmap(z = cm.values, x = list(cm.columns), y = list(cm.index), colorscale = 'ice', showscale = True, reversescale = True)
fig.update_layout(width=500, height=500, title=f'Confusion Matrix - Accuracy: {acc:.4f}'. format(acc), font_size=16)
fig.show()
plot_confusion_matrix(cm)
from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred, target_names=['pred_1', 'pred_2', 'pred_3']))
precision recall f1-score support pred_1 1.00 1.00 1.00 76 pred_2 1.00 1.00 1.00 89 pred_3 1.00 1.00 1.00 85 accuracy 1.00 250 macro avg 1.00 1.00 1.00 250 weighted avg 1.00 1.00 1.00 250