Reproducible randomness design doc

Hi all,

Here’s a draft design doc for the new reproducible randomness scheme. I’ll update it once the implementation is done so that it reflects what was actually built, for future reference, but it should already be fairly complete.

Overview

The IR must associate with every RNG-using node a pair of bitstrings, staticBitstring and dynBitstring. staticBitstring is fully known at compile time, and every RNG-using node gets a unique staticBitstring. dynBitstring has a statically known length, but its value is dynamic; it is used to give RNG-using nodes inside loops an independent random source on each iteration. The encoding scheme that produces these bitstrings must guarantee that each dynamic invocation of any RNG-using node gets a globally distinct (staticBitstring, dynBitstring) pair.

RNG

To generate random numbers from a (staticBitstring, dynBitstring) pair, we need a pseudorandom function from pairs of bitstrings to fixed-length words. We use a modification of the PMAC [1] construction, based on the Threefish tweakable blockcipher [2] reduced to 20 rounds.

Threefish has a 256-bit block size and key size, and a 128-bit tweak, though we currently restrict it to just a 64-bit tweak. We denote the blockcipher as a function encrypt(key: Int256, tweak: Int64, block: Int256): Int256.

The modified PMAC construction we use is specified as follows:

The leftmost bit of the 64-bit tweak is reserved to divide the tweak space, ensuring that static and dynamic blocks are encrypted with distinct tweaks. The remaining 63 bits are used for the block counter, i.e. the index of the block within its corresponding bitstring; however, the maximum counter value 2^63 - 1 is reserved for finalization. Concretely, we define four constants:

  • staticTweakMask = 0L
  • dynTweakMask = 1L << 63
  • finalBlockNoPadTweak = -1L >>> 1
  • finalBlockPaddedTweak = -1L
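A small Python sketch of this tweak layout, modelling the 64-bit tweak as an unsigned integer (so Scala's -1L becomes 2^64 - 1). The constant names mirror the ones above; the helper block_tweak is hypothetical. It illustrates that combining the block counter with a mask by bitwise OR keeps every block tweak distinct from every other and from the two finalization tweaks.

```python
# Model the 64-bit tweak as an unsigned Python int.
MASK64 = (1 << 64) - 1

STATIC_TWEAK_MASK = 0                    # staticTweakMask = 0L
DYN_TWEAK_MASK = 1 << 63                 # dynTweakMask = 1L << 63
FINAL_BLOCK_NO_PAD_TWEAK = MASK64 >> 1   # -1L >>> 1, i.e. 2^63 - 1
FINAL_BLOCK_PADDED_TWEAK = MASK64        # -1L, all 64 bits set

# Counter values 0 .. 2^63 - 2 are usable; 2^63 - 1 is reserved for finalization.
RESERVED_COUNTER = (1 << 63) - 1


def block_tweak(i: int, dynamic: bool) -> int:
    """Hypothetical helper: the tweak for block i of staticBitstring or dynBitstring."""
    if not 0 <= i < RESERVED_COUNTER:
        raise ValueError(f"block index {i} out of range")
    # The counter occupies the low 63 bits; OR in the mask to set the top bit.
    # (Using AND instead would map every small i, static or dynamic, to tweak 0.)
    return i | (DYN_TWEAK_MASK if dynamic else STATIC_TWEAK_MASK)


tweaks = [block_tweak(i, d) for i in range(4) for d in (False, True)]
tweaks += [FINAL_BLOCK_NO_PAD_TWEAK, FINAL_BLOCK_PADDED_TWEAK]
print(len(tweaks), len(set(tweaks)))  # 10 10
```

The two finalization tweaks share the reserved counter value 2^63 - 1 and differ only in the top bit, so they cannot collide with any block tweak.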

To compute the hash of the two bitstrings:

  • Pad staticBitstring by appending a single 1, followed by 0s until the length is a multiple of 256. Split into 256-bit blocks S[0], ..., S[s].
  • Split dynBitstring into 256-bit blocks D[0], ..., D[d], allowing the last block to be shorter than 256 bits (possibly empty).
  • Let SE[i] = encrypt(key, i | staticTweakMask, S[i]), for i = 0..s
  • Let SE be the xor of all SE[i]
  • Let DE[i] = encrypt(key, i | dynTweakMask, D[i]), for i = 0..d-1 (all but the last block)
  • Let DE be the xor of all DE[i]
  • If the last dynamic block is full, compute the final hash as
    • encrypt(key, finalBlockNoPadTweak, SE ^ DE ^ D[d])
  • Otherwise, let B be D[d] padded by a single 1 followed by 0s, to 256 bits, and compute the final hash as
    • encrypt(key, finalBlockPaddedTweak, SE ^ DE ^ B)
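The hashing steps above can be sketched end to end in Python. This is a sketch under several assumptions: bitstrings are Python strings of '0'/'1' characters, the first bit of each block is the most significant bit of its 256-bit integer, an empty dynBitstring is treated as a single empty last block, and encrypt is a STAND-IN built from SHA-256 (it is not Threefish and not even a blockcipher; it only fills the same slot so the mode's structure can be run). rng_hash, pad and blocks are hypothetical names.

```python
import hashlib

BLOCK_BITS = 256
MASK64 = (1 << 64) - 1
STATIC_TWEAK_MASK = 0
DYN_TWEAK_MASK = 1 << 63
FINAL_BLOCK_NO_PAD_TWEAK = MASK64 >> 1
FINAL_BLOCK_PADDED_TWEAK = MASK64


def encrypt(key: int, tweak: int, block: int) -> int:
    """STAND-IN for 20-round Threefish-256: SHA-256 of (key, tweak, block).
    Not invertible, so not a real blockcipher; used only to exercise the mode."""
    data = key.to_bytes(32, "big") + tweak.to_bytes(8, "big") + block.to_bytes(32, "big")
    return int.from_bytes(hashlib.sha256(data).digest(), "big")


def pad(bits: str) -> str:
    """Append a single 1, then 0s until the length is a multiple of 256."""
    bits += "1"
    return bits + "0" * (-len(bits) % BLOCK_BITS)


def blocks(bits: str) -> list:
    """Split into 256-bit chunks; the last chunk may be shorter."""
    return [bits[i:i + BLOCK_BITS] for i in range(0, len(bits), BLOCK_BITS)]


def rng_hash(key: int, static_bits: str, dyn_bits: str) -> int:
    # Static part: pad, then encrypt every block (including the last) with a static tweak.
    se = 0
    for i, s in enumerate(blocks(pad(static_bits))):
        se ^= encrypt(key, i | STATIC_TWEAK_MASK, int(s, 2))

    # Dynamic part: assume an empty dynBitstring is a single empty last block.
    d = blocks(dyn_bits) or [""]
    de = 0
    for i, block in enumerate(d[:-1]):          # all but the last block
        de ^= encrypt(key, i | DYN_TWEAK_MASK, int(block, 2))

    last = d[-1]
    if len(last) == BLOCK_BITS:                 # full last block: xor in unpadded
        return encrypt(key, FINAL_BLOCK_NO_PAD_TWEAK, se ^ de ^ int(last, 2))
    b = int(pad(last), 2)                       # partial last block: pad to 256 bits
    return encrypt(key, FINAL_BLOCK_PADDED_TWEAK, se ^ de ^ b)
```

As in PMAC, the last dynamic block is xored into the sum unencrypted and only the final call is keyed with a finalization tweak; the two finalization tweaks keep a full 256-bit last block distinct from a shorter block that happens to look the same after padding.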

[1] Rogaway, “Efficient Instantiations of Tweakable Blockciphers and Refinements to Modes OCB and PMAC”
[2] Ferguson et al., “The Skein Hash Function Family”

New types and IR nodes

To implement this design, we only need one new type and one or two new nodes.

The new type is RNGState. A value of this type contains a mix of static and dynamic data:

  • dynBlocksSum: IndexedSeq[Value[Long]]: the xor of the encrypted contents of all full blocks of dynBitstring
  • numWordsInLastDynBlock: Int: the number of words (longs), in the range [0, 4), currently contained in lastDynBlock
  • lastDynBlock: IndexedSeq[Value[Long]]: the partial contents of the last block of dynBitstring. The length of the sequence is numWordsInLastDynBlock
  • staticBlocksSum: IndexedSeq[Long]: the xor of the encrypted contents of all full blocks of staticBitstring
  • lastStaticBlock: Bitstring: the partial contents of the last block of staticBitstring
  • numStaticBlocks: Int: the number of blocks in staticBitstring
  • numDynBlocks: Int: the number of blocks in dynBitstring

Note that at runtime a value of RNGState consists of [4, 8) longs (the 4 longs of dynBlocksSum plus the 0 to 3 longs of lastDynBlock), and, as we will see, much of that state is shared between values.
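To make the static/dynamic split of RNGState concrete, here is an illustrative Python model of its fields. RNGStateModel and RuntimeLong are hypothetical names, and a runtime Value[Long] is modelled simply as the name of a local variable; the point is only that the runtime footprint is dynBlocksSum plus lastDynBlock.

```python
from dataclasses import dataclass
from typing import Tuple

# Hypothetical stand-in for Value[Long]: the name of a runtime local holding a long.
RuntimeLong = str


@dataclass(frozen=True)
class RNGStateModel:
    """Illustrative model of RNGState's fields, grouped by when they are known."""
    # Runtime part: values that only exist when the compiled code runs.
    dyn_blocks_sum: Tuple[RuntimeLong, ...]    # always 4 longs (one 256-bit block)
    last_dyn_block: Tuple[RuntimeLong, ...]    # 0 to 3 longs
    # Compile-time part: fully known while generating code.
    static_blocks_sum: Tuple[int, int, int, int]
    last_static_block: str                     # partial static block, as bits
    num_static_blocks: int
    num_dyn_blocks: int

    def __post_init__(self):
        if len(self.dyn_blocks_sum) != 4:
            raise ValueError("dyn_blocks_sum must hold exactly 4 longs")
        if not 0 <= len(self.last_dyn_block) < 4:
            raise ValueError("last_dyn_block must hold 0 to 3 longs")

    @property
    def num_words_in_last_dyn_block(self) -> int:
        return len(self.last_dyn_block)

    def runtime_longs(self) -> int:
        """Longs occupied at runtime; always in the range [4, 8)."""
        return len(self.dyn_blocks_sum) + self.num_words_in_last_dyn_block
```

In this model numWordsInLastDynBlock is derived from lastDynBlock rather than stored separately, since both are fixed once the state's type is known.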

The new node is RNGSplit(state: IR, static: Bitstring, dyn: IR), where state has type RNGState and dyn is an IR of any nested tuple type whose leaves are all Longs, which we interpret as a flattened array of longs with statically known length. The semantics of the node is simply to append static to the state’s staticBitstring and dyn to its dynBitstring, and return the resulting state. Operationally, this appends to the appropriate last block, runs encrypt if the last block is filled, and xors the result into the appropriate sum (all at compile time in the case of static). Either static or dyn may be empty (for dyn this means the empty tuple type).
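The "append" semantics can be pinned down with a reference model that keeps the whole bitstrings instead of running sums (exactly what RNGState is designed to avoid at runtime). SpecState, rng_split, flatten_longs and long_bits are hypothetical names, and encoding each long as 64 bits, most significant bit first, is an assumption.

```python
from dataclasses import dataclass

MASK64 = (1 << 64) - 1


def flatten_longs(dyn) -> list:
    """Flatten a nested tuple whose leaves are longs, left to right."""
    if isinstance(dyn, tuple):
        return [w for child in dyn for w in flatten_longs(child)]
    return [dyn]


def long_bits(x: int) -> str:
    """64-bit two's-complement encoding of a long, most significant bit first."""
    return format(x & MASK64, "064b")


@dataclass(frozen=True)
class SpecState:
    """Reference model: keeps the full bitstrings rather than running sums."""
    static_bits: str = ""
    dyn_bits: str = ""


def rng_split(state: SpecState, static: str, dyn) -> SpecState:
    """Hypothetical model of RNGSplit: append static and the flattened dyn longs."""
    dyn_bits = "".join(long_bits(w) for w in flatten_longs(dyn))
    return SpecState(state.static_bits + static, state.dyn_bits + dyn_bits)
```

For example, rng_split(SpecState(), "10", ((1, 2), 3)) appends "10" to the static bitstring and three 64-bit words to the dynamic one, while rng_split(s, "", ()) leaves s unchanged.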

We may also need a node like RNGStateLiteral, which can create an RNGState from scratch, at least for testing purposes and possibly to create the initial state at the root of an IR before compiling.

Elaborating an IR

When we construct an IR in Python, we must elaborate it with RNGSplit nodes to fix the RNG semantics before anything else can modify the IR. This should be done bottom-up, so that an IR’s RNG semantics doesn’t depend on its context. Since the RNG semantics is completely encoded in the IR before compilation, using regular types and nodes whose semantics are pure functions without side effects, most of the compiler doesn’t need to be aware of randomness at all.

We introduce a reserved identifier __rng_state. If a node uses an RNG, __rng_state must be bound in its environment.

Static splitting

Let us consider static splitting first. Conceptually, this works as follows:

  • in a bottom-up pass, compute a boolean flag uses_randomness for every IR node, where nodes like ApplySeeded use randomness, and otherwise a node uses randomness iff one of its children does.
  • in a top-down pass, each node splits the __rng_state in its environment n ways, where n is the number of its children which use randomness, and rebinds __rng_state to one of the n split states in each of the n children.

How we split a state n ways is arbitrary, but must be fixed. We choose a recursive method which splits the children into a first half of floor(n/2) children and a second half with the rest, creates the splits for each half, then prepends 0 to the first half and 1 to the second. This minimizes the average length of the splitting bitstrings. For example, with 5 children using randomness, we create the 5 static bitstrings [00, 01, 10, 110, 111]:

split(5) = ['0' + split(2), '1' + split(3)]
         = ['0' + ['0' + split(1), '1' + split(1)], '1' + ['0' + split(1), '1' + split(2)]]
         = ['0' + ['0' + '', '1' + ''], '1' + ['0' + '', '1' + ['0', '1']]]
         = ['00', '01', '10', '110', '111']
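The recursive split can be written directly in Python; this is a sketch of the rule described above (split is a hypothetical name), and it reproduces the 5-way example.

```python
def split(n: int) -> list:
    """Prefix-free splitting bitstrings for n children: the first floor(n/2)
    children get a leading '0', the remaining children a leading '1'."""
    if n < 1:
        raise ValueError("n must be positive")
    if n == 1:
        return [""]          # a single child needs no extra bits
    half = n // 2
    return ["0" + s for s in split(half)] + ["1" + s for s in split(n - half)]


print(split(5))  # ['00', '01', '10', '110', '111']
```

Because the result is prefix-free, appending these bitstrings along different paths in the IR can never make two RNG uses end up with the same staticBitstring.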

This could be implemented by wrapping each child in, e.g., Let(__rng_state, RNGSplit(Ref(__rng_state), Bitstring(110), ())), but this would pollute the IR with a lot of Lets. Instead, we can accumulate static splitting bitstrings along each path in the IR, so that for each leaf RNG use we have the entire bitstring of static splits from the root of the IR.
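Both passes, with the bitstrings accumulated along paths rather than materialized as Lets, might look like this on a toy tree. Node, mark_randomness and assign_static_bits are hypothetical, and the sketch assumes seeded nodes (like ApplySeeded) are leaves, so a node never needs a split for itself in addition to its children.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical toy IR node; only the RNG-relevant structure is modelled."""
    name: str
    children: List["Node"] = field(default_factory=list)
    seeded: bool = False                 # True for nodes like ApplySeeded (assumed leaves)
    uses_rng: bool = False               # filled in by mark_randomness
    static_bits: Optional[str] = None    # filled in by assign_static_bits


def split(n: int) -> List[str]:
    if n == 1:
        return [""]
    half = n // 2
    return ["0" + s for s in split(half)] + ["1" + s for s in split(n - half)]


def mark_randomness(node: Node) -> bool:
    """Bottom-up pass: a node uses randomness if it is seeded or any child does."""
    flags = [mark_randomness(c) for c in node.children]
    node.uses_rng = node.seeded or any(flags)
    return node.uses_rng


def assign_static_bits(node: Node, prefix: str = "") -> None:
    """Top-down pass: extend the accumulated bitstring instead of inserting Lets."""
    if node.seeded:
        node.static_bits = prefix
    rng_children = [c for c in node.children if c.uses_rng]
    if rng_children:
        for child, bits in zip(rng_children, split(len(rng_children))):
            assign_static_bits(child, prefix + bits)
```

For a MakeTuple whose children are an Add containing one ApplySeeded, a MakeArray of two ApplySeeded nodes, and a Str, the three seeded leaves receive "0", "10" and "11": the Add has only one random child, so it adds no bits.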

Dynamic splitting

For dynamic splitting, we must modify the IR so that in any loop context, i.e. a body of IR which may execute more than once at run-time, there is some UID bound in the environment which is guaranteed to have a different value on every execution.

As a simple example,

TailLoop("loop", [("var", init)], ..Recur(next_var)..)

is modified to (if the loop body uses randomness)

TailLoop(
  "loop",
  [("var", init), ("__uid", 0)],
  Let(__rng_state, RNGSplit(Ref(__rng_state), Bitstring(), Ref("__uid")),
      ..Recur(next_var, Ref("__uid") + 1)..))
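Because RNGSplit appends to dynBitstring, nesting such rewritten loops concatenates their uids. The toy model below (nested_tail_loop_pairs is a hypothetical name; it tracks only the bitstrings, not real IR) records the (staticBitstring, dyn) pair seen by the inner loop body on every execution and shows that the pairs are all distinct.

```python
def nested_tail_loop_pairs(static_bits: str, outer_iters: int, inner_iters: int):
    """Toy model of two nested rewritten TailLoops: returns the
    (staticBitstring, dyn) pair seen by the inner body on each execution."""
    pairs = []
    outer_uid = 0                                   # ("__uid", 0) in the outer loop
    while outer_uid < outer_iters:
        outer_dyn = (outer_uid,)                    # RNGSplit(..., Bitstring(), Ref("__uid"))
        inner_uid = 0                               # the inner loop binds its own "__uid"
        while inner_uid < inner_iters:
            # The inner RNGSplit appends to the already-split state.
            pairs.append((static_bits, outer_dyn + (inner_uid,)))
            inner_uid += 1                          # Recur(..., Ref("__uid") + 1)
        outer_uid += 1
    return pairs
```

With 3 outer and 4 inner iterations this yields 12 pairs, e.g. ("110", (1, 1)) for the second inner iteration of the second outer iteration, with no repeats.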

For tables, we require every table producer to be able to create a new row field “__uid” whose value is guaranteed to be distinct across rows. Streams are similar: if the element type is a struct, we add a “__uid” field; otherwise, if the element type is T, we replace it with (UID, T), where UID is any type valid for dynamic splitting, i.e. a nested tuple whose leaves are all Longs. This special handling of struct element types is needed for stream join nodes, so that the join fields remain top-level fields.
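The element-type rule can be sketched in Python, modelling struct elements as dicts and any other element as an opaque value; tag_with_uids is a hypothetical name used only for illustration.

```python
def tag_with_uids(elements, uids):
    """Hypothetical model of adding uids to stream elements.
    Structs are modelled as dicts; any other element T becomes (uid, T)."""
    if len(elements) != len(uids):
        raise ValueError("need exactly one uid per element")
    tagged = []
    for uid, elem in zip(uids, elements):
        if isinstance(elem, dict):
            if "__uid" in elem:
                raise ValueError("element already has a __uid field")
            # Keep existing fields (e.g. join keys) at the top level.
            tagged.append({**elem, "__uid": uid})
        else:
            tagged.append((uid, elem))
    return tagged
```

Adding a field rather than wrapping the struct is what keeps a join key such as "k" addressable at the top level.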

We implement this with a single recursive traversal over stream/table pipelines. At a high level, starting at the final consumer of a stream/table pipeline:

  • recursively process stream/table children which use randomness
  • create uids if requested by the parent
  • if the body IR (if any) uses randomness, split __rng_state in the body using the row UID

In the last two bullets, we must create uids for the resulting stream/table, to pass to the parent node and/or to split the RNG state in the body. We do this by requesting uids from one or more stream/table children and using them to construct a new uid (how to do this is specific to each node). In the base case, for streams we can simply zip with index. For tables, we must add a flag to all table producer nodes, like TableRead and TableParallelize, to optionally create a uid field. Each node is free to do this however it likes, but a good default uid, which most nodes should be able to implement efficiently, is the tuple (partitionNumber, indexWithinPartition).
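A sketch of the two base cases, plus one possible way for a node to build new uids from its children's. The flatMap-style rule (outer uid followed by inner uid) is an illustrative assumption, not something this design prescribes; all function names are hypothetical.

```python
from typing import List, Sequence, Tuple

Uid = Tuple[int, ...]


def stream_base_uids(n: int) -> List[Uid]:
    """Base case for streams: zip with index."""
    return [(i,) for i in range(n)]


def table_base_uids(partition_sizes: Sequence[int]) -> List[Uid]:
    """Suggested default for table producers: (partitionNumber, indexWithinPartition)."""
    return [(p, i) for p, size in enumerate(partition_sizes) for i in range(size)]


def flat_map_uids(outer_uids: Sequence[Uid], inner_uids: Sequence[Sequence[Uid]]) -> List[Uid]:
    """Hypothetical rule for a flatMap-like node: each output element's uid is
    its outer element's uid followed by its uid within the inner stream."""
    return [outer + inner for outer, inners in zip(outer_uids, inner_uids) for inner in inners]
```

Each rule produces tuples of a fixed length, so the resulting uids still have the statically known length that dynamic splitting requires.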

Rebuilding an IR in Python like this can interfere with the Python object-based CSE, but in this case it’s safe, since stream/table nodes can’t be CSE’d anyway.