from stack import Stack


class Node:
    """One node of a binary tree: a stored value plus left/right child links."""

    def __init__(self, value):
        self.value = value
        # Children start out empty; the tree attaches them later.
        self.left = self.right = None

    def __repr__(self):
        # Shows only the stored value, e.g. ``Node(5)``; children are omitted.
        return "Node({})".format(self.value)


class BinarySearchTree:
    def __init__(self):
        self.__root = None

    def insert(self, value):
        """Add new node to the tree (unless the value is already in the tree)."""
        self.__root = self.__insert(value, self.__root)

    def __insert(self, value, node: Node) -> Node:
        if node is None:
            return Node(value)
        elif value < node.value:
            node.left = self.__insert(value, node.left)
        elif value > node.value:
            node.right = self.__insert(value, node.right)
        return node

    def find(self, value):
        """Returns `True` if `value` is in the tree."""
        return self._find(value, self._root) is not None

    def _find(self, value, node):
        if node is None:
            return None
        elif node.value == value:
            return node
        elif value < node.value:
            return self._find(value, node.left)
        else:
            return self._find(value, node.right)

    def delete(self, value):
        self._root = self._delete(value, self._root)

    def _delete(self, value, node):
        if node is None:
            return None
        elif value < node.value:
            node.left = self._delete(value, node.left)
        elif value > node.value:
            node.right = self._delete(value, node.right)
        else:  # value == node.value
            if node.left is None and node.right is None:
                return None
            elif node.left is None:
                return node.right
            elif node.right is None:
                return node.left
            else:  # both sons
                m = self._find_min(node.right)
                node.value = m.value
                node.right = self._delete(m.value, node.right)
        return node

    def find_min(self):
        return self._find_min(self._root)

    def _find_min(self, node):
        if node is None:  # empty tree
            return None
        
        while node.left is not None:
            node = node.left
        
        return node

    def traverse(self, preorder=None, inorder=None, postorder=None):
        self._traverse(self.__root, preorder, inorder, postorder)

    def _traverse(self, node, preorder=None, inorder=None, postorder=None):
        if node is None:
            return
        
        if preorder is not None:
            preorder(node)

        if node.left is not None:
            self._traverse(node.left, preorder, inorder, postorder)

        if inorder is not None:
            inorder(node)
        
        if node.right is not None:
            self._traverse(node.right, preorder, inorder, postorder)

        if postorder is not None:
            postorder(node)

    def __repr__(self):
        representation = []
        def preorder(_):
            representation.append("(")
        def inorder(node):
            representation.append(str(node.value))
        def postorder(_):
            representation.append(")")
        self.traverse(preorder, inorder, postorder)
        return "".join(representation)

    # NOTE: this is slow (it takes O(n) time), it is easier to have a __count
    # property in the class and modify it on inserts and deletes.  
    def __len__(self):
        count = 0
        def increment(_):
            nonlocal count
            count += 1
        self.traverse(increment)
        return count

    def dfs(self):
        if self.__root is None:
            return

        stack = Stack()
        stack.push(self.__root)

        while not stack.is_empty():
            node = stack.pop()
            print(node)

            if node.left is not None:
                stack.push(node.left)
            if node.right is not None:
                stack.push(node.right)


def main():
    """Build a small sample tree and print its shape, size, and DFS order."""
    tree = BinarySearchTree()
    for value in (5, 6, 3, 4):
        tree.insert(value)

    print(tree)
    print(len(tree))
    tree.dfs()


# Only run the demo when executed directly, not on import.
if __name__ == "__main__":
    main()
