agabka/projekt.ipynb
2024-01-01 18:07:42 +01:00

241 KiB
Raw Blame History

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.figure_factory as ff
import seaborn as sns
# Switch matplotlib figures to seaborn's default theme.
sns.set()
# Load the patient data set; the first CSV column is used as the DataFrame index.
# NOTE(review): this is an absolute path into one user's Windows profile, so the
# notebook cannot run on another machine without editing it.
dane = pd.read_csv(r'C:\Users\HP\Desktop\podyplomówka\cancer_patient_data_sets.csv', index_col = 0)
# Preview the first rows of the frame.
dane.head()
Patient Id Age Gender Air Pollution Alcohol use Dust Allergy OccuPational Hazards Genetic Risk chronic Lung Disease Balanced Diet ... Fatigue Weight Loss Shortness of Breath Wheezing Swallowing Difficulty Clubbing of Finger Nails Frequent Cold Dry Cough Snoring Level
index
0 P1 33 1 2 4 5 4 3 2 2 ... 3 4 2 2 3 1 2 3 4 Low
1 P10 17 1 3 1 5 3 4 2 2 ... 1 3 7 8 6 2 1 7 2 Medium
2 P100 35 1 4 5 6 5 5 4 6 ... 8 7 9 2 1 4 6 7 2 High
3 P1000 37 1 7 7 7 7 6 7 7 ... 4 2 3 1 4 5 6 7 5 High
4 P101 46 1 6 8 7 7 7 6 7 ... 3 2 4 1 4 2 4 2 3 High

5 rows × 25 columns

# Print the frame summary: index range, column names, non-null counts, dtypes and memory usage.
dane.info()
<class 'pandas.core.frame.DataFrame'>
Index: 1000 entries, 0 to 999
Data columns (total 25 columns):
 #   Column                    Non-Null Count  Dtype 
---  ------                    --------------  ----- 
 0   Patient Id                1000 non-null   object
 1   Age                       1000 non-null   int64 
 2   Gender                    1000 non-null   int64 
 3   Air Pollution             1000 non-null   int64 
 4   Alcohol use               1000 non-null   int64 
 5   Dust Allergy              1000 non-null   int64 
 6   OccuPational Hazards      1000 non-null   int64 
 7   Genetic Risk              1000 non-null   int64 
 8   chronic Lung Disease      1000 non-null   int64 
 9   Balanced Diet             1000 non-null   int64 
 10  Obesity                   1000 non-null   int64 
 11  Smoking                   1000 non-null   int64 
 12  Passive Smoker            1000 non-null   int64 
 13  Chest Pain                1000 non-null   int64 
 14  Coughing of Blood         1000 non-null   int64 
 15  Fatigue                   1000 non-null   int64 
 16  Weight Loss               1000 non-null   int64 
 17  Shortness of Breath       1000 non-null   int64 
 18  Wheezing                  1000 non-null   int64 
 19  Swallowing Difficulty     1000 non-null   int64 
 20  Clubbing of Finger Nails  1000 non-null   int64 
 21  Frequent Cold             1000 non-null   int64 
 22  Dry Cough                 1000 non-null   int64 
 23  Snoring                   1000 non-null   int64 
 24  Level                     1000 non-null   object
dtypes: int64(23), object(2)
memory usage: 203.1+ KB
# Rows whose Level is 'High', restricted to three selected columns,
# taken in a single label-based selection.
dane0 = dane.loc[dane['Level'] == 'High', ['Air Pollution', 'Smoking', 'Passive Smoker']]
dane0
Air Pollution Smoking Passive Smoker
index
2 4 2 3
3 7 7 7
4 6 8 7
5 4 2 3
10 6 7 8
... ... ... ...
995 6 7 8
996 6 7 8
997 4 2 3
998 6 8 7
999 6 2 3

365 rows × 3 columns

