Decision tree created

This commit is contained in:
Robert Bendun 2021-05-24 07:50:59 +02:00
parent 565a61e77a
commit 1b71852d8b

View File

@@ -1,10 +1,59 @@
#!/usr/bin/python3
from sklearn import tree
from pprint import PrettyPrinter
# Shared pretty-printer: two-space indent, compact wrapping of long sequences.
pp = PrettyPrinter(compact=True, indent=2)


def p(*args, **kwargs):
    """Pretty-print via the module-level ``pp`` printer.

    All positional and keyword arguments are forwarded unchanged to
    ``pp.pprint``.
    """
    return pp.pprint(*args, **kwargs)
def invoke_consume_exceptions(function, *args, **kwargs):
    """Call ``function(*args, **kwargs)`` and swallow ordinary errors.

    Args:
        function: Callable to invoke.
        *args: Positional arguments forwarded to ``function``.
        **kwargs: Keyword arguments forwarded to ``function``.

    Returns:
        Whatever ``function`` returns, or ``None`` if it raised an
        ``Exception``.
    """
    try:
        return function(*args, **kwargs)
    except Exception:
        # Deliberately best-effort.  Only ``Exception`` is caught so that
        # KeyboardInterrupt / SystemExit still propagate to the caller.
        return None
def read_tsv_from(filename):
    """Read a tab-separated file into a list of dicts keyed by the header row.

    The first row supplies the keys.  Each following row becomes one dict
    whose values are that row's cells with surrounding whitespace stripped.
    Keys and cells are paired with ``zip``, so a row shorter than the header
    lacks the trailing keys and cells beyond the header length are dropped.

    Args:
        filename: Path of the TSV file to read.

    Returns:
        A list with one dict per data row; an empty list when the file is
        empty or contains only the header.

    Raises:
        OSError: If the file cannot be opened.
    """
    from csv import reader

    # newline='' leaves line-ending handling to the csv module.
    with open(filename, 'r', newline='') as f:
        rows = reader(f, delimiter='\t')
        header = next(rows, None)
        if header is None:
            # Empty file: there is no header to build records from.
            return []
        return [
            dict(zip(header, (cell.strip() for cell in row)))
            for row in rows
        ]
# NOTE(review): module-level call -- it runs on import, reads the hard-coded
# './data.tsv' and prints the parsed rows.  Looks like a debugging leftover
# (possibly a removed line kept by the diff view); confirm before deleting.
print(read_tsv_from('./data.tsv'))
def main():
    """Fit a decision tree on the TSV file named by the first CLI argument.

    Steps:
      1. Read the rows with ``read_tsv_from``.
      2. For every column except ``'nazwa'``, map each distinct value to an
         integer code (its position in sorted order).
      3. Build feature vectors ``X`` from every column except ``'nazwa'`` and
         ``'polka'``, and labels ``Y`` from the ``'polka'`` column, using only
         rows that contain a ``'polka'`` key.
      4. Fit ``tree.DecisionTreeClassifier`` and print its leaf count and
         depth.

    Raises:
        SystemExit: If no input file was given on the command line.
    """
    from sys import argv
    import os

    if len(argv) < 2:
        raise SystemExit('usage: <script> <data.tsv>')
    source_file = argv[1]

    # Best-effort: any failure (directory already exists, empty dirname, ...)
    # is swallowed by invoke_consume_exceptions.
    invoke_consume_exceptions(os.mkdir, os.path.dirname(source_file))
    data = read_tsv_from(source_file)

    # Distinct values seen in each column; 'nazwa' is never encoded.
    types = dict()
    for row in data:
        for key, value in row.items():
            if key != "nazwa":
                types.setdefault(key, set()).add(value)

    # Text -> integer code per column.  enumerate() places no cap on the
    # number of distinct values a column may hold.
    t2n = {
        key: {value: index for index, value in enumerate(sorted(values))}
        for key, values in types.items()
    }

    # Only rows that carry a label can be used for training.
    labelled = [sample for sample in data if 'polka' in sample]
    X = [
        [
            t2n[name][feature]
            for name, feature in sample.items()
            if name not in ('nazwa', 'polka')
        ]
        for sample in labelled
    ]
    Y = [t2n['polka'][sample['polka']] for sample in labelled]

    clf = tree.DecisionTreeClassifier()
    clf.fit(X, Y)

    leaves = clf.get_n_leaves()
    depth = clf.get_depth()
    print(f'Leaves: {leaves}\nDepth: {depth}')
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()