Depth-first search is a straightforward algorithm for visiting the nodes of a tree or tree-like data structure.
Here’s how you might implement it in Python:
```python
def search(node):
    for child in children(node):
        search(child)
```
This works well in many cases, but it has a few problems:
It descends the tree using a function call, `search(child)`, and function calls are quite slow in Python.
It descends the tree by recursing, so it uses as many levels of stack as the depth of the deepest node in the tree. But Python’s call stack is limited in size (see `sys.getrecursionlimit`), so deep enough trees will run out of call stack and fail with “RuntimeError: maximum recursion depth exceeded.” (Since Python 3.5 the exception raised is `RecursionError`, a subclass of `RuntimeError`.)
If you’d like to be able to stop the search partway through and return a result (for example, returning a target node as soon as you find it), then this requires a slightly awkward change to the code. You could add logic for exiting the recursion:
```python
def search(node):
    if target(node):
        return node
    for child in children(node):
        result = search(child)
        if result is not None:
            return result
```
or you could return the result non-locally using an exception:
```python
class Found(Exception):
    pass

def search(node):
    if target(node):
        raise Found(node)
    for child in children(node):
        search(child)
```
or you could rewrite the search to use generators:[1]
```python
def search(node):
    if target(node):
        yield node
    for child in children(node):
        yield from search(child)
```
and then the caller can call
These problems can all be avoided using the stack of iterators design pattern.[2]
```python
def search(root):
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            stack.append(iter(children(node)))
            break
        else:
            stack.pop()
```
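To see the recursion limit problem and its absence side by side, here is a minimal sketch. It assumes two hypothetical trees: `chain`, where node `n` has the single child `n - 1`, and a small dictionary-based tree `TREE`. The names `search_recursive`, `search_stack`, `chain` and `TREE` are invented for this illustration, and `children` and a `visit` callback are passed in as parameters rather than being global functions as in the examples above.

```python
import sys


def search_recursive(node, children, visit):
    """Recursive depth-first search; calls visit(node) on each node in preorder."""
    visit(node)
    for child in children(node):
        search_recursive(child, children, visit)


def search_stack(root, children, visit):
    """Stack of iterators search; calls visit(node) on each node in preorder."""
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            visit(node)
            stack.append(iter(children(node)))
            break
        else:
            stack.pop()


def chain(n):
    """Hypothetical tree: node n has the single child n - 1, down to 0."""
    return [n - 1] if n > 0 else []


# Hypothetical small branching tree, stored as a dictionary of child lists.
TREE = {"a": ["b", "e"], "b": ["c", "d"], "c": [], "d": [], "e": []}

# A chain ten times deeper than the current recursion limit.
depth = sys.getrecursionlimit() * 10

# The stack version only grows a list, so the recursion limit does not apply.
visited = []
search_stack(depth, chain, visited.append)
print(len(visited) == depth + 1)  # every node depth, depth - 1, ..., 0 was visited

# The recursive version needs one Python frame per level and gives up.
recursive_failed = False
try:
    search_recursive(depth, chain, lambda node: None)
except RecursionError:
    recursive_failed = True
print(recursive_failed)

# On an ordinary tree, both versions visit the nodes in the same order.
order_recursive, order_stack = [], []
search_recursive("a", lambda node: TREE[node], order_recursive.append)
search_stack("a", lambda node: TREE[node], order_stack.append)
print(order_stack)  # ['a', 'b', 'c', 'd', 'e']
```

Each level of the tree costs the recursive version a Python stack frame, whereas the stack version pays for one list slot and one iterator object, and a list is limited only by available memory.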
This avoids the three problems above:
Pushing and popping a list is faster than calling a function in Python.
Lists can grow without limit, unlike the function call stack.
Since there’s no recursion, you can just return when you have a result:
```python
def search(root):
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            if target(node):
                return node
            stack.append(iter(children(node)))
            break
        else:
            stack.pop()
```
The control flow here might seem a bit tricky if you’re not used to the way that Python’s `for ... else` construct interacts with `break`. The pattern works by maintaining a stack of iterators that remember the position reached in the iteration over the children of the node at each level of the search. After pushing a new iterator on the stack, the `break` exits the `for` loop, bypasses the `else:` clause, and goes round the `while` loop again, so that it picks up the new iterator from `stack[-1]`. When there are no more children in the current iteration, the `for` loop exits via the `else:` clause and pops the stack. Then the next iteration of the `while` loop picks up the iteration at the previous level from where it left off.
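One way to watch this control flow is to log each step. The sketch below is a hypothetical instrumented copy of the basic `search` (no `target` check), run on a made-up three-node tree; `traced_search` and the log format are inventions for illustration only.

```python
def traced_search(root, children):
    """Stack of iterators search that records what each step of the loop does."""
    log = []
    stack = [iter([root])]
    while stack:
        for node in stack[-1]:
            # In the loop body: push an iterator over node's children, then
            # break so that the while loop picks it up from stack[-1].
            log.append(f"visit {node}")
            stack.append(iter(children(node)))
            break
        else:
            # The for loop ran out of items without a break: this level is
            # finished, so pop it and resume the level below where it left off.
            log.append("pop")
            stack.pop()
    return log


# Hypothetical tree: "a" has children "b" and "c", which are leaves.
tree = {"a": ["b", "c"], "b": [], "c": []}
log = traced_search("a", lambda node: tree[node])
print(log)
# ['visit a', 'visit b', 'pop', 'visit c', 'pop', 'pop', 'pop']
```

Every visit pushes exactly one iterator, and every iterator, including the initial one over `[root]`, is popped exactly once when it runs out, so a complete search logs one more `pop` than `visit`.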
Here are two examples. First, finding a key in a possibly nested dictionary:[3]
```python
def search(d, key, default=None):
    """Return a value corresponding to the specified key in the (possibly
    nested) dictionary d. If there is no item with that key, return
    default.

    """
    stack = [iter(d.items())]
    while stack:
        for k, v in stack[-1]:
            if isinstance(v, dict):
                stack.append(iter(v.items()))
                break
            elif k == key:
                return v
        else:
            stack.pop()
    return default
```
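The same pattern can carry extra state alongside the stack. As a sketch, here is a hypothetical variant, `find_path`, that keeps a parallel list of the keys of the nested dictionaries currently being searched, so it can report where the key was found. Like `search` above, it descends into any dictionary value rather than matching that value’s own key; the sample dictionary is made up.

```python
def find_path(d, key):
    """Hypothetical variant of search(): return the list of keys leading to
    the first item with the given key in the (possibly nested) dictionary d,
    or None if there is no such item."""
    stack = [iter(d.items())]
    path = []  # keys of the nested dictionaries currently being searched
    while stack:
        for k, v in stack[-1]:
            if isinstance(v, dict):
                path.append(k)  # descend: remember which key we went into
                stack.append(iter(v.items()))
                break
            elif k == key:
                return path + [k]
        else:
            stack.pop()
            if path:  # the outermost dictionary has no key of its own
                path.pop()
    return None


config = {"a": 1, "b": {"c": 2, "d": {"e": 3}}}
print(find_path(config, "e"))  # ['b', 'd', 'e']
print(find_path(config, "x"))  # None
```

The bookkeeping keeps `len(path) == len(stack) - 1` at every step, which is why each pop of the stack is matched by a pop of `path`.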
Second, finding a simple path visiting a set of positions on a grid:[4]
```python
def hamilton_path(start, positions, directions=((0, 1), (1, 0), (0, -1), (-1, 0))):
    """Find a simple path that visits all positions.

    start: tuple(int, int)
        Starting position for the path.
    positions: iterable of tuple(int, int)
        Iterable of positions to be visited by the path.
    directions: iterable of tuple(int, int)
        Iterable of directions to take at each step.

    Return the path as a list of tuple(int, int) giving the order in
    which the positions are visited. Raise ValueError if there are no
    paths visiting all positions.

    """
    positions = set(positions)
    directions = list(directions)
    path = [start]
    stack = [iter(directions)]
    while path:
        x, y = path[-1]
        for dx, dy in stack[-1]:
            pos = x + dx, y + dy
            if pos in positions:
                path.append(pos)
                positions.remove(pos)
                stack.append(iter(directions))
                if not positions:
                    return path
                break
        else:
            positions.add(path.pop())
            stack.pop()
    raise ValueError("no path")
```
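Because the search state lives in `path`, `positions` and `stack` rather than in the call stack, the same loop can be turned into a generator that keeps going after it finds a solution. The following is a sketch of a hypothetical variant, `hamilton_paths`, that yields every simple path instead of returning the first. It assumes that `start` is not itself in `positions` and that `positions` is non-empty; the 2×2 grid is a made-up example.

```python
def hamilton_paths(start, positions, directions=((0, 1), (1, 0), (0, -1), (-1, 0))):
    """Hypothetical generator variant of hamilton_path: yield every simple
    path from start that visits all positions.

    Assumes start is not in positions and positions is non-empty.
    """
    positions = set(positions)
    directions = list(directions)
    path = [start]
    stack = [iter(directions)]
    while path:
        x, y = path[-1]
        for dx, dy in stack[-1]:
            pos = x + dx, y + dy
            if pos in positions:
                path.append(pos)
                positions.remove(pos)
                stack.append(iter(directions))
                if not positions:
                    # Yield a copy: path keeps changing as the search goes on.
                    yield list(path)
                break
        else:
            # All directions tried from path[-1]: backtrack one step.
            positions.add(path.pop())
            stack.pop()


# A 2x2 grid, starting in one corner.
paths = list(hamilton_paths((0, 0), [(0, 1), (1, 0), (1, 1)]))
print(paths)
# [[(0, 0), (0, 1), (1, 1), (1, 0)], [(0, 0), (1, 0), (1, 1), (0, 1)]]
```

After a path is yielded, the next pass of the `while` loop finds no unvisited neighbours, so the `else:` clause backtracks exactly as it does after a dead end.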
Update 2016-10-19. There’s some discussion of this article at Reddit. Here are answers to some of the points raised.[5]
Is it faster to push and pop a stack than it is to call a function? Aren’t stack operations like `stack.pop()` themselves function calls?
In CPython, functions implemented in Python (like `search` in the example above) are called differently from functions implemented in C (like `stack.append`), and are indeed slightly slower to call:
```python
>>> from timeit import timeit
>>> def f(): pass
...
>>> timeit(f)  # Call function implemented in Python, 10**6 times.
0.17691538999497425
>>> timeit(int)  # Call function implemented in C, 10**6 times.
0.12964718499279115
```
But the more important difference between calling a function and pushing a stack is that the former has to allocate memory for a call stack frame whereas the latter only has to allocate memory for one slot in a list, and the deeper the function recurses, the bigger the difference.
But is it really faster to push and pop a stack than it is to call a function? In my timing tests it didn’t seem to be the case.
When timing differences are fairly small (as they are here) you have to be careful about how you interpret them. Timing comparisons can be skewed by things like the time taken to look up attribute names or global variables. Here’s a comparison of the two approaches that I think is fair:
```python
def recurse(n):
    if n > 0:
        recurse(n - 1)

def pushpop(n):
    stack = [iter([n])]
    while stack:
        for n in stack[-1]:
            if n > 0:
                stack.append(iter((n - 1,)))
                break
        else:
            stack.pop()
```
Let’s increase the recursion limit to 10,007 and compare the two functions:
```python
>>> from sys import setrecursionlimit
>>> setrecursionlimit(10**4 + 7)
>>> from timeit import timeit
>>> timeit(lambda:recurse(10**4), number=100)
0.9759416219894774
>>> timeit(lambda:pushpop(10**4), number=100)
0.8555211240018252
```
The stack of iterators pattern takes about 12% less time on this test, despite paying the cost of looking up methods such as `stack.pop`, and the built-in function `iter`, on each loop. What if we apply a standard optimization and cache these lookups in local variables?
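```python
def pushpop2(n):
    _iter = iter  # cache the built-in in a local variable
    stack = [_iter([n])]
    _push = stack.append  # cache bound methods (stack must exist first)
    _pop = stack.pop
    while stack:
        for n in stack[-1]:
            if n > 0:
                _push(_iter((n - 1,)))
                break
        else:
            _pop()
```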
Now the stack of iterators pattern takes 25% less time than recursion:
```python
>>> timeit(lambda:pushpop2(10**4), number=100)
0.7366151169990189
```
Who cares about such a small performance improvement? If you really cared about performance, you wouldn’t be using Python in the first place.
Sure, the small performance improvement in CPython probably isn’t convincing enough, by itself, to make you adopt the stack of iterators pattern. But I think it’s important to mention it because it pre-empts a likely objection to the pattern. (Perhaps it was a mistake to make it my first point.) It’s really the second point that is likely to be most convincing: sometimes you’d like to be able to traverse deeply nested data structures without having to figure out in advance what value you need to pass to
Isn’t it kind of strange to enter the body of a `for` loop and then immediately
Isn’t it kind of strange to enter the body of a function and then immediately call another function?
[1] This requires Python 3.3 or later in order to be able to use the `yield from` syntax.
[2] In software, a design pattern is a technique for working around a defect or omission in a particular programming language. For example, the well-known singleton pattern is a work-around for the lack of global variables in the Java language.
[3] Example adapted from this answer on Code Review.
[4] Example adapted from this answer on Code Review.
[5] The points here are my summary of the whole discussion and not intended to reflect the opinion of any particular participant.