## Recursion: introduction and binary trees

A recursive function is one that calls itself. We can illustrate this with the usual example: the factorial function. Before presenting the example, however, it’s advisable to begin with a warning. Recursion is definitely not the best way to implement calculations that can be done much more easily and more efficiently using loops. I’m giving the factorial example partly because it’s one of the simplest to write using recursion and because it illustrates a common problem with recursion. More on that after we give the example.

The factorial of a non-negative integer $n$ is defined as

$$n! = n \times (n-1) \times (n-2) \times \cdots \times 2 \times 1$$

The case $n = 0$ is defined to have a factorial of $0! = 1$ (to make it consistent with the mathematical gamma function, but we don’t need to know that).

From its definition, we can see a recursive algorithm that can be used to calculate $n!$. This is

$$n! = n \times (n-1)!$$

One important thing to note about this formula is that it is valid only if $n > 0$. If we tried to use it to calculate $0!$, the formula would ask us for $(-1)!$ which, at least for our purposes, doesn’t exist. Thus we must first check that $n > 0$ before applying the recursive formula.
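
As a quick worked example (the choice of 3 is just for illustration), applying the recursive formula repeatedly until the special case $0! = 1$ is reached gives:

$$
\begin{aligned}
3! &= 3 \times 2! \\
   &= 3 \times 2 \times 1! \\
   &= 3 \times 2 \times 1 \times 0! \\
   &= 3 \times 2 \times 1 \times 1 = 6
\end{aligned}
$$

Each line applies the formula once; the expansion stops only because $0!$ is defined directly rather than by the formula.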

This is a special case of a more general principle. In any recursive algorithm there must be an anchor step, that is, a stopping point for the algorithm beyond which no further recursive calls are made. If we don’t have an anchor step, the algorithm could fall into an infinite recursion which, on a computer with a finite memory, will cause a stack overflow (more on this in a minute).

The code for implementing the recursive factorial algorithm is as follows:

    def factorial(n):
        if n == 0: return 1          # anchor step: no further recursive call
        return n * factorial(n-1)    # recursive step

    while True:
        num = input('Enter number for factorial: ')
        if num == 'quit': break

        num = int(num)
        print(f'{num}! = {factorial(num)}')

We see that the factorial() function implements the above algorithm. It checks first to see if $n$ is zero and, if so, returns 1 with no further recursive call. If $n$ is not 0, we use the recursive algorithm to calculate $n!$ as $n$ multiplied by a recursive call to factorial(n-1). [We haven’t included any error checks here to make sure that $n$ is a non-negative integer, but you can add these using exceptions if you like.]
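
One possible way to add the error checks mentioned above is sketched here. The names checked_factorial() and describe() are made up for this illustration, and the choice of TypeError and ValueError is an assumption about what a caller would find most useful, not the only reasonable design.

```python
def checked_factorial(n):
    """Recursive factorial that rejects invalid arguments (illustrative name)."""
    # bool is a subclass of int in Python, so exclude it explicitly
    if isinstance(n, bool) or not isinstance(n, int):
        raise TypeError(f'factorial needs an integer, got {n!r}')
    if n < 0:
        raise ValueError(f'factorial is undefined for negative numbers: {n}')
    if n == 0:          # anchor step
        return 1
    return n * checked_factorial(n - 1)


def describe(value):
    """Return a printable result or an error message for one input."""
    try:
        return f'{value}! = {checked_factorial(value)}'
    except (TypeError, ValueError) as err:
        return f'Error: {err}'


print(describe(5))      # 5! = 120
print(describe(-3))     # Error: factorial is undefined for negative numbers: -3
print(describe(2.5))    # Error: factorial needs an integer, got 2.5
```

Note that this sketch repeats the checks on every recursive call, which is wasteful; a tidier version would validate once and then call an unchecked helper.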

The while loop just asks the user for a number and then calls factorial() to calculate its factorial.

If you run this program it seems to work well for numbers up to 998 (on my system, at least). For 999 or higher, we get the error: RecursionError: maximum recursion depth exceeded in comparison. To understand what has gone wrong here, we need to understand a bit about how recursive functions work.

Whenever any function (recursive or otherwise) is called, the state of the program at the point where the function is called is pushed onto a stack (basically, this means the state of the program is stored in memory). When the function finishes, the stored state is popped off the stack and restored, so it no longer takes up memory. If one function calls another, then two program states must be stored in memory while the inner function is running.

In a recursive function, each recursive call stores another program state until the anchor step is reached, at which point the program backs out of all the function calls, restoring one program state with each recursive call that finishes. When we attempt to use the factorial() function to calculate, say, 1000!, then 1000 program states must be stored in memory before we get to the anchor step. Depending on the complexity of the program, storing this number of program states could cause the computer to run out of memory.
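
To make the build-up and unwinding of stored states visible, here is a small sketch that prints a line on entry to and exit from each call. The function traced_factorial() and its depth parameter are additions for this illustration only.

```python
def traced_factorial(n, depth=0):
    """Factorial that prints each call and return, indented by call depth."""
    indent = '  ' * depth
    print(f'{indent}factorial({n}) called')
    if n == 0:
        result = 1                                        # anchor step
    else:
        result = n * traced_factorial(n - 1, depth + 1)   # one level deeper
    print(f'{indent}factorial({n}) returns {result}')
    return result


traced_factorial(3)
```

The indentation grows by one level with each recursive call and shrinks again as the calls return, mirroring the states being pushed onto and popped off the stack.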

In practice, however, Python imposes a limit of 1000 on the recursion depth, after which it generates a RecursionError exception. This is a safeguard against a program falling into an infinite recursion (which could happen in our factorial example if we forgot to put in the anchor step). It’s possible to increase the allowable recursion depth by adding the code:

    import sys
    sys.setrecursionlimit(5000)

This sets the limit at 5000 recursive calls, but of course, the larger the depth you request, the larger the danger of running out of memory. If you find that you are using a recursion depth greater than the default value of 1000, it’s probably a sign that you should be looking for another way to do your calculation.
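
If you want to see what the limit is on your own system, and handle the error rather than letting the program crash, something like the following sketch may help. The helper exceeds_limit() is a hypothetical name introduced here; sys.getrecursionlimit() and RecursionError are part of standard Python.

```python
import sys


def factorial(n):
    if n == 0:
        return 1
    return n * factorial(n - 1)


def exceeds_limit(n):
    """Return True if factorial(n) runs past the recursion limit."""
    try:
        factorial(n)
    except RecursionError:
        return True
    return False


limit = sys.getrecursionlimit()
print(f'Current recursion limit: {limit}')
print(exceeds_limit(10))           # a shallow recursion is fine
print(exceeds_limit(limit + 10))   # needs more nested calls than the limit allows
```

The exact number at which the error first appears is usually a little below the limit, since calls already in progress when factorial() starts also count towards it.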

For the factorial, it’s much better to use a simple for loop. Using a loop is non-recursive so it doesn’t involve any function calls. As Python has no upper limit (apart from the computer’s memory) on the size of integers it can calculate, a for loop will calculate a factorial for as large a number as you like.
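
A minimal sketch of the loop version follows. The name factorial_loop() is chosen here for illustration, and the comparison with math.factorial() from the standard library is only a sanity check.

```python
import math


def factorial_loop(n):
    """Compute n! with a for loop instead of recursion."""
    result = 1
    for k in range(2, n + 1):   # the range is empty for n = 0 or 1, so result stays 1
        result *= k
    return result


print(factorial_loop(5))                  # 120
big = factorial_loop(5000)                # well beyond the default recursion limit
print(big == math.factorial(5000))        # True
print(big.bit_length())                   # size of the result in bits
```

The last line prints the bit length rather than the number itself because recent Python versions (3.11 and later) refuse by default to convert integers with more than a few thousand digits to a string.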

## Binary trees

A binary tree is a data structure that is most easily understood by drawing a diagram.

The tree consists of a number of nodes. The node at the top, with no arrows leading into it, is called the root, with the nodes at the bottom (with no arrows leading out of them) known as leaves. [Yes, I know, the tree is pictured upside down relative to a botanical tree, but that’s just the way the convention has evolved.] It’s called a binary tree because each node can have up to 2 arrows leading out of it. Other types of tree allow more branches, but of course they get much more complex.

A common use of a binary tree is the sorting of data. In the diagram above, we see that, for any node, the node that is its left child always has a number less than its parent, and the node that is its right child has a number greater than the parent. We can devise an algorithm for inserting a new node into the tree that preserves this structure.

## Inserting nodes into a binary sort tree

Suppose we want to insert a node with the value of 27. We start at the root and compare 27 with the root’s value of 12. Since 27 > 12, we move down the right branch from the root and encounter the node with the value 33. Comparing 27 with 33, we see that 27 < 33, so we go down the left branch from 33, and encounter 24. Since 27 > 24, and there is no right child for 24, we add a new node as the right child of 24 and place 27 there.

A little thought shows that this algorithm is inherently recursive. We start at the root and do a comparison to decide which branch to follow from the root. When we arrive at one of the children of the root, we apply the same algorithm again at the next node, and continue until we run out of nodes to compare with. At this point, we add a new node into the tree.

We can illustrate this in Python code. The TreeNode class that follows stores a number and its square at each node.

    class TreeNode(object):
        """a node in a binary tree"""

        def __init__(self, data):
            self.left = None
            self.right = None
            self.data = data
            self.square = data ** 2

        def insert(self, data):
            if data > self.data:
                # larger values go into the right subtree
                if self.right: self.right.insert(data)
                else: self.right = TreeNode(data)
            else:
                # smaller or equal values go into the left subtree
                if self.left: self.left.insert(data)
                else: self.left = TreeNode(data)

The constructor initializes a new node by setting both its children (self.left and self.right) to None. It then stores the number (data) and the square of that number within the node.

The insert() function performs the algorithm above. The data argument to insert() is the new number to be entered into the tree, so it’s compared with the self.data in the existing node. If the new number is greater than the existing number, we check whether the existing node has a right child. If so, we recursively call insert() to insert the new data into the subtree that starts with self.right. If self.right doesn’t exist (its value is None), then we create a new TreeNode and attach it to self.right. The outer else branch does the same thing if data <= self.data, so the new node gets inserted into the left subtree.

The anchor step here is the check that either the left or right child exists. If it doesn’t, then a new TreeNode is attached to the tree and no further recursive call is made.
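
To check the walkthrough for inserting 27, the sketch below rebuilds the tree from the values mentioned in the text. The insertion order 12, 6, 33, 4, 24, 42 is an assumption chosen to reproduce the shape described (other orders could give the same tree), and the class is repeated only so the example runs on its own.

```python
class TreeNode:
    """Copy of the TreeNode class above, so this example is self-contained."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data
        self.square = data ** 2

    def insert(self, data):
        if data > self.data:
            if self.right: self.right.insert(data)
            else: self.right = TreeNode(data)
        else:
            if self.left: self.left.insert(data)
            else: self.left = TreeNode(data)


# Assumed insertion order that reproduces the tree described in the text
root = TreeNode(12)
for value in (6, 33, 4, 24, 42):
    root.insert(value)

root.insert(27)

# 27 > 12 (go right), 27 < 33 (go left), 27 > 24 (new right child of 24)
print(root.right.data)              # 33
print(root.right.left.data)         # 24
print(root.right.left.right.data)   # 27
```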

## Traversing a binary sort tree

Once we have stored our data in the tree, we need some way of retrieving it. One way is to traverse the tree, which means to list its contents in sorted order. This process, too, is a naturally recursive one. Consider again the diagram above, and suppose we want to generate a list of the numbers in sorted order. As usual, we have to start at the root (in practice, the root node of the tree is the only one we have direct access to via a named variable).

If we want a list in ascending order, we want to find the smallest node first. Thus from the root, we travel down to its left child. If that node, in turn, has a left child, we travel down it as well until we find a node with no left child (the 4 node in the diagram). We then return a list with that node as its only element. At this point, we need to check if the current node has a right child, since any nodes in the right child’s subtree will come between the 4 and the parent of 4. In our example, 4 has no right child, so we back up to its parent node, 6.

The 6 is appended to the right end of the list. Since 6 has no right child, we back up to 12 and add 12 to the list. The 12 node does have a right child, so we travel down to 33. 33 has a left child, so we visit it, then back up to 33 then down 33’s right child to 42.

The recursive algorithm therefore has the following form.

At each node:

1. Find the traversal of the left child (if any).
2. Append the value from the current node.
3. Append the traversal of the right child (if any).

Steps 1 and 3 are the recursive steps. Their anchor steps are the cases where the left or right child doesn’t exist.

Here’s the expanded TreeNode class with a traverse() function added:

    class TreeNode(object):
        """a node in a binary tree"""

        def __init__(self, data):
            self.left = None
            self.right = None
            self.data = data
            self.square = data ** 2

        def insert(self, data):
            if data > self.data:
                if self.right: self.right.insert(data)
                else: self.right = TreeNode(data)
            else:
                if self.left: self.left.insert(data)
                else: self.left = TreeNode(data)

        def traverse(self):
            travList = []
            if self.left: travList = self.left.traverse()
            travList += [self.data]
            if self.right: travList += self.right.traverse()
            return travList

We initialize travList as an empty list. If there is a left child, we set travList to the traversal from self.left, then append self.data, then the traversal of self.right. For both self.left and self.right, we implement the anchor step by checking if they exist before attempting the recursive call.

We can test this code with the following main program, in which we generate 21 random integers between 1 and 1000 (one for the root and 20 more in the loop), insert these into a binary tree and then print out a traversal of the tree.

    from random import randint
    from TreeNode import TreeNode   # assumes the class is saved in TreeNode.py

    sortTree = TreeNode(randint(1, 1000))
    for i in range(20):
        sortTree.insert(randint(1, 1000))

    print(sortTree.traverse())
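
Rather than checking the printed list by eye, we can compare the traversal with Python’s built-in sorted(). The sketch below does this; the class is repeated so the block runs on its own, and the second tree uses an assumed insertion order (12, 6, 33, 4, 24, 42) that matches the nodes named in the text.

```python
from random import randint


class TreeNode:
    """Copy of the TreeNode class above, so this check is self-contained."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data
        self.square = data ** 2

    def insert(self, data):
        if data > self.data:
            if self.right: self.right.insert(data)
            else: self.right = TreeNode(data)
        else:
            if self.left: self.left.insert(data)
            else: self.left = TreeNode(data)

    def traverse(self):
        travList = []
        if self.left: travList = self.left.traverse()
        travList += [self.data]
        if self.right: travList += self.right.traverse()
        return travList


# Keep the random values in a list so we can compare afterwards
values = [randint(1, 1000) for _ in range(21)]
tree = TreeNode(values[0])
for value in values[1:]:
    tree.insert(value)

print(tree.traverse() == sorted(values))   # True

# The tree described in the text (assumed insertion order)
example = TreeNode(12)
for value in (6, 33, 4, 24, 42):
    example.insert(value)
print(example.traverse())                  # [4, 6, 12, 24, 33, 42]
```

Duplicates are handled too: equal values are inserted into the left subtree, so they appear next to each other in the traversal, just as they do in the output of sorted().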

## Exercise

Extend the TreeNode class above by writing a find() function which searches for a particular number in an existing binary tree. It might be helpful to use the diagram above to work out the recursive algorithm for searching for a number. Consider first the case where the number is present in the tree, and then the other case where the number is not in the tree.

Add some code to the main program above which allows the user to enter a number, then searches the tree and prints out either the square of the number (if found) or the message ‘Not found’ if not.

The algorithm is similar to that for traversing the tree. If we’re looking for a given number data, we start at the root and compare data with self.data. If we have a match, we return self.square and stop, with no recursive call.

If self.data > data, we need to search the left subtree from the root. If there is no left child, then data is not in the tree, so we can return None and stop. If there is a left child, we call find() recursively on the left child.

If self.data < data, we do the same thing except we go down the right subtree from the root.

The TreeNode class with a find() function is as follows:

    class TreeNode(object):
        """a node in a binary tree"""

        def __init__(self, data):
            self.left = None
            self.right = None
            self.data = data
            self.square = data ** 2

        def insert(self, data):
            if data > self.data:
                if self.right: self.right.insert(data)
                else: self.right = TreeNode(data)
            else:
                if self.left: self.left.insert(data)
                else: self.left = TreeNode(data)

        def traverse(self):
            travList = []
            if self.left: travList = self.left.traverse()
            travList += [self.data]
            if self.right: travList += self.right.traverse()
            return travList

        def find(self, data):
            if self.data == data: return self.square
            elif self.data > data:
                if not self.left: return None
                else: return self.left.find(data)
            else:
                if not self.right: return None
                else: return self.right.find(data)



The find() function implements the above algorithm directly.
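
As a quick check of both cases (number present and number absent), the sketch below runs find() on a small tree built from the values named earlier in the text. The insertion order is an assumption, and the class is trimmed to the methods needed so that the block runs on its own.

```python
class TreeNode:
    """Trimmed copy of the TreeNode class above (insert and find only)."""

    def __init__(self, data):
        self.left = None
        self.right = None
        self.data = data
        self.square = data ** 2

    def insert(self, data):
        if data > self.data:
            if self.right: self.right.insert(data)
            else: self.right = TreeNode(data)
        else:
            if self.left: self.left.insert(data)
            else: self.left = TreeNode(data)

    def find(self, data):
        if self.data == data: return self.square
        elif self.data > data:
            if not self.left: return None
            else: return self.left.find(data)
        else:
            if not self.right: return None
            else: return self.right.find(data)


# Assumed insertion order for the example tree
root = TreeNode(12)
for value in (6, 33, 4, 24, 42):
    root.insert(value)

print(root.find(12))   # 144  (match at the root, no recursive call)
print(root.find(24))   # 576  (right to 33, then left to 24)
print(root.find(5))    # None (left to 6, left to 4, which has no right child)
```

Because a missing number is reported as None, the caller should test the result with "is None" rather than relying on truthiness, which would also treat a stored square of 0 as not found.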

We can add some code to the main program to allow the user to search for some numbers:

    from random import randint
    from TreeNode import TreeNode   # assumes the class is saved in TreeNode.py

    sortTree = TreeNode(randint(1, 1000))
    for i in range(20):
        sortTree.insert(randint(1, 1000))

    print(sortTree.traverse())

    while True:
        search = input('Enter number to find: ')
        if search == 'quit': break

        found = sortTree.find(int(search))
        # find() returns None when the number is not in the tree
        if found is None: print('Not found')
        else: print(f'Square = {found}')