Notes on the design and implementation of EmitStream
Motivation
A long-standing goal in the compiler backend has been to “lower” all table operations into array operations. Currently, tables are implemented using Apache Spark’s RDD interface, which entails creating Scala Iterators that produce region-allocated values. If we instead represented tables in a lowered IR form, we would benefit on several fronts:
- The ability to migrate away from Spark and use Dan’s distributed shuffler instead.
- Less maintenance burden for a new backend (C++ or LLVM), since there would be a smaller “surface area” of IR forms that we need to generate code for.
- The ability to take advantage of “staged programming” techniques, since the compiler becomes involved with partitioning and shuffling.
However, our previous implementation of arrays was not suited for lowering table operations. Crucially, operations that process two tables in parallel had fundamental performance problems: zip and join operations required copying one of the partitions entirely into RAM. Let’s explain why that is.
Previous design (“push-based iterators”)
The previous design works by compiling array operations into the ArrayIterator interface. This interface gives us the ability to perform some callback on every element of the array. Here’s a simplified version of it:
abstract class ArrayIterator[A] {
/* 'cont' gets called for each element of the array */
def addElements(cont: Code[A] => Code[Unit]): Code[Unit]
}
This interface is attractive because it allows us to compile, for instance, ArrayRange directly into a loop:
def range(len: Code[Int]) =
new ArrayIterator[Int] {
def addElements(cont: Code[Int] => Code[Unit]): Code[Unit] =
/* (Pseudo-code to avoid gritty details of JVM staged programming) */
""" var i = 0
while (i < len) {
cont(i)
i = i + 1
} """
}
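This push interface also composes well for element-wise operations. For instance, a map combinator (a sketch; the name and signature here are illustrative, not necessarily what Hail defines) just wraps the caller’s callback:
def map[A, B](ait: ArrayIterator[A])(f: Code[A] => Code[B]) =
  new ArrayIterator[B] {
    def addElements(cont: Code[B] => Code[Unit]): Code[Unit] =
      /* pass each transformed element straight through to the caller's callback */
      ait.addElements { a => cont(f(a)) }
  }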
However, we encounter a problem with this approach when we attempt to compile a parallel array operation, such as a “zip”:
def zipWith[A, B, C](
aitA: ArrayIterator[A],
aitB: ArrayIterator[B],
f: (Code[A], Code[B]) => Code[C]
) =
new ArrayIterator[C] {
def addElements(cont: Code[C] => Code[Unit]): Code[Unit] =
aitA.addElements { a =>
/* 'aitB.addElements(..)' would loop over all of 'aitB' */
/* There's no way to "step" through it one element at a time :( */
}
}
The previous “fix” for situations like this was to collect all of the elements of one of the arrays into memory, and then index into it. However, this is unnecessarily inefficient, and we can do better by changing the design entirely.
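Concretely, the old workaround amounted to filling in the body of zipWith above roughly like this (a sketch in the same pseudo-code style as before, not the actual Hail code); it materializes all of aitB before looping over aitA:
new ArrayIterator[C] {
  def addElements(cont: Code[C] => Code[Unit]): Code[Unit] =
    /* (Pseudo-code: buffer every element of 'aitB', then index into the
     * buffer while iterating over 'aitA') */
    """ val buf = new ArrayBuffer[B]()
        aitB.addElements { b => buf += b }
        var i = 0
        aitA.addElements { a =>
          cont(f(a, buf(i)))
          i = i + 1
        } """
}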
New design (“pull-based streams”)
What we need is an interface that lets us ask for just the next element, without looping over the entire array. Consider the following interface:
abstract class Stream[A] {
type S /* internal state type */
def init: Code[S]
def step(s: Code[S]): Code[Step[A, S]]
}
sealed trait Step[+A, +S]
case object EOS extends Step[Nothing, Nothing] /* "end of stream" */
case class Yield[A, S](elt: A, s: S) extends Step[A, S]
The existential type S tells us what type of data we need to represent the current “state” of the stream. init gives us an initial state value, and step signals either that the stream is done (EOS) or that there is an element plus a new state (Yield). To implement ArrayRange, we don’t generate the loop itself; rather, we generate the initialization and the update portions of our loop separately.
def range(len: Code[Int]) =
new Stream[Int] {
type S = Int
def init = "0"
def step(i: Code[Int]): Code[Step[Int, S]] =
""" if (i >= len)
EOS
else
Yield(i, i + 1) """
}
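With this interface, the zip that defeated the push design becomes straightforward: the combined state is a pair of states, and each step advances both sides by exactly one element. Here is a sketch in the same pseudo-code style (a fuller version, against a Parameterized stream interface, appears at the end of these notes):
def zipWith[A, B, C](
  sa: Stream[A],
  sb: Stream[B],
  f: (Code[A], Code[B]) => Code[C]
) =
  new Stream[C] {
    type S = (sa.S, sb.S)
    def init = "(sa.init, sb.init)"
    def step(s: Code[S]): Code[Step[C, S]] =
      /* (Pseudo-code: step both sides; stop as soon as either side stops) */
      """ (sa.step(s._1), sb.step(s._2)) match {
            case (Yield(a, sa1), Yield(b, sb1)) => Yield(f(a, b), (sa1, sb1))
            case _ => EOS
          } """
  }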
Aside: inspirations from category theory
Given the following functor F:
type F[X] = Step[A, X]
ArrayIterator[A] is the initial F-algebra, and Stream[A] is the terminal F-coalgebra; they are “duals.”
type InitAlg[F] = forAll[S] (F[S] => S) => S
type TermCoalg[F] = forSome[S] (S => F[S], S)
This is also the same functor one can use to define linked lists, using the fixpoint:
type List[A] = F[List[A]]
Iterating on the Stream interface
The JVM-staged-programming aficionados on the Hail team would have taken issue with an aspect of the above interface: whenever a stream Yields a new element, it must allocate a new Yield object at runtime. One allocation per element is an unacceptable amount of pressure to put on the GC. We need a design that is efficient and does not allocate at runtime. How can we achieve that while also keeping the same general stream design?
We can solve this by using a clever continuation-passing-style (CPS) trick. Rather than returning values at runtime, we can pass values to a “continuation” parameter on each function.
abstract class Stream[A] {
type S /* internal state type */
def init (k: Code[S] => Code[Ctrl]): Code[Ctrl]
def step(s: Code[S])(k: Step[Code[A], Code[S]] => Code[Ctrl]): Code[Ctrl]
}
Notice how, instead of k’s parameter being Code[Step[A, S]], it is Step[Code[A], Code[S]]. This means that the Step objects exist at compile time, and only the element (A) and state (S) values exist at runtime. What we have essentially done is convert data into control flow.
This technique can be utilized to further reduce allocations. Certain streams require multiple values in their state. We could combine them using e.g. a PStruct, but that would require us to allocate and manipulate pointers, and would be considerably less efficient than storing the state as just a few variables. This is solved by taking advantage of the CPS trick again: allow the structure of the stream state (S) to exist at compile time instead of runtime.
abstract class Stream[A] {
type S /* (staged representation of) internal state type */
def init (k: S => Code[Ctrl]): Code[Ctrl]
def step(s: S)(k: Step[Code[A], S] => Code[Ctrl]): Code[Ctrl]
}
/* for instance, if the state ought to be two integers, then
* we should define 'S' to be '(Code[Int], Code[Int])'
*/
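As an example of a stream whose state has several parts, consider a take(n) combinator (a sketch; take is illustrative, not necessarily a combinator Hail defines): its state is the remaining count paired with the inner stream’s state, and that pair exists only at compile time.
def take[A](inner: Stream[A], n: Code[Int]) =
  new Stream[A] {
    /* two pieces of state, combined at compile time rather than in a
     * runtime struct */
    type S = (Code[Int], inner.S)
    def init(k: S => Code[Ctrl]): Code[Ctrl] =
      inner.init { innerS0 => k((n, innerS0)) }
    def step(s: S)(k: Step[Code[A], S] => Code[Ctrl]): Code[Ctrl] = {
      val (remaining, innerS) = s
      /* (Pseudo-code for the generated branch) */
      """ if (remaining <= 0)
            k(EOS)
          else
            inner.step(innerS) {
              case EOS => k(EOS)
              case Yield(a, innerS1) => k(Yield(a, (remaining - 1, innerS1)))
            } """
    }
  }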
(There are a few other details involved in the stream implementation, but they are probably not significant enough to talk about in depth.)
Progress
Currently, these techniques are being used in the latest version of Hail to implement deforested array operations. There is a bit of evaluation and testing work to be done before the old iterator implementation can be removed from the codebase entirely, but it is on its way.
Future work
In order to finish the table lowering efforts, some new array IRs need to be designed to handle the various kinds of joins. The current ArrayLeftJoinDistinct node is probably not sufficient for the kinds of joins needed by lowering. Whatever joins we need should be implementable using this interface.
A bigger obstacle to table lowering is correct region management at the level of array operations. We need to settle on a stream region management strategy that is both correct and efficient before lowering real pipelines. The current implementation takes a naive approach and does not clean up regions after individual elements.
zip
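For reference, the code below shows a zip combinator written against a Parameterized variant of the stream interface, which additionally threads a parameter P into init and uses explicit join points (JoinPointBuilder) for generated control flow. The state is the pair of the two streams’ states; init starts both sides, short-circuiting to Missing if either side is missing, and step advances both sides, jumping to a shared eos join point as soon as either stream is exhausted.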
def zip[P, A, B](
left: Parameterized[P, A],
right: Parameterized[P, B]
): Parameterized[P, (A, B)] =
new Parameterized[P, (A, B)] {
implicit val leftStatePack = left.stateP
implicit val rightStatePack = right.stateP
type S = (left.S, right.S)
val stateP: ParameterPack[S] = implicitly
val emptyState: S = (left.emptyState, right.emptyState)
def length(s: S): Option[Code[Int]] =
(left.length(s._1), right.length(s._2)) match {
case (Some(n), Some(m)) => Some((n < m).mux(n, m))
case (_, _) => None
}
def init(
mb: MethodBuilder, jb: JoinPointBuilder, param: P
)(k: Init[S] => Code[Ctrl]): Code[Ctrl] = {
val j = jb.joinPoint()
j.define { _ => k(Missing) }
left.init(mb, jb, param) {
case Missing => j(())
case Start(leftS) =>
right.init(mb, jb, param) {
case Missing => j(())
case Start(rightS) => k(Start((leftS, rightS)))
}
}
}
def step(
mb: MethodBuilder, jb: JoinPointBuilder, state: S
)(k: Step[(A, B), S] => Code[Ctrl]): Code[Ctrl] = {
val (leftS, rightS) = state
val eos = jb.joinPoint()
eos.define { _ => k(EOS) }
left.step(mb, jb, leftS) {
case EOS => eos(())
case Yield(leftElt, leftS1) =>
right.step(mb, jb, rightS) {
case EOS => eos(())
case Yield(rightElt, rightS1) =>
k(Yield((leftElt, rightElt), (leftS1, rightS1)))
}
}
}
}
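Note the use of join points here: eos is defined once and targeted from both streams’ EOS cases, so the code that k(EOS) generates is emitted only once rather than duplicated per case; init uses the same pattern for its Missing cases.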