# Descriptive statistics of the numeric columns, transposed so that each
# column of `dane` becomes one row of the summary.
dane.describe().transpose()
count mean std min 25% 50% 75% max
Age 1000.0 37.174 12.005493 14.0 27.75 36.0 45.0 73.0
Gender 1000.0 1.402 0.490547 1.0 1.00 1.0 2.0 2.0
Air Pollution 1000.0 3.840 2.030400 1.0 2.00 3.0 6.0 8.0
Alcohol use 1000.0 4.563 2.620477 1.0 2.00 5.0 7.0 8.0
Dust Allergy 1000.0 5.165 1.980833 1.0 4.00 6.0 7.0 8.0
OccuPational Hazards 1000.0 4.840 2.107805 1.0 3.00 5.0 7.0 8.0
Genetic Risk 1000.0 4.580 2.126999 1.0 2.00 5.0 7.0 7.0
chronic Lung Disease 1000.0 4.380 1.848518 1.0 3.00 4.0 6.0 7.0
Balanced Diet 1000.0 4.491 2.135528 1.0 2.00 4.0 7.0 7.0
Obesity 1000.0 4.465 2.124921 1.0 3.00 4.0 7.0 7.0
Smoking 1000.0 3.948 2.495902 1.0 2.00 3.0 7.0 8.0
Passive Smoker 1000.0 4.195 2.311778 1.0 2.00 4.0 7.0 8.0
Chest Pain 1000.0 4.438 2.280209 1.0 2.00 4.0 7.0 9.0
Coughing of Blood 1000.0 4.859 2.427965 1.0 3.00 4.0 7.0 9.0
Fatigue 1000.0 3.856 2.244616 1.0 2.00 3.0 5.0 9.0
Weight Loss 1000.0 3.855 2.206546 1.0 2.00 3.0 6.0 8.0
Shortness of Breath 1000.0 4.240 2.285087 1.0 2.00 4.0 6.0 9.0
Wheezing 1000.0 3.777 2.041921 1.0 2.00 4.0 5.0 8.0
Swallowing Difficulty 1000.0 3.746 2.270383 1.0 2.00 4.0 5.0 8.0
Clubbing of Finger Nails 1000.0 3.923 2.388048 1.0 2.00 4.0 5.0 9.0
Frequent Cold 1000.0 3.536 1.832502 1.0 2.00 3.0 5.0 7.0
Dry Cough 1000.0 3.853 2.039007 1.0 2.00 4.0 6.0 7.0
Snoring 1000.0 2.926 1.474686 1.0 2.00 3.0 4.0 7.0
# Median of 'Genetic Risk'. The double brackets select a one-column DataFrame,
# so the result is a one-element Series rather than a scalar.
dane[['Genetic Risk']].median()
Genetic Risk    5.0
dtype: float64
# List the column labels of the frame.
dane.columns
Index(['Patient Id', 'Age', 'Gender', 'Air Pollution', 'Alcohol use',
       'Dust Allergy', 'OccuPational Hazards', 'Genetic Risk',
       'chronic Lung Disease', 'Balanced Diet', 'Obesity', 'Smoking',
       'Passive Smoker', 'Chest Pain', 'Coughing of Blood', 'Fatigue',
       'Weight Loss', 'Shortness of Breath', 'Wheezing',
       'Swallowing Difficulty', 'Clubbing of Finger Nails', 'Frequent Cold',
       'Dry Cough', 'Snoring', 'Level'],
      dtype='object')
# Number of rows for each Gender code.
dane2 = dane.groupby('Gender').size()
dane2
Gender
1    598
2    402
dtype: int64

# Bar chart of how many rows carry each Gender code.
_ = dane['Gender'].value_counts().plot.bar()
_ = plt.legend()
# Row counts per Smoking score, stored as a one-element list holding the Series.
dane3 = [dane.groupby(by='Smoking').size()]
dane3
[Smoking
 1    181
 2    222
 3    172
 4     59
 5     10
 6     60
 7    207
 8     89
 dtype: int64]
# Pie chart of the Smoking score frequencies.
_ = dane['Smoking'].value_counts().plot.pie()
# Row counts per Passive Smoker score, stored as a one-element list holding the Series.
dane4 = [dane.groupby(by='Passive Smoker').size()]
dane4
[Passive Smoker
 1     60
 2    284
 3    140
 4    161
 5     30
 6     30
 7    187
 8    108
 dtype: int64]
