This is a proposal for a representation of aggregation in the IR. Compared to the current representation, it eliminates the need for:
- multiple binding contexts (pre `ExtractAggregators`),
- side effects in the IR (i.e. `SeqOp` and `CompOp`, post `ExtractAggregators`),
- a global `ExtractAggregators` pass.
We add a new branch of the type hierarchy:

```scala
case class AggType(agg: Agg, retType: Type) extends BaseType
```

This represents a computation which has read/write access to an aggregator state, and returns a value of type `retType`. Note that this is a lazy representation; it doesn't do anything until it is passed an aggregator state. This allows us to build up a complete computation before kicking it off by initializing a new aggregator state.
The type of the aggregator state is determined by `agg: Agg`, where `Agg` is a type enumerating the primitive aggregators and the combinators which create compound aggregators:

```scala
abstract class Agg
case class PrimAgg(aggSig: AggSignature) extends Agg
case class TupleAgg(childAggs: IndexedSeq[Agg]) extends Agg
case class ArrayPerElementAgg(childAgg: Agg) extends Agg
case object IDAgg extends Agg
```

Here, `IDAgg` is the trivial empty aggregator state.
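For example, the state for an aggregation that computes both a sum and a count of the elements it sees would be described by a compound `Agg`. A minimal sketch, where `sumSig` and `countSig` stand in for whatever `AggSignature`s the primitive sum and count aggregators use:

```scala
// Hypothetical signatures for the primitive sum and count aggregators.
val sumSig: AggSignature = ???
val countSig: AggSignature = ???

// A compound aggregator state holding one sum state and one count state.
val sumAndCount: Agg = TupleAgg(IndexedSeq(PrimAgg(sumSig), PrimAgg(countSig)))
```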
For a fixed `Agg`, the primitive ways of building up an aggregation are `PureAgg` and `AggMap`:
```scala
// PureAgg[agg <: Agg, R](x: R): AggType[agg, R]
//
// (state) =>
//   return x;
case class PureAgg(x: IR, agg: Agg) extends AggIR {
  def inferType: AggType = AggType(agg, x.typ)
}
```
```scala
// AggMap[A <: Agg, R...](children...: AggType[A, R]..., body: (R...) => S): AggType[A, S]
//
// (state) =>
//   res1 = child1(&state);
//   ...
//   resN = childN(&state);
//   return body(res1, ..., resN);
case class AggMap(children: Seq[(String, AggIR)], body: IR) extends AggIR {
  val agg = children.head._2.typ.agg
  assert(children.forall(_._2.typ.agg == agg))
  def inferType: AggType = AggType(agg, body.typ)
  // the names bound to the children's results are visible only in the body,
  // which is the last child of this node
  def bindings(childIdx: Int): Iterable[(String, Type)] =
    if (childIdx == children.length)
      children.map { case (name, ir) => name -> ir.typ.retType }
    else
      Array()
}
```
`PureAgg` represents a degenerate computation which doesn't access the agg state at all. `AggMap` is the primitive way of composing multiple computations in sequence. The `children` are executed one at a time, left to right, modifying a single copy of the agg state. Because each returns some value, we must specify how to combine all the results into a single result in `body`.
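Concretely, if `childA` and `childB` are two aggregating computations over the same agg state, running them both and pairing up their results might look like the following sketch (using the existing value-level `MakeStruct` and `Ref` nodes; exact constructor shapes are illustrative):

```scala
// Two assumed aggregating computations sharing the same Agg state.
val childA: AggIR = ???
val childB: AggIR = ???

// Run childA, then childB, against the same state, and return both results
// as a struct. The names "a" and "b" are bound only in the body.
val paired: AggIR = AggMap(
  Seq("a" -> childA, "b" -> childB),
  MakeStruct(Seq(
    "a" -> Ref("a", childA.typ.retType),
    "b" -> Ref("b", childB.typ.retType))))
```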
Some useful special cases of `AggMap` are

```scala
AggSequence[A <: Agg, R](child1: AggType[A, Unit], ..., child(N-1): AggType[A, Unit], childN: AggType[A, R]): AggType[A, R]
```

which executes a sequence of computations, returning the result of the last one, and

```scala
AggLet[A, R1, R2](child: AggType[A, R1], body: (R1) => R2): AggType[A, R2]
```

which executes a computation once, binds its result to a name, then executes `body` (which has no access to the agg state).
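Both can be expressed directly in terms of `AggMap`; roughly (a sketch, assuming the existing `Ref` node):

```scala
// AggSequence(c1, ..., cN): run each child in order against the shared
// state, returning only the last child's result.
def aggSequence(children: IndexedSeq[AggIR]): AggIR = {
  val named = children.zipWithIndex.map { case (c, i) => s"__res$i" -> c }
  AggMap(named, Ref(s"__res${children.length - 1}", children.last.typ.retType))
}

// AggLet(name, child, body): run child once, bind its result to `name`
// in `body`, which does not touch the agg state.
def aggLet(name: String, child: AggIR, body: IR): AggIR =
  AggMap(Seq(name -> child), body)
```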
For the primitive aggregators, we have:

```scala
// SeqOp[aggSig <: AggSignature](args...): AggType[PrimAgg[aggSig], Unit]
//
// (state) =>
//   aggSig.seqOp(&state, args...);
//   return ();
case class SeqOp(aggSig: AggSignature, args: IndexedSeq[IR]) extends AggIR {
  def inferType: AggType = AggType(PrimAgg(aggSig), TUnit)
}

// ReturnOp[aggSig <: AggSignature]: AggType[PrimAgg[aggSig], AggOp.getType[aggSig]]
//
// (state) =>
//   return aggSig.returnOp(const& state);
case class ReturnOp(aggSig: AggSignature) extends AggIR {
  def inferType: AggType = AggType(PrimAgg(aggSig), AggOp.getType(aggSig))
}
```
which, respectively, only modify and only read the aggregator state.
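For example, feeding one value to a sum aggregator and then reading out the running total could be written as the following sketch (reusing the `aggSequence` helper sketched above; `sumSig` is an assumed `AggSignature`):

```scala
// Add `x` to the sum state, then return the total so far.
// Both steps operate on the same PrimAgg(sumSig) state.
def observeAndReturn(x: IR): AggIR =
  aggSequence(IndexedSeq(
    SeqOp(sumSig, IndexedSeq(x)), // AggType(PrimAgg(sumSig), TUnit): modify-only
    ReturnOp(sumSig)))            // read-only: returns the current sum
```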
To actually perform an aggregation and get the resulting regular Hail value (not wrapped in an `AggType`), the primitive operation is

```scala
// RunAgg[A <: Agg, R](initArgs: AggInitArgs[A], aggIR: AggType[A, R]): R
//
// state = A.initOp(initArgs);
// res = aggIR(&state);
// free state;
// return res;
case class RunAgg(initArgs: AggInitArgs, aggIR: AggIR) extends IR {
  require(initArgs.agg == aggIR.typ.agg)
  val aggType = aggIR.typ
  def inferType: Type = aggType.retType
}
```
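As a small end-to-end sketch, summing two constants (using the assumed `sumSig`, the `aggSequence` helper from above, the `PrimAggInit` constructor described below, and the existing `I32` node):

```scala
// Allocate a fresh sum state, feed it 3 and 4, read out the total, free
// the state, and yield the total as an ordinary (non-AggType) value.
val total: IR = RunAgg(
  PrimAggInit(PrimAgg(sumSig)),
  aggSequence(IndexedSeq(
    SeqOp(sumSig, IndexedSeq(I32(3))),
    SeqOp(sumSig, IndexedSeq(I32(4))),
    ReturnOp(sumSig))))
```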
Note that the arguments to the `initOp` have moved from `ApplyAggOp` to the top of the aggregation, where they are actually used. This solves another scoping complexity of the current design. We also have specialized operations for kicking off aggregation, such as:
```scala
// ArrayScan[A <: Agg, X, R](array: Array[X], body: (X) => AggType[A, R], initArgs: AggInitArgs[A]): Array[R]
//
// state = A.initOp(initArgs);
// for i = 0 .. array.length() {
//   res[i] = body(array[i])(&state);
// }
// free state;
// return res;
case class ArrayScan(array: IR, eltName: String, body: AggIR, initArgs: AggInitArgs) extends IR {
  require(initArgs.agg == body.typ.agg)
  def inferType: Type = TArray(body.typ.retType)
  def bindings(childIdx: Int): Iterable[(String, Type)] =
    if (childIdx == 1)
      Array(eltName -> array.typ.asInstanceOf[TArray].elementType)
    else
      Array()
}

case class TableMapRows(child: TableIR, newRow: AggIR, initArgs: AggInitArgs) extends TableIR {
  require(initArgs.agg == newRow.typ.agg)
}
```
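For instance, a running sum over an array could be expressed as the following sketch (`xs` is some IR of array type, `eltType` its element type, and `sumSig` the assumed sum `AggSignature`):

```scala
// Assumed inputs for the sketch.
val xs: IR = ???
val eltType: Type = ???

// For each element, add it to the state and emit the partial total so far,
// yielding an array of running sums.
val runningSums: IR = ArrayScan(
  array = xs,
  eltName = "x",
  body = aggSequence(IndexedSeq(
    SeqOp(sumSig, IndexedSeq(Ref("x", eltType))),
    ReturnOp(sumSig))),
  initArgs = PrimAggInit(PrimAgg(sumSig)))
```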
Here `AggInitArgs` is a simple abstraction for building the set of arguments expected by the aggregator:

```scala
abstract class AggInitArgs { val agg: Agg }

case class PrimAggInit(agg: PrimAgg, initOpArgs: IR*) extends AggInitArgs

case class TupleAggInit(children: IndexedSeq[AggInitArgs]) extends AggInitArgs {
  val agg = TupleAgg(children.map(_.agg))
}

case class ArrayPerElementAggInit(init: AggInitArgs) extends AggInitArgs {
  val agg = ArrayPerElementAgg(init.agg)
}

case object IDAggInit extends AggInitArgs {
  val agg = IDAgg
}
```
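For the sum-and-count state sketched earlier, the init arguments mirror the shape of the state (neither primitive aggregator takes `initOp` arguments here):

```scala
// Init arguments for TupleAgg(IndexedSeq(PrimAgg(sumSig), PrimAgg(countSig))).
val sumAndCountInit: AggInitArgs = TupleAggInit(IndexedSeq(
  PrimAggInit(PrimAgg(sumSig)),
  PrimAggInit(PrimAgg(countSig))))
// sumAndCountInit.agg == sumAndCount
```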
To aggregate over an array, we have:

```scala
// ArrayAgg[A <: Agg, X, R](array: Array[X], body: (X) => AggType[A, R]): AggType[A, R]
//
// build computation that returns final agg result over array
//
// (state) =>
//   for i = 0 .. array.length() {
//     res = body(array[i])(&state);
//   }
//   return res;
case class ArrayAgg(array: IR, eltName: String, body: AggIR) extends AggIR {
  val childAggType = body.typ
  def inferType: AggType = childAggType
  def bindings(childIdx: Int): Iterable[(String, Type)] =
    if (childIdx == 1)
      Array(eltName -> array.typ.asInstanceOf[TArray].elementType)
    else
      Array()
}
```
Note that, unlike `ArrayScan`, this has an `AggType` type, and so doesn't take init args. To actually perform an array aggregation, we can wrap this in a `RunAgg`. But we can also compose multiple `ArrayAgg`s in a single aggregation, effectively implementing `AggExplode`, as in the sketch below. We could also give `ArrayScan` an `AggType` type, which would allow implementing `AggExplode` for scans (which I think we can't do currently?), but this complicates the codegen/lowering, and I'm keeping things as simple as possible for a first pass.
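For example, summing every inner element of an array of arrays (the job of the current `AggExplode`) is just one `ArrayAgg` nested inside another, with a single `RunAgg` at the top. A sketch (`rows` is an assumed IR of array-of-arrays type, with `innerType` and `eltType` the corresponding element types):

```scala
// Assumed inputs for the sketch.
val rows: IR = ???        // Array[Array[eltType]]
val innerType: Type = ??? // Array[eltType]
val eltType: Type = ???

// Every SeqOp below writes to the same PrimAgg(sumSig) state, so the
// nested ArrayAggs together sum all inner elements in one aggregation.
val explodedSum: IR = RunAgg(
  PrimAggInit(PrimAgg(sumSig)),
  aggSequence(IndexedSeq(
    ArrayAgg(rows, "inner",
      ArrayAgg(Ref("inner", innerType), "x",
        SeqOp(sumSig, IndexedSeq(Ref("x", eltType))))),
    ReturnOp(sumSig))))
```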
Finally, we need some simple combinators for building up aggregations with compound agg states:

```scala
// MakeTupleAgg[A... <: Agg..., R...](childAggIRs...: AggType[A, R]...): AggType[TupleAgg[A...], Tuple[R...]]
//
// (state) =>
//   ret1 = child1(&state._1);
//   ...
//   retN = childN(&state._N);
//   return (ret1, ..., retN);
case class MakeTupleAgg(childAggIRs: IndexedSeq[AggIR]) extends AggIR {
  val childAggs = childAggIRs.map(_.typ.agg)
  val childRetTypes = childAggIRs.map(_.typ.retType)
  def inferType: AggType = AggType(TupleAgg(childAggs), TTuple(childRetTypes: _*))
}

// MakeArrayPerElementAgg[A <: Agg, X, R](
//   array: Array[X],
//   body: (eltName: X) => AggType[A, R]
// ): AggType[ArrayPerElementAgg[A], Array[R]]
//
// for i = 0 .. array.length() {
//   res[i] = body(array[i])(&state[i]);
// }
// return res;
case class MakeArrayPerElementAgg(array: IR, eltName: String, body: AggIR) extends AggIR {
  val childAggType = body.typ
  def inferType: AggType = AggType(ArrayPerElementAgg(childAggType.agg), TArray(childAggType.retType))
  def bindings(childIdx: Int): Iterable[(String, Type)] =
    if (childIdx == 1)
      Array(eltName -> array.typ.asInstanceOf[TArray].elementType)
    else
      Array()
}
```
Besides just making the types work out, these should simplify the codegen by marking exactly where we need to load a component of a compound state, allowing the code generated for each child to be unaware of the compound structure.
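As an end-to-end sketch of the compound combinators, here is a single pass over an array that computes both a sum and a count, reusing the assumed `sumSig`/`countSig`, the `aggSequence` helper, and the assumed inputs `xs`/`eltType` from the earlier sketches:

```scala
// The tuple state has a sum component and a count component. Each SeqOp
// targets its own component via MakeTupleAgg; the final MakeTupleAgg of
// ReturnOps reads out (sum, count) as an ordinary tuple value.
val sumAndCountResult: IR = RunAgg(
  TupleAggInit(IndexedSeq(
    PrimAggInit(PrimAgg(sumSig)),
    PrimAggInit(PrimAgg(countSig)))),
  aggSequence(IndexedSeq(
    ArrayAgg(xs, "x",
      MakeTupleAgg(IndexedSeq(
        SeqOp(sumSig, IndexedSeq(Ref("x", eltType))),
        SeqOp(countSig, IndexedSeq())))),
    MakeTupleAgg(IndexedSeq(
      ReturnOp(sumSig),
      ReturnOp(countSig))))))
```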
I'll add more to this, elaborating on codegen/lowering and optimization, and proposing some possible first steps to migrate to the new design. I could also give some more detail on the semantics of the `AggType` abstraction, and why it's guaranteed to be parallelizable.