Inzynierka/Lib/site-packages/sklearn/tree/tests/test_reingold_tilford.py
2023-06-02 12:51:02 +02:00

49 lines
1.4 KiB
Python

import numpy as np
import pytest
from sklearn.tree._reingold_tilford import buchheim, Tree
simple_tree = Tree("", 0, Tree("", 1), Tree("", 2))
bigger_tree = Tree(
"",
0,
Tree(
"",
1,
Tree("", 3),
Tree("", 4, Tree("", 7), Tree("", 8)),
),
Tree("", 2, Tree("", 5), Tree("", 6)),
)
@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
def test_buchheim(tree, n_nodes):
def walk_tree(draw_tree):
res = [(draw_tree.x, draw_tree.y)]
for child in draw_tree.children:
# parents higher than children:
assert child.y == draw_tree.y + 1
res.extend(walk_tree(child))
if len(draw_tree.children):
# these trees are always binary
# parents are centered above children
assert (
draw_tree.x == (draw_tree.children[0].x + draw_tree.children[1].x) / 2
)
return res
layout = buchheim(tree)
coordinates = walk_tree(layout)
assert len(coordinates) == n_nodes
# test that x values are unique per depth / level
# we could also do it quicker using defaultdicts..
depth = 0
while True:
x_at_this_depth = [node[0] for node in coordinates if node[1] == depth]
if not x_at_this_depth:
# reached all leafs
break
assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth)
depth += 1