agabka/projekt (5) (6).ipynb
ag.gabka@gmail.com 6bdf024137 projekt.ipynb
2024-04-04 22:23:05 +02:00

648 KiB
Raw Blame History

This dataset contains information on patients with lung cancer, including their age, gender, air pollution exposure, alcohol use, dust allergy, occupational hazards, genetic risk, chronic lung disease, balanced diet, obesity, smoking, passive smoker, chest pain, coughing of blood, fatigue, weight loss, shortness of breath, wheezing, swallowing difficulty, clubbing of finger nails and snoring.

https://www.kaggle.com/datasets/thedevastator/cancer-patients-and-air-pollution-a-new-link/data



import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.figure_factory as ff
import seaborn as sns
sns.set()
import plotly.express as px
import numpy as np
import sklearn 
pip install plotnine
import plotnine
# Load the patient data set; the first CSV column becomes the row index.
# NOTE(review): absolute, machine-specific Windows path — the notebook only runs
# where this exact file exists.
dane = pd.read_csv(r'C:\Users\HP\Desktop\podyplomówka\cancer_patient_data_sets.csv', index_col = 0)
# Preview the first rows.
dane.head()
Patient Id Age Gender Air Pollution Alcohol use Dust Allergy OccuPational Hazards Genetic Risk chronic Lung Disease Balanced Diet ... Fatigue Weight Loss Shortness of Breath Wheezing Swallowing Difficulty Clubbing of Finger Nails Frequent Cold Dry Cough Snoring Level
index
0 P1 33 1 2 4 5 4 3 2 2 ... 3 4 2 2 3 1 2 3 4 Low
1 P10 17 1 3 1 5 3 4 2 2 ... 1 3 7 8 6 2 1 7 2 Medium
2 P100 35 1 4 5 6 5 5 4 6 ... 8 7 9 2 1 4 6 7 2 High
3 P1000 37 1 7 7 7 7 6 7 7 ... 4 2 3 1 4 5 6 7 5 High
4 P101 46 1 6 8 7 7 7 6 7 ... 3 2 4 1 4 2 4 2 3 High

5 rows × 25 columns

# Column names, non-null counts and dtypes of the loaded frame.
dane.info()
<class 'pandas.core.frame.DataFrame'>
Index: 1000 entries, 0 to 999
Data columns (total 25 columns):
 #   Column                    Non-Null Count  Dtype 
---  ------                    --------------  ----- 
 0   Patient Id                1000 non-null   object
 1   Age                       1000 non-null   int64 
 2   Gender                    1000 non-null   int64 
 3   Air Pollution             1000 non-null   int64 
 4   Alcohol use               1000 non-null   int64 
 5   Dust Allergy              1000 non-null   int64 
 6   OccuPational Hazards      1000 non-null   int64 
 7   Genetic Risk              1000 non-null   int64 
 8   chronic Lung Disease      1000 non-null   int64 
 9   Balanced Diet             1000 non-null   int64 
 10  Obesity                   1000 non-null   int64 
 11  Smoking                   1000 non-null   int64 
 12  Passive Smoker            1000 non-null   int64 
 13  Chest Pain                1000 non-null   int64 
 14  Coughing of Blood         1000 non-null   int64 
 15  Fatigue                   1000 non-null   int64 
 16  Weight Loss               1000 non-null   int64 
 17  Shortness of Breath       1000 non-null   int64 
 18  Wheezing                  1000 non-null   int64 
 19  Swallowing Difficulty     1000 non-null   int64 
 20  Clubbing of Finger Nails  1000 non-null   int64 
 21  Frequent Cold             1000 non-null   int64 
 22  Dry Cough                 1000 non-null   int64 
 23  Snoring                   1000 non-null   int64 
 24  Level                     1000 non-null   object
