Tree Recursion
Tree recursion occurs when a function calls itself more than once per invocation, so the calls fan out into a tree rather than forming a single chain. It is common in problems that split into several smaller subproblems, and those subproblems often overlap.
Mental Model
Tree recursion branches at every non-base call: each invocation that is not a base case spawns two or more recursive calls, forming a tree of computation. When the subproblems overlap and there is no memoization, the same values are recomputed many times, which typically leads to exponential running time. Recognizing tree recursion is the first step toward applying memoization or converting to bottom-up dynamic programming.
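The bottom-up alternative can be made concrete with Fibonacci, which the next section develops. Below is a minimal sketch of that version. The function name `fib_bottom_up` is our own, chosen for illustration. It starts from the base cases and iterates upward, so each subproblem is computed exactly once.

```python
def fib_bottom_up(n):
    """Return the nth Fibonacci number using bottom-up dynamic programming.

    Each value fib(0), fib(1), ..., fib(n) is computed once, so this runs in
    O(n) time and keeps only the last two values (O(1) extra space).
    """
    if n <= 1:
        return n
    prev, curr = 0, 1  # fib(0), fib(1)
    for _ in range(2, n + 1):
        prev, curr = curr, prev + curr  # slide the window one step up
    return curr


print(fib_bottom_up(5))   # 5
print(fib_bottom_up(6))   # 8
print(fib_bottom_up(50))  # 12586269025
```

Keeping only two running values, rather than a full table, works because each Fibonacci number depends only on the two before it.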
Fibonacci: Classic Tree Recursion

```python
def fibonacci(n):
    '''Fibonacci sequence using naive recursion'''
    if n <= 1:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)


# Example calls
print(fibonacci(5))  # Output: 5
print(fibonacci(6))  # Output: 8
```
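The call tree for fibonacci(5), with fib as shorthand for fibonacci:

```text
                             fib(5)
                        /              \
                fib(4)                     fib(3)
              /        \                 /        \
         fib(3)        fib(2)        fib(2)      fib(1)
         /    \        /    \        /    \
      fib(2) fib(1) fib(1) fib(0) fib(1) fib(0)
     /    \
  fib(1) fib(0)
```

Notice that fib(3) is computed twice and fib(2) three times, even though each always returns the same value. The tree has 15 nodes, one per call.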
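The repeated work in the tree can also be measured directly. The sketch below records every argument passed to a naive Fibonacci and tallies the arguments with `collections.Counter`. The helper name `fib_traced` is hypothetical and introduced only for this illustration.

```python
from collections import Counter


def fib_traced(n, seen):
    """Naive Fibonacci that records every argument it is called with."""
    seen[n] += 1
    if n <= 1:
        return n
    return fib_traced(n - 1, seen) + fib_traced(n - 2, seen)


seen = Counter()
fib_traced(5, seen)
for arg in sorted(seen, reverse=True):
    print(f"fib({arg}) computed {seen[arg]} time(s)")
# fib(5) computed 1 time(s)
# fib(4) computed 1 time(s)
# fib(3) computed 2 time(s)
# fib(2) computed 3 time(s)
# fib(1) computed 5 time(s)
# fib(0) computed 3 time(s)
```

The counts add up to 15, which matches the number of nodes in the call tree.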
Counting Recursive Calls

```python
def fibonacci_counted(n, call_count=None):
    '''Fibonacci with call counter'''
    if call_count is None:
        call_count = {'count': 0}

    call_count['count'] += 1
    if n <= 1:
        return n
    return fibonacci_counted(n - 1, call_count) + fibonacci_counted(n - 2, call_count)


# Count calls for different inputs
for n in [5, 10, 15, 20]:
    call_count = {'count': 0}
    fibonacci_counted(n, call_count)
    print(f"fib({n:2d}): {call_count['count']:6d} calls")
```
Output:

```text
fib( 5):     15 calls
fib(10):    177 calls
fib(15):   1973 calls
fib(20):  21891 calls
```
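These counts follow a simple formula. A call to fibonacci_counted(n) counts itself plus every call made by its two children. The count therefore obeys C(n) = 1 + C(n-1) + C(n-2), with C(0) = C(1) = 1. Adding 1 to both sides shows that C(n) + 1 follows the Fibonacci recurrence, starting from 2 = 2·F(1) and 2 = 2·F(2). This gives C(n) = 2·F(n+1) - 1. The sketch below checks the formula against the counter. `fib_iter` is a helper introduced here only to evaluate F.

```python
def fibonacci_counted(n, call_count):
    """Naive Fibonacci that increments call_count['count'] on every call."""
    call_count['count'] += 1
    if n <= 1:
        return n
    return fibonacci_counted(n - 1, call_count) + fibonacci_counted(n - 2, call_count)


def fib_iter(n):
    """Iterative Fibonacci, used only to evaluate the closed-form prediction."""
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a


results = []
for n in [5, 10, 15, 20]:
    call_count = {'count': 0}
    fibonacci_counted(n, call_count)
    predicted = 2 * fib_iter(n + 1) - 1  # C(n) = 2 * F(n + 1) - 1
    results.append((n, call_count['count'], predicted))
    print(n, call_count['count'], predicted)
# 5 15 15
# 10 177 177
# 15 1973 1973
# 20 21891 21891
```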
Binary Search Tree Traversal

```python
class TreeNode:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


def traverse_tree(node, result=None):
    '''In-order tree traversal'''
    if result is None:
        result = []

    if node is None:
        return result

    traverse_tree(node.left, result)
    result.append(node.value)
    traverse_tree(node.right, result)
    return result


# Build and traverse
root = TreeNode(4)
root.left = TreeNode(2)
root.right = TreeNode(6)
root.left.left = TreeNode(1)
root.left.right = TreeNode(3)

print(traverse_tree(root))  # [1, 2, 3, 4, 6]
```
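The traversal also calls itself twice per invocation. Unlike Fibonacci, though, every call works on a different subtree, so no subproblem is repeated. A tree with n nodes has n + 1 empty (None) child slots, so the traversal makes 2n + 1 calls in total. That is linear, not exponential. The sketch below counts those calls. `traverse_counted` is a hypothetical variant of traverse_tree written for this check.

```python
class TreeNode:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


def traverse_counted(node, result, calls):
    """In-order traversal that also counts every call, including None children."""
    calls[0] += 1
    if node is None:
        return result
    traverse_counted(node.left, result, calls)
    result.append(node.value)
    traverse_counted(node.right, result, calls)
    return result


root = TreeNode(4)
root.left = TreeNode(2)
root.right = TreeNode(6)
root.left.left = TreeNode(1)
root.left.right = TreeNode(3)

calls = [0]
values = traverse_counted(root, [], calls)
print(values)    # [1, 2, 3, 4, 6]
print(calls[0])  # 11, i.e. 2 * 5 nodes + 1
```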
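Performance Comparison

Because tree recursion grows exponentially, a small increase in input size causes a large increase in running time. Each time n grows by 1, naive fibonacci makes roughly 1.6 times as many calls, and the ratio approaches the golden ratio; the call counts above rise from 177 to 1973 when n goes from 10 to 15. Memoization (see the chapter on memoization) removes this blowup by computing each subproblem only once.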
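The growth rate can be checked numerically. The sketch below computes the naive call counts bottom-up from the recurrence C(n) = 1 + C(n-1) + C(n-2), so it does not have to run the slow recursion itself. It then compares successive ratios with the golden ratio. The name `naive_call_count` is our own.

```python
def naive_call_count(n):
    """Calls made by naive fibonacci(n), computed bottom-up from
    C(n) = 1 + C(n - 1) + C(n - 2) with C(0) = C(1) = 1."""
    counts = [1, 1]
    for k in range(2, n + 1):
        counts.append(1 + counts[k - 1] + counts[k - 2])
    return counts[n]


golden_ratio = (1 + 5 ** 0.5) / 2  # about 1.618

for n in [10, 20, 30, 40]:
    growth = naive_call_count(n + 1) / naive_call_count(n)
    print(f"n={n}: {naive_call_count(n):,} calls; n+1 needs {growth:.4f}x as many")

# Ten more units of input multiply the work by roughly golden_ratio ** 10 (about 123)
print(naive_call_count(40) / naive_call_count(30))
```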
Exercises
Exercise 1.
Write a tree-recursive function count_partitions(n, max_part) that counts the number of ways to partition the integer n using parts up to max_part. For example, count_partitions(6, 4) returns 9. Identify the two recursive branches (use the largest part vs. exclude it).
Solution to Exercise 1
```python
def count_partitions(n, max_part):
    if n == 0:
        return 1
    if n < 0 or max_part == 0:
        return 0
    # Use max_part + don't use max_part
    return (count_partitions(n - max_part, max_part)
            + count_partitions(n, max_part - 1))


print(count_partitions(6, 4))   # 9
print(count_partitions(5, 5))   # 7
print(count_partitions(10, 5))  # 30
```
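To see the partitions being counted, the same two branches can build lists instead of adding numbers. The sketch below does this. `list_partitions` is a hypothetical companion to count_partitions written for this illustration. Its base cases mirror the count: n == 0 yields one (empty) partition, and an invalid state yields none.

```python
def list_partitions(n, max_part):
    """Return every partition of n into parts no larger than max_part."""
    if n == 0:
        return [[]]  # exactly one way: the empty partition
    if n < 0 or max_part == 0:
        return []    # no valid partitions
    # Branch 1: use max_part at least once, then partition what remains
    with_max = [[max_part] + rest for rest in list_partitions(n - max_part, max_part)]
    # Branch 2: never use max_part
    without_max = list_partitions(n, max_part - 1)
    return with_max + without_max


partitions = list_partitions(6, 4)
for p in partitions:
    print(" + ".join(map(str, p)))
print(len(partitions))  # 9
```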
Exercise 2.
Add a call counter to the naive recursive fibonacci(n) function. Compute fibonacci(25) and print the total number of function calls. Then add @lru_cache and repeat, printing the new call count to demonstrate the dramatic difference.
Solution to Exercise 2
```python
from functools import lru_cache

# Without memoization
naive_calls = 0

def fib_naive(n):
    global naive_calls
    naive_calls += 1
    if n < 2:
        return n
    return fib_naive(n - 1) + fib_naive(n - 2)

fib_naive(25)
print(f"Naive calls for fib(25): {naive_calls}")  # 242785

# With memoization
memo_calls = 0

@lru_cache(maxsize=None)
def fib_cached(n):
    global memo_calls
    memo_calls += 1
    if n < 2:
        return n
    return fib_cached(n - 1) + fib_cached(n - 2)

fib_cached(25)
print(f"Cached calls for fib(25): {memo_calls}")  # 26
```
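As an alternative to a global counter, functions wrapped by `functools.lru_cache` provide a `cache_info()` method. Its `misses` field counts the calls that actually ran the function body, and `hits` counts the calls answered from the cache. The sketch below applies it to a fresh cached Fibonacci.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def fib_cached(n):
    if n < 2:
        return n
    return fib_cached(n - 1) + fib_cached(n - 2)


print(fib_cached(25))            # 75025
info = fib_cached.cache_info()
print(info)  # CacheInfo(hits=23, misses=26, maxsize=None, currsize=26)
```

The 26 misses match the counter result above, since each of fib(0) through fib(25) runs exactly once. The 23 hits are the second-branch calls fib(n - 2) for n = 3 to 25, which find their value already cached.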
Exercise 3.
Write a tree-recursive function paths_in_grid(rows, cols) that counts the number of unique paths from the top-left to the bottom-right of a rows x cols grid, where you can only move right or down. Draw the call tree for paths_in_grid(3, 3) in a comment and verify the result is 6.
Solution to Exercise 3
```python
def paths_in_grid(rows, cols):
    # Base cases: a single row or column has exactly one path
    if rows == 1 or cols == 1:
        return 1
    # Move down + move right
    return paths_in_grid(rows - 1, cols) + paths_in_grid(rows, cols - 1)


print(paths_in_grid(3, 3))  # 6
print(paths_in_grid(4, 4))  # 20

# Call tree for paths_in_grid(3, 3); every leaf returns 1 and there are 6 leaves:
#                  (3,3)
#            /             \
#       (2,3)                 (3,2)
#      /     \               /     \
#  (1,3)   (2,2)         (2,2)   (3,1)
#          /   \         /   \
#       (1,2) (2,1)   (1,2) (2,1)
```
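The result can also be cross-checked with a closed form. Every path consists of (rows - 1) down moves and (cols - 1) right moves in some order. The number of paths is therefore the number of ways to choose which of the rows + cols - 2 moves go down. The sketch below compares the recursion with `math.comb`, which is available from Python 3.8. `paths_closed_form` is a name introduced only for this check.

```python
from math import comb


def paths_in_grid(rows, cols):
    if rows == 1 or cols == 1:
        return 1
    return paths_in_grid(rows - 1, cols) + paths_in_grid(rows, cols - 1)


def paths_closed_form(rows, cols):
    # Choose which of the (rows + cols - 2) moves are "down" moves
    return comb(rows + cols - 2, rows - 1)


for r, c in [(3, 3), (4, 4), (3, 7)]:
    print(r, c, paths_in_grid(r, c), paths_closed_form(r, c))
# 3 3 6 6
# 4 4 20 20
# 3 7 28 28
```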