FoldConstants Rewrite

Steps:

  1. We introduce a function:
    def identifyConstantSubtrees(ir: BaseIR): Memo[Unit]
    It takes in an IR and recurs over it, doing the following:
  • For any IR that introduces a constant binding, use UsesAndDefs to find the Refs to that binding and mark them as constant. Then recur on the children.
  • For all other IRs, recur on the children to check whether they’re constant subtrees. For most value IRs, we can say that they are constant if all of their children are constant. There are exceptions, like In, ApplySeeded, and UUID4, which are never constant. TableIRs and MatrixIRs cannot be constant.
  2. We have a second function:
    def fixupUnrealizables(ir: BaseIR, constantSubtrees: Memo[Unit], usesAndDefs: Memo[IR], parentPointers: Memo[IR])
    This recurs over the tree, stopping at each identified constant subtree and checking whether it is of a realizable type. If it is not, we need to remove it from constantSubtrees and deal with the Refs it introduces. That means using usesAndDefs to jump to the relevant Refs and removing them from constantSubtrees, then using parentPointers to walk up the IR tree from each Ref, removing every IR encountered from constantSubtrees until reaching one that isn’t constant. We also need to recur on the children of the original unrealizable constant subtree.

(We may also consider filtering out degenerate constant subtrees at this step, i.e., we shouldn’t count a single constant, like a lone I32 or Literal, as a constant subtree.)
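
To make the first pass concrete, here is a minimal sketch on a toy IR. Everything in it is an assumption for illustration: the node types (I32, In, Ref, Add, Let) are simplified stand-ins rather than the real IR classes, an identity-keyed java.util.Set stands in for Memo[Unit], and an environment of bound names is used instead of UsesAndDefs to decide whether a Ref is constant.

```scala
import java.util.{Collections, IdentityHashMap}

object IdentifyConstantsSketch {
  // Toy IR, hypothetical and much smaller than the real one.
  sealed trait IR
  final case class I32(value: Int) extends IR
  final case class In(index: Int) extends IR // runtime input: never constant
  final case class Ref(name: String) extends IR
  final case class Add(left: IR, right: IR) extends IR
  final case class Let(name: String, value: IR, body: IR) extends IR

  // Identity-keyed set standing in for Memo[Unit].
  type NodeSet = java.util.Set[IR]
  def newNodeSet(): NodeSet =
    Collections.newSetFromMap(new IdentityHashMap[IR, java.lang.Boolean]())

  // env maps each bound name to whether the value it binds was constant.
  private def mark(ir: IR, env: Map[String, Boolean], memo: NodeSet): Boolean = {
    val isConstant = ir match {
      case I32(_) => true
      case In(_) => false
      case Ref(name) => env.getOrElse(name, false)
      case Add(l, r) =>
        val lc = mark(l, env, memo)
        val rc = mark(r, env, memo) // always visit both children
        lc && rc
      case Let(name, value, body) =>
        val vc = mark(value, env, memo)
        val bc = mark(body, env + (name -> vc), memo)
        vc && bc
    }
    if (isConstant) memo.add(ir)
    isConstant
  }

  def identifyConstantSubtrees(ir: IR): NodeSet = {
    val memo = newNodeSet()
    mark(ir, Map.empty, memo)
    memo
  }

  def main(args: Array[String]): Unit = {
    val constantLet = Let("x", I32(2), Add(Ref("x"), I32(1)))
    val inputLet = Let("x", In(0), Add(Ref("x"), I32(1)))
    println(identifyConstantSubtrees(constantLet).contains(constantLet)) // true
    println(identifyConstantSubtrees(inputLet).contains(inputLet)) // false
  }
}
```

The set is keyed on reference identity on purpose: two structurally equal nodes in different scopes (say, two Ref("x") under different binders) can differ in constness, so structural equality would conflate them.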

  3. Now we need to collect the constant subtrees to compile. constantSubtrees contains too many subtrees, since some are children of others, so we should just walk the IR again, stopping at and collecting each constant subtree we encounter without descending into it. We have to handle Lets in a special way as we do this. Define a function:
    def collectConstantIRs(ir: BaseIR, constantSubtrees: Memo[Unit]): (IndexedSeq[(String, IR)], IndexedSeq[IR])

Whenever we see a Let that isn’t a constant subtree, we need to check if the value it binds is a constant subtree. If it is, we add a pair of the name being bound and the constant IR it binds to an IndexedSeq[(String, IR)] (or a mutable list type if that’s easier). When we encounter a constant subtree, we add it to an IndexedSeq[IR]. We return both of these things.
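
A rough sketch of that collection walk, again on a hypothetical toy IR rather than the real classes, with an identity-keyed java.util.Set standing in for Memo[Unit]. One assumption that the description leaves open: when a non-constant Let binds a constant value, this sketch records the binding and does not also add the value to the list of subtrees.

```scala
import java.util.{Collections, IdentityHashMap}

object CollectConstantsSketch {
  // Toy IR, hypothetical.
  sealed trait IR { def children: Seq[IR] }
  final case class I32(value: Int) extends IR { def children: Seq[IR] = Nil }
  final case class In(index: Int) extends IR { def children: Seq[IR] = Nil }
  final case class Ref(name: String) extends IR { def children: Seq[IR] = Nil }
  final case class Add(l: IR, r: IR) extends IR { def children: Seq[IR] = Seq(l, r) }
  final case class Let(name: String, value: IR, body: IR) extends IR {
    def children: Seq[IR] = Seq(value, body)
  }

  type NodeSet = java.util.Set[IR]
  def newNodeSet(): NodeSet =
    Collections.newSetFromMap(new IdentityHashMap[IR, java.lang.Boolean]())

  def collectConstantIRs(
    ir: IR,
    constantSubtrees: NodeSet
  ): (IndexedSeq[(String, IR)], IndexedSeq[IR]) = {
    val bindings = IndexedSeq.newBuilder[(String, IR)]
    val subtrees = IndexedSeq.newBuilder[IR]

    def go(node: IR): Unit = node match {
      // A maximal constant subtree: collect it and do not descend further.
      case n if constantSubtrees.contains(n) =>
        subtrees += n
      // A non-constant Let: remember a constant binding, then keep walking.
      case Let(name, value, body) =>
        if (constantSubtrees.contains(value)) bindings += ((name, value))
        else go(value)
        go(body)
      case other =>
        other.children.foreach(go)
    }

    go(ir)
    (bindings.result(), subtrees.result())
  }

  def main(args: Array[String]): Unit = {
    val five = I32(5)
    val inner = Add(Ref("x"), I32(1))
    val tree = Let("x", five, Add(In(0), inner))
    val constants = newNodeSet()
    constants.add(five)
    constants.add(inner)
    val (bindings, subtrees) = collectConstantIRs(tree, constants)
    println(bindings) // one binding: x -> I32(5)
    println(subtrees) // one subtree: Add(Ref(x),I32(1))
  }
}
```

The binding for x has to be carried along because the collected subtree Add(Ref("x"), I32(1)) refers to x, which is defined outside it.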

  4. We turn the IndexedSeq[(String, IR)] into a series of nested let bindings. In the body of the innermost let, we use MakeTuple to create a tuple of all of the constant subtrees. Then we call CompileAndEvaluate on the IR we’ve generated.

  5. One last walk through the IR, reconstructing it as we go this time. When we encounter a constant subtree, we figure out which position in the tuple it corresponds to and replace the constant subtree with a Literal or a primitive IR wrapping the value from the tuple.
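
The last two steps might look roughly like this on a toy IR. All names here are illustrative assumptions: the node types are simplified stand-ins, evalTuple is a tiny Int-only interpreter standing in for CompileAndEvaluate, and collected subtrees are matched by reference identity when rebuilding.

```scala
object FoldAndReplaceSketch {
  // Toy IR, hypothetical.
  sealed trait IR
  final case class I32(value: Int) extends IR
  final case class In(index: Int) extends IR
  final case class Ref(name: String) extends IR
  final case class Add(l: IR, r: IR) extends IR
  final case class Let(name: String, value: IR, body: IR) extends IR
  final case class MakeTuple(elems: IndexedSeq[IR]) extends IR

  // Wrap the tuple of subtrees in the collected bindings, first binding outermost.
  def buildTupleIR(bindings: IndexedSeq[(String, IR)], subtrees: IndexedSeq[IR]): IR =
    bindings.foldRight(MakeTuple(subtrees): IR) {
      case ((name, value), body) => Let(name, value, body)
    }

  // Toy evaluator for constant Int expressions only.
  def evalInt(ir: IR, env: Map[String, Int]): Int = ir match {
    case I32(v) => v
    case Ref(n) => env(n)
    case Add(l, r) => evalInt(l, env) + evalInt(r, env)
    case Let(n, v, b) => evalInt(b, env + (n -> evalInt(v, env)))
    case In(_) | MakeTuple(_) => sys.error("not a constant Int expression")
  }

  // Evaluate nested Lets ending in a MakeTuple to one value per subtree.
  def evalTuple(ir: IR, env: Map[String, Int] = Map.empty): IndexedSeq[Int] = ir match {
    case Let(n, v, b) => evalTuple(b, env + (n -> evalInt(v, env)))
    case MakeTuple(es) => es.map(evalInt(_, env))
    case other => sys.error(s"expected Let or MakeTuple, got $other")
  }

  // Rebuild the IR, swapping each collected subtree for its evaluated value.
  def replace(ir: IR, subtrees: IndexedSeq[IR], values: IndexedSeq[Int]): IR = {
    val idx = subtrees.indexWhere(_ eq ir) // position in the tuple, by identity
    if (idx >= 0) I32(values(idx))
    else ir match {
      case Add(l, r) => Add(replace(l, subtrees, values), replace(r, subtrees, values))
      case Let(n, v, b) => Let(n, replace(v, subtrees, values), replace(b, subtrees, values))
      case MakeTuple(es) => MakeTuple(es.map(replace(_, subtrees, values)))
      case leaf => leaf
    }
  }

  def main(args: Array[String]): Unit = {
    val inner = Add(Ref("x"), I32(1))
    val tree = Let("x", I32(5), Add(In(0), inner))
    val bindings = IndexedSeq[(String, IR)](("x", I32(5)))
    val subtrees = IndexedSeq[IR](inner)
    val values = evalTuple(buildTupleIR(bindings, subtrees))
    println(replace(tree, subtrees, values)) // Let(x,I32(5),Add(In(0),I32(6)))
  }
}
```

Evaluating everything as one tuple means a single compile for all constant subtrees, at the cost of having to track which tuple position belongs to which subtree.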

This should depend on the node defining the binding, right? So in Let(value, "x", body), first recur on value; then mark Ref("x")s to be constant if and only if value was inferred to be constant; then recur on body.

Yes, I’ll edit, thanks.

Slight complication. We need to potentially update the Refs twice during phase 1.

Consider

StreamMap('x', StreamRange(1, 10), ApplyBinaryPrimOp(Add(), Ref("x"), ApplySeeded("randNorm" ....)))

We would mark Ref('x') as constant based on the StreamRange, but then the body wouldn’t be constant because of the ApplySeeded. At that moment you need to go back and make the Ref('x') not constant anymore, as well as its parents.

Doesn’t the second pass fix this up? We mark the StreamMap node as not realizable, so the second pass will fix up its refs and propagate the non-static property upward from them.

I guess it does. I originally had the second pass only considering things that were constant, which in this case the StreamMap wouldn’t be, but it wouldn’t be a big lift to get it to include this too.
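
For what it’s worth, here is a sketch of how that could play out on a toy IR. This is one reading of the exchange above, not the actual implementation: the node types are hypothetical (Seeded stands in for ApplySeeded, Range for StreamRange), an identity-keyed java.util.Set stands in for Memo[Unit], an explicit ancestor list replaces parentPointers, and the second pass here only handles binders that turned out not to be constant.

```scala
import java.util.{Collections, IdentityHashMap}

object DemoteRefsSketch {
  // Toy IR, hypothetical.
  sealed trait IR { def children: Seq[IR] }
  final case class I32(value: Int) extends IR { def children: Seq[IR] = Nil }
  final case class Ref(name: String) extends IR { def children: Seq[IR] = Nil }
  final case class Seeded(seed: Long) extends IR { def children: Seq[IR] = Nil }
  final case class Add(l: IR, r: IR) extends IR { def children: Seq[IR] = Seq(l, r) }
  final case class Range(start: IR, stop: IR) extends IR {
    def children: Seq[IR] = Seq(start, stop)
  }
  final case class StreamMap(stream: IR, name: String, body: IR) extends IR {
    def children: Seq[IR] = Seq(stream, body)
  }

  type NodeSet = java.util.Set[IR]
  def newNodeSet(): NodeSet =
    Collections.newSetFromMap(new IdentityHashMap[IR, java.lang.Boolean]())

  // Pass 1: a Ref is marked constant when the stream it iterates over is.
  def markConstants(ir: IR, env: Map[String, Boolean], consts: NodeSet): Boolean = {
    val isConstant = ir match {
      case Ref(n) => env.getOrElse(n, false)
      case Seeded(_) => false
      case StreamMap(s, n, b) =>
        val sc = markConstants(s, env, consts)
        val bc = markConstants(b, env + (n -> sc), consts)
        sc && bc
      case other =>
        other.children.map(markConstants(_, env, consts)).forall(identity)
    }
    if (isConstant) consts.add(ir)
    isConstant
  }

  // Pass 2: for each binder that ended up non-constant, demote its Refs
  // and every constant ancestor between a Ref and the first non-constant node.
  def fixupRefs(ir: IR, consts: NodeSet): Unit = {
    ir match {
      case sm @ StreamMap(_, name, body) if !consts.contains(sm) =>
        demoteRefs(body, name, List(sm), consts)
      case _ =>
    }
    ir.children.foreach(fixupRefs(_, consts))
  }

  // ancestors holds the path back to the binder, nearest parent first.
  private def demoteRefs(ir: IR, name: String, ancestors: List[IR], consts: NodeSet): Unit =
    ir match {
      case r @ Ref(`name`) =>
        consts.remove(r)
        ancestors.takeWhile(a => consts.contains(a)).foreach(a => consts.remove(a))
      case StreamMap(s, n, b) =>
        demoteRefs(s, name, ir :: ancestors, consts)
        if (n != name) demoteRefs(b, name, ir :: ancestors, consts) // skip shadowed bodies
      case other =>
        other.children.foreach(demoteRefs(_, name, other :: ancestors, consts))
    }

  def main(args: Array[String]): Unit = {
    val ref = Ref("x")
    val inner = Add(ref, I32(1))
    val tree = StreamMap(Range(I32(1), I32(10)), "x", Add(inner, Seeded(0L)))
    val consts = newNodeSet()
    markConstants(tree, Map.empty, consts)
    println(consts.contains(inner)) // true after pass 1
    fixupRefs(tree, consts)
    println(consts.contains(inner)) // false after pass 2
  }
}
```

Without the second pass, Add(Ref("x"), I32(1)) would be collected as a constant subtree even though x is only bound inside the non-constant StreamMap.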