tensorflow

This commit is contained in:
Maciej Sobkowiak 2021-04-26 02:14:45 +02:00
parent 2358985132
commit 1ce93643ae
3 changed files with 183 additions and 2 deletions

4
.gitignore vendored
View File

@ -1,2 +1,2 @@
/venv
/.ipynb_checkpoints
__pycache__
.ipynb_checkpoints

119
countries_map.py Normal file
View File

@ -0,0 +1,119 @@
countries = {1: 'Albania',
2: 'Antigua and Barbuda',
3: 'Argentina',
4: 'Armenia',
5: 'Aruba',
6: 'Australia',
7: 'Austria',
8: 'Azerbaijan',
9: 'Bahamas',
10: 'Bahrain',
11: 'Barbados',
12: 'Belarus',
13: 'Belgium',
14: 'Belize',
15: 'Bermuda',
16: 'Bosnia and Herzegovina',
17: 'Brazil',
18: 'Brunei Darussalam',
19: 'Bulgaria',
20: 'Cabo Verde',
21: 'Canada',
22: 'Cayman Islands',
23: 'Chile',
24: 'Colombia',
25: 'Costa Rica',
26: 'Croatia',
27: 'Cyprus',
28: 'Czech Republic',
29: 'Denmark',
30: 'Dominica',
31: 'Ecuador',
32: 'Egypt',
33: 'El Salvador',
34: 'Estonia',
35: 'Fiji',
36: 'Finland',
37: 'France',
38: 'French Guiana',
39: 'Georgia',
40: 'Germany',
41: 'Greece',
42: 'Grenada',
43: 'Guadeloupe',
44: 'Guatemala',
45: 'Guyana',
46: 'Hong Kong SAR',
47: 'Hungary',
48: 'Iceland',
49: 'Iran (Islamic Rep of)',
50: 'Ireland',
51: 'Israel',
52: 'Italy',
53: 'Jamaica',
54: 'Japan',
55: 'Kazakhstan',
56: 'Kiribati',
57: 'Kuwait',
58: 'Kyrgyzstan',
59: 'Latvia',
60: 'Lithuania',
61: 'Luxembourg',
62: 'Macau',
63: 'Maldives',
64: 'Malta',
65: 'Martinique',
66: 'Mauritius',
67: 'Mayotte',
68: 'Mexico',
69: 'Mongolia',
70: 'Montenegro',
71: 'Netherlands',
72: 'New Zealand',
73: 'Nicaragua',
74: 'Norway',
75: 'Oman',
76: 'Panama',
77: 'Paraguay',
78: 'Philippines',
79: 'Poland',
80: 'Portugal',
81: 'Puerto Rico',
82: 'Qatar',
83: 'Republic of Korea',
84: 'Republic of Moldova',
85: 'Reunion',
86: 'Rodrigues',
87: 'Romania',
88: 'Russian Federation',
89: 'Saint Kitts and Nevis',
90: 'Saint Lucia',
91: 'Saint Vincent and Grenadines',
92: 'San Marino',
93: 'Sao Tome and Principe',
94: 'Serbia',
95: 'Seychelles',
96: 'Singapore',
97: 'Slovakia',
98: 'Slovenia',
99: 'South Africa',
100: 'Spain',
101: 'Sri Lanka',
102: 'Suriname',
103: 'Sweden',
104: 'Switzerland',
105: 'TFYR Macedonia',
106: 'Thailand',
107: 'Trinidad and Tobago',
108: 'Turkey',
109: 'Turkmenistan',
110: 'Ukraine',
111: 'United Arab Emirates',
112: 'United Kingdom',
113: 'United States of America',
114: 'Uruguay',
115: 'Uzbekistan',
116: 'Venezuela (Bolivarian Republic of)',
117: 'Virgin Islands (USA)',
118: 'Cuba',
}

62
training.py Normal file
View File

@ -0,0 +1,62 @@
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import tensorflow as tf
from countries_map import countries
def mapSet(set):
age = {"5-14 years": 0, "15-24 years": 1, "25-34 years": 2,
"35-54 years": 3, "55-74 years": 4, "75+ years": 5}
sex = {"male": 0, "female": 1}
set["age"].replace(age, inplace=True)
set["sex"].replace(sex, inplace=True)
set["country"].replace({v: k for k, v in countries.items()}, inplace=True)
return set
column_names = ["country", "year", "sex", "age", "suicides_no", "population"]
feature_names = ["country", "year", "sex", "age", "population"]
label_name = column_names[4]
sc = pd.read_csv('who_suicide_statistics.csv')
train, validate, test = np.split(sc.sample(frac=1, random_state=42),
[int(.6*len(sc)), int(.8*len(sc))])
train.dropna(inplace=True)
validate.dropna(inplace=True)
test.dropna(inplace=True)
train_n = mapSet(train)
validate_n = mapSet(validate)
test_n = mapSet(validate)
train_csv = pd.DataFrame.to_csv(train_n, index=False)
train_dataset = tf.data.experimental.make_csv_dataset(
train_csv,
1000,
column_names=column_names,
label_name=label_name,
num_epochs=1)
features, labels = next(iter(train_dataset))
print(features)
plt.scatter(features['year'],
features['age'],
c=labels,
cmap='sex')
plt.xlabel("year")
plt.ylabel("age")
plt.show()
print("Features: {}".format(feature_names))
print("Label: {}".format(label_name))
# print(train)