dtypes: int64(23), object(2)
memory usage: 203.1+ KB
# Summary statistics of the numeric columns, transposed so each variable is one row.
dane.describe().T
count mean std min 25% 50% 75% max
Age 1000.0 37.174 12.005493 14.0 27.75 36.0 45.0 73.0
Gender 1000.0 1.402 0.490547 1.0 1.00 1.0 2.0 2.0
Air Pollution 1000.0 3.840 2.030400 1.0 2.00 3.0 6.0 8.0
Alcohol use 1000.0 4.563 2.620477 1.0 2.00 5.0 7.0 8.0
Dust Allergy 1000.0 5.165 1.980833 1.0 4.00 6.0 7.0 8.0
OccuPational Hazards 1000.0 4.840 2.107805 1.0 3.00 5.0 7.0 8.0
Genetic Risk 1000.0 4.580 2.126999 1.0 2.00 5.0 7.0 7.0
chronic Lung Disease 1000.0 4.380 1.848518 1.0 3.00 4.0 6.0 7.0
Balanced Diet 1000.0 4.491 2.135528 1.0 2.00 4.0 7.0 7.0
Obesity 1000.0 4.465 2.124921 1.0 3.00 4.0 7.0 7.0
Smoking 1000.0 3.948 2.495902 1.0 2.00 3.0 7.0 8.0
Passive Smoker 1000.0 4.195 2.311778 1.0 2.00 4.0 7.0 8.0
Chest Pain 1000.0 4.438 2.280209 1.0 2.00 4.0 7.0 9.0
Coughing of Blood 1000.0 4.859 2.427965 1.0 3.00 4.0 7.0 9.0
Fatigue 1000.0 3.856 2.244616 1.0 2.00 3.0 5.0 9.0
Weight Loss 1000.0 3.855 2.206546 1.0 2.00 3.0 6.0 8.0
Shortness of Breath 1000.0 4.240 2.285087 1.0 2.00 4.0 6.0 9.0
Wheezing 1000.0 3.777 2.041921 1.0 2.00 4.0 5.0 8.0
Swallowing Difficulty 1000.0 3.746 2.270383 1.0 2.00 4.0 5.0 8.0
Clubbing of Finger Nails 1000.0 3.923 2.388048 1.0 2.00 4.0 5.0 9.0
Frequent Cold 1000.0 3.536 1.832502 1.0 2.00 3.0 5.0 7.0
Dry Cough 1000.0 3.853 2.039007 1.0 2.00 4.0 6.0 7.0
Snoring 1000.0 2.926 1.474686 1.0 2.00 3.0 4.0 7.0
# List the column labels (note the inconsistent capitalisation in the source data).
dane.columns
Index(['Patient Id', 'Age', 'Gender', 'Air Pollution', 'Alcohol use',
       'Dust Allergy', 'OccuPational Hazards', 'Genetic Risk',
       'chronic Lung Disease', 'Balanced Diet', 'Obesity', 'Smoking',
       'Passive Smoker', 'Chest Pain', 'Coughing of Blood', 'Fatigue',
       'Weight Loss', 'Shortness of Breath', 'Wheezing',
       'Swallowing Difficulty', 'Clubbing of Finger Nails', 'Frequent Cold',
       'Dry Cough', 'Snoring', 'Level'],
      dtype='object')
# Share of patients in each risk level, drawn as a pie chart.
level_counts = dane['Level'].value_counts()
ax = level_counts.plot(kind='pie', autopct='%1.1f%%', startangle=90)
# Title the axes directly so `ax` keeps referring to the Axes
# (plt.title returns a Text object, which previously overwrote `ax`).
ax.set_title('Risk distribution')

# Ten-year age bins (0-10, 10-20, ..., 90-100) for the patient age histogram.
bins = list(range(0, 101, 10))

# plt.hist returns the per-bin counts and the bin edges. Reusing them for the labels
# keeps the printed numbers identical to the drawn bars (including the closed right
# edge of the last bin) and centres each label without assuming a bin width.
counts, edges, _ = plt.hist(dane['Age'], bins, histtype='bar', rwidth=0.8)
for left, right, count in zip(edges[:-1], edges[1:], counts):
    plt.text((left + right) / 2, count, str(int(count)), ha='center', va='bottom')
plt.xlabel('Age')
plt.ylabel('Number of patients')
plt.title('Age of the patients')
plt.show()
# Average patient age across the whole data set.
mean_age = dane['Age'].mean()
mean_age
37.174
# Patients per gender. value_counts() orders by frequency, so sort by the gender code
# to pin the bar order (1, then 2); otherwise the hard-coded colours and tick labels
# below would silently swap whenever the second code became the more frequent one.
gender_counts = dane['Gender'].value_counts().sort_index()
ax = gender_counts.plot(kind='bar', color=['blue', 'pink'])

# Print the count above each bar.
for i, value in enumerate(gender_counts):
    ax.text(i, value + 0.1, str(value), ha='center', va='bottom')

# Replace the numeric gender codes on the x axis with readable names.
ax.set_xticks([0, 1])
ax.set_xticklabels(['Man', 'Woman'])
ax.set_title("Distribution of patients' gender")

# Add the legend.
plt.legend()

# Show the chart.
plt.show()
# Patient counts per gender (rows) and risk level (columns); combinations that do not
# occur are counted as 0 instead of NaN so they can still be drawn and annotated.
grouped_data = dane.groupby(['Gender', 'Level']).size().unstack(fill_value=0)

# One group of bars per level, with the two genders side by side.
categories = grouped_data.columns
bar_width = 0.35
bar_positions_man = np.arange(len(categories))
bar_positions_woman = bar_positions_man + bar_width

fig, ax = plt.subplots()

# Rows are selected by the numeric gender codes 1 and 2.
ax.bar(bar_positions_man, grouped_data.loc[1], width=bar_width, label='Man')
ax.bar(bar_positions_woman, grouped_data.loc[2], width=bar_width, label='Woman')

# Annotate every bar with that gender's share of all patients at the same level.
level_totals = grouped_data.sum(axis=0)
for i, column in enumerate(categories):
    for j, gender in enumerate(grouped_data.index):
        height = grouped_data.loc[gender, column]
        percent = height / level_totals[column]
        ax.text(i + j * bar_width, height + 0.2, f'{percent:.0%}', ha='center', va='bottom')

