# Tail call recursion in Python

On this page, we’re going to look at tail call recursion and see how to force Python to let us eliminate tail calls by using a trampoline. We will go through two iterations of the design: first to get it working, and second to make the syntax reasonable. I would not consider this a useful technique in itself, but I do think it’s a good example of the power of decorators.

The first thing we should be clear about is the definition of a tail
call. The “call” part means that we are considering function
calls, and the “tail” part means that, of those, we are
considering calls which are the last thing a function does before it
returns. In the following example, the recursive call to `f`
is a tail call (the use of the variable `ret` is immaterial
because it just connects the result of the call to `f` to the
`return` statement), and the call to `g` is not a tail
call because the operation of adding one is done after `g`
returns (so it’s not in “tail position”).

```python
def f(n):
    if n > 0:
        n -= 1
        ret = f(n)      # tail call: ret is returned unchanged
        return ret
    else:
        ret = g(n)      # not a tail call: 1 is added after g returns
        return ret + 1
```
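To make “tail position” concrete, here is a small pair of functions with hypothetical names (`sum_to` and `sum_to_acc` are illustrative, not from any library). The first still has work to do after its recursive call returns; the second carries the running total in an accumulator, so the recursive call is the last thing it does. The factorial example later on uses the same accumulator trick.

```python
def sum_to(n):
    """Sum 1..n. The recursive call is NOT a tail call."""
    if n == 0:
        return 0
    # The addition happens after sum_to(n - 1) returns.
    return n + sum_to(n - 1)


def sum_to_acc(n, acc=0):
    """Sum 1..n. The recursive call IS a tail call."""
    if n == 0:
        return acc
    # The result of the recursive call is returned unchanged.
    return sum_to_acc(n - 1, acc + n)


print(sum_to(100), sum_to_acc(100))  # 5050 5050
```

Note that CPython still pushes one stack frame per call for both versions; being in tail position only makes elimination *possible*.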

## 1. Why tail calls matter

Recursive tail calls can be replaced by jumps. This is called “tail call elimination,” and it is a transformation that can limit the maximum stack depth used by a recursive function, with the added benefit of reducing memory traffic, since no new stack frames need to be allocated. Sometimes a recursive function that ordinarily couldn’t run because of stack overflow is transformed into one that can.

Because of the benefits, some compilers (like `gcc`) perform
tail call elimination^{[1]},
replacing recursive tail calls with jumps (and, depending on the
language and circumstances, tail calls to other functions can
sometimes be replaced with stack massaging and a jump). In the
following example, we will eliminate the tail calls in a piece of
code which does a binary search. It has two recursive tail calls.

```python
def binary_search(x, lst, low=None, high=None):
    if low is None:
        low = 0
    if high is None:
        high = len(lst) - 1
    mid = low + (high - low) // 2
    if low > high:
        return None
    elif lst[mid] == x:
        return mid
    elif lst[mid] > x:
        return binary_search(x, lst, low, mid - 1)  # tail call
    else:
        return binary_search(x, lst, mid + 1, high)  # tail call
```

Supposing Python had a `goto` statement, we could replace the tail calls with a jump to the beginning of the function, modifying the arguments at the call sites appropriately:

```
# Pseudocode: Python has no goto statement or labels.
def binary_search(x, lst, low=None, high=None):
    start:
    if low is None:
        low = 0
    if high is None:
        high = len(lst) - 1
    mid = low + (high - low) // 2
    if low > high:
        return None
    elif lst[mid] == x:
        return mid
    elif lst[mid] > x:
        (x, lst, low, high) = (x, lst, low, mid - 1)
        goto start
    else:
        (x, lst, low, high) = (x, lst, mid + 1, high)
        goto start
```

which, one can observe, can be written in actual Python as

```python
def binary_search(x, lst, low=None, high=None):
    if low is None:
        low = 0
    if high is None:
        high = len(lst) - 1
    while True:
        mid = low + (high - low) // 2
        if low > high:
            return None
        elif lst[mid] == x:
            return mid
        elif lst[mid] > x:
            high = mid - 1
        else:
            low = mid + 1
```

I haven’t tested the speed difference between this iterative version and the original recursive version, but I would expect it to be quite a bit faster, since it allocates no new stack frames and so causes much, much less memory traffic.
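One quick way to gain some confidence that the loop is equivalent to the recursion is to run both on the same inputs. The sketch below assumes the two versions are renamed `binary_search_rec` and `binary_search_iter` (hypothetical names) so they can coexist in one module; `agree` is likewise a hypothetical helper.

```python
def binary_search_rec(x, lst, low=None, high=None):
    # Recursive version: both recursive calls are tail calls.
    if low is None:
        low = 0
    if high is None:
        high = len(lst) - 1
    mid = low + (high - low) // 2
    if low > high:
        return None
    elif lst[mid] == x:
        return mid
    elif lst[mid] > x:
        return binary_search_rec(x, lst, low, mid - 1)
    else:
        return binary_search_rec(x, lst, mid + 1, high)


def binary_search_iter(x, lst, low=None, high=None):
    # Loop version: each tail call became an update of low or high.
    if low is None:
        low = 0
    if high is None:
        high = len(lst) - 1
    while True:
        mid = low + (high - low) // 2
        if low > high:
            return None
        elif lst[mid] == x:
            return mid
        elif lst[mid] > x:
            high = mid - 1
        else:
            low = mid + 1


def agree(lst, probes):
    """True if both versions return the same result for every probe."""
    return all(binary_search_rec(x, lst) == binary_search_iter(x, lst)
               for x in probes)


evens = list(range(0, 200, 2))
print(agree(evens, range(-5, 205)))   # True
print(binary_search_iter(42, evens))  # 21
print(binary_search_iter(43, evens))  # None
```

Since binary search halves the range on every step, the recursive version only goes about log2(len(lst)) calls deep, so this particular function is in no real danger of overflowing the stack; it is a convenient illustration of the transformation rather than a case that needs it.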

Unfortunately, the transformed code is harder to prove correct. For the original recursive algorithm, correctness follows almost trivially by induction.
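One way to recover some of that inductive argument for the loop is to state it as a loop invariant: if `x` occurs in the list at all, every occurrence lies within `lst[low:high + 1]`. The sketch below (the name `binary_search_checked` is hypothetical) checks that invariant with `assert` on every iteration. The check costs O(n) per iteration, so this is for illustration only.

```python
def binary_search_checked(x, lst):
    """Iterative binary search that asserts its loop invariant each iteration."""
    low, high = 0, len(lst) - 1
    while True:
        # Invariant: every occurrence of x in lst lies in lst[low:high + 1].
        # This plays the role of the recursive version's induction hypothesis.
        assert all(lst[i] != x for i in range(len(lst)) if not low <= i <= high)
        mid = low + (high - low) // 2
        if low > high:
            # Empty range: by the invariant, x does not occur in lst.
            return None
        elif lst[mid] == x:
            return mid
        elif lst[mid] > x:
            high = mid - 1   # lst is sorted, so lst[mid:] holds only values > x
        else:
            low = mid + 1    # likewise, lst[:mid + 1] holds only values < x


data = [1, 3, 3, 5, 8, 13]
print(binary_search_checked(5, data))  # 3
print(binary_search_checked(4, data))  # None
```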

Programming languages like Scheme depend on tail calls being
eliminated for control flow, and tail call elimination is also
necessary for continuation passing style.^{[2]}
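For readers who haven’t met continuation passing style, here is a minimal Python sketch (the name `fact_cps` is hypothetical). Instead of returning its result to its caller, each function passes it to a continuation `k`, so every call, including each call to a continuation, is in tail position. In Python each of those calls still pushes a stack frame, which is exactly why CPS relies on tail call elimination.

```python
def fact_cps(n, k):
    """Factorial in continuation passing style; k receives the result."""
    if n <= 1:
        # Tail call: hand the base-case result to the continuation.
        return k(1)
    # Tail call: the new continuation multiplies by n, then passes on to k.
    return fact_cps(n - 1, lambda v: k(n * v))


print(fact_cps(5, lambda v: v))  # 120
fact_cps(5, print)               # also prints 120
```

Because both the recursion and the chain of nested continuations grow with `n`, a call such as `fact_cps(5000, lambda v: v)` exceeds CPython’s default recursion limit.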

## 2. A first attempt

Our running example is going to be the factorial function (a classic), written with an accumulator argument so that its recursive call is a tail call:

```python
def fact(n, r=1):
    if n <= 1:
        return r
    else:
        return fact(n - 1, n * r)  # tail call
```

If `n` is too large, this recursive function will overflow the stack, even though Python can deal with really big integers. On my machine, it can compute `fact(999)`, but `fact(1000)` results in a sad `RuntimeError: maximum recursion depth exceeded`. (In Python 3.5 and later, the exception is `RecursionError`, a subclass of `RuntimeError`.)
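The failure is easy to reproduce. The sketch below reuses the accumulator `fact` and adds a hypothetical helper, `overflows`; it assumes CPython 3.5 or later, where exceeding the limit raises `RecursionError`. `sys.getrecursionlimit()` reports the current limit (1000 by default in CPython); `sys.setrecursionlimit()` can raise it, but that only postpones the problem.

```python
import sys


def fact(n, r=1):
    # Tail-recursive factorial; CPython still pushes one frame per call.
    if n <= 1:
        return r
    else:
        return fact(n - 1, n * r)


def overflows(n):
    """Hypothetical helper: True if fact(n) exceeds the recursion limit."""
    try:
        fact(n)
    except RecursionError:
        return True
    return False


limit = sys.getrecursionlimit()
print(overflows(100))          # False
print(overflows(limit + 100))  # True
```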

One solution is to modify `fact` to return objects which
represent tail calls and then to build a trampoline underneath
`fact` which executes these tail calls after `fact`
returns. This way, the stack depth stays constant: there is a frame
for the trampoline plus a frame for whichever call to `fact` is
currently running, rather than one frame per recursive call.

First, we define a tail call object which reifies the concept of a tail call:

```python
class TailCall(object):
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        # Perform the deferred call.
        return self.call(*self.args, **self.kwargs)
```

This is basically just the thunk `lambda: call(*args, **kwargs)`, but we don’t use a thunk because we would like to be able to distinguish a tail call from returning a function as a value.
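To see why the distinction matters, compare a hypothetical thunk-based trampoline, which has to treat *any* callable result as a pending tail call, with one that only continues on explicit `TailCall` objects. `run_with_thunks`, `run_with_tailcalls`, and `make_greeter` are illustrative names for this sketch.

```python
class TailCall(object):
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        return self.call(*self.args, **self.kwargs)


def run_with_thunks(f, *args):
    """Hypothetical thunk-based trampoline: any callable result is called."""
    ret = f(*args)
    while callable(ret):
        ret = ret()
    return ret


def run_with_tailcalls(f, *args):
    """Trampoline that only continues on explicit TailCall objects."""
    ret = f(*args)
    while type(ret) is TailCall:
        ret = ret.handle()
    return ret


def make_greeter():
    # Returns a function as an ordinary value; no tail call is intended.
    return lambda: "hello"


print(run_with_thunks(make_greeter))               # hello (the function was wrongly called)
print(callable(run_with_tailcalls(make_greeter)))  # True (the function comes back intact)
```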

The next ingredient is a function which wraps a trampoline around an arbitrary function:

```python
def t(f):
    def _f(*args, **kwargs):
        ret = f(*args, **kwargs)
        # Keep executing tail calls until a real value comes back.
        while type(ret) is TailCall:
            ret = ret.handle()
        return ret
    return _f
```

Then, we modify `fact` to be

```python
def fact(n, r=1):
    if n <= 1:
        return r
    else:
        return TailCall(fact, n - 1, n * r)
```

Now, instead of calling `fact(n)`, we must invoke
`t(fact)(n)` (otherwise we’d just get a `TailCall`
object).
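Putting the pieces of the first attempt together gives a script along these lines. It is a sketch that simply combines the `TailCall`, `t`, and `fact` definitions above, with `math.factorial` used as an independent check.

```python
import math


class TailCall(object):
    """A reified tail call: a function plus the arguments to call it with."""
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        return self.call(*self.args, **self.kwargs)


def t(f):
    """Wrap f in a trampoline that runs returned TailCalls in a loop."""
    def _f(*args, **kwargs):
        ret = f(*args, **kwargs)
        while type(ret) is TailCall:
            ret = ret.handle()
        return ret
    return _f


def fact(n, r=1):
    if n <= 1:
        return r
    else:
        return TailCall(fact, n - 1, n * r)


print(type(fact(5)).__name__)                 # TailCall (no trampoline, no answer)
print(t(fact)(5))                             # 120
print(t(fact)(5000) == math.factorial(5000))  # True, with no RecursionError
```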

This isn’t that bad: we can get tail calls of arbitrary depth, and
it’s Pythonic in the sense that the user must explicitly label the
tail calls, limiting the amount of unexpected magic. But can we
eliminate the need to wrap `t` around the initial call? I
myself find it unclean to have to write that `t` because it
makes calling `fact` different from calling a normal function
(which is how it was before the transformation).

## 3. A second attempt

The basic idea is that we will redefine `fact` to roughly be
`t(fact)`. It’s tempting to just use `t` as a
decorator:

```python
@t
def fact(n, r=1):
    if n <= 1:
        return r
    else:
        return TailCall(fact, n - 1, n * r)
```

(which, if you aren’t familiar with decorator syntax, is equivalent to writing `fact = t(fact)` right after the function definition). However, there is a problem with this: the `fact` in the returned tail call is bound to `t(fact)`, so the trampoline will recursively call the trampoline, completely defeating the purpose of our work. In fact, the situation is now worse than before: on my machine, `fact(333)` causes a `RuntimeError`!
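The failure can be reproduced with a short script. This sketch assumes Python 3.5 or later (for `RecursionError`), and `overflows` is a hypothetical helper. Each step now costs extra frames (a `handle` frame plus a fresh `_f` trampoline frame), so the limit is reached well before `n` gets to the recursion limit.

```python
import sys


class TailCall(object):
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        return self.call(*self.args, **self.kwargs)


def t(f):
    def _f(*args, **kwargs):
        ret = f(*args, **kwargs)
        while type(ret) is TailCall:
            ret = ret.handle()
        return ret
    return _f


@t
def fact(n, r=1):
    if n <= 1:
        return r
    else:
        # Here the name fact refers to the *wrapped* function, so handle()
        # enters a fresh trampoline instead of returning to the outer loop.
        return TailCall(fact, n - 1, n * r)


def overflows(n):
    """Hypothetical helper: True if fact(n) raises RecursionError."""
    try:
        fact(n)
    except RecursionError:
        return True
    return False


print(fact(10))                            # 3628800: small inputs still work
print(overflows(sys.getrecursionlimit()))  # True
```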

For this solution, the first ingredient is the following class, which defines the trampoline as before, but wraps it in a new type so we can distinguish a trampolined function from a plain old function:

```python
class TailCaller(object):
    def __init__(self, f):
        self.f = f

    def __call__(self, *args, **kwargs):
        # The trampoline: run f, then keep handling tail calls.
        ret = self.f(*args, **kwargs)
        while type(ret) is TailCall:
            ret = ret.handle()
        return ret
```

and then we modify `TailCall` to be aware of `TailCaller`s:

```python
class TailCall(object):
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        if type(self.call) is TailCaller:
            # Bypass the wrapper's trampoline and call the raw function,
            # so the outer trampoline stays in charge.
            return self.call.f(*self.args, **self.kwargs)
        else:
            return self.call(*self.args, **self.kwargs)
```

Since calling a class returns a newly constructed object, classes are
function-like, so we can decorate our factorial function with `TailCaller` directly:

```python
@TailCaller
def fact(n, r=1):
    if n <= 1:
        return r
    else:
        return TailCall(fact, n - 1, n * r)
```

And then we can call `fact` directly with large numbers!

Also, unlike in the first attempt, we can now have mutually recursive
functions which all perform tail calls. The first `TailCaller` to be
called will handle all the trampolining.
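As a concrete check of the mutual-recursion claim, here is a sketch using the classic even/odd pair (`is_even` and `is_odd` are illustrative names; the two classes are copied from above so the block runs on its own). Each function tail-calls the other, and an input of 100000 would far exceed CPython’s default recursion limit without the trampoline.

```python
class TailCaller(object):
    def __init__(self, f):
        self.f = f

    def __call__(self, *args, **kwargs):
        ret = self.f(*args, **kwargs)
        while type(ret) is TailCall:
            ret = ret.handle()
        return ret


class TailCall(object):
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        if type(self.call) is TailCaller:
            return self.call.f(*self.args, **self.kwargs)
        else:
            return self.call(*self.args, **self.kwargs)


@TailCaller
def is_even(n):
    if n == 0:
        return True
    return TailCall(is_odd, n - 1)   # tail call into the other function


@TailCaller
def is_odd(n):
    if n == 0:
        return False
    return TailCall(is_even, n - 1)  # and back again


print(is_even(100000))  # True
print(is_odd(100001))   # True
print(is_even(7))       # False
```

Only the outermost call goes through `TailCaller.__call__`; every later hop goes through `TailCall.handle`, which unwraps the `TailCaller` and calls the raw function, so the single outer loop does all the work.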

If we wanted, we could also define the following function to make the
argument lists of tail calls look more like those of normal
function calls:^{[3]}

```python
def tailcall(f):
    # Curried constructor: tailcall(f)(*args) builds TailCall(f, *args).
    def _f(*args, **kwargs):
        return TailCall(f, *args, **kwargs)
    return _f
```

and then `fact` could be rewritten as

```python
@TailCaller
def fact(n, r=1):
    if n <= 1:
        return r
    else:
        return tailcall(fact)(n - 1, n * r)
```
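Assembling the second attempt into one script gives something like the following sketch. It simply combines `TailCaller`, the `TailCaller`-aware `TailCall`, and the `tailcall` helper from above, and checks the result against `math.factorial`. Calling `fact` now looks like calling any other function.

```python
import math


class TailCaller(object):
    """Decorator: marks f as trampolined and runs the trampoline when called."""
    def __init__(self, f):
        self.f = f

    def __call__(self, *args, **kwargs):
        ret = self.f(*args, **kwargs)
        while type(ret) is TailCall:
            ret = ret.handle()
        return ret


class TailCall(object):
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        if type(self.call) is TailCaller:
            # Call the undecorated function so the outer trampoline keeps control.
            return self.call.f(*self.args, **self.kwargs)
        else:
            return self.call(*self.args, **self.kwargs)


def tailcall(f):
    """Curried TailCall constructor: tailcall(f)(*args) builds TailCall(f, *args)."""
    def _f(*args, **kwargs):
        return TailCall(f, *args, **kwargs)
    return _f


@TailCaller
def fact(n, r=1):
    if n <= 1:
        return r
    else:
        return tailcall(fact)(n - 1, n * r)


print(fact(5))                             # 120
print(fact(5000) == math.factorial(5000))  # True
```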

Ideally, we could do away with marking tail calls by hand, but I can’t think of any way to detect whether a call is a tail call without inspecting the source code. Perhaps an idea for further work is to convince Guido van Rossum that Python should support tail call elimination (which is quite unlikely to happen).

^{[1]}This is compiler-writer speak. For some reason, “elimination” is what you do when you replace a computation with something equivalent. In this case, it’s true that the call is being eliminated, but in its place there’s a jump. The same is true for “common subexpression elimination” (known as CSE), which takes, for instance,

```
a = b + c
d = (b + c) + e
```

and replaces it with

```
a = b + c
d = a + e
```

Sure, the `b + c` is eliminated from the second statement, but it’s not *really* gone...

The optimization known as “dead code elimination” actually eliminates something, but that’s because dead code has no effect, and so it can be removed (that is, be replaced with nothing).

^{[2]}In Scheme, all loops are written as recursive functions since tail calls are the pure way of redefining variables (this is the same technique Haskell uses). For instance, to print the numbers from 1 to 100, you’d write

```scheme
(let next ((n 1))
  (if (<= n 100)
      (begin
        (display n)
        (newline)
        (next (+ n 1)))))
```

where `next` is bound to a function of one argument, `n`, whose body is the body of the `let` expression. If that `100` were some arbitrarily large number, the tail call to `next` had better be handled as a jump; otherwise the stack would overflow! And there’s no other reasonable way to write such a loop!

Continuation passing style is commonly used to handle exceptions and backtracking. You write functions of the form

```scheme
(define (f cont)
  (let ((cont2 (lambda ... (cont ...) ...)))
    (g cont2)))
```

along with functions which take multiple such `f`s and combine them into another function which also takes a single `cont` argument. I’ll probably talk about this more on another page, but for now notice how the call to `g` is in tail position.

^{[3]}That is, *Schönfinkelized*.