
Tree - Trim a Binary Search Tree

All diagrams presented herein are original creations, meticulously designed to enhance comprehension and recall. Crafting these aids required considerable effort, and I kindly request attribution if this content is reused elsewhere.

Difficulty: Easy

DFS

Problem

Given the root of a binary search tree and the lowest and highest boundaries as low and high, trim the tree so that all its elements lie in [low, high]. Trimming the tree should not change the relative structure of the elements that will remain in the tree (i.e., any node’s descendant should remain a descendant). It can be proven that there is a unique answer.

Return the root of the trimmed binary search tree. Note that the root may change depending on the given bounds.

Example 1:

Input: root = [1,0,2], low = 1, high = 2
Output: [1,null,2]

Example 2:

Input: root = [3,0,4,null,2,null,null,1], low = 1, high = 3
Output: [3,2,null,1]
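The inputs and outputs above use LeetCode's level-order serialization, where null marks a missing child and the children of a missing node are not listed. To try the examples locally, a pair of conversion helpers is handy. This is a minimal sketch: `build_tree` and `serialize` are hypothetical names introduced here, and the `TreeNode` class is assumed to match the one in the final code.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list (None = missing child). Hypothetical helper."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value is the left child, the one after it the right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def serialize(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Inverse of build_tree: level-order list with trailing Nones dropped."""
    result: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            # Record the gap, but a missing node has no children to list.
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result
```

Round-tripping Example 2's input, `serialize(build_tree([3, 0, 4, None, 2, None, None, 1]))` gives back the same list.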

Solution

Since the tree is a BST, every value in a node's left subtree is smaller than the node's value, and every value in its right subtree is larger. This ordering lets us trim the tree with a simple recursion.
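One way to see why the ordering matters: an in-order traversal of a BST visits values in ascending order, so the values that survive trimming always form one contiguous run of that sorted sequence. A minimal sketch, using a `TreeNode` class like the one in the final code; `inorder` is a helper name introduced only for illustration.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(root: Optional[TreeNode]) -> List[int]:
    """Visit left subtree, then the node, then the right subtree."""
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


# Example 2's input tree: [3,0,4,null,2,null,null,1]
root = TreeNode(3,
                TreeNode(0, None, TreeNode(2, TreeNode(1))),
                TreeNode(4))

values = inorder(root)
print(values)  # [0, 1, 2, 3, 4]: sorted, as the BST property guarantees

low, high = 1, 3
kept = [v for v in values if low <= v <= high]
print(kept)  # [1, 2, 3]: one contiguous run of the sorted list
```

The expected output of Example 2, [3,2,null,1], has the same in-order traversal: 1, 2, 3.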

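Because the tree is already a BST, we never have to search for individual out-of-range nodes: when a node's value is below low, its entire left subtree is below low as well, so both can be discarded in one step (and symmetrically for values above high). Each node is inspected at most once, so the solution runs in O(n) time, plus O(h) space for the recursion stack, where h is the height of the tree.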
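A small experiment makes the subtree-skipping concrete. The sketch below is the same recursion developed in this post, instrumented with a `visited` list that records every node it inspects; `trim_bst_counted` is a hypothetical name, and the example tree is my own rather than one from the problem.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_bst_counted(root: Optional[TreeNode], low: int, high: int,
                     visited: List[int]) -> Optional[TreeNode]:
    """Same recursion as trim_bst, but records each node it inspects."""
    if not root:
        return None
    visited.append(root.val)
    if root.val < low:
        # root and its whole left subtree are too small: root.left is never visited.
        return trim_bst_counted(root.right, low, high, visited)
    if root.val > high:
        # root and its whole right subtree are too large: root.right is never visited.
        return trim_bst_counted(root.left, low, high, visited)
    root.left = trim_bst_counted(root.left, low, high, visited)
    root.right = trim_bst_counted(root.right, low, high, visited)
    return root


# BST:       10
#          /    \
#         5      15
#        / \
#       2   7
root = TreeNode(10, TreeNode(5, TreeNode(2), TreeNode(7)), TreeNode(15))
visited: List[int] = []
new_root = trim_bst_counted(root, 12, 20, visited)
print(new_root.val, visited)  # 15 [10, 15]
```

Only 2 of the 5 nodes are inspected: since 10 < 12, the subtree rooted at 5 (nodes 5, 2 and 7) is dropped without being visited.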

Start with the base case.

def trim_bst(root, low, high):
  if not root:
    return None

Now consider the case where the root's value is less than low. The root is out of range, and so is its entire left subtree, since every value there is even smaller. We discard both and call trim_bst on the right subtree; whatever that call returns becomes the new root.

  if root.val < low:
    return trim_bst(root.right, low, high)

Conversely, if the root's value is greater than high, the root and its entire right subtree are out of range, so we discard them and explore the left subtree instead.

  if root.val > high:
    return trim_bst(root.left, low, high)
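Taken together, the two out-of-range cases simply walk down the tree, going right when the value is too small and left when it is too large, until they reach a node inside [low, high] (or run out of nodes). That first in-range node is the new root. An iterative sketch of just that walk; `find_new_root` is a hypothetical helper, not part of the solution.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_new_root(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    """Follow the two out-of-range cases until a node inside [low, high] appears."""
    while root and not (low <= root.val <= high):
        # Too small: anything in range must be to the right. Too large: to the left.
        root = root.right if root.val < low else root.left
    return root


# Example 1: root = [1,0,2], low = 1, high = 2; root 1 is already in range.
example1 = TreeNode(1, TreeNode(0), TreeNode(2))
print(find_new_root(example1, 1, 2).val)  # 1

# Root 0 with right child 2, low = 1: 0 is too small, so 2 becomes the new root.
shifted = TreeNode(0, None, TreeNode(2))
print(find_new_root(shifted, 1, 2).val)  # 2
```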

If neither condition holds, root.val lies within [low, high], so the root stays in the trimmed tree. What remains is to trim its left and right subtrees recursively and reattach the results.

  root.left = trim_bst(root.left, low, high)
  root.right = trim_bst(root.right, low, high)
  
  return root

Final Code

Here is the full code.

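# Definition for a binary tree node, as assumed by trim_bst:
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

def trim_bst(root, low, high):
  # Base case: an empty subtree stays empty.
  if not root:
    return None

  # root and its whole left subtree are below low: keep only the trimmed right subtree.
  if root.val < low:
    return trim_bst(root.right, low, high)
  # root and its whole right subtree are above high: keep only the trimmed left subtree.
  if root.val > high:
    return trim_bst(root.left, low, high)

  # root is within [low, high]: keep it and trim both children.
  root.left = trim_bst(root.left, low, high)
  root.right = trim_bst(root.right, low, high)

  return root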
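For comparison, here is an iterative variant that avoids the recursion stack, run on both examples. Note that this is a different technique from the recursive solution above: it first walks down to the new root, then repairs the left and right edges of the remaining tree in place. It is a sketch only; `trim_bst_iterative` and the `serialize` helper (which prints trees in the same level-order form as the examples) are hypothetical names introduced here.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_bst_iterative(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    # 1. Walk down to the first node inside [low, high]; it becomes the new root.
    while root and not (low <= root.val <= high):
        root = root.right if root.val < low else root.left
    if root is None:
        return None

    # 2. Left side: every value here is below root.val <= high, so only values < low go.
    node = root
    while node.left:
        if node.left.val < low:
            # Drop node.left and its left subtree; its right subtree may still be in range.
            node.left = node.left.right
        else:
            node = node.left

    # 3. Right side: mirror image, removing values > high.
    node = root
    while node.right:
        if node.right.val > high:
            node.right = node.right.left
        else:
            node = node.right
    return root


def serialize(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Level-order list (None = missing child, trailing Nones dropped)."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        out.append(node.val if node else None)
        if node:
            queue.extend([node.left, node.right])
    while out and out[-1] is None:
        out.pop()
    return out


# Example 1: root = [1,0,2], low = 1, high = 2
ex1 = TreeNode(1, TreeNode(0), TreeNode(2))
print(serialize(trim_bst_iterative(ex1, 1, 2)))  # [1, None, 2]

# Example 2: root = [3,0,4,null,2,null,null,1], low = 1, high = 3
ex2 = TreeNode(3, TreeNode(0, None, TreeNode(2, TreeNode(1))), TreeNode(4))
print(serialize(trim_bst_iterative(ex2, 1, 3)))  # [3, 2, None, 1]
```

Both versions run in O(n) time; the iterative one needs only O(1) extra space instead of O(h) for the recursion stack, at the cost of being a little longer.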
This post is licensed under CC BY 4.0 by the author.