[RFC] Aggregation IR design

This is a proposal for a representation of aggregation in the IR. Compared to the current representation, it eliminates the need for:

  • multiple binding contexts (pre ExtractAggregators),
  • side-effects in the IR (i.e. SeqOp and CompOp, post ExtractAggregators),
  • a global ExtractAggregators pass.

We add a new branch of the type hierarchy:

case class AggType(agg: Agg, retType: Type) extends BaseType

This represents a computation which has read/write access to an aggregator state, and returns a value of type retType. Note that this is a lazy representation; it doesn’t do anything until it is passed an aggregator state. This allows us to build up a complete computation before kicking it off by initializing a new aggregator state.

The type of the aggregator state is determined by agg: Agg, where Agg is a type enumerating the primitive aggregators and the combinators which create compound aggregators:

abstract class Agg
case class PrimAgg(aggSig: AggSignature) extends Agg
case class TupleAgg(childAggs: IndexedSeq[Agg]) extends Agg
case class ArrayPerElementAgg(childAgg: Agg) extends Agg
case object IDAgg extends Agg

Here, IDAgg is the trivial empty aggregator state.
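
As an example of a compound state, an aggregation that computes a mean as a sum and a count over the same stream of values would have its state described by the following (sumSig and countSig are hypothetical stand-ins for real AggSignatures):

// hypothetical primitive signatures, standing in for real AggSignatures
val sumSig: AggSignature = ???
val countSig: AggSignature = ???

// compound state for a mean: one sum state and one count state, side by side
val meanState: Agg = TupleAgg(IndexedSeq(PrimAgg(sumSig), PrimAgg(countSig)))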

For a fixed Agg, the primitive ways of building up an aggregation are PureAgg and AggMap:

// PureAgg[agg <: Agg, R](x: R): AggType[agg, R]
//
// (state) =>
//   return x;
case class PureAgg(x: IR, agg: Agg) extends AggIR {
  def inferType: AggType = AggType(agg, x.typ)
}
// AggMap[A <: Agg, R...](children...: AggType[A, R]..., body: (R...) => S): AggType[A, S]
//
// (state) =>
//   res1 = child1(&state);
//   ...
//   resN = childN(&state);
//   return body(res1, ..., resN);
case class AggMap(children: Seq[(String, AggIR)], body: IR) extends AggIR {
  val agg = children.head._2.typ.agg
  assert(children.forall(_._2.typ.agg == agg))
  def inferType: AggType = AggType(agg, body.typ)
  def bindings(childIdx: Int): Iterable[(String, Type)] =
    if (childIdx == children.length) // the body is the last child of this node
      children.map { case (name, ir) => name -> ir.typ.retType } 
    else 
      Array()
}

PureAgg represents a degenerate computation which doesn’t access the agg state at all. AggMap is the primitive way of composing multiple computations in sequence. The children are executed one at a time, left to right, modifying a single copy of the agg state. Because each returns some value, we must specify how to combine all the results into a single result in body.

Some useful special cases of AggMap are

AggSequence[A <: Agg, R](child1: AggType[A, Unit], ..., child(N-1): AggType[A, Unit], childN: AggType[A, R]): AggType[A, R]

which executes a sequence of computations, returning the result of the last one, and

AggLet[A, R1, R2](child: AggType[A, R1], body: (R1) => R2): AggType[A, R2]

which executes a computation once, binds its result to a name, then executes body (which has no access to the agg state).
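
Both can be defined as thin constructors over AggMap; a minimal sketch, assuming Ref behaves as in the current IR and genUID produces fresh names:

object AggSequence {
  // run the children left to right, return the result of the last one
  def apply(children: AggIR*): AggIR = {
    val named = children.toIndexedSeq.map(c => genUID() -> c)
    AggMap(named, Ref(named.last._1, named.last._2.typ.retType))
  }
}

object AggLet {
  // run child once, bind its result to `name`, then run `body` (which has no agg-state access)
  def apply(name: String, child: AggIR, body: IR): AggIR =
    AggMap(IndexedSeq(name -> child), body)
}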

For the primitive aggregators, we have:

// SeqOp[aggSig <: AggSignature](args...): AggType[PrimAgg[aggSig], Unit]
//
// (state) =>
//   aggSig.seqOp(&state, args...);
//   return ();
case class SeqOp(aggSig: AggSignature, args: IndexedSeq[IR]) extends AggIR {
  def inferType: AggType = AggType(PrimAgg(aggSig), TUnit)
}

// ReturnOp[aggSig <: AggSignature]: AggType[PrimAgg[aggSig], AggOp.getType[aggSig]]
//
// (state) =>
//   return aggSig.returnOp(const& state);
case class ReturnOp(aggSig: AggSignature) extends AggIR {
  def inferType: AggType = AggType(PrimAgg(aggSig), AggOp.getType(aggSig))
}

which are modify-only and read-only, respectively.
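
For example, what is currently a single ApplyAggOp for a sum becomes a seq step followed by a read step (sumSig is a hypothetical AggSignature, and x the IR for the value being aggregated):

// AggType[PrimAgg[sumSig], <sum result type>]: seq x into the state, then read out the result
val sumOfX: AggIR =
  AggSequence(
    SeqOp(sumSig, IndexedSeq(x)),   // modify-only, returns Unit
    ReturnOp(sumSig))               // read-only, returns the aggregator result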

To actually perform an aggregation and get the resulting regular hail value (not wrapped in an AggType), the primitive operation is

// RunAgg[A <: Agg, R](initArgs: AggInitArgs[A], aggIR: AggType[A, R]): R
//
// (state) =>
//   state = A.initOp(initArgs);
//   res = aggIR(&state);
//   free state;
//   return res;
case class RunAgg(initArgs: AggInitArgs, aggIR: AggIR) extends IR {
  require (initArgs.agg == aggIR.typ.agg)
  val aggType = aggIR.typ
  def inferType: Type = aggType.retType
}

Note that the arguments to the initOp have moved from ApplyAggOp to the top of the aggregation, where they are actually used. This resolves another scoping complexity of the current design. We also have specialized operations for kicking off an aggregation, such as:

// ArrayScan[A <: Agg, X, R](array: Array[X], body: (X) => AggType[A, R], initArgs: AggInitArgs[A]): Array[R]
//
// state = A.initOp(initArgs);
// for i = 0 .. array.length() {
//   res[i] = body(array[i])(&state);
// }
// free state;
// return res;
case class ArrayScan(array: IR, eltName: String, body: AggIR, initArgs: AggInitArgs) extends IR {
  require (initArgs.agg == body.typ.agg)
  def inferType: Type = TArray(body.typ.retType)
  def bindings(childIdx: Int): Iterable[(String, Type)] =
    if (childIdx == 1)
      Array(eltName -> array.typ.asInstanceOf[TArray].elementType)
    else
      Array()
}

case class TableMapRows(child: TableIR, newRow: AggIR, initArgs: AggInitArgs) extends TableIR {
  require (initArgs.agg == newRow.typ.agg)
}

Here AggInitArgs is a simple abstraction for building the set of arguments expected by the aggregator:

abstract class AggInitArgs { val agg: Agg }
case class PrimAggInit(agg: PrimAgg, initOpArgs: IR*) extends AggInitArgs
case class TupleAggInit(children: IndexedSeq[AggInitArgs]) extends AggInitArgs {
  val agg = TupleAgg(children.map(_.agg))
}
case class ArrayPerElementAggInit(init: AggInitArgs) extends AggInitArgs {
  val agg = ArrayPerElementAgg(init.agg)
}
case object IDAggInit extends AggInitArgs {
  val agg = IDAgg
}

To aggregate over an array, we have:

// ArrayAgg[A <: Agg, X, R](array: Array[X], body: (X) => AggType[A, R]): AggType[A, R]
// build computation that returns final agg result over array
//
// (state) =>
//   for i = 0 .. array.length() {
//     res = body(array[i])(&state);
//   }
//   return res;
case class ArrayAgg(array: IR, eltName: String, body: AggIR) extends AggIR {
  val childAggType = body.typ
  def inferType: AggType = childAggType
  def bindings(childIdx: Int): Iterable[(String, Type)] =
    if (childIdx == 1)
      Array(eltName -> array.typ.asInstanceOf[TArray].elementType)
    else
      Array()
}

Note that, unlike ArrayScan, this has an AggType type, and so doesn’t take init args. To actually perform an array aggregation, we can wrap it in a RunAgg. But we can also compose multiple ArrayAggs within a single aggregation, effectively implementing AggExplode (see the sketch below). We could also give ArrayScan an AggType type, which would allow implementing AggExplode for scans (which I think we can’t do currently?), but that complicates the codegen/lowering, and I’m keeping things as simple as possible for a first pass.
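
For example, summing every element of an Array[Array[X]] (what AggExplode handles today) is just a nested ArrayAgg wrapped in a RunAgg (nested, elemType, and sumSig are hypothetical placeholders):

// seq every element of every inner array into one sum state, then read it out once
val flatSum: IR =
  RunAgg(
    PrimAggInit(PrimAgg(sumSig)),
    AggSequence(
      ArrayAgg(nested, "inner",
        ArrayAgg(Ref("inner", TArray(elemType)), "x",
          SeqOp(sumSig, IndexedSeq(Ref("x", elemType))))),
      ReturnOp(sumSig)))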

Finally, we need some simple combinators for building up aggregations with compound agg states:

// MakeTupleAgg[A... <: Agg..., R...](childAggIRs...: AggType[A, R]...): AggType[TupleAgg[A...], Tuple[R...]]
//
// (state) =>
//   ret1 = child1(&state._1);
//   ...
//   retN = childN(&state._N);
//   return (ret1, ..., retN);
case class MakeTupleAgg(childAggIRs: IndexedSeq[AggIR]) extends AggIR {
  val childAggs = childAggIRs.map(_.typ.agg)
  val childRetTypes = childAggIRs.map(_.typ.retType)
  def inferType: AggType = AggType(TupleAgg(childAggs), TTuple(childRetTypes: _*))
}

// MakeArrayPerElementAgg[A <: Agg, X, R](
//   array: Array[X],
//   body: (eltName: X) => AggType[A, R]
// ): AggType[ArrayPerElementAgg[A], Array[R]]
//
// (state) =>
//   for i = 0 .. array.length() {
//     res[i] = body(array[i])(&state[i]);
//   }
//   return res;
case class MakeArrayPerElementAgg(array: IR, eltName: String, body: AggIR) extends AggIR {
  val childAggType = body.typ
  def inferType: AggType = AggType(ArrayPerElementAgg(childAggType.agg), TArray(childAggType.retType))
  def bindings(childIdx: Int): Iterable[(String, Type)] =
    if (childIdx == 1)
      Array(eltName -> array.typ.asInstanceOf[TArray].elementType)
    else
      Array()
}

Besides just making the types work out, these combinators should simplify codegen by marking exactly where a component of a compound state needs to be loaded, so that the code downstream of that point can be unaware of the compound state.
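
As a usage sketch, here is a mean over a stream of values x, written as a tuple of a sum and a count whose results are combined by a body with no agg-state access (sumSig, countSig, x, and sumAndCountType are hypothetical placeholders; GetTupleElement and ApplyBinaryPrimOp are the existing IR nodes):

// sumAndCountType = TTuple(<sum result type>, <count result type>)
// AggType[TupleAgg[PrimAgg[sumSig], PrimAgg[countSig]], Float64]
val mean: AggIR =
  AggLet("sumAndCount",
    MakeTupleAgg(IndexedSeq(
      AggSequence(SeqOp(sumSig, IndexedSeq(x)), ReturnOp(sumSig)),        // runs against state._1
      AggSequence(SeqOp(countSig, IndexedSeq()), ReturnOp(countSig)))),   // runs against state._2
    ApplyBinaryPrimOp(FloatingPointDivide(),
      GetTupleElement(Ref("sumAndCount", sumAndCountType), 0),
      GetTupleElement(Ref("sumAndCount", sumAndCountType), 1)))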

I’ll add more to this, elaborating on codegen/lowering and optimization, and proposing some possible first steps to migrate to the new design. I could also give some more detail on the semantics of the AggType abstraction, and why it’s guaranteed to be parallelizable.

The first post didn’t mention filter. It is slightly tricky.

We can add an AggFilter node in the obvious way:

// AggFilter[A <: Agg, R](cond: Bool, child: AggType[A, R]): AggType[A, R]
case class AggFilter(cond: IR, child: AggIR) extends AggIR {
  def inferType: AggType = child.typ
}

What makes this tricky is that child represents a computation which both modifies the current state and computes some result based on the new state. If the filter condition is false, we want to leave the state as is, but we still need to produce a value from the state.

In this simple version, which leaves out dependent aggregations (where one aggregation depends on the result of another) and requires ArrayScan to be top-level (not producing an AggType), it turns out that every aggregation can be split into a modify step followed by a read step. This makes filter easy to implement.

The following simple pass eliminates the AggFilter node, implementing it in terms of the remaining nodes:

def lowerFilter(ir: IR): IR =
  ir.copy(ir.children.map(lowerFilter))  // recurse: rebuild each node from its lowered children

def lowerFilter(ir: AggIR): AggIR = {
  val (write, read) = split(ir)
  AggSequence(write, read)
}

// split ir into two steps: a write-only step with return type TUnit,
// and a read-only step with the original return type
def split(ir: AggIR): (AggIR, AggIR) = ir match {
  case AggFilter(cond, child) =>
    val (write, read) = split(child)
    (If(cond, write, PureAgg(Void(), child.typ.agg)),
     read)
  case ir@ReturnOp(_) =>
    (PureAgg(Void(), ir.typ.agg),
     ir)
  case ir@SeqOp(_, _) =>
    (ir,
     PureAgg(Void(), ir.typ.agg))
  case ArrayAgg(array, eltName, body) =>
    val (write, read) = split(body)
    (ArrayAgg(array, eltName, write),
     read)
  case AggMap(children, body) =>
    val names = children.map(_._1)
    val (writes, reads) = children.map { case (_, child) => split(child) }.unzip
    (AggSequence(writes: _*),
     AggMap(names zip reads, body))
  case MakeTupleAgg(children) =>
    val (writes, reads) = children.map(split).unzip
    // the write piece has type Tuple[Unit, ...] instead of Unit, but that's easy to fix
    (MakeTupleAgg(writes),
     MakeTupleAgg(reads))
  // etc.
}
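
For a concrete example, here is what the lowering does to a filtered sum (after simplifying away the trivial PureAgg(Void(), _) padding; sumSig, cond, and x are hypothetical placeholders):

// before lowering
AggFilter(cond,
  AggSequence(SeqOp(sumSig, IndexedSeq(x)), ReturnOp(sumSig)))

// after lowering: the write step is guarded by the condition, the read step is not
AggSequence(
  If(cond, SeqOp(sumSig, IndexedSeq(x)), PureAgg(Void(), PrimAgg(sumSig))),
  ReturnOp(sumSig))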

By inspection, the result contains no AggFilter nodes, and the pass introduces no code duplication, aside from the duplicated MakeTupleAgg nodes (and the analogous, unshown, duplicated MakeArrayPerElementAgg nodes), which correspond to loading the appropriate agg states multiple times. That duplication is unavoidable in general; avoiding it where possible would be an optimization, either integrated into this pass or done by a later cleanup pass.