# The x axis carries the risk levels and the legend distinguishes the genders
# (the two captions were previously swapped).
ax.set_xlabel('Level')
ax.set_ylabel('Count')
ax.set_title('Distribution of level by gender')
ax.legend(title='Gender')

# Centre each tick under its pair of bars.
ax.set_xticks(bar_positions_man + bar_width / 2)
ax.set_xticklabels(categories)

# Show the chart.
plt.show()
# Number of patients at each smoking score (the Series is wrapped in a one-element list).
dane3 = [dane.groupby('Smoking').size()]
dane3 
[Smoking
 1    181
 2    222
 3    172
 4     59
 5     10
 6     60
 7    207
 8     89
 dtype: int64]
# Recode gender in place for readable labels (1 -> 'Man', 2 -> 'Woman').
dane['Gender'] = dane['Gender'].replace({1: 'Man', 2: 'Woman'})
smoking_counts = dane.groupby(['Smoking', 'Gender']).size()

# Convert each count into a percentage of all patients of the same gender.
smoking_percentages = smoking_counts / smoking_counts.groupby('Gender').sum() * 100

# Order the bars by smoking score.
smoking_percentages_sorted = smoking_percentages.sort_index(level='Smoking', sort_remaining=False)

plt.figure(figsize=(14, 6))

# Horizontal bar chart with one bar per (smoking score, gender) pair.
ax = smoking_percentages_sorted.plot(kind='barh')

# Write the percentage next to each bar.
for i, value in enumerate(smoking_percentages_sorted):
    ax.text(value + 0.1, i, f'{value:.2f}%', ha='left', va='center')

# No legend here: the plotted series has no labelled artists, so plt.legend() drew
# nothing and only emitted a "No artists with labels found" warning.

# Axis labels and title.
plt.xlabel('Percentage')
plt.ylabel('Smoking, Gender')
plt.title('Distribution of smoking by gender (%)')

plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
# Number of patients at each passive-smoker score (the Series is wrapped in a one-element list).
dane3 = [dane.groupby('Passive Smoker').size()]
dane3 
[Passive Smoker
 1     60
 2    284
 3    140
 4    161
 5     30
 6     30
 7    187
 8    108
 dtype: int64]

# Recode gender in place for readable labels (1 -> 'Man', 2 -> 'Woman'); this is a
# no-op when the column already holds the names.
dane['Gender'] = dane['Gender'].replace({1: 'Man', 2: 'Woman'})
smoking_counts = dane.groupby(['Passive Smoker', 'Gender']).size()

# Convert each count into a percentage of all patients of the same gender.
smoking_percentages = smoking_counts / smoking_counts.groupby('Gender').sum() * 100

# Order the bars by passive-smoker score.
smoking_percentages_sorted = smoking_percentages.sort_index(level='Passive Smoker', sort_remaining=False)

plt.figure(figsize=(15, 6))

# Horizontal bar chart with one bar per (passive-smoker score, gender) pair.
ax = smoking_percentages_sorted.plot(kind='barh')

# Write the percentage next to each bar.
for i, value in enumerate(smoking_percentages_sorted):
    ax.text(value + 0.1, i, f'{value:.2f}%', ha='left', va='center')

# No legend here: the plotted series has no labelled artists, so plt.legend() drew
# nothing and only emitted a "No artists with labels found" warning.

# Axis labels and title.
plt.xlabel('Percentage')
plt.ylabel('Passive Smoker, Gender')
plt.title('Distribution of passive smokers by gender (%)')

plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
# Recode gender in place for readable labels (1 -> 'Man', 2 -> 'Woman'); this is a
# no-op when the column already holds the names.
dane['Gender'] = dane['Gender'].replace({1: 'Man', 2: 'Woman'})
Genetic_risk_counts = dane.groupby(['Genetic Risk', 'Gender']).size()

# Convert each count into a percentage of all patients of the same gender.
Genetic_risk_percentages = Genetic_risk_counts / Genetic_risk_counts.groupby('Gender').sum() * 100

# Order the bars by genetic-risk score.
Genetic_risk_percentages_sorted = Genetic_risk_percentages.sort_index(level='Genetic Risk', sort_remaining=False)

plt.figure(figsize=(15, 6))

# Horizontal bar chart with one bar per (genetic-risk score, gender) pair.
ax = Genetic_risk_percentages_sorted.plot(kind='barh')

# Write the percentage next to each bar.
for i, value in enumerate(Genetic_risk_percentages_sorted):
    ax.text(value + 0.1, i, f'{value:.2f}%', ha='left', va='center')

# No legend here: the plotted series has no labelled artists, so plt.legend() drew
# nothing and only emitted a "No artists with labels found" warning.

# Axis labels and title.
plt.xlabel('Percentage')
plt.ylabel('Genetic risk, Gender')
plt.title('Distribution of genetic risk by gender (%)')

# Show the chart.
plt.show()
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.

# Number of patients for every (genetic-risk score, risk level) combination.
Genetic_risk_counts = dane.groupby(['Genetic Risk', 'Level']).size()

