The stack of iterators pattern


Depth-first search is a straightforward algorithm for visiting the nodes of a tree or tree-like data structure.

Here’s how you might implement it in Python:

def search(node):
    for child in children(node):
        search(child)

This works well in many cases, but it has a couple of problems:1

  1. It descends the tree by recursing, and so it uses as many levels of stack as the depth of the deepest node in the tree. But Python’s call stack is limited in size (see sys.getrecursionlimit) and so deep enough trees will run out of call stack and fail with “RuntimeError: maximum recursion depth exceeded” (in Python 3.5 and later the exception is RecursionError, a subclass of RuntimeError).

  2. If you’d like to be able to stop the search part way through and return a result (for example returning a target node as soon as you find it), then this requires a slightly awkward change to the code. You could add logic for exiting the recursion:

    def search(node):
        if target(node):
            return node
        for child in children(node):
            result = search(child)
            if result is not None:
                return result
    

    or you could return the result non-locally using an exception:

    class Found(Exception):
        pass
    
    def search(node):
        if target(node):
            raise Found(node)
        for child in children(node):
            search(child)
    

    or you could rewrite the search to use generators:2

    def search(node):
        if target(node):
            yield node
        for child in children(node):
            yield from search(child)
    

    and then the caller can call next(search(root)).
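
As a small sketch of how the generator version is used, here it is run against a toy tree. The TREE dictionary, the children helper and the extra target parameter are assumptions made for this illustration and are not part of the code above. Passing a default to next avoids a StopIteration when nothing matches.

```python
# Hypothetical toy tree: each key maps to the list of its children.
TREE = {"root": ["a", "b"], "a": ["c"], "b": [], "c": []}

def children(node):
    return TREE[node]

def search(node, target):
    # Generator version from above, with target passed in as a parameter.
    if target(node):
        yield node
    for child in children(node):
        yield from search(child, target)

# next() pulls only the first match, so the search stops as soon as it is found.
first = next(search("root", lambda n: n == "c"), None)    # "c"
# With a default, a failed search gives None instead of raising StopIteration.
missing = next(search("root", lambda n: n == "z"), None)  # None
```

Note that this is still recursive, so it remains subject to the recursion limit described in the first problem.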

The problems can be avoided using the stack of iterators design pattern.3

def search(root):
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            stack.append(iter(children(node)))
            break
        else:
            stack.pop()

This avoids the problems above:

  1. A list is limited only by available memory, unlike the function call stack, so the depth of the search is no longer bounded by the recursion limit.

  2. Since there’s no recursion, you can just return when you have a result:

    def search(root):
        stack = [iter([root])]
        while stack:
            for node in stack[-1]:
                if target(node):
                    return node
                stack.append(iter(children(node)))
                break
            else:
                stack.pop()
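
To see both points at work, here is a sketch that runs the recursive and the iterative search on a tree that is much deeper than the recursion limit. The representation (a node is just the list of its children), the target predicate and the chain helper are all assumptions made for this illustration.

```python
import sys

def children(node):
    # Assumption for this sketch: a node is simply the list of its children.
    return node

def target(node):
    # Hypothetical target: a leaf, that is, a node with no children.
    return not node

def search_recursive(node):
    if target(node):
        return node
    for child in children(node):
        result = search_recursive(child)
        if result is not None:
            return result

def search_iterative(root):
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            if target(node):
                return node
            stack.append(iter(children(node)))
            break
        else:
            stack.pop()

def chain(depth):
    """Return (root, leaf) for a tree consisting of one path of the given depth."""
    root = node = []
    for _ in range(depth):
        child = []
        node.append(child)
        node = child
    return root, node

root, leaf = chain(sys.getrecursionlimit() * 10)

try:
    search_recursive(root)
except RecursionError:
    recursion_failed = True   # the call stack ran out before the leaf was reached
else:
    recursion_failed = False

found = search_iterative(root)  # the explicit stack just keeps growing
```

The recursive version gives up with RecursionError, while the iterative version returns the leaf at the bottom of the chain.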
    

The control flow here might seem a bit tricky if you’re not used to the way that Python’s for ... else construct interacts with break. The pattern works by maintaining a stack of iterators that remember the position reached in the iteration over the children of the node at each level of the search. After pushing a new iterator on the stack, the break exits the for loop, bypasses the else: clause, and goes round the while loop again, so that it picks up the new iterator from stack[-1]. When there are no more children in the current iteration, the for loop exits via the else: clause and pops the stack. Then the next iteration of the while loop picks up the iteration at the previous level from where it left off.
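
If the control flow is still hard to picture, it can help to record the order in which nodes are visited. The following is a minimal sketch on a toy tree; the TREE dictionary, the children helper and the visit_order name are assumptions made for this illustration.

```python
# Hypothetical toy tree: each key maps to the list of its children.
TREE = {
    "a": ["b", "e"],
    "b": ["c", "d"],
    "c": [],
    "d": [],
    "e": ["f"],
    "f": [],
}

def children(node):
    return TREE[node]

def visit_order(root):
    """Return the nodes in the order the stack of iterators visits them."""
    order = []
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            order.append(node)
            # Descend: push an iterator over this node's children and
            # restart the while loop so that it becomes stack[-1].
            stack.append(iter(children(node)))
            break
        else:
            # The iterator on top of the stack is exhausted: go back up
            # and resume the parent's iterator where it left off.
            stack.pop()
    return order

order = visit_order("a")  # ['a', 'b', 'c', 'd', 'e', 'f']
```

The result is the same pre-order sequence that the recursive version would produce: after finishing with "c", the iterator over the children of "b" is still on the stack and yields "d" next.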

Three examples. First, finding a key in a possibly nested dictionary:4

def search(d, key, default=None):
    """Return a value corresponding to the specified key in the (possibly
    nested) dictionary d. If there is no item with that key, return
    default.
 
    """
    stack = [iter(d.items())]
    while stack:
        for k, v in stack[-1]:
            if isinstance(v, dict):
                stack.append(iter(v.items()))
                break
            elif k == key:
                return v
        else:
            stack.pop()
    return default
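
A short usage sketch, with the function repeated so that it runs on its own. The config dictionary is made up for this illustration. One behaviour worth noticing is that values which are themselves dictionaries are descended into rather than returned, so searching for a key whose value is a dictionary gives the default.

```python
def search(d, key, default=None):
    # Same function as above, repeated so that this sketch is self-contained.
    stack = [iter(d.items())]
    while stack:
        for k, v in stack[-1]:
            if isinstance(v, dict):
                stack.append(iter(v.items()))
                break
            elif k == key:
                return v
        else:
            stack.pop()
    return default

# Hypothetical nested dictionary.
config = {"name": "demo", "server": {"host": "localhost", "ports": {"http": 80}}}

http = search(config, "http")                    # 80, found two levels down
missing = search(config, "https", default=443)   # 443, no such key anywhere
nested = search(config, "server")                # None, dict values are never returned
```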

Second, finding a simple path visiting a set of positions on a grid:5

def hamilton_path(start, positions, directions=((0, 1), (1, 0), (0, -1), (-1, 0))):
    """Find a simple path that visits all positions.
 
    start: tuple(int, int)
        Starting position for the path.
    positions: iterable of tuple(int, int)
        Iterable of positions to be visited by the path.
    directions: iterable of tuple(int, int)
        Iterable of directions to take at each step.
 
    Return the path as a list of tuple(int, int) giving the order in
    which the positions are visited. Raise ValueError if there are no
    paths visiting all positions.
 
    """
    positions = set(positions)
    directions = list(directions)
    path = [start]
    stack = [iter(directions)]
    while path:
        x, y = path[-1]
        for dx, dy in stack[-1]:
            pos = x + dx, y + dy
            if pos in positions:
                path.append(pos)
                positions.remove(pos)
                stack.append(iter(directions))
                if not positions:
                    return path
                break
        else:
            positions.add(path.pop())
            stack.pop()
    raise ValueError("no path")
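
A usage sketch on a 2×2 grid, again with the function repeated (docstring shortened) so that it runs on its own. The grid and the unreachable position are made up for this illustration, and I am assuming that the caller leaves the start out of positions, since the path begins there already.

```python
def hamilton_path(start, positions, directions=((0, 1), (1, 0), (0, -1), (-1, 0))):
    """Find a simple path from start that visits all positions."""
    # Same function as above, repeated so that this sketch is self-contained.
    positions = set(positions)
    directions = list(directions)
    path = [start]
    stack = [iter(directions)]
    while path:
        x, y = path[-1]
        for dx, dy in stack[-1]:
            pos = x + dx, y + dy
            if pos in positions:
                path.append(pos)
                positions.remove(pos)
                stack.append(iter(directions))
                if not positions:
                    return path
                break
        else:
            # Dead end: un-visit the last position and backtrack.
            positions.add(path.pop())
            stack.pop()
    raise ValueError("no path")

# The three remaining cells of a 2x2 grid, starting from the corner (0, 0).
path = hamilton_path((0, 0), [(0, 1), (1, 0), (1, 1)])
# [(0, 0), (0, 1), (1, 1), (1, 0)]

# A position that cannot be reached by unit steps makes the search fail.
try:
    hamilton_path((0, 0), [(5, 5)])
except ValueError:
    failed = True
else:
    failed = False
```

Here path and stack grow and shrink together: the iterator at stack[-1] always records which directions have already been tried from path[-1].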

Third, finding the minimum sum of values along a path from the root to a leaf in a tree:6

def min_path_sum(root):
    """Return minimum sum of values of path from root to a leaf."""
    min_sum = float('inf')
    current_path = [0]
    current_sum = 0
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            current_path.append(node.value)
            current_sum += node.value
            children = [n for n in (node.left, node.right) if n]
            stack.append(iter(children))
            if not children:
                min_sum = min(min_sum, current_sum)
            break
        else:
            current_sum -= current_path.pop()
            stack.pop()
    return min_sum
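
A usage sketch for this example. The Node class is an assumption made for this illustration (the function only needs value, left and right attributes), and the function is repeated so that the sketch runs on its own.

```python
class Node:
    # Hypothetical binary tree node with the attributes min_path_sum relies on.
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

def min_path_sum(root):
    """Return minimum sum of values of path from root to a leaf."""
    min_sum = float('inf')
    current_path = [0]   # sentinel, popped when the root's iterator is exhausted
    current_sum = 0
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            current_path.append(node.value)
            current_sum += node.value
            children = [n for n in (node.left, node.right) if n]
            stack.append(iter(children))
            if not children:
                # Reached a leaf: current_sum is the sum along the whole path.
                min_sum = min(min_sum, current_sum)
            break
        else:
            # Finished with a node: remove its value from the running sum.
            current_sum -= current_path.pop()
            stack.pop()
    return min_sum

#        5
#       / \
#      4   8
#     /   / \
#   11  13   1
tree = Node(5, Node(4, Node(11)), Node(8, Node(13), Node(1)))
result = min_path_sum(tree)  # 14, along the path 5 -> 8 -> 1
```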

I’m not the first to discover this pattern. After writing this article I found the 2009 post Detecting Cycles in a Directed Graph by Guido van Rossum. He leaves the full details, including the stack of iterators itself, as an exercise for the reader, but it’s clear that his technique is essentially the same as in the presentation above, including the combination of for ... else with break.

A close cousin of the pattern appears in the 2013 post Stack-based graph traversal ≠ depth first search by David Eppstein. His version uses try ... except instead of for ... else and break but the core idea is the same.7


  1.  Originally I mentioned a third problem and corresponding benefit, but even though it was the least important and interesting of the three, it led to an influx of comments all disputing this point and ignoring the rest of the article. Accordingly, I’ve rewritten this to remove the contentious point, and deleted all the argumentative comments. Hopefully the interesting material will be clearer without the distraction.

  2.  This requires Python 3.3 or later, which introduced the yield from expression.

  3.  In software, a design pattern is a technique for working around a defect or omission in a particular programming language. For example, the well-known singleton pattern is a work-around for the lack of global variables in the Java language. In the case of the stack of iterators pattern, the defects in Python that we are working around are (1) stack size is bounded; and (2) there is no mechanism for returning from an enclosing scope: that is, nothing like return-from in Common Lisp.

  4.  Example adapted from this answer on Code Review Stack Exchange.

  5.  Example adapted from this answer on Code Review Stack Exchange.

  6.  Example adapted from this answer on Code Review Stack Exchange.

  7.  Eppstein's version of the pattern is:

    def search(root):
        stack = [iter([root])]
        while stack:
            try:
                node = next(stack[-1])
                stack.append(iter(children(node)))
            except StopIteration:
                stack.pop()
    

    which is fine if you know that children(node) cannot raise StopIteration, but in the general case I would prefer to write:

    def search(root):
        stack = [iter([root])]
        while stack:
            try:
                node = next(stack[-1])
            except StopIteration:
                stack.pop()
            else:
                stack.append(iter(children(node)))
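
To illustrate the difference, here is a sketch in which children leaks a StopIteration for one node. The toy tree, the deliberately buggy children helper and the visited lists are assumptions made for this illustration.

```python
# Hypothetical toy tree; node "a" has no entry because children() fails on it.
TREE = {"root": ["a", "b"], "b": []}

def children(node):
    # Deliberately buggy helper: for node "a" it calls next() on an empty
    # iterator, so a StopIteration escapes from it.
    if node == "a":
        return next(iter([]))
    return TREE[node]

def search_combined(root):
    # Both statements inside the try, as in the first version.
    visited = []
    stack = [iter([root])]
    while stack:
        try:
            node = next(stack[-1])
            visited.append(node)
            stack.append(iter(children(node)))
        except StopIteration:
            stack.pop()
    return visited

def search_separate(root):
    # Only next() inside the try, as in the second version.
    visited = []
    stack = [iter([root])]
    while stack:
        try:
            node = next(stack[-1])
        except StopIteration:
            stack.pop()
        else:
            visited.append(node)
            stack.append(iter(children(node)))
    return visited

# The stray StopIteration is mistaken for "no more siblings", so the iterator
# that still held "b" is popped and "b" is silently skipped.
combined = search_combined("root")  # ['root', 'a']

# With the narrower try, the bug surfaces as an exception instead.
try:
    search_separate("root")
except StopIteration:
    bug_surfaced = True
else:
    bug_surfaced = False
```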