# Fun with compiler optimizations

I just had to rewrite a non-trivial recursive algorithm as an imperative one (Python has no tail-call elimination and a fairly low default recursion limit, so deep recursion is a problem), and had some fun applying some compiler techniques to automate the process. I thought somebody else might think this was cool too.

I’ll abstract the problem to a standard depth-first traversal of a tree, where each node can have a variable number of children (e.g. our IR), and we may need to do some work at pre-order, post-order, and all the in-order positions. We’re also maintaining some state, which I’m hiding inside some `State` type, and to make things easier later I’m treating `State` as immutable, so most functions return the updated state.

And to throw in an extra wrinkle, in my case I can sometimes stop the traversal early: when `cond(x, i)` holds, `shallow_visit(x, i, s)` visits the `i`-th child of `x` without looking at any of that child's children.

``````def visit(x: Node, s: State) -> State:
s = pre(x, s)
for i in range(len(x.children)):
if cond(x, i):
s = shallow_visit(x, i, s)
else:
s = pre_child(x, i, s)
child = x.children[i]
s = visit(child, s)
s = post_child(x, i, s)
return post(x, s)

def start(root: Node) -> X:
    """Entry point: run the traversal from root with a fresh state."""
    s = init_state()
    s = visit(root, s)
    # (Computing the final X from s is elided.)
``````

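Ultimately, the way to turn recursion into loops is to write the recursion in a tail-recursive form. A tail call has no work left to do after it returns, so it can be replaced by reassigning the parameters and jumping back to the top of the function, which is a loop. So most of this process is trying to get to a tail-recursive formulation.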

First I replace the for loop with recursion (moving further from my goal!). This helps make all control flow explicit.

``````def visit(x: Node, s: State) -> State:
s = pre(x, s)
def loop(i: int, ls: State) -> State:
if i >= len(x.children):
return ls
if cond(x, i):
ls = shallow_visit(x, i, ls)
loop(i+1, ls)
else:
ls = pre_child(x, i, ls)
child = x.children[i]
ls = visit(child, ls)
ls = post_child(x, i, ls)
loop(i+1, ls)
s = loop(0, s)
return post(x, s)

def start(root: Node) -> X:
    """Entry point: run the traversal from root with a fresh state."""
    s = init_state()
    s = visit(root, s)
    # (Computing the final X from s is elided.)
``````

Then I lift the `loop` function to the top level:

``````def loop(x: Node, i: int, ls: State) -> State:
if i >= len(x.children):
return ls
if cond(x, i):
ls = shallow_visit(x, i, ls)
return loop(x, i+1, ls)
else:
ls = pre_child(x, i, ls)
child = x.children[i]
ls = visit(child, ls)
ls = post_child(x, i, ls)
return loop(x, i+1, ls)

def visit(x: Node, s: State) -> State:
    """Visit x: pre-order hook, all children via loop, then post-order hook."""
    s = pre(x, s)
    s = loop(x, 0, s)
    return post(x, s)

def start(root: Node) -> X:
    """Entry point: run the traversal from root with a fresh state."""
    s = init_state()
    s = visit(root, s)
    # (Computing the final X from s is elided.)
``````

Next I perform a (partial) CPS transform. CPS is continuation passing style: instead of returning values, functions are passed callbacks which they call with their result value. This makes even more control flow explicit. In this case I’m still returning some `X` value, where `X` is the type of the final result of this entire algorithm, so `k` represents the computation remaining to be done in the algorithm. (From here on the snippets use multi-statement lambdas and `State -> X` function types, neither of which is valid Python, so read them as pseudocode.)

``````def loop(x: Node, i: int, ls: State, k: State -> X) -> X:
if i >= len(x.children):
return k(ls)
if cond(x, i):
ls = shallow_visit(x, i, ls)
return loop(x, i+1, ls, k)
else:
ls = pre_child(x, i, ls)
child = x.children[i]
return visit(child, ls, lambda ls1:
ls1 = post_child(x, i, ls1)
return loop(x, i+1, ls1, k)
)

def visit(x: Node, s: State, k: State -> X) -> X:
s = pre(x, s)
return loop(x, 0, s, lambda s1:
s1 = post(x, s1)
return k(s1)
)

def start(root: Node) -> X:
    """Entry point: the outermost continuation turns the final state into the result X."""
    s = init_state()
    return visit(root, s, lambda s1:
        ...  # compute the final X from s1 (elided)
    )
``````

Next I inline `visit`, so there is just one recursive function, instead of a pair of mutually recursive functions.

``````def loop(x: Node, i: int, ls: State, k: State -> X) -> X:
if i >= len(x.children):
return k(ls)
if cond(x, i):
ls = shallow_visit(x, i, ls)
return loop(x, i+1, ls, k)
else:
ls = pre_child(x, i, ls)
child = x.children[i]
s = pre(child, ls)
return loop(child, 0, s, lambda s1:
s1 = post(child, s1)
ls1 = post_child(x, i, s1)
return loop(x, i+1, ls1, k)
)

def start(root: Node) -> X:
s = init_state()
s = pre(root, s)
return loop(root, 0, s, lambda s1:
s1 = post(root, s1)
)
``````

Now the real meat (besides CPS): defunctionalization. This transforms a higher-order language (with lambdas and first-class function values) into a first-order language (all functions are defined at the top level) by replacing each lambda with an object that stores the variables it closes over; the lambda's body becomes that object's `__call__` method. Hopefully the basic idea is clear from the example:

``````def loop(x: Node, i: int, ls: State, k: Kont) -> X:
if i >= len(x.children):
return k(ls)
if cond(x, i):
ls = shallow_visit(x, i, ls)
return loop(x, i+1, ls, k)
else:
ls = pre_child(x, i, ls)
child = x.children[i]
s = pre(child, ls)
return loop(child, 0, s, PostChild(x, i, k))

def start(root: Node) -> X:
    """Entry point: the outermost continuation is an Init object for the root."""
    s = init_state()
    s = pre(root, s)
    return loop(root, 0, s, Init(root))

class Kont:
    """Base class for defunctionalized continuations (they replace the `State -> X` lambdas)."""

    def __call__(self, s: State) -> X:
        ...

class PostChild(Kont):
    """Continuation that finishes child i of x, then resumes x's loop at i + 1."""

    def __init__(self, x: Node, i: int, k: Kont):
        self.x = x  # parent node
        self.i = i  # index of the child being finished
        self.k = k  # continuation handed on to the parent's loop

    def __call__(self, s: State) -> X:
        s = post(self.x.children[self.i], s)
        s = post_child(self.x, self.i, s)
        return loop(self.x, self.i + 1, s, self.k)

class Init(Kont):
    """Outermost continuation: runs the root's post-order hook."""

    def __init__(self, x: Node):
        self.x = x  # the root node

    def __call__(self, s: State):
        s = post(self.x, s)
        # ...then compute the final X from s (elided).
``````

and then I inline `Kont.__call__`:

``````def loop(x: Node, i: int, ls: State, k: Kont) -> X:
if i >= len(x.children):
if isinstance(k, PostChild):
s = post(k.x.children[k.i], ls)
s = post_child(k.x, k.i, s)
return loop(k.x, k.i + 1, s, k.k)
else:
assert(isinstance(k, Init))
s = post(k.x, ls)
if cond(x, i):
ls = shallow_visit(x, i, ls)
return loop(x, i+1, ls, k)
else:
ls = pre_child(x, i, ls)
child = x.children[i]
s = pre(child, ls)
return loop(child, 0, s, PostChild(x, i, child, k))

def start(root: Node) -> X:
    """Entry point: the outermost continuation is an Init object for the root."""
    s = init_state()
    s = pre(root, s)
    return loop(root, 0, s, Init(root))

class Kont:
    """Continuation base class; its __call__ has been inlined into loop."""

class PostChild(Kont):
    """Continuation record: finish child i of x, then resume x's loop with continuation k."""

    def __init__(self, x: Node, i: int, k: Kont):
        self.x = x
        self.i = i
        self.k = k

class Init(Kont):
    """Outermost continuation record; x is the root node."""

    def __init__(self, x: Node):
        self.x = x
``````

Now I’ve gotten to a state where `loop` is tail-recursive, so finally tail-call optimization can remove all recursion. I also inline `loop`.

``````def start(root: Node) -> X:
s = init_state()
s = pre(root, s)
x = root
i = 0
k = Init(x)
while True:
if i >= len(x.children):
if isinstance(k, PostChild):
s = post(k.x.children[k.i], s)
s = post_child(k.x, k.i, s)
x = k.x
i = k.i + 1
k = k.k
continue
else:
assert(isinstance(k, Init))
s = post(k.x, s)
if cond(x, i):
s = shallow_visit(x, i, s)
i = i+1
continue
else:
s = pre_child(x, i, s)
child = x.children[i]
s = pre(child, s)
k = PostChild(x, i, k)
x = child
i = 0
continue

class Kont:
    """Continuation base class; its __call__ has been inlined into start's loop."""

class PostChild(Kont):
    """Continuation record: finish child i of x, then resume x's loop with continuation k."""

    def __init__(self, x: Node, i: int, k: Kont):
        self.x = x
        self.i = i
        self.k = k

class Init(Kont):
    """Outermost continuation record; x is the root node."""

    def __init__(self, x: Node):
        self.x = x
``````

Finally, notice that `Kont` is really a linked list: there are two concrete subclasses, one containing a pointer to another `Kont` (a cons cell), and one that doesn’t (the terminator of the list). Since the terminator holds a `Node`, this is equivalent to a list of `(Node, int)` pairs plus a `Node` on the side that was in the terminator; in this case that node is always `root`:

``````def start(root: Node) -> X:
s = init_state()
s = pre(root, s)
x = root
i = 0
k: List[StackFrame] = []
while True:
if i >= len(x.children):
if len(k) > 0:
frame = k[-1]
s = post(frame.x.children[frame.i], s)
s = post_child(frame.x, frame.i, s)
x = frame.x
i = frame.i + 1
k.pop()
continue
else:
s = post(root, s)
if cond(x, i):
s = shallow_visit(x, i, s)
i = i+1
continue
else:
s = pre_child(x, i, s)
child = x.children[i]
s = pre(child, s)
k.append(StackFrame(x, i))
x = child
i = 0
continue

class StackFrame:
    """A pending parent on the explicit stack: resume node x's children after child i."""

    def __init__(self, x: Node, i: int):
        self.x = x
        self.i = i
``````

I think I would have had a hard time coming up with quite as clean an imperative algorithm without the mechanical transformation (though it looks so obvious after the fact). I also think it’s cool how knowledge of compiler optimizations sometimes comes in handy for writing good code by hand!

For a more leisurely intro to this technique, there is a nice talk with a very readable transcript here, and the original papers (one and two) are also very readable and have a bunch more applications.