# Order by genetic risk first and, within each score, by level as High, Medium, Low.
# sort_index applies `key` to every index level separately, so the name-to-rank mapping
# must only touch the 'Level' values: mapping the numeric 'Genetic Risk' values through
# it would turn them all into NaN and discard the ordering by genetic risk.
level_order = {'High': 1, 'Medium': 2, 'Low': 3}
Genetic_risk_counts_sorted = Genetic_risk_counts.sort_index(
    level=['Genetic Risk', 'Level'],
    key=lambda idx: idx.map(level_order) if idx.name == 'Level' else idx,
)

plt.figure(figsize=(10, 6))

# Horizontal bar chart with one bar per (genetic-risk score, level) pair.
ax = Genetic_risk_counts_sorted.plot(kind='barh')

# Write the count next to each bar.
for i, value in enumerate(Genetic_risk_counts_sorted):
    ax.text(value + 0.1, i, str(value), ha='left', va='center')

# Axis labels and title.
plt.xlabel('Count')
plt.ylabel('Genetic Risk, Level')
plt.title('Distribution of level by genetic risk')

# Show the chart.
plt.show()
# Patients scoring at least 7 on both the smoking and the alcohol-use scale.
x = dane.loc[(dane['Smoking'] >= 7) & (dane['Alcohol use'] >= 7)]
# Keep the two scores plus the risk level, highest scores first.
sort_columns = ['Smoking', 'Alcohol use']
x_sorted = x[sort_columns + ['Level']].sort_values(by=sort_columns, ascending=False)
x_sorted
Smoking Alcohol use Level
index
4 8 8 High
20 8 8 High
22 8 8 High
46 8 8 High
68 8 8 High
... ... ... ...
989 7 7 High
992 7 7 High
993 7 7 High
994 7 7 High
995 7 7 High

256 rows × 3 columns

# Number of patients at each air-pollution score, displayed from rarest to most common.
dane7 = dane['Air Pollution'].value_counts()
dane7.sort_values()
Air Pollution
8     19
5     20
7     30
4     90
1    141
3    173
2    201
6    326
Name: count, dtype: int64
# Number of patients for every (air-pollution score, risk level) combination.
air_pollution = dane.groupby(['Air Pollution', 'Level']).size()

# Order by air-pollution score first and, within each score, by level as High, Medium, Low.
# sort_index applies `key` to every index level separately, so the name-to-rank mapping
# must only touch the 'Level' values: mapping the numeric 'Air Pollution' values through
# it would turn them all into NaN and discard the ordering by air pollution.
level_order = {'High': 1, 'Medium': 2, 'Low': 3}
air_pollution_sorted = air_pollution.sort_index(
    level=['Air Pollution', 'Level'],
    key=lambda idx: idx.map(level_order) if idx.name == 'Level' else idx,
)

plt.figure(figsize=(10, 6))

# Horizontal bar chart with one bar per (air-pollution score, level) pair.
ax = air_pollution_sorted.plot(kind='barh')

# Write the count next to each bar.
for i, value in enumerate(air_pollution_sorted):
    ax.text(value + 0.1, i, str(value), ha='left', va='center')

# Axis labels and title.
plt.xlabel('Count')
plt.ylabel('Air Pollution, Level')
plt.title('Distribution of level by air pollution')

# Show the chart.
plt.show()
# Encode the risk level as an ordinal number (Low=1, Medium=2, High=3) so it can take
# part in the correlation. An explicit map on the one column avoids the implicit
# object-to-int downcasting that DataFrame.replace performs (deprecated in recent pandas).
level_codes = {'High': 3, 'Medium': 2, 'Low': 1}
data = dane.assign(Level=dane['Level'].map(level_codes))
data = data.drop(['Patient Id', 'Gender', 'Age'], axis=1)

# Pairwise correlation of the remaining columns.
corr_matrix = data.corr()

# Long-format copy of the matrix (one row per variable pair); the plot below does not use it.
corr_df = corr_matrix.stack().reset_index()
corr_df.columns = ['x', 'y', 'value']

# Heatmap on a fixed [-1, 1] range. The colour scale is defined once, in the layout
# below; the named scale previously passed here was overridden by it.
fig = px.imshow(
    corr_matrix,
    x=corr_matrix.columns.tolist(),
    y=corr_matrix.columns.tolist(),
    zmin=-1,
    zmax=1,
)

fig.update_layout(
    width=800,
    height=800,
    coloraxis_colorbar=dict(
        title='Correlation',
        tickvals=[-1, -0.5, 0, 0.5, 1],
        ticktext=['-1', '-0.5', '0', '0.5', '1'],
    ),
    coloraxis=dict(
        colorscale=[
            [0, 'rgb(58, 89, 156)'],  # blue (cool)
            [0.5, 'rgb(255, 255, 255)'],  # white
            [1, 'rgb(179, 35, 26)']  # red (warm)
        ],
    ),
    font=dict(
        family="Arial",
        size=12,
        color="black",
    ))

