Notes on EmitStream

Notes on the design and implementation of EmitStream

Motivation

A long-standing goal in the compiler backend has been to try to “lower” all table operations into array operations. Currently, tables are implemented using Apache Spark’s RDD interface, which entails creating Scala Iterators that produce region-allocated values. If we instead represented them in a lowered IR form, we would benefit on various fronts, such as:

  • The ability to migrate away from Spark and use Dan’s distributed shuffler instead.
  • Less maintenance burden for a new backend (C++ or LLVM), since there would be a smaller “surface area” of IR forms that we need to generate code for.
  • The ability to take advantage of “staged programming” techniques, since the compiler becomes involved with partitioning and shuffling.

However, our previous implementation of arrays was not suited to lowering table operations. Crucially, operations that process two tables in parallel had fundamental performance problems: zip and join operations required that one of the partitions be copied entirely into RAM. Let’s explain why that is.

Previous design (“push-based iterators”)

The previous design works by compiling array operations into the ArrayIterator interface. This interface gives us the ability to perform some callback on every element of the array. Here’s a simplified version of it:

abstract class ArrayIterator[A] {
  /* 'cont' gets called for each element of the array */
  def addElements(cont: Code[A] => Code[Unit]): Code[Unit]
}

This interface is attractive because it allows us to compile, for instance, ArrayRange, directly
into a loop:

def range(len: Code[Int]) =
  new ArrayIterator[Int] {
    def addElements(cont: Code[Int] => Code[Unit]): Code[Unit] =
      /* (Pseudo-code to avoid gritty details of JVM staged programming) */
      """ var i = 0
          while (i < len) {
            cont(i)
            i = i + 1
          } """
  }
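
Consumers of the interface compile just as directly. As an illustration (the 'sum' helper below is hypothetical, not something from the codebase), summing the elements of an array becomes a single accumulating loop:

def sum(ait: ArrayIterator[Int]): Code[Int] =
  /* (Pseudo-code, as before) */
  """ var acc = 0
      ait.addElements { i => acc = acc + i }
      acc """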

However, we encounter a problem with this approach when we attempt to compile a parallel array
operation, such as a “zip”:

def zipWith(
  aitA: ArrayIterator[A],
  aitB: ArrayIterator[B],
  f: (Code[A], Code[B]) => Code[C]
) =
  new ArrayIterator[C] {
    def addElements(cont: Code[C] => Code[Unit]): Code[Unit] =
      aitA.addElements { a =>
        /* 'aitB.addElements(..)' would loop over all of 'aitB' */
        /* There's no way to "step" through it one element at a time :( */
      }
  }

The previous “fix” for situations like this was to collect all of the elements of one of the arrays into memory and then index into it. This costs memory proportional to the size of the materialized array, which is unnecessary, and we can do better by changing the design entirely.
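
For concreteness, the materializing workaround looked roughly like the following sketch (pseudo-code, as before; 'ArrayBuilder' is a hypothetical stand-in for whatever growable buffer is used to hold the materialized elements):

def zipWith(
  aitA: ArrayIterator[A],
  aitB: ArrayIterator[B],
  f: (Code[A], Code[B]) => Code[C]
) =
  new ArrayIterator[C] {
    def addElements(cont: Code[C] => Code[Unit]): Code[Unit] =
      /* (Pseudo-code, as before) */
      """ val bElts = new ArrayBuilder[B]()
          aitB.addElements { b => bElts += b }  /* all of 'aitB' now lives in RAM */
          var i = 0
          aitA.addElements { a =>
            cont(f(a, bElts(i)))
            i = i + 1
          } """
  }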

New design (“pull-based streams”)

What we need is some interface that allows us to ask for the next element on demand, without looping over the entire array. Consider the following interface:

abstract class Stream[A] {
  type S /* internal state type */
  def init: Code[S]
  def step(s: Code[S]): Code[Step[A, S]]
}

sealed trait Step[+A, +S]
case object EOS                           extends Step[Nothing, Nothing] /* "end of stream" */
case class Yield[A, S](elt: A, s: S)      extends Step[A, S]

The existential type S tells us what type of data we need to represent the current “state” of the stream. init gives us an initial state value, and step signals that either the stream is done (EOS) or that there is an element, plus a new state (Yield). To implement ArrayRange, we don’t generate the loop itself; rather, we generate the initialization and the update portions of our loop separately.

def range(len: Code[Int]) =
  new Stream[Int] {
    type S = Int
    def init = "0"
    def step(i: Code[Int]): Code[Step[Int, S]] =
      """ if (i >= len)
            EOS
          else
            Yield(i, i + 1) """
  }
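
With this interface, the zipWith that defeated us before becomes straightforward, because each input can be stepped one element at a time. Here is a sketch (pseudo-code in the same loose style; the paired state is written as an ordinary tuple for clarity):

def zipWith(
  sA: Stream[A],
  sB: Stream[B],
  f: (Code[A], Code[B]) => Code[C]
) =
  new Stream[C] {
    type S = (sA.S, sB.S)
    def init = (sA.init, sB.init) /* hand-waving over the Code wrapping of the pair */
    def step(s: Code[S]): Code[Step[C, S]] =
      """ sA.step(s._1) match {
            case EOS           => EOS
            case Yield(a, sa1) =>
              sB.step(s._2) match {
                case EOS           => EOS
                case Yield(b, sb1) => Yield(f(a, b), (sa1, sb1))
              }
          } """
  }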

Aside: inspirations from category theory

Given the following functor F:

type F[X] = Step[A, X]

ArrayIterator[A] is the initial F-algebra, and Stream[A] is the terminal F-coalgebra; they are “duals.”

type InitAlg[F]   = forAll[S]  (F[S] => S) => S
type TermCoalg[F] = forSome[S] (S => F[S],    S)

This is also the same functor one can use to define linked lists, using the fixpoint:

type List[A] = F[List[A]]
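
Spelled out in plain, unstaged Scala, that fixpoint is just the ordinary cons-list, with EOS playing the role of the empty list and Yield the role of cons:

sealed trait List[+A]
case object Empty                          extends List[Nothing]  /* ~ EOS */
case class Cons[A](head: A, tail: List[A]) extends List[A]        /* ~ Yield(elt, s) */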

Iterating on the Stream interface

The JVM-staged-programming aficionados on the Hail team would have taken issue with an aspect of the above interface: whenever a stream yields a new element, it must allocate a new Yield object at runtime. One allocation per element is an unacceptable amount of pressure to put on the GC. We need a design that is efficient and does not allocate at runtime. How can we achieve that while also keeping the same general stream design?

We can solve this by using a clever continuation-passing-style (CPS) trick. Rather than returning values at runtime, we can pass values to a “continuation” parameter on each function.

abstract class Stream[A] {
  type S /* internal state type */
  def init            (k: Code[S]                => Code[Ctrl]): Code[Ctrl]
  def step(s: Code[S])(k: Step[Code[A], Code[S]] => Code[Ctrl]): Code[Ctrl]
}

Notice how, instead of k’s parameter being Code[Step[A, S]], it is Step[Code[A], Code[S]]. This means that the Step objects exist at compile time, and only the element (A) and state (S) values exist at runtime. What we have essentially done is convert data into control flow.
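
To make this concrete, here is a sketch of range against the CPS’d interface ('const' and 'mux' stand in for whatever the code-generation API provides for integer literals and staged conditionals; the real implementation also uses join points, as in the zip code at the end of these notes, so that calling 'k' in two branches does not duplicate the continuation’s generated code):

def range(len: Code[Int]) =
  new Stream[Int] {
    type S = Int
    def init(k: Code[Int] => Code[Ctrl]): Code[Ctrl] =
      k(const(0))
    def step(i: Code[Int])(k: Step[Code[Int], Code[Int]] => Code[Ctrl]): Code[Ctrl] =
      (i >= len).mux(
        k(EOS),             /* EOS and Yield now exist only at compile time */
        k(Yield(i, i + 1)))
  }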

This technique can be utilized to further reduce allocations. Certain streams require multiple values in their state. We could combine them using e.g. a PStruct, but that would require us to allocate and manipulate pointers, and would be considerably less efficient than storing the state as just a few variables. This is solved by taking advantage of the CPS trick again: allow the structure of the stream state (S) to exist at compile time instead of runtime.

abstract class Stream[A] {
  type S /* (staged representation of) internal state type */
  def init      (k: S                => Code[Ctrl]): Code[Ctrl]
  def step(s: S)(k: Step[Code[A], S] => Code[Ctrl]): Code[Ctrl]
}

/* for instance, if the state ought to be two integers, then 
 * we should define 'S' to be '(Code[Int], Code[Int])'
 */
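
As an illustration of that (a hypothetical combinator, not taken from the codebase), here is a sketch of a 'take' whose state is the inner stream’s state paired with a Code[Int] counter; the pair exists only at compile time, so at runtime the state is just a few local variables:

def take[A](inner: Stream[A], n: Code[Int]) =
  new Stream[A] {
    type S = (inner.S, Code[Int])   /* compile-time pair of states */
    def init(k: S => Code[Ctrl]): Code[Ctrl] =
      inner.init { innerS0 => k((innerS0, n)) }
    def step(s: S)(k: Step[Code[A], S] => Code[Ctrl]): Code[Ctrl] = {
      val (innerS, remaining) = s
      (remaining <= 0).mux(
        k(EOS),
        inner.step(innerS) {
          case EOS               => k(EOS)
          case Yield(a, innerS1) => k(Yield(a, (innerS1, remaining - 1)))
        })
    }
  }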

(There are a few other details involved in the stream implementation, but they are probably not significant enough to talk about in depth.)

Progress

Currently, these techniques are being used in the latest version of Hail to implement deforested array operations. There is a bit of evaluation and testing work to be done before the old iterator implementation can be removed from the codebase entirely, but it is on its way.

Future work

In order to finish the table lowering efforts, some new array IRs need to be designed to handle the various kinds of joins. The current ArrayLeftJoinDistinct node is probably not sufficient for the kinds of joins needed by lowering. It should be possible to implement whatever joins we need on top of this interface.

A bigger obstacle to table lowering is correct region management at the level of array operations. We need to figure out a stream region management strategy that is both correct and efficient before lowering real pipelines. The current implementation takes a naive approach: it does not clean up regions after individual elements.

Appendix: zip

For reference, here is the zip combinator as implemented against the full interface (Parameterized, which extends the CPS’d Stream above with a parameter P, a possibly-missing initialization result, and explicit method and join-point builders):
  def zip[P, A, B](
    left: Parameterized[P, A],
    right: Parameterized[P, B]
  ): Parameterized[P, (A, B)] =
    new Parameterized[P, (A, B)] {
      implicit val leftStatePack = left.stateP
      implicit val rightStatePack = right.stateP

      type S = (left.S, right.S)
      val stateP: ParameterPack[S] = implicitly

      val emptyState: S = (left.emptyState, right.emptyState)

      def length(s: S): Option[Code[Int]] =
        (left.length(s._1), right.length(s._2)) match {
          case (Some(n), Some(m)) => Some((n < m).mux(n, m))
          case (_, _) => None
        }

      def init(
        mb: MethodBuilder, jb: JoinPointBuilder, param: P
      )(k: Init[S] => Code[Ctrl]): Code[Ctrl] = {
        val j = jb.joinPoint()
        j.define { _ => k(Missing) }
        left.init(mb, jb, param) {
          case Missing => j(())
          case Start(leftS) =>
            right.init(mb, jb, param) {
              case Missing => j(())
              case Start(rightS) => k(Start((leftS, rightS)))
            }
        }
      }

      def step(
        mb: MethodBuilder, jb: JoinPointBuilder, state: S
      )(k: Step[(A, B), S] => Code[Ctrl]): Code[Ctrl] = {
        val (leftS, rightS) = state
        val eos = jb.joinPoint()
        eos.define { _ => k(EOS) }
        left.step(mb, jb, leftS) {
          case EOS => eos(())
          case Yield(leftElt, leftS1) =>
            right.step(mb, jb, rightS) {
              case EOS => eos(())
              case Yield(rightElt, rightS1) =>
                k(Yield((leftElt, rightElt), (leftS1, rightS1)))
            }
        }
      }
    }