tensorflow
This commit is contained in:
parent
2358985132
commit
1ce93643ae
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,2 +1,2 @@
|
||||
/venv
|
||||
/.ipynb_checkpoints
|
||||
__pycache__
|
||||
.ipynb_checkpoints
|
119
countries_map.py
Normal file
119
countries_map.py
Normal file
@ -0,0 +1,119 @@
|
||||
countries = {1: 'Albania',
|
||||
2: 'Antigua and Barbuda',
|
||||
3: 'Argentina',
|
||||
4: 'Armenia',
|
||||
5: 'Aruba',
|
||||
6: 'Australia',
|
||||
7: 'Austria',
|
||||
8: 'Azerbaijan',
|
||||
9: 'Bahamas',
|
||||
10: 'Bahrain',
|
||||
11: 'Barbados',
|
||||
12: 'Belarus',
|
||||
13: 'Belgium',
|
||||
14: 'Belize',
|
||||
15: 'Bermuda',
|
||||
16: 'Bosnia and Herzegovina',
|
||||
17: 'Brazil',
|
||||
18: 'Brunei Darussalam',
|
||||
19: 'Bulgaria',
|
||||
20: 'Cabo Verde',
|
||||
21: 'Canada',
|
||||
22: 'Cayman Islands',
|
||||
23: 'Chile',
|
||||
24: 'Colombia',
|
||||
25: 'Costa Rica',
|
||||
26: 'Croatia',
|
||||
27: 'Cyprus',
|
||||
28: 'Czech Republic',
|
||||
29: 'Denmark',
|
||||
30: 'Dominica',
|
||||
31: 'Ecuador',
|
||||
32: 'Egypt',
|
||||
33: 'El Salvador',
|
||||
34: 'Estonia',
|
||||
35: 'Fiji',
|
||||
36: 'Finland',
|
||||
37: 'France',
|
||||
38: 'French Guiana',
|
||||
39: 'Georgia',
|
||||
40: 'Germany',
|
||||
41: 'Greece',
|
||||
42: 'Grenada',
|
||||
43: 'Guadeloupe',
|
||||
44: 'Guatemala',
|
||||
45: 'Guyana',
|
||||
46: 'Hong Kong SAR',
|
||||
47: 'Hungary',
|
||||
48: 'Iceland',
|
||||
49: 'Iran (Islamic Rep of)',
|
||||
50: 'Ireland',
|
||||
51: 'Israel',
|
||||
52: 'Italy',
|
||||
53: 'Jamaica',
|
||||
54: 'Japan',
|
||||
55: 'Kazakhstan',
|
||||
56: 'Kiribati',
|
||||
57: 'Kuwait',
|
||||
58: 'Kyrgyzstan',
|
||||
59: 'Latvia',
|
||||
60: 'Lithuania',
|
||||
61: 'Luxembourg',
|
||||
62: 'Macau',
|
||||
63: 'Maldives',
|
||||
64: 'Malta',
|
||||
65: 'Martinique',
|
||||
66: 'Mauritius',
|
||||
67: 'Mayotte',
|
||||
68: 'Mexico',
|
||||
69: 'Mongolia',
|
||||
70: 'Montenegro',
|
||||
71: 'Netherlands',
|
||||
72: 'New Zealand',
|
||||
73: 'Nicaragua',
|
||||
74: 'Norway',
|
||||
75: 'Oman',
|
||||
76: 'Panama',
|
||||
77: 'Paraguay',
|
||||
78: 'Philippines',
|
||||
79: 'Poland',
|
||||
80: 'Portugal',
|
||||
81: 'Puerto Rico',
|
||||
82: 'Qatar',
|
||||
83: 'Republic of Korea',
|
||||
84: 'Republic of Moldova',
|
||||
85: 'Reunion',
|
||||
86: 'Rodrigues',
|
||||
87: 'Romania',
|
||||
88: 'Russian Federation',
|
||||
89: 'Saint Kitts and Nevis',
|
||||
90: 'Saint Lucia',
|
||||
91: 'Saint Vincent and Grenadines',
|
||||
92: 'San Marino',
|
||||
93: 'Sao Tome and Principe',
|
||||
94: 'Serbia',
|
||||
95: 'Seychelles',
|
||||
96: 'Singapore',
|
||||
97: 'Slovakia',
|
||||
98: 'Slovenia',
|
||||
99: 'South Africa',
|
||||
100: 'Spain',
|
||||
101: 'Sri Lanka',
|
||||
102: 'Suriname',
|
||||
103: 'Sweden',
|
||||
104: 'Switzerland',
|
||||
105: 'TFYR Macedonia',
|
||||
106: 'Thailand',
|
||||
107: 'Trinidad and Tobago',
|
||||
108: 'Turkey',
|
||||
109: 'Turkmenistan',
|
||||
110: 'Ukraine',
|
||||
111: 'United Arab Emirates',
|
||||
112: 'United Kingdom',
|
||||
113: 'United States of America',
|
||||
114: 'Uruguay',
|
||||
115: 'Uzbekistan',
|
||||
116: 'Venezuela (Bolivarian Republic of)',
|
||||
117: 'Virgin Islands (USA)',
|
||||
118: 'Cuba',
|
||||
}
|
62
training.py
Normal file
62
training.py
Normal file
@ -0,0 +1,62 @@
|
||||
import sys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
import tensorflow as tf
|
||||
from countries_map import countries
|
||||
|
||||
|
||||
def mapSet(set):
|
||||
age = {"5-14 years": 0, "15-24 years": 1, "25-34 years": 2,
|
||||
"35-54 years": 3, "55-74 years": 4, "75+ years": 5}
|
||||
sex = {"male": 0, "female": 1}
|
||||
|
||||
set["age"].replace(age, inplace=True)
|
||||
set["sex"].replace(sex, inplace=True)
|
||||
set["country"].replace({v: k for k, v in countries.items()}, inplace=True)
|
||||
|
||||
return set
|
||||
|
||||
|
||||
column_names = ["country", "year", "sex", "age", "suicides_no", "population"]
|
||||
feature_names = ["country", "year", "sex", "age", "population"]
|
||||
label_name = column_names[4]
|
||||
|
||||
sc = pd.read_csv('who_suicide_statistics.csv')
|
||||
|
||||
train, validate, test = np.split(sc.sample(frac=1, random_state=42),
|
||||
[int(.6*len(sc)), int(.8*len(sc))])
|
||||
train.dropna(inplace=True)
|
||||
validate.dropna(inplace=True)
|
||||
test.dropna(inplace=True)
|
||||
|
||||
train_n = mapSet(train)
|
||||
validate_n = mapSet(validate)
|
||||
test_n = mapSet(validate)
|
||||
|
||||
train_csv = pd.DataFrame.to_csv(train_n, index=False)
|
||||
|
||||
train_dataset = tf.data.experimental.make_csv_dataset(
|
||||
train_csv,
|
||||
1000,
|
||||
column_names=column_names,
|
||||
label_name=label_name,
|
||||
num_epochs=1)
|
||||
|
||||
features, labels = next(iter(train_dataset))
|
||||
print(features)
|
||||
|
||||
plt.scatter(features['year'],
|
||||
features['age'],
|
||||
c=labels,
|
||||
cmap='sex')
|
||||
|
||||
plt.xlabel("year")
|
||||
plt.ylabel("age")
|
||||
plt.show()
|
||||
|
||||
print("Features: {}".format(feature_names))
|
||||
print("Label: {}".format(label_name))
|
||||
|
||||
# print(train)
|
Loading…
Reference in New Issue
Block a user