# Show the chart.
fig.show()
import numpy as np
import sklearn 
# Rebuild the modelling frame from `dane`, encoding the risk level as Low=1, Medium=2, High=3.
data = dane.replace({'Level':{'High' : 3, 'Medium' : 2, 'Low' : 1}})
# Turn the gender names back into numeric codes; this relies on earlier cells having
# replaced 1/2 with 'Man'/'Woman' in `dane`.
data['Gender'] = data['Gender'].replace({'Man' : 1, 'Woman' : 2})
data.head()
Patient Id Age Gender Air Pollution Alcohol use Dust Allergy OccuPational Hazards Genetic Risk chronic Lung Disease Balanced Diet ... Fatigue Weight Loss Shortness of Breath Wheezing Swallowing Difficulty Clubbing of Finger Nails Frequent Cold Dry Cough Snoring Level
index
0 P1 33 1 2 4 5 4 3 2 2 ... 3 4 2 2 3 1 2 3 4 1
1 P10 17 1 3 1 5 3 4 2 2 ... 1 3 7 8 6 2 1 7 2 2
2 P100 35 1 4 5 6 5 5 4 6 ... 8 7 9 2 1 4 6 7 2 3
3 P1000 37 1 7 7 7 7 6 7 7 ... 4 2 3 1 4 5 6 7 5 3
4 P101 46 1 6 8 7 7 7 6 7 ... 3 2 4 1 4 2 4 2 3 3

5 rows × 25 columns

# Fix the global NumPy seed and make printed arrays easier to read.
np.random.seed(10)
np.set_printoptions(precision=6, suppress=True)

# Target is the encoded risk level; features are every other column except the patient id.
y = data['Level']
X = data.drop(columns=['Level', 'Patient Id'])

for caption, part in (("Y shape:", y), ("X shape:", X)):
    print(caption, part.shape)
Y shape: (1000,)
X shape: (1000, 23)
from sklearn.model_selection import train_test_split

# Hold out part of the rows for evaluation (scikit-learn's default test share).
# An explicit random_state makes the split reproducible even when this cell is re-run
# on its own, instead of depending on the global NumPy seed set in another cell.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)

print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_test shape:", X_test.shape)
print("y_test shape:", y_test.shape)
X_train shape: (750, 23)
y_train shape: (750,)
X_test shape: (250, 23)
y_test shape: (250,)

from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Standardise the features before the logistic regression: on the raw features the
# lbfgs solver hit the iteration limit (ConvergenceWarning), and scaling is the remedy
# the warning itself recommends. The pipeline exposes the same fit / predict /
# predict_proba interface, so the following cells work unchanged.
classifier = make_pipeline(StandardScaler(), LogisticRegression(max_iter=200))

classifier.fit(X_train, y_train)
c:\Users\HP\anaconda3\lib\site-packages\sklearn\linear_model\_logistic.py:460: ConvergenceWarning:

lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression

