Decision tree created
This commit is contained in:
parent
565a61e77a
commit
1b71852d8b
@ -1,10 +1,59 @@
|
|||||||
#!/usr/bin/python3
|
#!/usr/bin/python3
|
||||||
from sklearn import tree
|
from sklearn import tree
|
||||||
|
from pprint import PrettyPrinter
|
||||||
|
|
||||||
|
pp = PrettyPrinter(indent=2, compact=True)
|
||||||
|
def p(*args, **kwargs):
|
||||||
|
pp.pprint(*args, **kwargs)
|
||||||
|
|
||||||
|
def invoke_consume_exceptions(function, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
return function(*args, **kwargs)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
def read_tsv_from(filename):
|
def read_tsv_from(filename):
|
||||||
from csv import reader
|
from csv import reader
|
||||||
with open(filename, 'r') as f:
|
with open(filename, 'r') as f:
|
||||||
header, *rows = list(reader(f, delimiter='\t'))
|
header, *rows = list(reader(f, delimiter='\t'))
|
||||||
return [dict(zip(header, row)) for row in rows]
|
return [dict(zip(header, (el.strip() for el in row))) for row in rows]
|
||||||
|
|
||||||
print(read_tsv_from('./data.tsv'))
|
def main():
|
||||||
|
from sys import argv
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
source_file = argv[1]
|
||||||
|
invoke_consume_exceptions(os.mkdir, os.path.dirname(source_file))
|
||||||
|
data = read_tsv_from(source_file)
|
||||||
|
|
||||||
|
types = dict()
|
||||||
|
for row in data:
|
||||||
|
for (key, value) in row.items():
|
||||||
|
if key != "nazwa":
|
||||||
|
v = types.get(key, set())
|
||||||
|
v.add(value)
|
||||||
|
types[key] = v
|
||||||
|
|
||||||
|
base = dict(
|
||||||
|
(key, list(zip(sorted(values), range(1000))))
|
||||||
|
for (key, values) in types.items()
|
||||||
|
)
|
||||||
|
|
||||||
|
t2n = dict((key, dict(v)) for (key, v) in base.items())
|
||||||
|
n2t = dict((key, dict((b, a) for (a, b) in v)) for (key, v) in base.items())
|
||||||
|
|
||||||
|
X = [[
|
||||||
|
t2n[name][feature]
|
||||||
|
for (name, feature) in sample.items()
|
||||||
|
if name not in ['nazwa', 'polka']] for sample in data if 'polka' in sample]
|
||||||
|
|
||||||
|
Y = [t2n['polka'][sample['polka']] for sample in data if 'polka' in sample]
|
||||||
|
|
||||||
|
clf = tree.DecisionTreeClassifier()
|
||||||
|
clf.fit(X, Y)
|
||||||
|
l = clf.get_n_leaves()
|
||||||
|
d = clf.get_depth()
|
||||||
|
print(f'Leaves: {l}\nDepth: {d}')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user