diff --git a/.gitignore b/.gitignore
index edbe8ca..82da6c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,2 @@
-/venv
-/.ipynb_checkpoints
\ No newline at end of file
+__pycache__
+.ipynb_checkpoints
\ No newline at end of file
diff --git a/countries_map.py b/countries_map.py
new file mode 100644
index 0000000..610f0be
--- /dev/null
+++ b/countries_map.py
@@ -0,0 +1,119 @@
+countries = {1: 'Albania',
+             2: 'Antigua and Barbuda',
+             3: 'Argentina',
+             4: 'Armenia',
+             5: 'Aruba',
+             6: 'Australia',
+             7: 'Austria',
+             8: 'Azerbaijan',
+             9: 'Bahamas',
+             10: 'Bahrain',
+             11: 'Barbados',
+             12: 'Belarus',
+             13: 'Belgium',
+             14: 'Belize',
+             15: 'Bermuda',
+             16: 'Bosnia and Herzegovina',
+             17: 'Brazil',
+             18: 'Brunei Darussalam',
+             19: 'Bulgaria',
+             20: 'Cabo Verde',
+             21: 'Canada',
+             22: 'Cayman Islands',
+             23: 'Chile',
+             24: 'Colombia',
+             25: 'Costa Rica',
+             26: 'Croatia',
+             27: 'Cyprus',
+             28: 'Czech Republic',
+             29: 'Denmark',
+             30: 'Dominica',
+             31: 'Ecuador',
+             32: 'Egypt',
+             33: 'El Salvador',
+             34: 'Estonia',
+             35: 'Fiji',
+             36: 'Finland',
+             37: 'France',
+             38: 'French Guiana',
+             39: 'Georgia',
+             40: 'Germany',
+             41: 'Greece',
+             42: 'Grenada',
+             43: 'Guadeloupe',
+             44: 'Guatemala',
+             45: 'Guyana',
+             46: 'Hong Kong SAR',
+             47: 'Hungary',
+             48: 'Iceland',
+             49: 'Iran (Islamic Rep of)',
+             50: 'Ireland',
+             51: 'Israel',
+             52: 'Italy',
+             53: 'Jamaica',
+             54: 'Japan',
+             55: 'Kazakhstan',
+             56: 'Kiribati',
+             57: 'Kuwait',
+             58: 'Kyrgyzstan',
+             59: 'Latvia',
+             60: 'Lithuania',
+             61: 'Luxembourg',
+             62: 'Macau',
+             63: 'Maldives',
+             64: 'Malta',
+             65: 'Martinique',
+             66: 'Mauritius',
+             67: 'Mayotte',
+             68: 'Mexico',
+             69: 'Mongolia',
+             70: 'Montenegro',
+             71: 'Netherlands',
+             72: 'New Zealand',
+             73: 'Nicaragua',
+             74: 'Norway',
+             75: 'Oman',
+             76: 'Panama',
+             77: 'Paraguay',
+             78: 'Philippines',
+             79: 'Poland',
+             80: 'Portugal',
+             81: 'Puerto Rico',
+             82: 'Qatar',
+             83: 'Republic of Korea',
+             84: 'Republic of Moldova',
+             85: 'Reunion',
+             86: 'Rodrigues',
+             87: 'Romania',
+             88: 'Russian Federation',
+             89: 'Saint Kitts and Nevis',
+             90: 'Saint Lucia',
+             91: 'Saint Vincent and Grenadines',
+             92: 'San Marino',
+             93: 'Sao Tome and Principe',
+             94: 'Serbia',
+             95: 'Seychelles',
+             96: 'Singapore',
+             97: 'Slovakia',
+             98: 'Slovenia',
+             99: 'South Africa',
+             100: 'Spain',
+             101: 'Sri Lanka',
+             102: 'Suriname',
+             103: 'Sweden',
+             104: 'Switzerland',
+             105: 'TFYR Macedonia',
+             106: 'Thailand',
+             107: 'Trinidad and Tobago',
+             108: 'Turkey',
+             109: 'Turkmenistan',
+             110: 'Ukraine',
+             111: 'United Arab Emirates',
+             112: 'United Kingdom',
+             113: 'United States of America',
+             114: 'Uruguay',
+             115: 'Uzbekistan',
+             116: 'Venezuela (Bolivarian Republic of)',
+             117: 'Virgin Islands (USA)',
+             118: 'Cuba',
+             }
diff --git a/training.py b/training.py
new file mode 100644
index 0000000..a4c7c63
--- /dev/null
+++ b/training.py
@@ -0,0 +1,85 @@
+"""Prepare the WHO suicide statistics and preview one training batch."""
+import sys
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import tensorflow as tf
+from countries_map import countries
+
+
+def mapSet(set):
+    """Replace the categorical columns of a data frame with integer codes.
+
+    The "age", "sex" and "country" columns of ``set`` are rewritten on the
+    frame that was passed in; values missing from the lookup tables are
+    left unchanged.
+
+    Returns the same data frame, for convenience.
+    """
+    age = {"5-14 years": 0, "15-24 years": 1, "25-34 years": 2,
+           "35-54 years": 3, "55-74 years": 4, "75+ years": 5}
+    sex = {"male": 0, "female": 1}
+    # countries maps code -> name; invert it to look codes up by name.
+    country = {name: code for code, name in countries.items()}
+
+    # Assign the result back instead of calling replace(inplace=True) on the
+    # selected column: under pandas Copy-on-Write the chained form only
+    # changes a temporary and leaves the frame untouched.
+    set["age"] = set["age"].replace(age)
+    set["sex"] = set["sex"].replace(sex)
+    set["country"] = set["country"].replace(country)
+
+    return set
+
+
+column_names = ["country", "year", "sex", "age", "suicides_no", "population"]
+feature_names = ["country", "year", "sex", "age", "population"]
+label_name = column_names[4]
+
+sc = pd.read_csv('who_suicide_statistics.csv')
+
+# Shuffle once, then cut 60% / 20% / 20%.  Plain iloc slices are used because
+# np.split on a DataFrame goes through the deprecated DataFrame.swapaxes; the
+# copies make each split an independent frame that is safe to modify.
+shuffled = sc.sample(frac=1, random_state=42)
+train_end, validate_end = int(.6*len(sc)), int(.8*len(sc))
+train = shuffled.iloc[:train_end].copy()
+validate = shuffled.iloc[train_end:validate_end].copy()
+test = shuffled.iloc[validate_end:].copy()
+train.dropna(inplace=True)
+validate.dropna(inplace=True)
+test.dropna(inplace=True)
+
+train_n = mapSet(train)
+validate_n = mapSet(validate)
+test_n = mapSet(test)  # map the test split itself, not validate again
+
+# make_csv_dataset reads CSV files matched by a path pattern, so write the
+# training split to disk instead of passing the CSV text itself.
+train_csv = 'train.csv'
+train_n.to_csv(train_csv, index=False)
+
+train_dataset = tf.data.experimental.make_csv_dataset(
+    train_csv,
+    1000,
+    column_names=column_names,
+    label_name=label_name,
+    num_epochs=1)
+
+features, labels = next(iter(train_dataset))
+print(features)
+
+# 'sex' is not a Matplotlib colormap name (scatter raises ValueError); colour
+# the points by label with a real colormap.
+plt.scatter(features['year'],
+            features['age'],
+            c=labels,
+            cmap='viridis')
+
+plt.xlabel("year")
+plt.ylabel("age")
+plt.show()
+
+print("Features: {}".format(feature_names))
+print("Label: {}".format(label_name))