LogisticRegression(max_iter=200)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LogisticRegression(max_iter=200)
# Predicted class probabilities for the held-out rows (one column per class).
y_prob = classifier.predict_proba(X_test)
y_prob
array([[0.01354 , 0.98631 , 0.00015 ],
       [0.      , 0.000005, 0.999995],
       [0.000005, 0.00068 , 0.999315],
       [0.999999, 0.000001, 0.      ],
       [0.000003, 0.000947, 0.99905 ],
       [0.001809, 0.982182, 0.016009],
       [0.      , 0.012048, 0.987952],
       [0.059547, 0.831086, 0.109367],
       [0.      , 0.075388, 0.924612],
       [0.039136, 0.960851, 0.000013],
       [0.000022, 0.003654, 0.996323],
       [0.006091, 0.993801, 0.000108],
       [0.017895, 0.982104, 0.000001],
       [0.998827, 0.001173, 0.      ],
       [0.928696, 0.070643, 0.000662],
       [0.033735, 0.966263, 0.000003],
       [0.      , 0.000113, 0.999887],
       [0.934007, 0.009674, 0.056319],
       [0.      , 0.000031, 0.999969],
       [0.999648, 0.000306, 0.000047],
       [0.000004, 0.00404 , 0.995956],
       [0.999251, 0.000216, 0.000533],
       [0.      , 0.00512 , 0.99488 ],
       [0.      , 0.004695, 0.995305],
       [0.      , 0.000124, 0.999876],
       [0.1259  , 0.859667, 0.014433],
       [0.059547, 0.831086, 0.109367],
       [0.991172, 0.008807, 0.000021],
       [0.999909, 0.00009 , 0.000001],
       [0.1259  , 0.859667, 0.014433],
       [0.976724, 0.023247, 0.000028],
       [0.008947, 0.990345, 0.000708],
       [0.      , 0.004695, 0.995305],
       [0.999251, 0.000216, 0.000533],
       [0.005118, 0.994787, 0.000095],
       [0.      , 0.00008 , 0.99992 ],
       [0.015368, 0.98458 , 0.000052],
       [0.      , 0.000003, 0.999997],
       [0.00583 , 0.994147, 0.000023],
       [0.993283, 0.00006 , 0.006656],
       [0.1259  , 0.859667, 0.014433],
       [0.004655, 0.989938, 0.005407],
       [0.      , 0.000134, 0.999866],
       [0.000007, 0.999991, 0.000002],
       [0.      , 0.012048, 0.987952],
       [0.000022, 0.003654, 0.996323],
       [0.931524, 0.068283, 0.000192],
       [0.      , 0.000031, 0.999969],
       [0.      , 0.000165, 0.999835],
       [0.009961, 0.988549, 0.00149 ],
       [0.006484, 0.974177, 0.019339],
       [0.00523 , 0.994198, 0.000572],
       [0.021549, 0.975853, 0.002598],
       [0.      , 0.000012, 0.999988],
       [0.009961, 0.988549, 0.00149 ],
       [0.000001, 0.000025, 0.999974],
       [0.      , 0.016222, 0.983778],
       [0.00003 , 0.010307, 0.989664],
       [0.993286, 0.006713, 0.000001],
       [0.000175, 0.040597, 0.959228],
       [0.000004, 0.00404 , 0.995956],
       [0.984367, 0.012436, 0.003197],
       [0.000005, 0.00068 , 0.999315],
       [0.967662, 0.024355, 0.007983],
       [0.      , 0.000134, 0.999866],
       [0.976724, 0.023247, 0.000028],
       [0.007089, 0.99289 , 0.000021],
       [0.      , 0.000165, 0.999835],
       [0.      , 0.00512 , 0.99488 ],
       [0.016788, 0.983198, 0.000014],
       [0.988196, 0.002318, 0.009486],
       [0.999947, 0.000021, 0.000032],
       [0.990009, 0.009991, 0.      ],
       [0.016788, 0.983198, 0.000014],
       [0.015399, 0.984596, 0.000005],
       [0.022511, 0.975873, 0.001616],
       [0.008947, 0.990345, 0.000708],
       [0.007156, 0.985278, 0.007566],
       [0.      , 0.00068 , 0.99932 ],
       [0.991172, 0.008807, 0.000021],
       [0.      , 0.000106, 0.999894],
       [0.993286, 0.006713, 0.000001],
       [0.      , 0.004695, 0.995305],
       [0.006484, 0.974177, 0.019339],
       [0.      , 0.00001 , 0.99999 ],
       [0.      , 0.03253 , 0.967469],
       [0.      , 0.000014, 0.999986],
       [0.000011, 0.999989, 0.      ],
       [0.991172, 0.008807, 0.000021],
       [0.      , 0.00068 , 0.99932 ],
       [0.013247, 0.986728, 0.000025],
       [0.003605, 0.971443, 0.024952],
       [0.999999, 0.000001, 0.      ],
       [0.931524, 0.068283, 0.000192],
       [0.00583 , 0.994147, 0.000023],
       [0.00583 , 0.994147, 0.000023],
       [0.1259  , 0.859667, 0.014433],
       [0.022062, 0.977641, 0.000297],
       [0.016788, 0.983198, 0.000014],
       [0.957888, 0.04211 , 0.000001],
       [0.      , 0.016222, 0.983778],
       [0.000003, 0.999994, 0.000003],
       [0.9999  , 0.0001  , 0.      ],
       [0.      , 0.012048, 0.987952],
       [0.999909, 0.00009 , 0.000001],
       [0.928696, 0.070643, 0.000662],
       [0.931524, 0.068283, 0.000192],
       [0.000003, 0.99999 , 0.000007],
       [0.      , 0.004695, 0.995305],
       [0.      , 0.000014, 0.999986],
       [0.993283, 0.00006 , 0.006656],
       [0.000003, 0.000947, 0.99905 ],
       [0.008109, 0.979381, 0.01251 ],
       [0.934007, 0.009674, 0.056319],
       [0.999648, 0.000306, 0.000047],
       [0.000022, 0.003654, 0.996323],
       [0.013247, 0.986728, 0.000025],
       [0.000011, 0.999989, 0.      ],
       [0.993286, 0.006713, 0.000001],
       [0.9999  , 0.0001  , 0.      ],
       [0.      , 0.000113, 0.999887],
       [0.      , 0.000063, 0.999937],
       [0.      , 0.00512 , 0.99488 ],
       [0.011601, 0.988358, 0.000041],
       [0.999909, 0.00009 , 0.000001],
       [0.000005, 0.00068 , 0.999315],
       [0.000175, 0.040597, 0.959228],
       [0.990009, 0.009991, 0.      ],
       [0.978169, 0.021712, 0.000119],
       [0.015368, 0.98458 , 0.000052],
       [0.022062, 0.977641, 0.000297],
       [0.      , 0.021094, 0.978906],
       [0.973999, 0.013972, 0.012029],
       [0.      , 0.004695, 0.995305],
       [0.996352, 0.003647, 0.000001],
       [0.967662, 0.024355, 0.007983],
       [0.985581, 0.000103, 0.014316],
       [0.      , 0.00512 , 0.99488 ],
       [0.984367, 0.012436, 0.003197],
       [0.      , 0.000468, 0.999532],
       [0.008109, 0.979381, 0.01251 ],
       [0.991172, 0.008807, 0.000021],
       [0.013247, 0.986728, 0.000025],
       [0.013247, 0.986728, 0.000025],
       [0.999999, 0.000001, 0.      ],
       [0.985581, 0.000103, 0.014316],
       [0.007156, 0.985278, 0.007566],
       [0.000022, 0.003654, 0.996323],
       [0.999947, 0.000021, 0.000032],
       [0.973999, 0.013972, 0.012029],
       [0.000022, 0.003654, 0.996323],
       [0.993286, 0.006713, 0.000001],
       [0.1259  , 0.859667, 0.014433],
       [0.008803, 0.991195, 0.000002],
       [0.008947, 0.990345, 0.000708],
       [0.      , 0.002567, 0.997433],
       [0.999999, 0.000001, 0.      ],
       [0.      , 0.004695, 0.995305],
       [0.      , 0.000013, 0.999987],
       [0.98178 , 0.01822 , 0.      ],
       [0.891574, 0.108425, 0.      ],
       [0.006622, 0.993257, 0.000122],
       [0.972232, 0.027758, 0.00001 ],
       [0.008109, 0.979381, 0.01251 ],
       [0.928696, 0.070643, 0.000662],
       [0.      , 0.000196, 0.999804],
       [0.000003, 0.999994, 0.000003],
       [0.991172, 0.008807, 0.000021],
       [0.021549, 0.975853, 0.002598],
       [0.999947, 0.000021, 0.000032],
       [0.      , 0.000134, 0.999866],
       [0.013247, 0.986728, 0.000025],
       [0.      , 0.00512 , 0.99488 ],
       [0.991172, 0.008807, 0.000021],
       [0.      , 0.00512 , 0.99488 ],
       [0.153033, 0.8185  , 0.028467],
       [0.008947, 0.990345, 0.000708],
       [0.009961, 0.988549, 0.00149 ],
       [0.015399, 0.984596, 0.000005],
       [0.000011, 0.999989, 0.      ],
       [0.022062, 0.977641, 0.000297],
       [0.999999, 0.000001, 0.      ],
       [0.967662, 0.024355, 0.007983],
       [0.006622, 0.993257, 0.000122],
       [0.015399, 0.984596, 0.000005],
       [0.      , 0.000196, 0.999804],
       [0.      , 0.016222, 0.983778],
       [0.      , 0.000124, 0.999876],
       [0.      , 0.001328, 0.998672],
       [0.000011, 0.999989, 0.      ],
       [0.022511, 0.975873, 0.001616],
       [0.999251, 0.000216, 0.000533],
       [0.999999, 0.000001, 0.      ],
       [0.      , 0.03253 , 0.967469],
       [0.      , 0.000063, 0.999937],
       [0.993283, 0.00006 , 0.006656],
       [0.007089, 0.99289 , 0.000021],
       [0.009961, 0.988549, 0.00149 ],
       [0.931524, 0.068283, 0.000192],
       [0.007156, 0.985278, 0.007566],
       [0.006622, 0.993257, 0.000122],
       [0.021549, 0.975853, 0.002598],
       [0.      , 0.000003, 0.999997],
       [0.009393, 0.990606, 0.000001],
       [0.      , 0.001767, 0.998233],
       [0.000175, 0.040597, 0.959228],
       [0.931524, 0.068283, 0.000192],
       [0.01354 , 0.98631 , 0.00015 ],
       [0.000011, 0.999989, 0.      ],
       [0.07487 , 0.902033, 0.023097],
       [0.      , 0.000113, 0.999887],
       [0.999999, 0.000001, 0.      ],
       [0.      , 0.000031, 0.999969],
       [0.      , 0.0011  , 0.9989  ],
       [0.891574, 0.108425, 0.      ],
       [0.957888, 0.04211 , 0.000001],
       [0.004898, 0.993937, 0.001166],
       [0.      , 0.      , 1.      ],
       [0.891574, 0.108425, 0.      ],
       [0.990009, 0.009991, 0.      ],
       [0.976724, 0.023247, 0.000028],
       [0.000007, 0.999991, 0.000002],
       [0.891574, 0.108425, 0.      ],
       [0.      , 0.000014, 0.999986],
       [0.015399, 0.984596, 0.000005],
       [0.003605, 0.971443, 0.024952],
       [0.1259  , 0.859667, 0.014433],
       [0.991172, 0.008807, 0.000021],
       [0.      , 0.00001 , 0.99999 ],
       [0.004655, 0.989938, 0.005407],
       [0.      , 0.000063, 0.999937],
       [0.999999, 0.000001, 0.      ],
       [0.007156, 0.985278, 0.007566],
       [0.000007, 0.999991, 0.000002],
       [0.      , 0.001638, 0.998362],
       [0.006622, 0.993257, 0.000122],
       [0.      , 0.000468, 0.999532],
       [0.978169, 0.021712, 0.000119],
       [0.891574, 0.108425, 0.      ],
       [0.      , 0.000196, 0.999804],
       [0.934007, 0.009674, 0.056319],
       [0.008109, 0.979381, 0.01251 ],
       [0.000004, 0.00404 , 0.995956],
       [0.022511, 0.975873, 0.001616],
       [0.      , 0.03253 , 0.967469],
       [0.00007 , 0.024082, 0.975848],
       [0.      , 0.000031, 0.999969],
       [0.      , 0.012048, 0.987952],
       [0.999947, 0.000021, 0.000032],
       [0.004898, 0.993937, 0.001166]])
