241 KiB
241 KiB
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.figure_factory as ff
import seaborn as sns
sns.set()
dane = pd.read_csv(r'C:\Users\HP\Desktop\podyplomówka\cancer_patient_data_sets.csv', index_col = 0)
dane.head()
Patient Id | Age | Gender | Air Pollution | Alcohol use | Dust Allergy | OccuPational Hazards | Genetic Risk | chronic Lung Disease | Balanced Diet | ... | Fatigue | Weight Loss | Shortness of Breath | Wheezing | Swallowing Difficulty | Clubbing of Finger Nails | Frequent Cold | Dry Cough | Snoring | Level | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
index | |||||||||||||||||||||
0 | P1 | 33 | 1 | 2 | 4 | 5 | 4 | 3 | 2 | 2 | ... | 3 | 4 | 2 | 2 | 3 | 1 | 2 | 3 | 4 | Low |
1 | P10 | 17 | 1 | 3 | 1 | 5 | 3 | 4 | 2 | 2 | ... | 1 | 3 | 7 | 8 | 6 | 2 | 1 | 7 | 2 | Medium |
2 | P100 | 35 | 1 | 4 | 5 | 6 | 5 | 5 | 4 | 6 | ... | 8 | 7 | 9 | 2 | 1 | 4 | 6 | 7 | 2 | High |
3 | P1000 | 37 | 1 | 7 | 7 | 7 | 7 | 6 | 7 | 7 | ... | 4 | 2 | 3 | 1 | 4 | 5 | 6 | 7 | 5 | High |
4 | P101 | 46 | 1 | 6 | 8 | 7 | 7 | 7 | 6 | 7 | ... | 3 | 2 | 4 | 1 | 4 | 2 | 4 | 2 | 3 | High |
5 rows × 25 columns
dane.info()
<class 'pandas.core.frame.DataFrame'> Index: 1000 entries, 0 to 999 Data columns (total 25 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Patient Id 1000 non-null object 1 Age 1000 non-null int64 2 Gender 1000 non-null int64 3 Air Pollution 1000 non-null int64 4 Alcohol use 1000 non-null int64 5 Dust Allergy 1000 non-null int64 6 OccuPational Hazards 1000 non-null int64 7 Genetic Risk 1000 non-null int64 8 chronic Lung Disease 1000 non-null int64 9 Balanced Diet 1000 non-null int64 10 Obesity 1000 non-null int64 11 Smoking 1000 non-null int64 12 Passive Smoker 1000 non-null int64 13 Chest Pain 1000 non-null int64 14 Coughing of Blood 1000 non-null int64 15 Fatigue 1000 non-null int64 16 Weight Loss 1000 non-null int64 17 Shortness of Breath 1000 non-null int64 18 Wheezing 1000 non-null int64 19 Swallowing Difficulty 1000 non-null int64 20 Clubbing of Finger Nails 1000 non-null int64 21 Frequent Cold 1000 non-null int64 22 Dry Cough 1000 non-null int64 23 Snoring 1000 non-null int64 24 Level 1000 non-null object dtypes: int64(23), object(2) memory usage: 203.1+ KB
dane0 = dane[dane['Level'] == 'High'][['Air Pollution', 'Smoking', 'Passive Smoker']]
dane0
Air Pollution | Smoking | Passive Smoker | |
---|---|---|---|
index | |||
2 | 4 | 2 | 3 |
3 | 7 | 7 | 7 |
4 | 6 | 8 | 7 |
5 | 4 | 2 | 3 |
10 | 6 | 7 | 8 |
... | ... | ... | ... |
995 | 6 | 7 | 8 |
996 | 6 | 7 | 8 |
997 | 4 | 2 | 3 |
998 | 6 | 8 | 7 |
999 | 6 | 2 | 3 |
365 rows × 3 columns
dane.describe().T
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
Age | 1000.0 | 37.174 | 12.005493 | 14.0 | 27.75 | 36.0 | 45.0 | 73.0 |
Gender | 1000.0 | 1.402 | 0.490547 | 1.0 | 1.00 | 1.0 | 2.0 | 2.0 |
Air Pollution | 1000.0 | 3.840 | 2.030400 | 1.0 | 2.00 | 3.0 | 6.0 | 8.0 |
Alcohol use | 1000.0 | 4.563 | 2.620477 | 1.0 | 2.00 | 5.0 | 7.0 | 8.0 |
Dust Allergy | 1000.0 | 5.165 | 1.980833 | 1.0 | 4.00 | 6.0 | 7.0 | 8.0 |
OccuPational Hazards | 1000.0 | 4.840 | 2.107805 | 1.0 | 3.00 | 5.0 | 7.0 | 8.0 |
Genetic Risk | 1000.0 | 4.580 | 2.126999 | 1.0 | 2.00 | 5.0 | 7.0 | 7.0 |
chronic Lung Disease | 1000.0 | 4.380 | 1.848518 | 1.0 | 3.00 | 4.0 | 6.0 | 7.0 |
Balanced Diet | 1000.0 | 4.491 | 2.135528 | 1.0 | 2.00 | 4.0 | 7.0 | 7.0 |
Obesity | 1000.0 | 4.465 | 2.124921 | 1.0 | 3.00 | 4.0 | 7.0 | 7.0 |
Smoking | 1000.0 | 3.948 | 2.495902 | 1.0 | 2.00 | 3.0 | 7.0 | 8.0 |
Passive Smoker | 1000.0 | 4.195 | 2.311778 | 1.0 | 2.00 | 4.0 | 7.0 | 8.0 |
Chest Pain | 1000.0 | 4.438 | 2.280209 | 1.0 | 2.00 | 4.0 | 7.0 | 9.0 |
Coughing of Blood | 1000.0 | 4.859 | 2.427965 | 1.0 | 3.00 | 4.0 | 7.0 | 9.0 |
Fatigue | 1000.0 | 3.856 | 2.244616 | 1.0 | 2.00 | 3.0 | 5.0 | 9.0 |
Weight Loss | 1000.0 | 3.855 | 2.206546 | 1.0 | 2.00 | 3.0 | 6.0 | 8.0 |
Shortness of Breath | 1000.0 | 4.240 | 2.285087 | 1.0 | 2.00 | 4.0 | 6.0 | 9.0 |
Wheezing | 1000.0 | 3.777 | 2.041921 | 1.0 | 2.00 | 4.0 | 5.0 | 8.0 |
Swallowing Difficulty | 1000.0 | 3.746 | 2.270383 | 1.0 | 2.00 | 4.0 | 5.0 | 8.0 |
Clubbing of Finger Nails | 1000.0 | 3.923 | 2.388048 | 1.0 | 2.00 | 4.0 | 5.0 | 9.0 |
Frequent Cold | 1000.0 | 3.536 | 1.832502 | 1.0 | 2.00 | 3.0 | 5.0 | 7.0 |
Dry Cough | 1000.0 | 3.853 | 2.039007 | 1.0 | 2.00 | 4.0 | 6.0 | 7.0 |
Snoring | 1000.0 | 2.926 | 1.474686 | 1.0 | 2.00 | 3.0 | 4.0 | 7.0 |
dane[['Genetic Risk']].median()
Genetic Risk 5.0 dtype: float64
dane.columns
Index(['Patient Id', 'Age', 'Gender', 'Air Pollution', 'Alcohol use', 'Dust Allergy', 'OccuPational Hazards', 'Genetic Risk', 'chronic Lung Disease', 'Balanced Diet', 'Obesity', 'Smoking', 'Passive Smoker', 'Chest Pain', 'Coughing of Blood', 'Fatigue', 'Weight Loss', 'Shortness of Breath', 'Wheezing', 'Swallowing Difficulty', 'Clubbing of Finger Nails', 'Frequent Cold', 'Dry Cough', 'Snoring', 'Level'], dtype='object')
dane2 = dane.groupby('Gender').size()
dane2
Gender 1 598 2 402 dtype: int64
_ = dane['Gender'].value_counts().plot(kind = 'bar')
_ = plt.legend()
dane3 = [dane.groupby('Smoking').size()]
dane3
[Smoking 1 181 2 222 3 172 4 59 5 10 6 60 7 207 8 89 dtype: int64]
_ = dane['Smoking'].value_counts().plot(kind = 'pie')
dane4 = [dane.groupby('Passive Smoker').size()]
dane4
[Passive Smoker 1 60 2 284 3 140 4 161 5 30 6 30 7 187 8 108 dtype: int64]
_ = dane['Passive Smoker'].value_counts().plot(kind = 'pie')
dane.groupby(['Smoking','Gender']).size()
Smoking Gender 1 1 102 2 79 2 1 102 2 120 3 1 79 2 93 4 1 49 2 10 5 1 10 6 1 28 2 32 7 1 167 2 40 8 1 61 2 28 dtype: int64
dane6 = dane.groupby(['Smoking','Gender'])
_ = dane6[['Smoking', 'Gender']].value_counts().plot(kind = 'bar')
_ = plt.legend()
dane[['Smoking']].median()
Smoking 3.0 dtype: float64
x = dane[['Genetic Risk', 'Smoking','Alcohol use']]
x.sort_values('Genetic Risk')
Genetic Risk | Smoking | Alcohol use | |
---|---|---|---|
index | |||
725 | 1 | 4 | 3 |
59 | 1 | 4 | 3 |
727 | 1 | 3 | 3 |
940 | 1 | 4 | 2 |
819 | 1 | 1 | 2 |
... | ... | ... | ... |
537 | 7 | 8 | 8 |
538 | 7 | 7 | 7 |
755 | 7 | 4 | 7 |
533 | 7 | 4 | 7 |
812 | 7 | 7 | 7 |
1000 rows × 3 columns
dane7 = dane['Air Pollution'].value_counts()
dane7.sort_values()
Air Pollution 8 19 5 20 7 30 4 90 1 141 3 173 2 201 6 326 Name: count, dtype: int64
_ = dane7 = dane['Air Pollution'].value_counts().plot(kind = 'bar')
_ = plt.legend()
import numpy as np
import sklearn
_ = dane8 = dane['Genetic Risk'].value_counts()
dane8.sort_values()
Genetic Risk 4 40 1 40 5 100 6 108 3 173 2 212 7 327 Name: count, dtype: int64
dane
Patient Id | Age | Gender | Air Pollution | Alcohol use | Dust Allergy | OccuPational Hazards | Genetic Risk | chronic Lung Disease | Balanced Diet | ... | Fatigue | Weight Loss | Shortness of Breath | Wheezing | Swallowing Difficulty | Clubbing of Finger Nails | Frequent Cold | Dry Cough | Snoring | Level | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
index | |||||||||||||||||||||
0 | P1 | 33 | 1 | 2 | 4 | 5 | 4 | 3 | 2 | 2 | ... | 3 | 4 | 2 | 2 | 3 | 1 | 2 | 3 | 4 | Low |
1 | P10 | 17 | 1 | 3 | 1 | 5 | 3 | 4 | 2 | 2 | ... | 1 | 3 | 7 | 8 | 6 | 2 | 1 | 7 | 2 | Medium |
2 | P100 | 35 | 1 | 4 | 5 | 6 | 5 | 5 | 4 | 6 | ... | 8 | 7 | 9 | 2 | 1 | 4 | 6 | 7 | 2 | High |
3 | P1000 | 37 | 1 | 7 | 7 | 7 | 7 | 6 | 7 | 7 | ... | 4 | 2 | 3 | 1 | 4 | 5 | 6 | 7 | 5 | High |
4 | P101 | 46 | 1 | 6 | 8 | 7 | 7 | 7 | 6 | 7 | ... | 3 | 2 | 4 | 1 | 4 | 2 | 4 | 2 | 3 | High |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | P995 | 44 | 1 | 6 | 7 | 7 | 7 | 7 | 6 | 7 | ... | 5 | 3 | 2 | 7 | 8 | 2 | 4 | 5 | 3 | High |
996 | P996 | 37 | 2 | 6 | 8 | 7 | 7 | 7 | 6 | 7 | ... | 9 | 6 | 5 | 7 | 2 | 4 | 3 | 1 | 4 | High |
997 | P997 | 25 | 2 | 4 | 5 | 6 | 5 | 5 | 4 | 6 | ... | 8 | 7 | 9 | 2 | 1 | 4 | 6 | 7 | 2 | High |
998 | P998 | 18 | 2 | 6 | 8 | 7 | 7 | 7 | 6 | 7 | ... | 3 | 2 | 4 | 1 | 4 | 2 | 4 | 2 | 3 | High |
999 | P999 | 47 | 1 | 6 | 5 | 6 | 5 | 5 | 4 | 6 | ... | 8 | 7 | 9 | 2 | 1 | 4 | 6 | 7 | 2 | High |
1000 rows × 25 columns
dane['Level'].value_counts()
Level High 365 Medium 332 Low 303 Name: count, dtype: int64
data = dane.replace({'Level':{'High' : 3, 'Medium' : 2, 'Low' : 1}})
data.head()
Patient Id | Age | Gender | Air Pollution | Alcohol use | Dust Allergy | OccuPational Hazards | Genetic Risk | chronic Lung Disease | Balanced Diet | ... | Fatigue | Weight Loss | Shortness of Breath | Wheezing | Swallowing Difficulty | Clubbing of Finger Nails | Frequent Cold | Dry Cough | Snoring | Level | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
index | |||||||||||||||||||||
0 | P1 | 33 | 1 | 2 | 4 | 5 | 4 | 3 | 2 | 2 | ... | 3 | 4 | 2 | 2 | 3 | 1 | 2 | 3 | 4 | 1 |
1 | P10 | 17 | 1 | 3 | 1 | 5 | 3 | 4 | 2 | 2 | ... | 1 | 3 | 7 | 8 | 6 | 2 | 1 | 7 | 2 | 2 |
2 | P100 | 35 | 1 | 4 | 5 | 6 | 5 | 5 | 4 | 6 | ... | 8 | 7 | 9 | 2 | 1 | 4 | 6 | 7 | 2 | 3 |
3 | P1000 | 37 | 1 | 7 | 7 | 7 | 7 | 6 | 7 | 7 | ... | 4 | 2 | 3 | 1 | 4 | 5 | 6 | 7 | 5 | 3 |
4 | P101 | 46 | 1 | 6 | 8 | 7 | 7 | 7 | 6 | 7 | ... | 3 | 2 | 4 | 1 | 4 | 2 | 4 | 2 | 3 | 3 |
5 rows × 25 columns
import sklearn
np.random.seed(10)
np.set_printoptions(precision=6, suppress=True)
X = data.drop(['Level', 'Patient Id'], axis=1)
y = data['Level']
print("Y shape:", y.shape)
print("X shape:", X.shape)
Y shape: (1000,) X shape: (1000, 23)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split (X, y)
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_test shape:", X_test.shape)
print("y_test shape:", y_test.shape)
X_train shape: (750, 23) y_train shape: (750,) X_test shape: (250, 23) y_test shape: (250,)
from sklearn.linear_model import LogisticRegression
classifier = LogisticRegression()
classifier.fit(X_train, y_train)
c:\Users\HP\anaconda3\lib\site-packages\sklearn\linear_model\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1): STOP: TOTAL NO. of ITERATIONS REACHED LIMIT. Increase the number of iterations (max_iter) or scale the data as shown in: https://scikit-learn.org/stable/modules/preprocessing.html Please also refer to the documentation for alternative solver options: https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
LogisticRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LogisticRegression()
y_prob = classifier.predict_proba(X_test)
y_prob
array([[0.019763, 0.980237, 0. ], [0. , 0. , 1. ], [0. , 0.000002, 0.999998], [0.999979, 0.000021, 0. ], [0. , 0.000038, 0.999962], [0.000022, 0.983401, 0.016577], [0. , 0.023981, 0.976019], [0.025065, 0.943631, 0.031305], [0. , 0.011278, 0.988722], [0.077079, 0.922921, 0. ], [0.000003, 0.000326, 0.999672], [0.000473, 0.999527, 0. ], [0.16753 , 0.83247 , 0. ], [0.995731, 0.004269, 0. ], [0.949387, 0.050613, 0. ], [0.21037 , 0.78963 , 0. ], [0. , 0. , 1. ], [0.91181 , 0.045917, 0.042272], [0. , 0.002178, 0.997822], [0.984437, 0.015558, 0.000005], [0. , 0.002066, 0.997934], [0.99922 , 0.00078 , 0. ], [0. , 0.000298, 0.999702], [0. , 0.004236, 0.995764], [0. , 0.000004, 0.999996], [0.238042, 0.760375, 0.001583], [0.025065, 0.943631, 0.031305], [0.99808 , 0.001917, 0.000003], [0.997777, 0.002223, 0. ], [0.238042, 0.760375, 0.001583], [0.913746, 0.086254, 0. ], [0.002141, 0.997859, 0. ], [0. , 0.004236, 0.995764], [0.99922 , 0.00078 , 0. ], [0.00192 , 0.99808 , 0. ], [0. , 0.012453, 0.987547], [0.00145 , 0.99855 , 0. ], [0. , 0. , 1. ], [0.00248 , 0.99752 , 0. ], [0.998712, 0.000937, 0.000351], [0.238042, 0.760375, 0.001583], [0.001213, 0.97811 , 0.020677], [0. , 0. , 1. ], [0.015839, 0.984161, 0. ], [0. , 0.023981, 0.976019], [0.000003, 0.000326, 0.999672], [0.930462, 0.069418, 0.000121], [0. , 0.002178, 0.997822], [0. , 0.000003, 0.999997], [0.001321, 0.998655, 0.000024], [0.001178, 0.998807, 0.000014], [0.00035 , 0.99965 , 0. ], [0.05124 , 0.948155, 0.000606], [0. , 0. , 1. ], [0.001321, 0.998655, 0.000024], [0. , 0. , 1. ], [0. , 0.004285, 0.995715], [0.000004, 0.02071 , 0.979285], [0.969283, 0.030717, 0. ], [0.000063, 0.040843, 0.959093], [0. , 0.002066, 0.997934], [0.942577, 0.018654, 0.038768], [0. , 0.000002, 0.999998], [0.870813, 0.129142, 0.000044], [0. , 0. , 1. ], [0.913746, 0.086254, 0. ], [0.000637, 0.999363, 0. ], [0. , 0.000003, 0.999997], [0. , 0.000298, 0.999702], [0.029989, 0.970011, 0. ], [0.994782, 0.000846, 0.004371], [0.999889, 0.000075, 0.000035], [0.947557, 0.052443, 0. ], [0.029989, 0.970011, 0. ], [0.129826, 0.870174, 0. ], [0.055579, 0.94417 , 0.000251], [0.002141, 0.997859, 0. ], [0.001398, 0.9986 , 0.000002], [0. , 0.000062, 0.999938], [0.99808 , 0.001917, 0.000003], [0. , 0.000004, 0.999996], [0.969283, 0.030717, 0. ], [0. , 0.004236, 0.995764], [0.001178, 0.998807, 0.000014], [0. , 0. , 1. ], [0. , 0.014023, 0.985977], [0. , 0. , 1. ], [0.033577, 0.966423, 0. ], [0.99808 , 0.001917, 0.000003], [0. , 0.000062, 0.999938], [0.099592, 0.900408, 0. ], [0.000177, 0.999787, 0.000036], [0.997699, 0.002301, 0. ], [0.930462, 0.069418, 0.000121], [0.00248 , 0.99752 , 0. ], [0.00248 , 0.99752 , 0. ], [0.238042, 0.760375, 0.001583], [0.101768, 0.898232, 0. ], [0.029989, 0.970011, 0. ], [0.98255 , 0.01745 , 0. ], [0. , 0.004285, 0.995715], [0.000332, 0.999668, 0. ], [0.999328, 0.000672, 0. ], [0. , 0.023981, 0.976019], [0.997777, 0.002223, 0. ], [0.949387, 0.050613, 0. ], [0.930462, 0.069418, 0.000121], [0.00028 , 0.99972 , 0. ], [0. , 0.004236, 0.995764], [0. , 0. , 1. ], [0.998712, 0.000937, 0.000351], [0. , 0.000038, 0.999962], [0.000898, 0.99788 , 0.001222], [0.91181 , 0.045917, 0.042272], [0.984437, 0.015558, 0.000005], [0.000003, 0.000326, 0.999672], [0.099592, 0.900408, 0. ], [0.033577, 0.966423, 0. ], [0.969283, 0.030717, 0. ], [0.999328, 0.000672, 0. ], [0. , 0. , 1. ], [0. , 0.008068, 0.991932], [0. , 0.000298, 0.999702], [0.003572, 0.996428, 0. ], [0.997777, 0.002223, 0. ], [0. , 0.000002, 0.999998], [0.000063, 0.040843, 0.959093], [0.947557, 0.052443, 0. ], [0.907855, 0.092061, 0.000084], [0.00145 , 0.99855 , 0. ], [0.101768, 0.898232, 0. ], [0. , 0.009143, 0.990857], [0.861344, 0.134685, 0.00397 ], [0. , 0.004236, 0.995764], [0.990975, 0.009025, 0. ], [0.870813, 0.129142, 0.000044], [0.995587, 0.000049, 0.004364], [0. , 0.000298, 0.999702], [0.942577, 0.018654, 0.038768], [0. , 0.000005, 0.999995], [0.000898, 0.99788 , 0.001222], [0.99808 , 0.001917, 0.000003], [0.099592, 0.900408, 0. ], [0.099592, 0.900408, 0. ], [0.999979, 0.000021, 0. ], [0.995587, 0.000049, 0.004364], [0.001398, 0.9986 , 0.000002], [0.000003, 0.000326, 0.999672], [0.999889, 0.000075, 0.000035], [0.861344, 0.134685, 0.00397 ], [0.000003, 0.000326, 0.999672], [0.969283, 0.030717, 0. ], [0.238042, 0.760375, 0.001583], [0.000977, 0.999023, 0. ], [0.002141, 0.997859, 0. ], [0.000005, 0.000077, 0.999918], [0.997699, 0.002301, 0. ], [0. , 0.004236, 0.995764], [0. , 0. , 1. ], [0.999425, 0.000575, 0. ], [0.974593, 0.025407, 0. ], [0.002453, 0.997528, 0.000019], [0.84426 , 0.15574 , 0. ], [0.000898, 0.99788 , 0.001222], [0.949387, 0.050613, 0. ], [0. , 0.000032, 0.999968], [0.000332, 0.999668, 0. ], [0.99808 , 0.001917, 0.000003], [0.05124 , 0.948155, 0.000606], [0.999889, 0.000075, 0.000035], [0. , 0. , 1. ], [0.099592, 0.900408, 0. ], [0. , 0.000298, 0.999702], [0.99808 , 0.001917, 0.000003], [0. , 0.000298, 0.999702], [0.598003, 0.391491, 0.010506], [0.002141, 0.997859, 0. ], [0.001321, 0.998655, 0.000024], [0.129826, 0.870174, 0. ], [0.033577, 0.966423, 0. ], [0.101768, 0.898232, 0. ], [0.999979, 0.000021, 0. ], [0.870813, 0.129142, 0.000044], [0.002453, 0.997528, 0.000019], [0.129826, 0.870174, 0. ], [0. , 0.000032, 0.999968], [0. , 0.004285, 0.995715], [0. , 0.000004, 0.999996], [0. , 0.000345, 0.999655], [0.033577, 0.966423, 0. ], [0.055579, 0.94417 , 0.000251], [0.99922 , 0.00078 , 0. ], [0.997699, 0.002301, 0. ], [0. , 0.014023, 0.985977], [0. , 0.008068, 0.991932], [0.998712, 0.000937, 0.000351], [0.000637, 0.999363, 0. ], [0.001321, 0.998655, 0.000024], [0.930462, 0.069418, 0.000121], [0.001398, 0.9986 , 0.000002], [0.002453, 0.997528, 0.000019], [0.05124 , 0.948155, 0.000606], [0. , 0. , 1. ], [0.001111, 0.998889, 0. ], [0. , 0.000203, 0.999797], [0.000063, 0.040843, 0.959093], [0.930462, 0.069418, 0.000121], [0.019763, 0.980237, 0. ], [0.033577, 0.966423, 0. ], [0.659404, 0.319765, 0.020831], [0. , 0. , 1. ], [0.999979, 0.000021, 0. ], [0. , 0.002178, 0.997822], [0. , 0.000085, 0.999915], [0.974593, 0.025407, 0. ], [0.98255 , 0.01745 , 0. ], [0.000308, 0.999692, 0. ], [0. , 0. , 1. ], [0.974593, 0.025407, 0. ], [0.947557, 0.052443, 0. ], [0.913746, 0.086254, 0. ], [0.015839, 0.984161, 0. ], [0.974593, 0.025407, 0. ], [0. , 0. , 1. ], [0.129826, 0.870174, 0. ], [0.000177, 0.999787, 0.000036], [0.238042, 0.760375, 0.001583], [0.99808 , 0.001917, 0.000003], [0. , 0. , 1. ], [0.001213, 0.97811 , 0.020677], [0. , 0.008068, 0.991932], [0.999979, 0.000021, 0. ], [0.001398, 0.9986 , 0.000002], [0.015839, 0.984161, 0. ], [0. , 0.021595, 0.978405], [0.002453, 0.997528, 0.000019], [0. , 0.000005, 0.999995], [0.907855, 0.092061, 0.000084], [0.974593, 0.025407, 0. ], [0. , 0.000032, 0.999968], [0.91181 , 0.045917, 0.042272], [0.000898, 0.99788 , 0.001222], [0. , 0.002066, 0.997934], [0.055579, 0.94417 , 0.000251], [0. , 0.014023, 0.985977], [0.000029, 0.006112, 0.993859], [0. , 0.002178, 0.997822], [0. , 0.023981, 0.976019], [0.999889, 0.000075, 0.000035], [0.000308, 0.999692, 0. ]])
y_pred = classifier.predict(X_test)
y_pred
array([2, 3, 3, 1, 3, 2, 3, 2, 3, 2, 3, 2, 2, 1, 1, 2, 3, 1, 3, 1, 3, 1, 3, 3, 3, 2, 2, 1, 1, 2, 1, 2, 3, 1, 2, 3, 2, 3, 2, 1, 2, 2, 3, 2, 3, 3, 1, 3, 3, 2, 2, 2, 2, 3, 2, 3, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1, 2, 3, 3, 2, 1, 1, 1, 2, 2, 2, 2, 2, 3, 1, 3, 1, 3, 2, 3, 3, 3, 2, 1, 3, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 3, 2, 1, 3, 1, 1, 1, 2, 3, 3, 1, 3, 2, 1, 1, 3, 2, 2, 1, 1, 3, 3, 3, 2, 1, 3, 3, 1, 1, 2, 2, 3, 1, 3, 1, 1, 1, 3, 1, 3, 2, 1, 2, 2, 1, 1, 2, 3, 1, 1, 3, 1, 2, 2, 2, 3, 1, 3, 3, 1, 1, 2, 1, 2, 1, 3, 2, 1, 2, 1, 3, 2, 3, 1, 3, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 1, 1, 3, 3, 1, 2, 2, 1, 2, 2, 2, 3, 2, 3, 3, 1, 2, 2, 1, 3, 1, 3, 3, 1, 1, 2, 3, 1, 1, 1, 2, 1, 3, 2, 2, 2, 1, 3, 2, 3, 1, 2, 2, 3, 2, 3, 1, 1, 3, 1, 2, 3, 2, 3, 3, 3, 3, 1, 2], dtype=int64)
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from mlxtend.plotting import plot_confusion_matrix
import seaborn as sns
sns.set()
cm = confusion_matrix(y_test, y_pred)
plot_confusion_matrix(cm)
acc = accuracy_score(y_test, y_pred)
print('Accuracy',':', acc)
Accuracy : 0.992
def plot_confusion_matrix(cm):
cm = cm[::-1]
cm = pd.DataFrame(cm, columns=['pred_1', 'pred_2', 'pred_3'], index=['true_1', 'true_2', 'true_3'])
fig = ff.create_annotated_heatmap(z = cm.values, x = list(cm.columns), y = list(cm.index), colorscale = 'ice', showscale = True, reversescale = True)
fig.update_layout(width=400, height=400, title='Confusion Matrix - Accuracy: {:.4f}'.format(acc), font_size=14)
fig.show()
plot_confusion_matrix(cm)
from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred, target_names=['pred_1', 'pred_2', 'pred_3']))
precision recall f1-score support pred_1 0.97 1.00 0.99 76 pred_2 1.00 0.98 0.99 89 pred_3 1.00 1.00 1.00 85 accuracy 0.99 250 macro avg 0.99 0.99 0.99 250 weighted avg 0.99 0.99 0.99 250