# Pie chart of the Passive Smoker score frequencies.
_ = dane['Passive Smoker'].value_counts().plot.pie()
# Row counts for every (Smoking, Gender) combination present in the data.
dane.groupby(by=['Smoking', 'Gender']).size()
Smoking  Gender
1        1         102
         2          79
2        1         102
         2         120
3        1          79
         2          93
4        1          49
         2          10
5        1          10
6        1          28
         2          32
7        1         167
         2          40
8        1          61
         2          28
dtype: int64
# Group the rows by (Smoking, Gender) and chart the per-group counts as bars.
# NOTE(review): counting the grouping columns themselves presumably yields the
# same numbers as .size() on the groupby - confirm before simplifying.
dane6 = dane.groupby(by=['Smoking', 'Gender'])
_ = dane6[['Smoking', 'Gender']].value_counts().plot.bar()
_ = plt.legend()
# Median Smoking score, returned as a one-element Series.
dane[['Smoking']].median()
Smoking    3.0
dtype: float64
# Three-column view of the data, displayed in ascending order of Genetic Risk.
x = dane.loc[:, ['Genetic Risk', 'Smoking', 'Alcohol use']]
x.sort_values(by='Genetic Risk')
Genetic Risk Smoking Alcohol use
index
725 1 4 3
59 1 4 3
727 1 3 3
940 1 4 2
819 1 1 2
... ... ... ...
537 7 8 8
538 7 7 7
755 7 4 7
533 7 4 7
812 7 7 7

1000 rows × 3 columns

# Frequency of each Air Pollution score, shown from rarest to most common.
dane7 = dane.loc[:, 'Air Pollution'].value_counts()
dane7.sort_values(ascending=True)
Air Pollution
8     19
5     20
7     30
4     90
1    141
3    173
2    201
6    326
Name: count, dtype: int64
# Bar chart of the Air Pollution score frequencies.
# The Axes returned by .plot() is only bound to the throw-away name `_`;
# it is no longer assigned to `dane7`, which would replace the value-count
# Series of that name with a matplotlib object.
_ = dane['Air Pollution'].value_counts().plot(kind='bar')
_ = plt.legend()
import numpy as np
import sklearn
# Frequency of each Genetic Risk score, shown from rarest to most common.
dane8 = dane['Genetic Risk'].value_counts()
dane8.sort_values()

Genetic Risk
4     40
1     40
5    100
6    108
3    173
2    212
7    327
Name: count, dtype: int64
# Display the full frame (the notebook truncates it to the first and last rows).
dane
Patient Id Age Gender Air Pollution Alcohol use Dust Allergy OccuPational Hazards Genetic Risk chronic Lung Disease Balanced Diet ... Fatigue Weight Loss Shortness of Breath Wheezing Swallowing Difficulty Clubbing of Finger Nails Frequent Cold Dry Cough Snoring Level
index
0 P1 33 1 2 4 5 4 3 2 2 ... 3 4 2 2 3 1 2 3 4 Low
1 P10 17 1 3 1 5 3 4 2 2 ... 1 3 7 8 6 2 1 7 2 Medium
2 P100 35 1 4 5 6 5 5 4 6 ... 8 7 9 2 1 4 6 7 2 High
3 P1000 37 1 7 7 7 7 6 7 7 ... 4 2 3 1 4 5 6 7 5 High
4 P101 46 1 6 8 7 7 7 6 7 ... 3 2 4 1 4 2 4 2 3 High
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
995 P995 44 1 6 7 7 7 7 6 7 ... 5 3 2 7 8 2 4 5 3 High
996 P996 37 2 6 8 7 7 7 6 7 ... 9 6 5 7 2 4 3 1 4 High
997 P997 25 2 4 5 6 5 5 4 6 ... 8 7 9 2 1 4 6 7 2 High
998 P998 18 2 6 8 7 7 7 6 7 ... 3 2 4 1 4 2 4 2 3 High
999 P999 47 1 6 5 6 5 5 4 6 ... 8 7 9 2 1 4 6 7 2 High

1000 rows × 25 columns