# Predicted class label (encoded risk level) for each held-out row.
y_pred = classifier.predict(X_test)
y_pred
array([2, 3, 3, 1, 3, 2, 3, 2, 3, 2, 3, 2, 2, 1, 1, 2, 3, 1, 3, 1, 3, 1,
       3, 3, 3, 2, 2, 1, 1, 2, 1, 2, 3, 1, 2, 3, 2, 3, 2, 1, 2, 2, 3, 2,
       3, 3, 1, 3, 3, 2, 2, 2, 2, 3, 2, 3, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1,
       2, 3, 3, 2, 1, 1, 1, 2, 2, 2, 2, 2, 3, 1, 3, 1, 3, 2, 3, 3, 3, 2,
       1, 3, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 3, 2, 1, 3, 1, 1, 1, 2, 3, 3,
       1, 3, 2, 1, 1, 3, 2, 2, 1, 1, 3, 3, 3, 2, 1, 3, 3, 1, 1, 2, 2, 3,
       1, 3, 1, 1, 1, 3, 1, 3, 2, 1, 2, 2, 1, 1, 2, 3, 1, 1, 3, 1, 2, 2,
       2, 3, 1, 3, 3, 1, 1, 2, 1, 2, 1, 3, 2, 1, 2, 1, 3, 2, 3, 1, 3, 2,
       2, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 2, 2, 1, 1, 3, 3, 1, 2, 2,
       1, 2, 2, 2, 3, 2, 3, 3, 1, 2, 2, 2, 3, 1, 3, 3, 1, 1, 2, 3, 1, 1,
       1, 2, 1, 3, 2, 2, 2, 1, 3, 2, 3, 1, 2, 2, 3, 2, 3, 1, 1, 3, 1, 2,
       3, 2, 3, 3, 3, 3, 1, 2], dtype=int64)
