Tail call recursion in Python
On this page, we’re going to look at tail call recursion and see how to force Python to let us eliminate tail calls by using a trampoline. We will go through two iterations of the design: the first to get it to work, and the second to make the syntax seem reasonable. I would not consider this a useful technique in itself, but I do think it’s a good example that shows off some of the power of decorators.
The first thing we should be clear about is the definition of a tail call. The “call” part means that we are considering function calls, and the “tail” part means that, of those, we are considering calls which are the last thing a function does before it returns. In the following example, the recursive call to f is a tail call (the use of the variable ret is immaterial because it just connects the result of the call to f to the return statement), and the call to g is not a tail call because the operation of adding one is done after g returns (so it’s not in “tail position”).
    def f(n):
        if n > 0:
            n -= 1
            ret = f(n)
            return ret
        else:
            ret = g(n)
            return ret + 1
1. Why tail calls matter
Recursive tail calls can be replaced by jumps. This is called “tail call elimination,” and it is a transformation that can limit the maximum stack depth used by a recursive function, with the added benefit of reducing memory traffic because stack frames no longer have to be allocated. Sometimes, recursive functions which would ordinarily fail with a stack overflow are transformed into functions which run to completion.
Because of the benefits, some compilers (like gcc) perform tail call elimination, replacing recursive tail calls with jumps (and, depending on the language and circumstances, tail calls to other functions can sometimes be replaced with stack massaging and a jump). In the following example, we will eliminate the tail calls in a piece of code which does a binary search. It has two recursive tail calls.
    def binary_search(x, lst, low=None, high=None):
        if low is None:
            low = 0
        if high is None:
            high = len(lst) - 1
        mid = low + (high - low) // 2
        if low > high:
            return None
        elif lst[mid] == x:
            return mid
        elif lst[mid] > x:
            return binary_search(x, lst, low, mid - 1)
        else:
            return binary_search(x, lst, mid + 1, high)

Supposing Python had a goto statement, we could replace the tail calls with a jump to the beginning of the function, modifying the arguments at the call sites appropriately:
    def binary_search(x, lst, low=None, high=None):
      start:
        if low is None:
            low = 0
        if high is None:
            high = len(lst) - 1
        mid = low + (high - low) // 2
        if low > high:
            return None
        elif lst[mid] == x:
            return mid
        elif lst[mid] > x:
            (x, lst, low, high) = (x, lst, low, mid - 1)
            goto start
        else:
            (x, lst, low, high) = (x, lst, mid + 1, high)
            goto start

which, one can observe, can be written in actual Python as
    def binary_search(x, lst, low=None, high=None):
        if low is None:
            low = 0
        if high is None:
            high = len(lst) - 1
        while True:
            mid = low + (high - low) // 2
            if low > high:
                return None
            elif lst[mid] == x:
                return mid
            elif lst[mid] > x:
                high = mid - 1
            else:
                low = mid + 1

I haven’t tested the speed difference between this iterative version and the original recursive version, but I would expect the iterative one to be quite a bit faster because there is far less memory traffic.
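Since the speed difference is left untested, here is a minimal sketch of how one might check that the two versions agree and get a rough timing with the standard timeit module. The names binary_search_rec, binary_search_iter, data and probes are made up for this illustration, and the timings will vary by machine, so no expected numbers are given.

```python
import timeit


def binary_search_rec(x, lst, low=None, high=None):
    """Recursive version, following the first listing."""
    if low is None:
        low = 0
    if high is None:
        high = len(lst) - 1
    mid = low + (high - low) // 2
    if low > high:
        return None
    elif lst[mid] == x:
        return mid
    elif lst[mid] > x:
        return binary_search_rec(x, lst, low, mid - 1)
    else:
        return binary_search_rec(x, lst, mid + 1, high)


def binary_search_iter(x, lst, low=None, high=None):
    """Iterative version: each tail call became an update plus a loop."""
    if low is None:
        low = 0
    if high is None:
        high = len(lst) - 1
    while True:
        mid = low + (high - low) // 2
        if low > high:
            return None
        elif lst[mid] == x:
            return mid
        elif lst[mid] > x:
            high = mid - 1
        else:
            low = mid + 1


# Hypothetical test data: sorted even numbers, probed with hits and misses.
data = list(range(0, 2000, 2))
probes = range(-1, 2001)

# Both versions should return the same index (or None) for every probe.
agree = all(binary_search_rec(x, data) == binary_search_iter(x, data)
            for x in probes)


def run(search):
    for x in probes:
        search(x, data)


rec_time = timeit.timeit(lambda: run(binary_search_rec), number=20)
iter_time = timeit.timeit(lambda: run(binary_search_iter), number=20)
print("versions agree:", agree)
print("recursive: %.4fs  iterative: %.4fs" % (rec_time, iter_time))
```

A binary search only recurses to a depth of about log2(len(lst)), so this measures call overhead rather than any stack overflow behaviour.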
Unfortunately, the transformation makes it harder to prove the binary search is correct in the resulting code. With the original recursive algorithm, it is almost trivial by induction.
Programming languages like Scheme depend on tail calls being eliminated for control flow, and it’s also necessary for continuation passing style.
2. A first attempt
Our running example is going to be the factorial function (a classic), written with an accumulator argument so that its recursive call is a tail call:
    def fact(n, r=1):
        if n <= 1:
            return r
        else:
            return fact(n - 1, n * r)
If n is too large, then this recursive function will overflow the stack, despite the fact that Python can deal with really big integers. On my machine, it can compute fact(999), but fact(1000) results in a sad RuntimeError: maximum recursion depth exceeded (in Python 3.5 and later the exception is a RecursionError, which is a subclass of RuntimeError).
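The exact threshold depends on the interpreter's recursion limit and on how many frames are already on the stack, so it will differ from machine to machine. The following sketch, which assumes Python 3.5 or later (where the exception is named RecursionError), reads the limit with sys.getrecursionlimit() instead of hard-coding a number. The helper name overflows is invented for this example.

```python
import sys


def fact(n, r=1):
    if n <= 1:
        return r
    else:
        return fact(n - 1, n * r)


def overflows(n):
    """Return True if fact(n) exhausts the interpreter's recursion limit."""
    try:
        fact(n)
    except RecursionError:  # a subclass of RuntimeError since Python 3.5
        return True
    return False


limit = sys.getrecursionlimit()
print("recursion limit:", limit)
print("fact(10) overflows:", overflows(10))
# fact(limit + 10) needs more frames than the limit allows.
print("fact(limit + 10) overflows:", overflows(limit + 10))
```

Raising the limit with sys.setrecursionlimit only moves the threshold; it does not remove it.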
One solution is to modify fact to return objects which represent tail calls and then to build a trampoline underneath fact which executes these tail calls after fact returns. This way, the stack never grows with n: it holds a small, constant number of frames, namely the trampoline’s plus those of the single call to fact currently in progress, because each call returns before the next one starts.
First, we define a tail call object which reifies the concept of a tail call:
    class TailCall(object):
        def __init__(self, call, *args, **kwargs):
            self.call = call
            self.args = args
            self.kwargs = kwargs

        def handle(self):
            return self.call(*self.args, **self.kwargs)

This is basically just the thunk lambda : call(*args, **kwargs), but we don’t use a thunk because we would like to be able to differentiate between a tail call and returning a function as a value.
The next ingredient is a function which wraps a trampoline around an arbitrary function:
    def t(f):
        def _f(*args, **kwargs):
            ret = f(*args, **kwargs)
            while type(ret) is TailCall:
                ret = ret.handle()
            return ret
        return _f
Then, we modify fact to be
    def fact(n, r=1):
        if n <= 1:
            return r
        else:
            return TailCall(fact, n - 1, n * r)
Now, instead of calling fact(n), we must instead invoke t(fact)(n) (otherwise we’d just get a TailCall object).
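Putting the three pieces of this first attempt together gives the following self-contained sketch. The value n = 5000 is an arbitrary choice for illustration, picked to be well above the usual default recursion limit of 1000, and the result is checked against math.factorial rather than printed, since the number has thousands of digits.

```python
import math


class TailCall:
    """Reified tail call: a callable plus the arguments to call it with."""

    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        return self.call(*self.args, **self.kwargs)


def t(f):
    """Wrap f in a trampoline that keeps running returned TailCall objects."""
    def _f(*args, **kwargs):
        ret = f(*args, **kwargs)
        while type(ret) is TailCall:
            ret = ret.handle()
        return ret
    return _f


def fact(n, r=1):
    if n <= 1:
        return r
    else:
        return TailCall(fact, n - 1, n * r)


n = 5000  # arbitrary; deeper than the default recursion limit
result = t(fact)(n)
print("matches math.factorial:", result == math.factorial(n))
# Calling fact directly performs only one step and hands back a TailCall.
print("fact(5) without the trampoline:", type(fact(5)).__name__)
```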
This isn’t that bad: we can get tail calls of arbitrary depth, and it’s Pythonic in the sense that the user must explicitly label the tail calls, limiting the amount of unexpected magic. But, can we eliminate the need to wrap t around the initial call? I myself find it unclean to have to write that t because it makes calling fact different from calling a normal function (which is how it was before the transformation).
3. A second attempt
The basic idea is that we will redefine fact to roughly be t(fact). It’s tempting to just use t as a decorator:
    @t
    def fact(n, r=1):
        if n <= 1:
            return r
        else:
            return TailCall(fact, n - 1, n * r)

(which, if you aren’t familiar with decorator syntax, is equivalent to writing fact = t(fact) right after the function definition). However, there is a problem with this in that the fact in the returned tail call is bound to t(fact), so the trampoline will recursively call the trampoline, completely defeating the purpose of our work. In fact, the situation is now worse than before: on my machine, fact(333) causes a RuntimeError!
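The failure can be reproduced with the sketch below, which again assumes Python 3.5 or later for the RecursionError name. Every tail call re-enters the wrapper produced by t through handle, so the stack grows with n instead of staying flat. The helper overflows and the choice of sys.getrecursionlimit() as the argument are assumptions of this example, not part of the design.

```python
import sys


class TailCall:
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        return self.call(*self.args, **self.kwargs)


def t(f):
    def _f(*args, **kwargs):
        ret = f(*args, **kwargs)
        while type(ret) is TailCall:
            ret = ret.handle()
        return ret
    return _f


@t
def fact(n, r=1):
    if n <= 1:
        return r
    else:
        # By the time this runs, the global name fact refers to the wrapper
        # returned by t, so handle() starts a nested trampoline.
        return TailCall(fact, n - 1, n * r)


def overflows(n):
    try:
        fact(n)
    except RecursionError:
        return True
    return False


small_ok = fact(10) == 3628800
big_overflows = overflows(sys.getrecursionlimit())
print("small input still works:", small_ok)
print("large input overflows:", big_overflows)
```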
For this solution, the first ingredient is the following class, which defines the trampoline as before, but wraps it in a new type so we can distinguish a trampolined function from a plain old function:
    class TailCaller(object):
        def __init__(self, f):
            self.f = f

        def __call__(self, *args, **kwargs):
            ret = self.f(*args, **kwargs)
            while type(ret) is TailCall:
                ret = ret.handle()
            return ret

and then we modify TailCall to be aware of TailCallers:
    class TailCall(object):
        def __init__(self, call, *args, **kwargs):
            self.call = call
            self.args = args
            self.kwargs = kwargs

        def handle(self):
            if type(self.call) is TailCaller:
                return self.call.f(*self.args, **self.kwargs)
            else:
                return self.call(*self.args, **self.kwargs)
Since classes are function-like and return their constructed object, we can just decorate our factorial function with TailCaller:
    @TailCaller
    def fact(n, r=1):
        if n <= 1:
            return r
        else:
            return TailCall(fact, n - 1, n * r)
And then we can call fact directly with large numbers!
Also, unlike in the first attempt, we can now have mutually recursive functions which all perform tail calls. The first-called TailCaller object will handle all the trampolining.
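As a sketch of the mutually recursive case, here is the classic even/odd pair written against the TailCaller and TailCall classes described above. The functions is_even and is_odd are invented for this illustration and assume a non-negative integer argument.

```python
class TailCall:
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = args
        self.kwargs = kwargs

    def handle(self):
        # Unwrap a TailCaller so that no nested trampoline is started.
        if type(self.call) is TailCaller:
            return self.call.f(*self.args, **self.kwargs)
        else:
            return self.call(*self.args, **self.kwargs)


class TailCaller:
    def __init__(self, f):
        self.f = f

    def __call__(self, *args, **kwargs):
        ret = self.f(*args, **kwargs)
        while type(ret) is TailCall:
            ret = ret.handle()
        return ret


@TailCaller
def is_even(n):
    if n == 0:
        return True
    return TailCall(is_odd, n - 1)


@TailCaller
def is_odd(n):
    if n == 0:
        return False
    return TailCall(is_even, n - 1)


# Whichever function is called first runs the only trampoline loop; the
# other function is reached through its undecorated .f attribute.
print(is_even(100000))  # True
print(is_odd(100000))   # False
```

Both decorated functions are still ordinary callables from the outside, so either can serve as the entry point.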
If we wanted, we could also define the following function to make the argument lists for tail calls be more consistent with those for normal function calls:
    def tailcall(f):
        def _f(*args, **kwargs):
            return TailCall(f, *args, **kwargs)
        return _f

and then fact could be rewritten as
    @TailCaller
    def fact(n, r=1):
        if n <= 1:
            return r
        else:
            return tailcall(fact)(n - 1, n * r)
One would hope that marking the tail calls manually could just be done away with, but I can’t think of any way to detect whether a call is a tail call without inspecting the source code. Perhaps an idea for further work is to convince Guido van Rossum that Python should support tail call elimination (which is quite unlikely to happen).
    a = b + c
    d = (b + c) + e

and replaces it with

    a = b + c
    d = a + e

Sure, the b+c is eliminated from the second statement, but it’s not really gone...
The optimization known as “dead code elimination” actually eliminates something, but that’s because dead code has no effect, and so it can be removed (that is, be replaced with nothing).
    (let next ((n 1))
      (if (<= n 100)
          (begin
            (display n)
            (newline)
            (next (+ n 1)))))

where next is bound to a function which takes one argument, n, and which has the body of the let form as its body. If that 100 were some arbitrarily large number, the tail call to next had better be handled as a jump, otherwise the stack would overflow! And there’s no other reasonable way to write such a loop!
Continuation passing style is commonly used to handle exceptions and backtracking. You write functions of the form
    (define (f cont)
      (let ((cont2 (lambda ... (cont ...) ...)))
        (g cont2)))

along with functions which take multiple such f’s and combine them into another function which also takes a single cont argument. I’ll probably talk about this more in another page, but for now notice how the call to g is in the tail position.