Add fixed bayes classification
This commit is contained in:
parent
0517754510
commit
f58a057024
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
env
|
6
README
Normal file
6
README
Normal file
@ -0,0 +1,6 @@
|
||||
Instalacja:
|
||||
|
||||
python -m venv env
|
||||
source ./env/bin/activate
|
||||
|
||||
pip install -r requirements.txt
|
116
bayes.py
Normal file
116
bayes.py
Normal file
@ -0,0 +1,116 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
#Wczytanie i normalizacja danych
|
||||
def NormalizeData(data):
|
||||
for col in data.columns:
|
||||
if data[col].dtype == object:
|
||||
data[col] = data[col].str.lower()
|
||||
if col == 'smoking_status':
|
||||
data[col] = data[col].str.replace(" ", "_")
|
||||
if col == 'work_type':
|
||||
data[col] = data[col].str.replace("-", "_")
|
||||
if col == 'bmi':
|
||||
bins = [0, 21, 28, 40]
|
||||
labels=['low','mid','high']
|
||||
data[col] = pd.cut(data[col], bins=bins, labels=labels)
|
||||
if col == 'age':
|
||||
bins = [18, 30, 40, 50, 60, 70, 120]
|
||||
labels = ['18-29', '30-39', '40-49', '50-59', '60-69', '70+']
|
||||
data[col] = pd.cut(data[col], bins, labels = labels,include_lowest = True)
|
||||
if col == 'stroke':
|
||||
data[col] = data[col].replace({1: 'yes'})
|
||||
data[col] = data[col].replace({0: 'no'})
|
||||
if col == 'hypertension':
|
||||
data[col] = data[col].replace({1: 'yes'})
|
||||
data[col] = data[col].replace({0: 'no'})
|
||||
if col == 'heart_disease':
|
||||
data[col] = data[col].replace({1: 'yes'})
|
||||
data[col] = data[col].replace({0: 'no'})
|
||||
data = data.dropna()
|
||||
return data
|
||||
|
||||
def count_a_priori_prob(dataset):
|
||||
is_stroke_amount = len(dataset[dataset.stroke == 'yes'])
|
||||
no_stroke_amount = len(dataset[dataset.stroke == 'no'])
|
||||
data_length = len(dataset.stroke)
|
||||
return {'yes': float(is_stroke_amount)/float(data_length), 'no': float(no_stroke_amount)/float(data_length)}
|
||||
|
||||
def separate_labels_from_properties(X_train):
|
||||
|
||||
labels = X_train.columns
|
||||
labels_values = {}
|
||||
for label in labels:
|
||||
labels_values[label] = set(X_train[label])
|
||||
|
||||
to_return = []
|
||||
for x in labels:
|
||||
to_return.append({x: labels_values[x]})
|
||||
|
||||
return to_return
|
||||
|
||||
data = pd.read_csv("healthcare-dataset-stroke-data.csv")
|
||||
data = NormalizeData(data)
|
||||
|
||||
#podział danych na treningowy i testowy
|
||||
data_train, data_test = train_test_split(data, random_state = 42)
|
||||
|
||||
#rozdzielenie etykiet i cech
|
||||
X_train =data_train[['gender', 'age', 'ever_married', 'Residence_type', 'bmi','smoking_status', 'work_type','hypertension','heart_disease']]
|
||||
Y_train = data_train['stroke']
|
||||
|
||||
#rozdzielenie etykiet i cech
|
||||
# Dane wejściowe - zbiór danych, wektor etykiet, wektor prawdopodobieństw a priori dla klas.
|
||||
|
||||
# Wygenerowanie wektora prawdopodobieństw a priori dla klas.
|
||||
a_priori_prob = count_a_priori_prob(data_train)
|
||||
labels = separate_labels_from_properties(X_train)
|
||||
|
||||
class NaiveBayes():
|
||||
def __init__(self, dataset, labels, a_priori_prob):
|
||||
self.dataset = dataset
|
||||
self.labels = labels
|
||||
self.a_priori_prob = a_priori_prob
|
||||
|
||||
def count_bayes(self):
|
||||
|
||||
final_probs = {'top_yes': 0.0, 'top_no': 0.0, 'total': 0.0}
|
||||
|
||||
# self.labels - Wartości etykiet które nas interesują, opcjonalnie podane sa wszystkie.
|
||||
# [{'gender': {'female', 'male', 'other'}}, {'age': {'50-59', '40-49', '60-69', '70+', '18-29', '30-39'}}, {'ever_married': {'no', 'yes'}}, {'Residence_type': {'rural', 'urban'}}, {'bmi': {'high', 'mid', 'low'}}, {'smoking_status': {'unknown', 'smokes', 'never_smoked', 'formerly_smoked'}}, {'work_type': {'self_employed', 'private', 'never_worked', 'govt_job'}}, {'hypertension': {'no', 'yes'}}, {'heart_disease': {'no', 'yes'}}]
|
||||
# Dla kazdej z klas - 'yes', 'no'
|
||||
for idx, cls in enumerate(list(set(self.dataset['stroke']))):
|
||||
label_probs = []
|
||||
for label in self.labels:
|
||||
label_name = list(label.keys())[0]
|
||||
for label_value in label[label_name]:
|
||||
# Oblicz ilość występowania danej cechy w zbiorze danych np. heart_disease.yes
|
||||
|
||||
amount_label_value_yes_class = len(self.dataset.loc[(self.dataset['stroke'] == 'yes') & (self.dataset[label_name] == label_value)])
|
||||
amount_label_value_no_class = len(self.dataset.loc[(self.dataset['stroke'] == 'no') & (self.dataset[label_name] == label_value)])
|
||||
|
||||
amount_yes_class = len(self.dataset.loc[(self.dataset['stroke'] == 'yes')])
|
||||
amount_no_class = len(self.dataset.loc[(self.dataset['stroke'] == 'no')])
|
||||
# Obliczenie P(heart_disease.yes|'stroke'|), P(heart_disease.yes|'no stroke') itd. dla kazdej cechy.
|
||||
# Zapisujemy do listy w formacie (cecha.wartość: prob stroke, cecha.wartość: prob no stroke)
|
||||
label_probs.append({str(label_name + "." + label_value):(amount_label_value_yes_class/amount_yes_class, amount_label_value_no_class/amount_no_class)})
|
||||
|
||||
# Suma prawdopodobienstw mozliwych wartosci danej cechy dla danej klasy, powinna sumować się do 1.
|
||||
print(label_probs)
|
||||
|
||||
# Obliczanie licznika wzoru Bayesa (mnozymy wartosci prob cech z prawdop apriori danej klasy):
|
||||
top = 1
|
||||
for label_prob in label_probs:
|
||||
top *= list(label_prob.values())[0][idx]
|
||||
top *= self.a_priori_prob[cls]
|
||||
|
||||
final_probs[cls] = top
|
||||
final_probs['total'] += top
|
||||
|
||||
print("Prawdopodobieństwo a posteriori dla klasy yes-stroke", final_probs['yes']/final_probs['total'])
|
||||
print("Prawdopodobieństwo a posteriori dla klasy no-stroke", final_probs['no']/final_probs['total'])
|
||||
|
||||
labels = [{'Residence_type': {'urban'}}]
|
||||
naive_bayes = NaiveBayes(data_train, labels, a_priori_prob)
|
||||
naive_bayes.count_bayes()
|
10
requirements.txt
Normal file
10
requirements.txt
Normal file
@ -0,0 +1,10 @@
|
||||
joblib==1.0.1
|
||||
numpy==1.20.3
|
||||
pandas==1.2.4
|
||||
python-dateutil==2.8.1
|
||||
pytz==2021.1
|
||||
scikit-learn==0.24.2
|
||||
scipy==1.6.3
|
||||
six==1.16.0
|
||||
sklearn==0.0
|
||||
threadpoolctl==2.1.0
|
Loading…
Reference in New Issue
Block a user