pip install mlxtend
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score

from mlxtend.plotting import plot_confusion_matrix
import seaborn as sns
sns.set()
#cm = confusion_matrix(y_test, y_pred)
#plot_confusion_matrix(cm)

#acc = accuracy_score(y_test, y_pred)
#print('Accuracy',':', acc)
# Overall accuracy and confusion matrix of the predictions on the held-out rows.
acc = accuracy_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred)
def plot_confusion_matrix(cm, accuracy=None):
    """Show a confusion matrix as an annotated Plotly heatmap.

    Args:
        cm: Square confusion matrix with true classes in rows and predicted
            classes in columns, classes numbered from 1 in ascending order.
        accuracy: Accuracy to print in the title. When omitted it is derived
            from ``cm`` itself (diagonal sum divided by the total count), so
            the title always matches the matrix being plotted instead of
            depending on a global variable.
    """
    cm = np.asarray(cm)
    n_classes = cm.shape[0]
    if accuracy is None:
        total = cm.sum()
        accuracy = np.trace(cm) / total if total else float('nan')

    # Reverse the rows so the highest true class comes first, matching the index labels.
    frame = pd.DataFrame(
        cm[::-1],
        columns=[f'pred_{k}' for k in range(1, n_classes + 1)],
        index=[f'true_{k}' for k in range(n_classes, 0, -1)],
    )
    fig = ff.create_annotated_heatmap(
        z=frame.values,
        x=list(frame.columns),
        y=list(frame.index),
        colorscale='ice',
        showscale=True,
        reversescale=True,
    )
    fig.update_layout(
        width=500,
        height=500,
        title=f'Confusion Matrix - Accuracy: {accuracy:.4f}',
        font_size=16,
    )
    fig.show()

# Draw the confusion matrix computed above.
plot_confusion_matrix(cm)
from sklearn.metrics import classification_report

# Per-class precision, recall and F1 on the held-out rows. `labels` pins which encoded
# level each display name belongs to (1 = Low, 2 = Medium, 3 = High), so the names
# cannot shift if a class is missing from the test split; the rows describe the true
# classes, hence the level names rather than "pred_*".
# NOTE(review): every score is 1.00 and the predicted-probability output contains many
# identical rows, which suggests duplicated patient records on both sides of the
# split — verify before trusting these numbers.
print(classification_report(y_test, y_pred, labels=[1, 2, 3], target_names=['Low', 'Medium', 'High']))
              precision    recall  f1-score   support

      pred_1       1.00      1.00      1.00        76
      pred_2       1.00      1.00      1.00        89
      pred_3       1.00      1.00      1.00        85

    accuracy                           1.00       250
   macro avg       1.00      1.00      1.00       250
weighted avg       1.00      1.00      1.00       250