tensorflow
This commit is contained in:
parent
2358985132
commit
1ce93643ae
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,2 +1,2 @@
|
|||||||
/venv
|
__pycache__
|
||||||
/.ipynb_checkpoints
|
.ipynb_checkpoints
|
119
countries_map.py
Normal file
119
countries_map.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
countries = {1: 'Albania',
|
||||||
|
2: 'Antigua and Barbuda',
|
||||||
|
3: 'Argentina',
|
||||||
|
4: 'Armenia',
|
||||||
|
5: 'Aruba',
|
||||||
|
6: 'Australia',
|
||||||
|
7: 'Austria',
|
||||||
|
8: 'Azerbaijan',
|
||||||
|
9: 'Bahamas',
|
||||||
|
10: 'Bahrain',
|
||||||
|
11: 'Barbados',
|
||||||
|
12: 'Belarus',
|
||||||
|
13: 'Belgium',
|
||||||
|
14: 'Belize',
|
||||||
|
15: 'Bermuda',
|
||||||
|
16: 'Bosnia and Herzegovina',
|
||||||
|
17: 'Brazil',
|
||||||
|
18: 'Brunei Darussalam',
|
||||||
|
19: 'Bulgaria',
|
||||||
|
20: 'Cabo Verde',
|
||||||
|
21: 'Canada',
|
||||||
|
22: 'Cayman Islands',
|
||||||
|
23: 'Chile',
|
||||||
|
24: 'Colombia',
|
||||||
|
25: 'Costa Rica',
|
||||||
|
26: 'Croatia',
|
||||||
|
27: 'Cyprus',
|
||||||
|
28: 'Czech Republic',
|
||||||
|
29: 'Denmark',
|
||||||
|
30: 'Dominica',
|
||||||
|
31: 'Ecuador',
|
||||||
|
32: 'Egypt',
|
||||||
|
33: 'El Salvador',
|
||||||
|
34: 'Estonia',
|
||||||
|
35: 'Fiji',
|
||||||
|
36: 'Finland',
|
||||||
|
37: 'France',
|
||||||
|
38: 'French Guiana',
|
||||||
|
39: 'Georgia',
|
||||||
|
40: 'Germany',
|
||||||
|
41: 'Greece',
|
||||||
|
42: 'Grenada',
|
||||||
|
43: 'Guadeloupe',
|
||||||
|
44: 'Guatemala',
|
||||||
|
45: 'Guyana',
|
||||||
|
46: 'Hong Kong SAR',
|
||||||
|
47: 'Hungary',
|
||||||
|
48: 'Iceland',
|
||||||
|
49: 'Iran (Islamic Rep of)',
|
||||||
|
50: 'Ireland',
|
||||||
|
51: 'Israel',
|
||||||
|
52: 'Italy',
|
||||||
|
53: 'Jamaica',
|
||||||
|
54: 'Japan',
|
||||||
|
55: 'Kazakhstan',
|
||||||
|
56: 'Kiribati',
|
||||||
|
57: 'Kuwait',
|
||||||
|
58: 'Kyrgyzstan',
|
||||||
|
59: 'Latvia',
|
||||||
|
60: 'Lithuania',
|
||||||
|
61: 'Luxembourg',
|
||||||
|
62: 'Macau',
|
||||||
|
63: 'Maldives',
|
||||||
|
64: 'Malta',
|
||||||
|
65: 'Martinique',
|
||||||
|
66: 'Mauritius',
|
||||||
|
67: 'Mayotte',
|
||||||
|
68: 'Mexico',
|
||||||
|
69: 'Mongolia',
|
||||||
|
70: 'Montenegro',
|
||||||
|
71: 'Netherlands',
|
||||||
|
72: 'New Zealand',
|
||||||
|
73: 'Nicaragua',
|
||||||
|
74: 'Norway',
|
||||||
|
75: 'Oman',
|
||||||
|
76: 'Panama',
|
||||||
|
77: 'Paraguay',
|
||||||
|
78: 'Philippines',
|
||||||
|
79: 'Poland',
|
||||||
|
80: 'Portugal',
|
||||||
|
81: 'Puerto Rico',
|
||||||
|
82: 'Qatar',
|
||||||
|
83: 'Republic of Korea',
|
||||||
|
84: 'Republic of Moldova',
|
||||||
|
85: 'Reunion',
|
||||||
|
86: 'Rodrigues',
|
||||||
|
87: 'Romania',
|
||||||
|
88: 'Russian Federation',
|
||||||
|
89: 'Saint Kitts and Nevis',
|
||||||
|
90: 'Saint Lucia',
|
||||||
|
91: 'Saint Vincent and Grenadines',
|
||||||
|
92: 'San Marino',
|
||||||
|
93: 'Sao Tome and Principe',
|
||||||
|
94: 'Serbia',
|
||||||
|
95: 'Seychelles',
|
||||||
|
96: 'Singapore',
|
||||||
|
97: 'Slovakia',
|
||||||
|
98: 'Slovenia',
|
||||||
|
99: 'South Africa',
|
||||||
|
100: 'Spain',
|
||||||
|
101: 'Sri Lanka',
|
||||||
|
102: 'Suriname',
|
||||||
|
103: 'Sweden',
|
||||||
|
104: 'Switzerland',
|
||||||
|
105: 'TFYR Macedonia',
|
||||||
|
106: 'Thailand',
|
||||||
|
107: 'Trinidad and Tobago',
|
||||||
|
108: 'Turkey',
|
||||||
|
109: 'Turkmenistan',
|
||||||
|
110: 'Ukraine',
|
||||||
|
111: 'United Arab Emirates',
|
||||||
|
112: 'United Kingdom',
|
||||||
|
113: 'United States of America',
|
||||||
|
114: 'Uruguay',
|
||||||
|
115: 'Uzbekistan',
|
||||||
|
116: 'Venezuela (Bolivarian Republic of)',
|
||||||
|
117: 'Virgin Islands (USA)',
|
||||||
|
118: 'Cuba',
|
||||||
|
}
|
62
training.py
Normal file
62
training.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import sys
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import tensorflow as tf
|
||||||
|
from countries_map import countries
|
||||||
|
|
||||||
|
|
||||||
|
def mapSet(set):
|
||||||
|
age = {"5-14 years": 0, "15-24 years": 1, "25-34 years": 2,
|
||||||
|
"35-54 years": 3, "55-74 years": 4, "75+ years": 5}
|
||||||
|
sex = {"male": 0, "female": 1}
|
||||||
|
|
||||||
|
set["age"].replace(age, inplace=True)
|
||||||
|
set["sex"].replace(sex, inplace=True)
|
||||||
|
set["country"].replace({v: k for k, v in countries.items()}, inplace=True)
|
||||||
|
|
||||||
|
return set
|
||||||
|
|
||||||
|
|
||||||
|
column_names = ["country", "year", "sex", "age", "suicides_no", "population"]
|
||||||
|
feature_names = ["country", "year", "sex", "age", "population"]
|
||||||
|
label_name = column_names[4]
|
||||||
|
|
||||||
|
sc = pd.read_csv('who_suicide_statistics.csv')
|
||||||
|
|
||||||
|
train, validate, test = np.split(sc.sample(frac=1, random_state=42),
|
||||||
|
[int(.6*len(sc)), int(.8*len(sc))])
|
||||||
|
train.dropna(inplace=True)
|
||||||
|
validate.dropna(inplace=True)
|
||||||
|
test.dropna(inplace=True)
|
||||||
|
|
||||||
|
train_n = mapSet(train)
|
||||||
|
validate_n = mapSet(validate)
|
||||||
|
test_n = mapSet(validate)
|
||||||
|
|
||||||
|
train_csv = pd.DataFrame.to_csv(train_n, index=False)
|
||||||
|
|
||||||
|
train_dataset = tf.data.experimental.make_csv_dataset(
|
||||||
|
train_csv,
|
||||||
|
1000,
|
||||||
|
column_names=column_names,
|
||||||
|
label_name=label_name,
|
||||||
|
num_epochs=1)
|
||||||
|
|
||||||
|
features, labels = next(iter(train_dataset))
|
||||||
|
print(features)
|
||||||
|
|
||||||
|
plt.scatter(features['year'],
|
||||||
|
features['age'],
|
||||||
|
c=labels,
|
||||||
|
cmap='sex')
|
||||||
|
|
||||||
|
plt.xlabel("year")
|
||||||
|
plt.ylabel("age")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
print("Features: {}".format(feature_names))
|
||||||
|
print("Label: {}".format(label_name))
|
||||||
|
|
||||||
|
# print(train)
|
Loading…
Reference in New Issue
Block a user