Fun with compiler optimizations

I just had to rewrite a non-trivial recursive algorithm as an imperative one (Python sucks at recursion: CPython has a fairly low default recursion limit and does no tail-call elimination), and had some fun applying compiler techniques to automate the process. I thought somebody else might think this was cool too.

I’ll abstract the problem to a standard depth-first traversal of a tree, where each node can have a variable number of children (e.g. our IR), and we may need to do some work at pre-order, post-order, and all the in-order positions. We’re also maintaining some state, which I’m hiding inside some State type, and to make things easier later I’m treating State as immutable, so most functions return the updated state.

And to throw in an extra wrinkle, in my case I can sometimes stop the traversal early. shallow_visit visits a node without looking at any of its children.

def visit(x: Node, s: State) -> State:
    s = pre(x, s)
    for i in range(len(x.children)):
        if cond(x, i):
            s = shallow_visit(x, i, s)
        else:
            s = pre_child(x, i, s)
            child = x.children[i]
            s = visit(child, s)
            s = post_child(x, i, s)
    return post(x, s)

def start(root: Node) -> X:
    s = init_state()
    s = visit(root, s)
    return extract_answer(s)
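
To actually run the recursive version, the abstract pieces need concrete definitions. Here is a minimal sketch with made-up stand-ins: a Node dataclass, a State that is just an immutable tuple of event strings, hooks that only record that they were called, and a cond that treats any child whose name starts with an underscore as one to visit shallowly. None of these definitions come from the real code; they only make the traversal order visible.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    children: list["Node"] = field(default_factory=list)

# Assumption: State is an immutable tuple of event strings, so every hook
# returns a new tuple instead of mutating the old one.
State = tuple

def init_state() -> State:
    return ()

def pre(x: Node, s: State) -> State:
    return s + (f"pre {x.name}",)

def post(x: Node, s: State) -> State:
    return s + (f"post {x.name}",)

def pre_child(x: Node, i: int, s: State) -> State:
    return s + (f"pre_child {x.name}[{i}]",)

def post_child(x: Node, i: int, s: State) -> State:
    return s + (f"post_child {x.name}[{i}]",)

def cond(x: Node, i: int) -> bool:
    # Assumption: children whose name starts with "_" are visited shallowly.
    return x.children[i].name.startswith("_")

def shallow_visit(x: Node, i: int, s: State) -> State:
    return s + (f"shallow {x.children[i].name}",)

def extract_answer(s: State) -> list:
    return list(s)

def visit(x: Node, s: State) -> State:
    s = pre(x, s)
    for i in range(len(x.children)):
        if cond(x, i):
            s = shallow_visit(x, i, s)
        else:
            s = pre_child(x, i, s)
            s = visit(x.children[i], s)
            s = post_child(x, i, s)
    return post(x, s)

def start(root: Node) -> list:
    return extract_answer(visit(root, init_state()))

# "_d" has a child "e", but e is never reached because _d is visited shallowly.
tree = Node("a", [Node("b", [Node("c")]), Node("_d", [Node("e")])])
trace = start(tree)
print("\n".join(trace))
```

The printed trace shows pre/post events nesting like parentheses, with a single "shallow _d" event where the traversal stops early.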

Ultimately, the way to turn recursion into loops is to write the recursion in a tail-recursive form. So most of this process is trying to get to a tail-recursive formulation.
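
As a small illustration of why the tail-recursive form is the goal (this is a toy, not part of the traversal), here is a sketch of a tail-recursive sum and the loop it turns into. The names total_rec and total_loop are made up for this example. Because the recursive call is the last thing the function does, it can be replaced by reassigning the parameters and jumping back to the top.

```python
def total_rec(xs: list[int], i: int = 0, acc: int = 0) -> int:
    """Tail-recursive sum: nothing happens after the recursive call returns."""
    if i >= len(xs):
        return acc
    return total_rec(xs, i + 1, acc + xs[i])  # tail call

def total_loop(xs: list[int]) -> int:
    """The same function after turning the tail call into a loop."""
    i, acc = 0, 0
    while True:
        if i >= len(xs):
            return acc
        # The tail call total_rec(xs, i + 1, acc + xs[i]) becomes an
        # assignment to the parameters; the right-hand side uses the old i.
        i, acc = i + 1, acc + xs[i]

small = [1, 2, 3, 4]
big = list(range(100_000))  # far deeper than CPython's default recursion limit

small_rec = total_rec(small)
small_loop = total_loop(small)
big_loop = total_loop(big)  # total_rec(big) would raise RecursionError by default
```

CPython does not do this rewrite for you, which is why the rest of the derivation ends with doing it by hand.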

First I replace the for loop with recursion (moving further from my goal!). This helps make all control flow explicit.

def visit(x: Node, s: State) -> State:
    s = pre(x, s)
    def loop(i: int, ls: State) -> State:
        if i >= len(x.children):
            return ls
        if cond(x, i):
            ls = shallow_visit(x, i, ls)
            return loop(i+1, ls)
        else:
            ls = pre_child(x, i, ls)
            child = x.children[i]
            ls = visit(child, ls)
            ls = post_child(x, i, ls)
            return loop(i+1, ls)
    s = loop(0, s)
    return post(x, s)

def start(root: Node) -> X:
    s = init_state()
    s = visit(root, s)
    return extract_answer(s)

Then I lift the loop function to the top level:

def loop(x: Node, i: int, ls: State) -> State:
    if i >= len(x.children):
        return ls
    if cond(x, i):
        ls = shallow_visit(x, i, ls)
        return loop(x, i+1, ls)
    else:
        ls = pre_child(x, i, ls)
        child = x.children[i]
        ls = visit(child, ls)
        ls = post_child(x, i, ls)
        return loop(x, i+1, ls)

def visit(x: Node, s: State) -> State:
    s = pre(x, s)
    s = loop(x, 0, s)
    return post(x, s)

def start(root: Node) -> X:
    s = init_state()
    s = visit(root, s)
    return extract_answer(s)

Next I perform a (partial) CPS transform. CPS is continuation-passing style: instead of returning values, functions are passed callbacks (continuations), which they call with their result value. This makes even more control flow explicit. In this case I’m still returning some X value, where X is the type of the final result of this entire algorithm, so k represents the computation remaining to be done in the algorithm.

from typing import Callable

def loop(x: Node, i: int, ls: State, k: Callable[[State], X]) -> X:
    if i >= len(x.children):
        return k(ls)
    if cond(x, i):
        ls = shallow_visit(x, i, ls)
        return loop(x, i+1, ls, k)
    else:
        ls = pre_child(x, i, ls)
        child = x.children[i]
        # Python lambdas are single expressions, so the continuation is a nested def.
        def after_child(ls1: State) -> X:
            ls1 = post_child(x, i, ls1)
            return loop(x, i+1, ls1, k)
        return visit(child, ls, after_child)

def visit(x: Node, s: State, k: Callable[[State], X]) -> X:
    s = pre(x, s)
    def after_loop(s1: State) -> X:
        s1 = post(x, s1)
        return k(s1)
    return loop(x, 0, s, after_loop)

def start(root: Node) -> X:
    s = init_state()
    return visit(root, s, lambda s1: extract_answer(s1))
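
To see the CPS idea in isolation, here is a sketch on a much smaller function than the traversal. The names fact, fact_cps and after are made up for this example. Python's lambda is limited to a single expression, so the continuation that does real work is written as a nested def; only the trivial final continuation is a lambda.

```python
from typing import Callable

def fact(n: int) -> int:
    """Direct style: the multiplication happens after the recursive call returns."""
    if n == 0:
        return 1
    return n * fact(n - 1)

def fact_cps(n: int, k: Callable[[int], int]) -> int:
    """CPS: the 'rest of the computation' is passed in as the callback k."""
    if n == 0:
        return k(1)
    def after(r: int) -> int:
        # What direct style did after the recursive call: multiply, then
        # hand the result to whatever was waiting for fact(n).
        return k(n * r)
    return fact_cps(n - 1, after)  # now a tail call

direct = fact(5)
cps = fact_cps(5, lambda r: r)  # the initial continuation just returns the result
```

In CPython this still consumes stack, since tail calls are not eliminated; the point of the transform is only that every call is now a tail call and the pending work lives in k.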

Next I inline visit, so there is just one recursive function, instead of a pair of mutually recursive functions.

from typing import Callable

def loop(x: Node, i: int, ls: State, k: Callable[[State], X]) -> X:
    if i >= len(x.children):
        return k(ls)
    if cond(x, i):
        ls = shallow_visit(x, i, ls)
        return loop(x, i+1, ls, k)
    else:
        ls = pre_child(x, i, ls)
        child = x.children[i]
        s = pre(child, ls)
        # Inlined visit(child, ...): the continuation runs post(child), then resumes the parent.
        def after_child(s1: State) -> X:
            s1 = post(child, s1)
            ls1 = post_child(x, i, s1)
            return loop(x, i+1, ls1, k)
        return loop(child, 0, s, after_child)

def start(root: Node) -> X:
    s = init_state()
    s = pre(root, s)
    def after_root(s1: State) -> X:
        s1 = post(root, s1)
        return extract_answer(s1)
    return loop(root, 0, s, after_root)

Now the real meat (besides CPS): Defunctionalization. This transforms a higher-order language (with lambdas and first-class function values) into a first-order language (all functions are defined at the top level), by replacing lambdas with objects that contain their closures. Hopefully the basic idea is clear from the example:

def loop(x: Node, i: int, ls: State, k: Kont) -> X:
    if i >= len(x.children):
        return k(ls)
    if cond(x, i):
        ls = shallow_visit(x, i, ls)
        return loop(x, i+1, ls, k)
    else:
        ls = pre_child(x, i, ls)
        child = x.children[i]
        s = pre(child, ls)
        return loop(child, 0, s, PostChild(x, i, k))

def start(root: Node) -> X:
    s = init_state()
    s = pre(root, s)
    return loop(root, 0, s, Init(root))

class Kont:
    def __call__(self, s: State) -> X:
        ...

class PostChild(Kont):
    def __init__(self, x: Node, i: int, k: Kont):
        self.x = x
        self.i = i
        self.k = k

    def __call__(self, s: State) -> X:
        s = post(self.x.children[self.i], s)
        s = post_child(self.x, self.i, s)
        return loop(self.x, self.i + 1, s, self.k)

class Init(Kont):
    def __init__(self, x: Node):
        self.x = x

    def __call__(self, s: State) -> X:
        s = post(self.x, s)
        return extract_answer(s)
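
The same move on a toy problem may make the pattern easier to see. This is a sketch with made-up names (FactKont, Done, MulBy, apply_kont, fact): a continuation-passing factorial would use two kinds of callbacks, `lambda r: r` and `lambda r: k(n * r)`. Each one becomes a class whose fields are the variables the lambda captured, and a single first-order function takes over the job of calling them.

```python
from dataclasses import dataclass

class FactKont:
    """Base class for the (hypothetical) continuations of a CPS factorial."""

@dataclass
class Done(FactKont):
    """Stands for `lambda r: r`: nothing is left to do."""

@dataclass
class MulBy(FactKont):
    """Stands for `lambda r: k(n * r)`; n and k are the captured variables."""
    n: int
    k: FactKont

def apply_kont(k: FactKont, r: int) -> int:
    # One first-order function replaces "calling the lambda": it dispatches
    # on which lambda the object stands for.
    if isinstance(k, MulBy):
        return apply_kont(k.k, k.n * r)
    assert isinstance(k, Done)
    return r

def fact(n: int, k: FactKont) -> int:
    if n == 0:
        return apply_kont(k, 1)
    return fact(n - 1, MulBy(n, k))  # build the continuation as plain data

result = fact(5, Done())
```

Here MulBy plays the role PostChild plays in the traversal, and Done plays the role of Init.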

Then I inline Kont.__call__ into loop, dispatching on the concrete subclass with isinstance:

def loop(x: Node, i: int, ls: State, k: Kont) -> X:
    if i >= len(x.children):
        if isinstance(k, PostChild):
            s = post(k.x.children[k.i], ls)
            s = post_child(k.x, k.i, s)
            return loop(k.x, k.i + 1, s, k.k)
        else:
            assert isinstance(k, Init)
            s = post(k.x, ls)
            return extract_answer(s)
    if cond(x, i):
        ls = shallow_visit(x, i, ls)
        return loop(x, i+1, ls, k)
    else:
        ls = pre_child(x, i, ls)
        child = x.children[i]
        s = pre(child, ls)
        return loop(child, 0, s, PostChild(x, i, k))

def start(root: Node) -> X:
    s = init_state()
    s = pre(root, s)
    return loop(root, 0, s, Init(root))

class Kont:
    ...

class PostChild(Kont):
    def __init__(self, x: Node, i: int, k: Kont):
        self.x = x
        self.i = i
        self.k = k

class Init(Kont):
    def __init__(self, x: Node):
        self.x = x

Now I’ve gotten to a state where loop is tail-recursive, so tail-call optimization can finally remove all recursion: each tail call becomes an assignment to the parameters followed by continue. I also inline loop into start.

def start(root: Node) -> X:
    s = init_state()
    s = pre(root, s)
    x = root
    i = 0
    k = Init(x)
    while True:
        if i >= len(x.children):
            if isinstance(k, PostChild):
                s = post(k.x.children[k.i], s)
                s = post_child(k.x, k.i, s)
                x = k.x
                i = k.i + 1
                k = k.k
                continue
            else:
                assert isinstance(k, Init)
                s = post(k.x, s)
                return extract_answer(s)
        if cond(x, i):
            s = shallow_visit(x, i, s)
            i = i+1
            continue
        else:
            s = pre_child(x, i, s)
            child = x.children[i]
            s = pre(child, s)
            k = PostChild(x, i, k)
            x = child
            i = 0
            continue

class Kont:
    ...

class PostChild(Kont):
    def __init__(self, x: Node, i: int, k: Kont):
        self.x = x
        self.i = i
        self.k = k

class Init(Kont):
    def __init__(self, x: Node):
        self.x = x

Finally, notice that Kont is really a linked list: there are two concrete subclasses, one containing a pointer to another Kont (a cons cell), and one that doesn’t (the terminator of the list). Since the terminator holds a Node, this is equivalent to a list of (Node, int) pairs, plus a Node on the side that was in the terminator. In this case that Node is always root, so Kont can be replaced by an ordinary Python list used as a stack:

def start(root: Node) -> X:
    s = init_state()
    s = pre(root, s)
    x = root
    i = 0
    k: list[StackFrame] = []  # explicit stack; the last element is the innermost frame
    while True:
        if i >= len(x.children):
            if len(k) > 0:
                frame = k[-1]
                s = post(frame.x.children[frame.i], s)
                s = post_child(frame.x, frame.i, s)
                x = frame.x
                i = frame.i + 1
                k.pop()
                continue
            else:
                s = post(root, s)
                return extract_answer(s)
        if cond(x, i):
            s = shallow_visit(x, i, s)
            i = i+1
            continue
        else:
            s = pre_child(x, i, s)
            child = x.children[i]
            s = pre(child, s)
            k.append(StackFrame(x, i))
            x = child
            i = 0
            continue

class StackFrame:
    def __init__(self, x: Node, i: int):
        self.x = x
        self.i = i
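
A mechanical derivation like this is easy to sanity-check: run the original recursive visit and the final loop on the same tree and compare the resulting states. The sketch below does that under some assumptions made purely for illustration: Node is a small dataclass, the hooks just record every call in an immutable tuple, cond treats names starting with an underscore as shallow, and plain (node, index) tuples stand in for StackFrame.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    children: list["Node"] = field(default_factory=list)

# Hypothetical hooks: the state is an immutable tuple recording each event.
def pre(x, s): return s + (("pre", x.name),)
def post(x, s): return s + (("post", x.name),)
def pre_child(x, i, s): return s + (("pre_child", x.name, i),)
def post_child(x, i, s): return s + (("post_child", x.name, i),)
def shallow_visit(x, i, s): return s + (("shallow", x.children[i].name),)
def cond(x, i): return x.children[i].name.startswith("_")

def visit(x, s):
    """The original recursive traversal."""
    s = pre(x, s)
    for i in range(len(x.children)):
        if cond(x, i):
            s = shallow_visit(x, i, s)
        else:
            s = pre_child(x, i, s)
            s = visit(x.children[i], s)
            s = post_child(x, i, s)
    return post(x, s)

def visit_iter(root, s):
    """The final loop, with (node, index) tuples as stack frames."""
    s = pre(root, s)
    x, i, stack = root, 0, []
    while True:
        if i >= len(x.children):
            if not stack:
                return post(root, s)
            px, pi = stack.pop()          # resume the parent where it left off
            s = post(px.children[pi], s)  # the child that just finished
            s = post_child(px, pi, s)
            x, i = px, pi + 1
        elif cond(x, i):
            s = shallow_visit(x, i, s)
            i += 1
        else:
            s = pre_child(x, i, s)
            stack.append((x, i))
            x, i = x.children[i], 0
            s = pre(x, s)

tree = Node("a", [Node("b", [Node("c"), Node("_d", [Node("e")])]), Node("f")])
same = visit(tree, ()) == visit_iter(tree, ())

# A chain deeper than CPython's default recursion limit (1000 frames):
# the recursive visit would raise RecursionError, the loop does not.
deep = Node("leaf")
for n in range(2000):
    deep = Node(f"n{n}", [deep])
deep_events = len(visit_iter(deep, ()))
```

For the deep chain, each of the 2001 nodes contributes a pre and a post event and each of the 2000 edges a pre_child and a post_child event, which gives a quick count to check against.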

I think I would have had a hard time coming up with quite as clean an imperative algorithm without the mechanical transformation (though it looks so obvious after the fact). I also think it’s cool how knowledge of compiler optimizations sometimes comes in handy for writing good code by hand!

For a more leisurely intro to this technique, there is a nice talk with very readable transcript here, and the original papers (one and two) are also very readable and have a bunch more applications.