# Class balance of the target: number of rows per Level value.
dane['Level'].value_counts()
Level
High      365
Medium    332
Low       303
Name: count, dtype: int64
# Encode the ordinal target as integers (Low=1, Medium=2, High=3) in a new
# frame; `dane` itself keeps the text labels.
# Series.map + astype is used instead of DataFrame.replace: replace leaves the
# column as object dtype unless pandas silently down-casts it, and that
# implicit down-cast is deprecated. The explicit cast guarantees an integer
# target and fails loudly if a label outside the mapping shows up (map turns
# it into NaN, which cannot be cast to int64).
data = dane.assign(
    Level=dane['Level'].map({'High': 3, 'Medium': 2, 'Low': 1}).astype('int64')
)
data.head()
Patient Id Age Gender Air Pollution Alcohol use Dust Allergy OccuPational Hazards Genetic Risk chronic Lung Disease Balanced Diet ... Fatigue Weight Loss Shortness of Breath Wheezing Swallowing Difficulty Clubbing of Finger Nails Frequent Cold Dry Cough Snoring Level
index
0 P1 33 1 2 4 5 4 3 2 2 ... 3 4 2 2 3 1 2 3 4 1
1 P10 17 1 3 1 5 3 4 2 2 ... 1 3 7 8 6 2 1 7 2 2
2 P100 35 1 4 5 6 5 5 4 6 ... 8 7 9 2 1 4 6 7 2 3
3 P1000 37 1 7 7 7 7 6 7 7 ... 4 2 3 1 4 5 6 7 5 3
4 P101 46 1 6 8 7 7 7 6 7 ... 3 2 4 1 4 2 4 2 3 3

5 rows × 25 columns

import sklearn
# Seed NumPy's global generator and print arrays in fixed-point notation
# with six decimals.
np.random.seed(10)
np.set_printoptions(suppress=True, precision=6)
# Target: the encoded Level. Features: every remaining column except the
# patient identifier.
y = data['Level']
X = data.drop(columns=['Level', 'Patient Id'])

print(f"Y shape: {y.shape}")
print(f"X shape: {X.shape}")
Y shape: (1000,)
X shape: (1000, 23)
from sklearn.model_selection import train_test_split
# Split into train and test parts using the default proportions.
# random_state is passed explicitly (same value as the np.random.seed call in
# this notebook) so the split is reproducible even when cells are re-run or
# executed out of order, instead of depending on the state of NumPy's global
# generator at the moment this cell runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)

print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_test shape:", X_test.shape)
print("y_test shape:", y_test.shape)
X_train shape: (750, 23)
y_train shape: (750,)
X_test shape: (250, 23)
y_test shape: (250,)
from sklearn.linear_model import LogisticRegression
# With the default iteration limit the lbfgs solver stops before converging on
# these unscaled features (scikit-learn emits a ConvergenceWarning), so the
# fitted coefficients are not the optimum. Give the solver more iterations.
# NOTE(review): standardising the features first is the other remedy that the
# warning suggests; it would also speed up convergence.
classifier = LogisticRegression(max_iter=1000)
classifier.fit(X_train, y_train)
c:\Users\HP\anaconda3\lib\site-packages\sklearn\linear_model\_logistic.py:460: ConvergenceWarning:

lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression

LogisticRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LogisticRegression()
# Predicted class-membership probabilities for the test rows: one row per
# sample, one column per class.
y_prob = classifier.predict_proba(X_test)
y_prob
array([[0.019763, 0.980237, 0.      ],
       [0.      , 0.      , 1.      ],
       [0.      , 0.000002, 0.999998],
       [0.999979, 0.000021, 0.      ],
       [0.      , 0.000038, 0.999962],
       [0.000022, 0.983401, 0.016577],
       [0.      , 0.023981, 0.976019],
       [0.025065, 0.943631, 0.031305],
       [0.      , 0.011278, 0.988722],
       [0.077079, 0.922921, 0.      ],
       [0.000003, 0.000326, 0.999672],
       [0.000473, 0.999527, 0.      ],
       [0.16753 , 0.83247 , 0.      ],
       [0.995731, 0.004269, 0.      ],
       [0.949387, 0.050613, 0.      ],
       [0.21037 , 0.78963 , 0.      ],
       [0.      , 0.      , 1.      ],
       [0.91181 , 0.045917, 0.042272],
       [0.      , 0.002178, 0.997822],
       [0.984437, 0.015558, 0.000005],
       [0.      , 0.002066, 0.997934],
       [0.99922 , 0.00078 , 0.      ],
       [0.      , 0.000298, 0.999702],
       [0.      , 0.004236, 0.995764],
       [0.      , 0.000004, 0.999996],
       [0.238042, 0.760375, 0.001583],
       [0.025065, 0.943631, 0.031305],
       [0.99808 , 0.001917, 0.000003],
       [0.997777, 0.002223, 0.      ],
       [0.238042, 0.760375, 0.001583],
       [0.913746, 0.086254, 0.      ],
       [0.002141, 0.997859, 0.      ],
       [0.      , 0.004236, 0.995764],
       [0.99922 , 0.00078 , 0.      ],
       [0.00192 , 0.99808 , 0.      ],
       [0.      , 0.012453, 0.987547],
       [0.00145 , 0.99855 , 0.      ],
       [0.      , 0.      , 1.      ],
       [0.00248 , 0.99752 , 0.      ],
       [0.998712, 0.000937, 0.000351],
       [0.238042, 0.760375, 0.001583],
       [0.001213, 0.97811 , 0.020677],
       [0.      , 0.      , 1.      ],
       [0.015839, 0.984161, 0.      ],
       [0.      , 0.023981, 0.976019],
       [0.000003, 0.000326, 0.999672],
       [0.930462, 0.069418, 0.000121],
       [0.      , 0.002178, 0.997822],
       [0.      , 0.000003, 0.999997],
       [0.001321, 0.998655, 0.000024],
       [0.001178, 0.998807, 0.000014],
       [0.00035 , 0.99965 , 0.      ],
       [0.05124 , 0.948155, 0.000606],
       [0.      , 0.      , 1.      ],
       [0.001321, 0.998655, 0.000024],
       [0.      , 0.      , 1.      ],
       [0.      , 0.004285, 0.995715],
       [0.000004, 0.02071 , 0.979285],
       [0.969283, 0.030717, 0.      ],
       [0.000063, 0.040843, 0.959093],
       [0.      , 0.002066, 0.997934],
       [0.942577, 0.018654, 0.038768],
       [0.      , 0.000002, 0.999998],
       [0.870813, 0.129142, 0.000044],
       [0.      , 0.      , 1.      ],
       [0.913746, 0.086254, 0.      ],
       [0.000637, 0.999363, 0.      ],
       [0.      , 0.000003, 0.999997],
       [0.      , 0.000298, 0.999702],
       [0.029989, 0.970011, 0.      ],
       [0.994782, 0.000846, 0.004371],
       [0.999889, 0.000075, 0.000035],
       [0.947557, 0.052443, 0.      ],
       [0.029989, 0.970011, 0.      ],
       [0.129826, 0.870174, 0.      ],
       [0.055579, 0.94417 , 0.000251],
       [0.002141, 0.997859, 0.      ],
       [0.001398, 0.9986  , 0.000002],
       [0.      , 0.000062, 0.999938],
       [0.99808 , 0.001917, 0.000003],
       [0.      , 0.000004, 0.999996],
       [0.969283, 0.030717, 0.      ],
       [0.      , 0.004236, 0.995764],
       [0.001178, 0.998807, 0.000014],
       [0.      , 0.      , 1.      ],
       [0.      , 0.014023, 0.985977],
       [0.      , 0.      , 1.      ],
       [0.033577, 0.966423, 0.      ],
       [0.99808 , 0.001917, 0.000003],
       [0.      , 0.000062, 0.999938],
       [0.099592, 0.900408, 0.      ],
       [0.000177, 0.999787, 0.000036],
       [0.997699, 0.002301, 0.      ],
       [0.930462, 0.069418, 0.000121],
       [0.00248 , 0.99752 , 0.      ],
       [0.00248 , 0.99752 , 0.      ],
       [0.238042, 0.760375, 0.001583],
       [0.101768, 0.898232, 0.      ],
       [0.029989, 0.970011, 0.      ],
       [0.98255 , 0.01745 , 0.      ],
       [0.      , 0.004285, 0.995715],
       [0.000332, 0.999668, 0.      ],
       [0.999328, 0.000672, 0.      ],
       [0.      , 0.023981, 0.976019],
       [0.997777, 0.002223, 0.      ],
       [0.949387, 0.050613, 0.      ],
       [0.930462, 0.069418, 0.000121],
       [0.00028 , 0.99972 , 0.      ],
       [0.      , 0.004236, 0.995764],
       [0.      , 0.      , 1.      ],
       [0.998712, 0.000937, 0.000351],
       [0.      , 0.000038, 0.999962],
       [0.000898, 0.99788 , 0.001222],
       [0.91181 , 0.045917, 0.042272],
       [0.984437, 0.015558, 0.000005],
       [0.000003, 0.000326, 0.999672],
       [0.099592, 0.900408, 0.      ],
       [0.033577, 0.966423, 0.      ],
       [0.969283, 0.030717, 0.      ],
       [0.999328, 0.000672, 0.      ],
       [0.      , 0.      , 1.      ],
       [0.      , 0.008068, 0.991932],
       [0.      , 0.000298, 0.999702],
       [0.003572, 0.996428, 0.      ],
       [0.997777, 0.002223, 0.      ],
       [0.      , 0.000002, 0.999998],
       [0.000063, 0.040843, 0.959093],
       [0.947557, 0.052443, 0.      ],
       [0.907855, 0.092061, 0.000084],
       [0.00145 , 0.99855 , 0.      ],
       [0.101768, 0.898232, 0.      ],
       [0.      , 0.009143, 0.990857],
       [0.861344, 0.134685, 0.00397 ],
       [0.      , 0.004236, 0.995764],
       [0.990975, 0.009025, 0.      ],
       [0.870813, 0.129142, 0.000044],
       [0.995587, 0.000049, 0.004364],
       [0.      , 0.000298, 0.999702],
       [0.942577, 0.018654, 0.038768],
       [0.      , 0.000005, 0.999995],
       [0.000898, 0.99788 , 0.001222],
       [0.99808 , 0.001917, 0.000003],
       [0.099592, 0.900408, 0.      ],
       [0.099592, 0.900408, 0.      ],
       [0.999979, 0.000021, 0.      ],
       [0.995587, 0.000049, 0.004364],
       [0.001398, 0.9986  , 0.000002],
       [0.000003, 0.000326, 0.999672],
       [0.999889, 0.000075, 0.000035],
       [0.861344, 0.134685, 0.00397 ],
       [0.000003, 0.000326, 0.999672],
       [0.969283, 0.030717, 0.      ],
       [0.238042, 0.760375, 0.001583],
       [0.000977, 0.999023, 0.      ],
       [0.002141, 0.997859, 0.      ],
       [0.000005, 0.000077, 0.999918],
       [0.997699, 0.002301, 0.      ],
       [0.      , 0.004236, 0.995764],
       [0.      , 0.      , 1.      ],
       [0.999425, 0.000575, 0.      ],
       [0.974593, 0.025407, 0.      ],
       [0.002453, 0.997528, 0.000019],
       [0.84426 , 0.15574 , 0.      ],
       [0.000898, 0.99788 , 0.001222],
       [0.949387, 0.050613, 0.      ],
       [0.      , 0.000032, 0.999968],
       [0.000332, 0.999668, 0.      ],
       [0.99808 , 0.001917, 0.000003],
       [0.05124 , 0.948155, 0.000606],
       [0.999889, 0.000075, 0.000035],
       [0.      , 0.      , 1.      ],
       [0.099592, 0.900408, 0.      ],
       [0.      , 0.000298, 0.999702],
       [0.99808 , 0.001917, 0.000003],
       [0.      , 0.000298, 0.999702],
       [0.598003, 0.391491, 0.010506],
       [0.002141, 0.997859, 0.      ],
       [0.001321, 0.998655, 0.000024],
       [0.129826, 0.870174, 0.      ],
       [0.033577, 0.966423, 0.      ],
       [0.101768, 0.898232, 0.      ],
       [0.999979, 0.000021, 0.      ],
       [0.870813, 0.129142, 0.000044],
       [0.002453, 0.997528, 0.000019],
       [0.129826, 0.870174, 0.      ],
       [0.      , 0.000032, 0.999968],
       [0.      , 0.004285, 0.995715],
       [0.      , 0.000004, 0.999996],
       [0.      , 0.000345, 0.999655],
       [0.033577, 0.966423, 0.      ],
       [0.055579, 0.94417 , 0.000251],
       [0.99922 , 0.00078 , 0.      ],
       [0.997699, 0.002301, 0.      ],
       [0.      , 0.014023, 0.985977],
       [0.      , 0.008068, 0.991932],
       [0.998712, 0.000937, 0.000351],
       [0.000637, 0.999363, 0.      ],
       [0.001321, 0.998655, 0.000024],
       [0.930462, 0.069418, 0.000121],
       [0.001398, 0.9986  , 0.000002],
       [0.002453, 0.997528, 0.000019],
       [0.05124 , 0.948155, 0.000606],
       [0.      , 0.      , 1.      ],
       [0.001111, 0.998889, 0.      ],
       [0.      , 0.000203, 0.999797],
       [0.000063, 0.040843, 0.959093],
       [0.930462, 0.069418, 0.000121],
       [0.019763, 0.980237, 0.      ],
       [0.033577, 0.966423, 0.      ],
       [0.659404, 0.319765, 0.020831],
       [0.      , 0.      , 1.      ],
       [0.999979, 0.000021, 0.      ],
       [0.      , 0.002178, 0.997822],
       [0.      , 0.000085, 0.999915],
       [0.974593, 0.025407, 0.      ],
       [0.98255 , 0.01745 , 0.      ],
       [0.000308, 0.999692, 0.      ],
       [0.      , 0.      , 1.      ],
       [0.974593, 0.025407, 0.      ],
       [0.947557, 0.052443, 0.      ],
       [0.913746, 0.086254, 0.      ],
       [0.015839, 0.984161, 0.      ],
       [0.974593, 0.025407, 0.      ],
       [0.      , 0.      , 1.      ],
       [0.129826, 0.870174, 0.      ],
       [0.000177, 0.999787, 0.000036],
       [0.238042, 0.760375, 0.001583],
       [0.99808 , 0.001917, 0.000003],
       [0.      , 0.      , 1.      ],
       [0.001213, 0.97811 , 0.020677],
       [0.      , 0.008068, 0.991932],
       [0.999979, 0.000021, 0.      ],
       [0.001398, 0.9986  , 0.000002],
       [0.015839, 0.984161, 0.      ],
       [0.      , 0.021595, 0.978405],
       [0.002453, 0.997528, 0.000019],
       [0.      , 0.000005, 0.999995],
       [0.907855, 0.092061, 0.000084],
       [0.974593, 0.025407, 0.      ],
       [0.      , 0.000032, 0.999968],
       [0.91181 , 0.045917, 0.042272],
       [0.000898, 0.99788 , 0.001222],
       [0.      , 0.002066, 0.997934],
       [0.055579, 0.94417 , 0.000251],
       [0.      , 0.014023, 0.985977],
       [0.000029, 0.006112, 0.993859],
       [0.      , 0.002178, 0.997822],
       [0.      , 0.023981, 0.976019],
       [0.999889, 0.000075, 0.000035],
       [0.000308, 0.999692, 0.      ]])
