The stack of iterators pattern

Depth-first search is a straightforward algorithm for visiting the nodes of a tree or tree-like data structure.

Here’s how you might implement it in Python:

def search(node):
    for child in children(node):
        search(child)
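
Here children(node) is assumed to return an iterable of the children of node; it isn’t part of the pattern. Purely for illustration in the snippets below, suppose the tree is represented as nested lists, with anything that isn’t a list being a leaf:

def children(node):
    # Illustrative tree representation (an assumption, not part of the
    # pattern): a node is a list whose elements are its subtrees, and
    # anything else is a leaf with no children.
    return node if isinstance(node, list) else []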

This works well in many cases, but it has a few problems:

  1. It descends the tree using a function call search(child), and function calls are quite slow in Python.

  2. It descends the tree by recursing, and so it uses as many levels of stack as the depth of the deepest node in the tree. But Python’s call stack is limited in size (see sys.getrecursionlimit) and so deep enough trees will run out of call stack and fail with “maximum recursion depth exceeded” (a RuntimeError, or RecursionError in Python 3.5 and later).

  3. If you’d like to be able to stop the search part way through and return a result (for example returning a target node as soon as you find it), then this requires a slightly awkward change to the code. You could add logic for exiting the recursion:

    def search(node):
        if target(node):
            return node
        for child in children(node):
            result = search(child)
            if result is not None:
                return result
    

    or you could return the result non-locally using an exception:

    class Found(Exception):
        pass
    
    def search(node):
        if target(node):
            raise Found(node)
        for child in children(node):
            search(child)
    

    or you could rewrite the search to use generators:1

    def search(node):
        if target(node):
            yield node
        for child in children(node):
            yield from search(child)
    

    and then the caller can call next(search(root)), or next(search(root), None) to get None rather than a StopIteration exception when there is no match.
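
    For instance, with the nested-list children function sketched earlier, and a target predicate that matches the value 5 (both assumptions made purely for illustration):

    >>> def target(node):
    ...     return node == 5
    >>> next(search([1, [2, [3, 5]], 4]), None)
    5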

These problems can all be avoided using the stack of iterators design pattern.2

def search(root):
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            stack.append(iter(children(node)))
            break
        else:
            stack.pop()

This avoids the three problems above:

  1. Pushing and popping a list is faster than calling a function in Python.

  2. Lists can grow without limit, unlike the function call stack.

  3. Since there’s no recursion, you can just return when you have a result:

    def search(root):
        stack = [iter([root])]
        while stack:
            for node in stack[-1]:
                if target(node):
                    return node
                stack.append(iter(children(node)))
                break
            else:
                stack.pop()
    

The control flow here might seem a bit tricky if you’re not used to the way that Python’s for ... else construct interacts with break. The pattern works by maintaining a stack of iterators that remember the position reached in the iteration over the children of the node at each level of the search. After pushing a new iterator on the stack, the break exits the for loop, bypasses the else: clause, and goes round the while loop again, so that it picks up the new iterator from stack[-1]. When there are no more children in the current iteration, the for loop exits via the else: clause and pops the stack. Then the next iteration of the while loop picks up the iteration at the previous level from where it left off.
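
If it helps to see this control flow in action, here is the same pattern reworked as a generator (a sketch, still assuming the nested-list children function from earlier) that yields each node in the order the stack discovers it:

def visit(root):
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            yield node                           # visit the node, then
            stack.append(iter(children(node)))   # descend to its children
            break
        else:
            stack.pop()                          # no children left: back up a level

For example:

>>> list(visit([1, [2, [3, 5]], 4]))
[[1, [2, [3, 5]], 4], 1, [2, [3, 5]], 2, [3, 5], 3, 5, 4]

And because the stack is an ordinary list rather than the call stack, the same generator happily walks a structure nested far deeper than the default recursion limit of 1000:

>>> deep = 0
>>> for _ in range(10**5):
...     deep = [deep]
...
>>> sum(1 for _ in visit(deep))
100001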

Two examples. First, finding a key in a possibly nested dictionary:3

def search(d, key, default=None):
    """Return a value corresponding to the specified key in the (possibly
    nested) dictionary d. If there is no item with that key, return
    default.
 
    """
    stack = [iter(d.items())]
    while stack:
        for k, v in stack[-1]:
            if isinstance(v, dict):
                stack.append(iter(v.items()))
                break
            elif k == key:
                return v
        else:
            stack.pop()
    return default
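
For example, on a small made-up nested dictionary:

>>> search({"a": 1, "b": {"c": {"d": 4}}}, "d")
4
>>> search({"a": 1}, "z", default=0)
0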

Second, finding a simple path visiting a set of positions on a grid:4

def hamilton_path(start, positions, directions=((0, 1), (1, 0), (0, -1), (-1, 0))):
    """Find a simple path that visits all positions.
 
    start: tuple(int, int)
        Starting position for the path.
    positions: iterable of tuple(int, int)
        Iterable of positions to be visited by the path.
    directions: iterable of tuple(int, int)
        Iterable of directions to take at each step.
 
    Return the path as a list of tuple(int, int) giving the order in
    which the positions are visited. Raise ValueError if there are no
    paths visiting all positions.
 
    """
    positions = set(positions)
    directions = list(directions)
    path = [start]
    stack = [iter(directions)]
    while path:
        x, y = path[-1]
        for dx, dy in stack[-1]:
            pos = x + dx, y + dy
            if pos in positions:
                path.append(pos)
                positions.remove(pos)
                stack.append(iter(directions))
                if not positions:
                    return path
                break
        else:
            positions.add(path.pop())
            stack.pop()
    raise ValueError("no path")
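
For example, here’s an illustrative call that visits the four cells of a 2×2 grid, starting from the corner (0, 0) (the starting position itself is not listed in positions):

>>> hamilton_path((0, 0), [(0, 1), (1, 0), (1, 1)])
[(0, 0), (0, 1), (1, 1), (1, 0)]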

Update 2016-10-19. There’s some discussion of this article at Reddit. Here are answers to some of the points raised.5

  1. Is it faster to push and pop a stack than it is to call a function? Aren’t stack.append() and stack.pop() themselves function calls?

    In CPython, functions implemented in Python (like search in the example above) are called differently from functions implemented in C (like stack.append), and are indeed slightly slower to call:

    >>> from timeit import timeit
    >>> def f(): pass
    >>> timeit(f) # Call function implemented in Python, 10**6 times.
    0.17691538999497425
    >>> timeit(int) # Call function implemented in C, 10**6 times.
    0.12964718499279115
    

    But the more important difference between calling a function and pushing a stack is that the former has to allocate memory for a call stack frame whereas the latter only has to allocate memory for one slot in a list, and the deeper the function recurses, the bigger the difference.

  2. But is it really faster to push and pop a stack than it is to call a function? In my timing tests it didn’t seem to be the case.

    When timing differences are fairly small (as they are here) you have to be careful about how you interpret them. Timing comparisons can be skewed by things like the time taken to look up attribute names or global variables. Here’s a comparison of the two approaches that I think is fair:

    def recurse(n):
        if n > 0:
            recurse(n - 1)
    
    def pushpop(n):
        stack = [iter([n])]
        while stack:
            for n in stack[-1]:
                if n > 0:
                    stack.append(iter((n - 1,)))
                    break
            else:
                stack.pop()
    

    Let’s increase the recursion limit to 10,007 and compare the two functions:

    >>> from sys import setrecursionlimit
    >>> setrecursionlimit(10**4 + 7)
    >>> from timeit import timeit
    >>> timeit(lambda:recurse(10**4), number=100)
    0.9759416219894774
    >>> timeit(lambda:pushpop(10**4), number=100)
    0.8555211240018252
    

    The stack of iterators pattern takes about 12% less time on this test, despite paying the cost of looking up the methods stack.append and stack.pop, and the built-in function iter, on each loop. What if we apply a standard optimization and cache these methods in local variables?

    def pushpop2(n):
        _iter = iter
        stack = [_iter([n])]
        _push = stack.append
        _pop = stack.pop
        while stack:
            for n in stack[-1]:
                if n > 0:
                    _push(_iter((n - 1,)))
                    break
            else:
                _pop()
    

    Now the stack of iterators pattern takes 25% less time than recursion:

    >>> timeit(lambda:pushpop2(10**4), number=100)
    0.7366151169990189
    
  3. Who cares about such a small performance improvement? If you really cared about performance, you wouldn’t be using Python in the first place.

    Sure, the small performance improvement in CPython probably isn’t convincing enough, by itself, to make you adopt the stack of iterators pattern. But I think it’s important to mention it because it pre-empts a likely objection to the pattern. (Perhaps it was a mistake to make it my first point.) It’s really the second point that is likely to be most convincing: sometimes you’d like to be able to traverse deeply nested data structures without having to figure out in advance what value you need to pass to sys.setrecursionlimit.

  4. Isn’t it kind of strange to enter the body of a for loop and then immediately break?

    Isn’t it kind of strange to enter the body of a function and then immediately call another function?


  1.  This requires Python 3.3 or later in order to be able to use the yield from syntax.

  2.  In software, a design pattern is a technique for working around a defect or omission in a particular programming language. For example, the well-known singleton pattern is a work-around for the lack of global variables in the Java language.

  3.  Example adapted from this answer on Code Review.

  4.  Example adapted from this answer on Code Review.

  5.  The points here are my summary of the whole discussion and not intended to reflect the opinion of any particular participant.