# Predicted class label (1, 2 or 3) for every test row.
y_pred = classifier.predict(X_test)
y_pred
array([2, 3, 3, 1, 3, 2, 3, 2, 3, 2, 3, 2, 2, 1, 1, 2, 3, 1, 3, 1, 3, 1,
       3, 3, 3, 2, 2, 1, 1, 2, 1, 2, 3, 1, 2, 3, 2, 3, 2, 1, 2, 2, 3, 2,
       3, 3, 1, 3, 3, 2, 2, 2, 2, 3, 2, 3, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1,
       2, 3, 3, 2, 1, 1, 1, 2, 2, 2, 2, 2, 3, 1, 3, 1, 3, 2, 3, 3, 3, 2,
       1, 3, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 3, 2, 1, 3, 1, 1, 1, 2, 3, 3,
       1, 3, 2, 1, 1, 3, 2, 2, 1, 1, 3, 3, 3, 2, 1, 3, 3, 1, 1, 2, 2, 3,
       1, 3, 1, 1, 1, 3, 1, 3, 2, 1, 2, 2, 1, 1, 2, 3, 1, 1, 3, 1, 2, 2,
       2, 3, 1, 3, 3, 1, 1, 2, 1, 2, 1, 3, 2, 1, 2, 1, 3, 2, 3, 1, 3, 1,
       2, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 1, 1, 3, 3, 1, 2, 2,
       1, 2, 2, 2, 3, 2, 3, 3, 1, 2, 2, 1, 3, 1, 3, 3, 1, 1, 2, 3, 1, 1,
       1, 2, 1, 3, 2, 2, 2, 1, 3, 2, 3, 1, 2, 2, 3, 2, 3, 1, 1, 3, 1, 2,
       3, 2, 3, 3, 3, 3, 1, 2], dtype=int64)
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from mlxtend.plotting import plot_confusion_matrix
import seaborn as sns
sns.set()
# Confusion matrix of the test predictions, drawn with the imported plotting helper.
cm = confusion_matrix(y_true=y_test, y_pred=y_pred)
plot_confusion_matrix(cm)

# Share of test rows that were classified correctly.
acc = accuracy_score(y_true=y_test, y_pred=y_pred)
print(f'Accuracy : {acc}')
Accuracy : 0.992
def plot_confusion_matrix(cm, labels=None, accuracy=None):
    """Show a confusion matrix as an annotated Plotly heatmap.

    NOTE(review): this definition shadows the ``plot_confusion_matrix``
    imported from ``mlxtend.plotting`` earlier in the notebook.

    Parameters
    ----------
    cm : array-like of shape (n_classes, n_classes)
        Confusion matrix with true classes in rows and predicted classes in
        columns, both in ascending label order.
    labels : sequence, optional
        Class labels in the row/column order of ``cm``. Defaults to
        ``1 .. n_classes``, which gives ``pred_1``/``true_1`` etc.
    accuracy : float, optional
        Accuracy printed in the title. When omitted it is computed from
        ``cm`` itself (diagonal sum divided by the total count) instead of
        being read from a global variable.
    """
    cm = np.asarray(cm)
    if labels is None:
        labels = range(1, cm.shape[0] + 1)
    labels = list(labels)

    if accuracy is None:
        total = cm.sum()
        accuracy = np.trace(cm) / total if total else float('nan')

    # The heatmap places the first row at the bottom of the y axis, so the
    # rows are flipped to get the first class at the top. The row labels must
    # be flipped together with the data, otherwise every "true_*" label ends
    # up on another class's counts.
    frame = pd.DataFrame(
        cm[::-1],
        columns=['pred_{}'.format(label) for label in labels],
        index=['true_{}'.format(label) for label in reversed(labels)],
    )
    fig = ff.create_annotated_heatmap(
        z=frame.values,
        x=list(frame.columns),
        y=list(frame.index),
        colorscale='ice',
        showscale=True,
        reversescale=True,
    )
    fig.update_layout(
        width=400,
        height=400,
        title='Confusion Matrix - Accuracy: {:.4f}'.format(accuracy),
        font_size=14,
    )
    fig.show()
plot_confusion_matrix(cm)
from sklearn.metrics import classification_report
# Each report row describes one true class (precision/recall/support), not a
# prediction, so the rows are labelled with the Level names behind the
# integer codes 1, 2, 3 rather than "pred_*".
print(classification_report(y_test, y_pred, target_names=['Low', 'Medium', 'High']))
              precision    recall  f1-score   support

      pred_1       0.97      1.00      0.99        76
      pred_2       1.00      0.98      0.99        89
      pred_3       1.00      1.00      1.00        85

    accuracy                           0.99       250
   macro avg       0.99      0.99      0.99       250
weighted avg       0.99      0.99      0.99       250