Hi all,
Here’s a draft design doc for the new reproducible randomness. I’ll update when the implementation is done to be sure it reflects the actual implementation, for future reference, but it should be pretty complete.
Overview
The IR must associate to every RNG-using node a pair of bit strings, staticBitstring and dynBitstring. staticBitstring is completely known statically, and every RNG-using node gets a unique staticBitstring. dynBitstring has statically known length, but its value is dynamic; it is used to give RNG-using nodes inside loops independent random sources on each iteration. We must ensure that the encoding scheme which produces these bitstrings guarantees that each dynamic invocation of any RNG-using node gets a globally distinct (staticBitstring, dynBitstring) pair.
RNG
To generate random numbers from a (staticBitstring, dynBitstring) pair, we need a pseudorandom function from pairs of bitstrings to fixed-length words. We use a modification of the PMAC [1] construction, based on the Threefish tweakable blockcipher [2] reduced to 20 rounds.
Threefish-256 has a 256-bit block size and key size, and a 128-bit tweak, though we currently restrict it to just a 64-bit tweak. We denote the blockcipher as a function encrypt(key: Int256, tweak: Int64, block: Int256): Int256.
The modified PMAC construction we use is specified as follows:
The leftmost bit of the 64-bit tweak is reserved to divide the tweak space, to ensure we use distinct tweaks on static and dynamic blocks. The remaining 63 bits are used for the block counter, i.e. the index of the block within its corresponding bitstring; however, the max value 2^63 - 1 is reserved for finalization. Concretely, we define four constants:
staticTweakMask = 0L
dynTweakMask = 1L << 63
finalBlockNoPadTweak = -1L >>> 1
finalBlockPaddedTweak = -1L
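The tweak layout can be sanity-checked by modelling it in Python. This is a sketch under assumptions: Scala's signed Long is represented as an unsigned Python int in [0, 2^64), and block_tweak is a hypothetical helper that puts the block index in the low 63 bits and selects the half of the tweak space with bitwise or.

```python
# Sketch of the 64-bit tweak layout. Scala's signed Longs are modelled as
# unsigned Python ints in [0, 2**64). block_tweak is a hypothetical helper.
MASK64 = (1 << 64) - 1

static_tweak_mask = 0                      # 0L
dyn_tweak_mask = 1 << 63                   # 1L << 63
final_block_no_pad_tweak = MASK64 >> 1     # -1L >>> 1, i.e. 2**63 - 1
final_block_padded_tweak = MASK64          # -1L, all 64 bits set

MAX_COUNTER = (1 << 63) - 1                # counter value reserved for finalization


def block_tweak(i: int, mask: int) -> int:
    """Tweak for block i: the mask picks the half, the low 63 bits hold i."""
    if not 0 <= i < MAX_COUNTER:
        raise ValueError("block index out of range")
    return mask | i
```

The two finalization tweaks are exactly the reserved counter value in each half, so they can never collide with a per-block tweak.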
To compute the hash of the two bitstrings:
- Pad staticBitstring by appending a single 1, followed by 0s until the length is a multiple of 256. Split it into 256-bit blocks S[0], ..., S[s].
- Split dynBitstring into 256-bit blocks D[0], ..., D[d], allowing the last block to be shorter than 256 bits.
- Let SE[i] = encrypt(key, i | staticTweakMask, S[i]), for i = 0..s.
- Let SE be the xor of all SE[i].
- Let DE[i] = encrypt(key, i | dynTweakMask, D[i]), for i = 0..d-1 (all but the last block).
- Let DE be the xor of all DE[i].
- If the last dynamic block is full, compute the final hash as encrypt(key, finalBlockNoPadTweak, SE ^ DE ^ D[d]).
- Otherwise, let B be D[d] padded by a single 1 followed by 0s to 256 bits, and compute the final hash as encrypt(key, finalBlockPaddedTweak, SE ^ DE ^ B).
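The steps above can be prototyped directly in Python. This is a sketch, not the implementation: encrypt is a stand-in keyed, tweaked PRF built from BLAKE2b (the tweak goes into the salt) because Threefish is not in the Python standard library, bitstrings are Python strings of '0'/'1', and pad, chunks and rng_hash are hypothetical names.

```python
import hashlib

BLOCK = 256
MASK64 = (1 << 64) - 1
STATIC_TWEAK_MASK = 0
DYN_TWEAK_MASK = 1 << 63
FINAL_NO_PAD_TWEAK = MASK64 >> 1
FINAL_PADDED_TWEAK = MASK64


def encrypt(key: bytes, tweak: int, block: int) -> int:
    # STAND-IN ONLY: a keyed, tweaked PRF from BLAKE2b, not the Threefish blockcipher.
    h = hashlib.blake2b(block.to_bytes(32, "big"), digest_size=32,
                        key=key, salt=tweak.to_bytes(8, "big"))
    return int.from_bytes(h.digest(), "big")


def pad(bits: str) -> str:
    # Append a single 1, then 0s up to a multiple of 256 bits.
    bits += "1"
    return bits + "0" * (-len(bits) % BLOCK)


def chunks(bits: str) -> list[str]:
    return [bits[i:i + BLOCK] for i in range(0, len(bits), BLOCK)]


def rng_hash(key: bytes, static_bits: str, dyn_bits: str) -> int:
    se = 0
    for i, blk in enumerate(chunks(pad(static_bits))):
        se ^= encrypt(key, STATIC_TWEAK_MASK | i, int(blk, 2))
    d_blocks = chunks(dyn_bits) or [""]  # empty dynBitstring: one empty last block
    de = 0
    for i, blk in enumerate(d_blocks[:-1]):  # all but the last block
        de ^= encrypt(key, DYN_TWEAK_MASK | i, int(blk, 2))
    last = d_blocks[-1]
    if len(last) == BLOCK:
        return encrypt(key, FINAL_NO_PAD_TWEAK, se ^ de ^ int(last, 2))
    return encrypt(key, FINAL_PADDED_TWEAK, se ^ de ^ int(pad(last), 2))
```

For example, a 255-bit dynBitstring of all 1s pads to the same 256-bit block as a full 256-bit dynBitstring of all 1s; only the distinct finalization tweaks keep the two hashes apart.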
[1] Rogaway, “Efficient Instantiations of Tweakable Blockciphers and Refinements to Modes OCB and PMAC”
[2] Ferguson et al., “The Skein Hash Function Family”
New types and IR nodes
To implement this design, we only need one new type and one or two new nodes.
The new type is RNGState. A value of this type contains a mix of static and dynamic data:
- dynBlocksSum: IndexedSeq[Value[Long]]: the xor of the encrypted contents of all full blocks of dynBitstring
- numWordsInLastDynBlock: Int: the number of words (longs), in the range [0, 4), currently contained in lastDynBlock
- lastDynBlock: IndexedSeq[Value[Long]]: the partial contents of the last block of dynBitstring. The length of the sequence is numWordsInLastDynBlock
- staticBlocksSum: IndexedSeq[Long]: the xor of the encrypted contents of all full blocks of staticBitstring
- lastStaticBlock: Bitstring: the partial contents of the last block of staticBitstring
- numStaticBlocks: Int: the number of blocks in staticBitstring
- numDynBlocks: Int: the number of blocks in dynBitstring
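Note that a value of RNGState at runtime consists of [4, 8) longs (the four longs of dynBlocksSum plus the zero to three longs of lastDynBlock; everything else is static), and as we will see, much of that state is shared between values.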
The new node is RNGSplit(state: IR, static: Bitstring, dyn: IR), where state has type RNGState, and dyn is an IR of any nested tuple type whose leaves are all Longs, which we interpret as a flattened array of longs with statically known length. The semantics of the node is simply to append static to the state’s staticBitstring, append dyn to the dynBitstring, and return the resulting state. Operationally, this appends to the appropriate last block, runs encrypt if the last block is filled, and xors the result into the appropriate sum (all at compile time in the case of static). Either static or dyn may be empty (for dyn this means the empty tuple type).
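To make the incremental semantics concrete, here is a Python model of RNGState and the split operation. It is a sketch under several assumptions: encrypt is a BLAKE2b-based stand-in for Threefish, the 256-bit sums are single Python ints rather than sequences of four longs, static bits are a '0'/'1' string, dyn is a flat tuple of 64-bit words, and rng_split and finalize are hypothetical names. One deliberate choice: a full last dynamic block is kept pending until another word arrives, so that finalization can still apply the no-padding case of the hash; in this sketch the pending block therefore holds up to 4 words.

```python
import hashlib
from dataclasses import dataclass

MASK64 = (1 << 64) - 1
DYN_TWEAK_MASK = 1 << 63


def encrypt(key: bytes, tweak: int, block: int) -> int:
    # STAND-IN ONLY: a keyed, tweaked PRF from BLAKE2b, not Threefish.
    h = hashlib.blake2b(block.to_bytes(32, "big"), digest_size=32,
                        key=key, salt=tweak.to_bytes(8, "big"))
    return int.from_bytes(h.digest(), "big")


def words_to_int(words) -> int:
    # Pack 64-bit words into one integer, first word most significant.
    acc = 0
    for w in words:
        acc = (acc << 64) | w
    return acc


@dataclass(frozen=True)
class RNGState:
    dyn_blocks_sum: int = 0      # xor of encrypted full dynamic blocks
    last_dyn_block: tuple = ()   # pending 64-bit words (at most 4 in this sketch)
    static_blocks_sum: int = 0   # xor of encrypted full static blocks
    last_static_block: str = ""  # pending static bits (fewer than 256)
    num_static_blocks: int = 0   # static blocks already encrypted
    num_dyn_blocks: int = 0      # dynamic blocks already encrypted


def rng_split(key: bytes, st: RNGState, static: str, dyn: tuple) -> RNGState:
    """Append static bits and dyn words to the state (hypothetical model of the split node)."""
    # Static part (done at compile time in the design): encrypt each full block.
    bits = st.last_static_block + static
    s_sum, n_s = st.static_blocks_sum, st.num_static_blocks
    while len(bits) >= 256:
        s_sum ^= encrypt(key, n_s, int(bits[:256], 2))  # tweak = staticTweakMask | n_s
        bits, n_s = bits[256:], n_s + 1
    # Dynamic part: a full block stays pending until a further word arrives.
    words = list(st.last_dyn_block) + [w & MASK64 for w in dyn]
    d_sum, n_d = st.dyn_blocks_sum, st.num_dyn_blocks
    while len(words) > 4:
        d_sum ^= encrypt(key, DYN_TWEAK_MASK | n_d, words_to_int(words[:4]))
        words, n_d = words[4:], n_d + 1
    return RNGState(dyn_blocks_sum=d_sum, last_dyn_block=tuple(words),
                    static_blocks_sum=s_sum, last_static_block=bits,
                    num_static_blocks=n_s, num_dyn_blocks=n_d)


def finalize(key: bytes, st: RNGState) -> int:
    """Compute the final hash from a state, following the modified PMAC steps."""
    tail = st.last_static_block + "1"
    tail += "0" * (256 - len(tail))  # pad static: a single 1, then 0s
    acc = st.static_blocks_sum ^ st.dyn_blocks_sum
    acc ^= encrypt(key, st.num_static_blocks, int(tail, 2))
    n = len(st.last_dyn_block)
    if n == 4:  # full last block: no padding
        return encrypt(key, MASK64 >> 1, acc ^ words_to_int(st.last_dyn_block))
    # Partial last block: the words, then a single 1 bit, then 0s, to 256 bits.
    padded = (words_to_int(st.last_dyn_block) << (256 - 64 * n)) | (1 << (255 - 64 * n))
    return encrypt(key, MASK64, acc ^ padded)
```

Because splitting only appends, splitting in several steps yields the same state as one combined split, which is what lets most of the compiler treat RNGState as an ordinary value.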
We may also need a node like RNGStateLiteral which can create an RNGState from scratch, at least for testing purposes, and possibly to create the initial state at the root of an IR before compiling.
Elaborating an IR
When we construct an IR in Python, we must elaborate it with RNGSplits to fix the RNG semantics before anything else can modify the IR. This should be done in a bottom-up way, so that an IR’s RNG semantics doesn’t depend on context. Since the RNG semantics is completely encoded in the IR before compilation, using regular types/nodes whose semantics are pure functions without side effects, most of the compiler doesn’t need to be aware of randomness at all.
We introduce a reserved identifier __rng_state. If a node uses an RNG, this must be bound in the environment.
Static splitting
Let us consider static splitting first. Conceptually, this works as follows:
- In a bottom-up pass, compute a boolean flag uses_randomness for every IR node, where nodes like ApplySeeded use randomness, and otherwise a node uses randomness iff one of its children does.
- In a top-down pass, each node splits the __rng_state in its environment n ways, where n is the number of its children which use randomness, and rebinds __rng_state to one of the n split states in each of those n children.
How we split a state n ways is arbitrary, but must be fixed. We choose a recursive method which splits the children in half (the 0 side gets the smaller half when n is odd), creates the splits for each half, then prepends 0 to one half and 1 to the other. This minimizes the average length of the splitting bitstrings, since all resulting bitstrings have lengths within one of each other. So for example, with 5 children using randomness, we create 5 static bitstrings [00, 01, 10, 110, 111], as
split(5) = ['0' + split(2), '1' + split(3)]
= ['0' + ['0' + split(1), '1' + split(1)], '1' + ['0' + split(1), '1' + split(2)]]
= ['0' + ['0' + '', '1' + ''], '1' + ['0' + '', '1' + ['0', '1']]]
= ['00', '01', '10', '110', '111']
This could be implemented by wrapping each child in e.g. Let(__rng_state, RNGSplit(Ref(__rng_state), Bitstring(110), ())), but this would pollute the IR with a lot of Lets. Instead, we can accumulate static splitting bitstrings along each path in the IR, so that for each leaf RNG use we have the entire bitstring of static splits from the root of the IR.
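Both the n-way split and the path accumulation are easy to prototype. The Python sketch below is illustrative only: Node is a hypothetical toy IR class, ApplySeeded nodes are treated as leaves, and split and assign_static_bits are made-up names. split reproduces the split(5) example, and assign_static_bits concatenates the split bits along each root-to-leaf path instead of inserting Let/RNGSplit wrappers.

```python
from dataclasses import dataclass, field


def split(n: int) -> list[str]:
    """Static bitstrings for n randomness-using children, by recursive halving."""
    if n < 1:
        raise ValueError("need at least one child")
    if n == 1:
        return [""]  # a single child reuses the parent's state unchanged
    half = n // 2
    return ["0" + s for s in split(half)] + ["1" + s for s in split(n - half)]


@dataclass(eq=False)  # identity semantics: equal-looking nodes are still distinct RNG uses
class Node:
    name: str
    children: list = field(default_factory=list)


def uses_randomness(node: Node, cache: dict) -> bool:
    # Bottom-up flag, cached by node identity so each node is evaluated once.
    if id(node) not in cache:
        cache[id(node)] = node.name == "ApplySeeded" or any(
            uses_randomness(c, cache) for c in node.children)
    return cache[id(node)]


def assign_static_bits(root: Node) -> dict:
    """Map id(leaf) -> accumulated static bitstring for every ApplySeeded leaf."""
    cache, out = {}, {}

    def visit(node: Node, prefix: str) -> None:
        if node.name == "ApplySeeded":
            out[id(node)] = prefix
        rng_children = [c for c in node.children if uses_randomness(c, cache)]
        if rng_children:
            for child, bits in zip(rng_children, split(len(rng_children))):
                visit(child, prefix + bits)

    visit(root, "")
    return out
```

For a root with children [ApplySeeded, I32, Let(ApplySeeded, ApplySeeded)], only two children use randomness, so the leaves receive "0", "10" and "11". Because each level's bitstrings are prefix-free, the accumulated bitstrings are distinct across leaves.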
Dynamic splitting
For dynamic splitting, we must modify the IR so that in any loop context, i.e. a body of IR which may execute more than once at run-time, there is some UID bound in the environment which is guaranteed to have a different value on every execution.
As a simple example,
TailLoop("loop", [("var", init)], ..Recur(next_var)..)
is modified to (if the loop body uses randomness)
TailLoop(
"loop",
[("var", init), ("__uid", 0)],
Let(__rng_state, RNGSplit(Ref(__rng_state), Bitstring(), Ref("__uid")),
..Recur(next_var, Ref("__uid") + 1)..))
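The uniqueness argument for nested loops can be checked with a small simulation. In this Python sketch, run_nested_loops is a hypothetical helper, and it records (static, dyn) pairs instead of threading a real __rng_state. Each loop carries its own __uid counter and appends it to the dynamic part on entry to the body, so an RNG use nested inside two loops sees a two-word dynBitstring that differs on every execution.

```python
def run_nested_loops(outer_iters: int, inner_iters: int, static_bits: str = "0"):
    """Record the (static, dyn) pair seen by one RNG use nested in two loops.

    Each loop starts its __uid at 0, splits the state with it on entry to the
    body, and increments it on Recur, mirroring the TailLoop rewrite above.
    """
    seen = []
    outer_uid = 0
    while outer_uid < outer_iters:
        dyn_outer = (outer_uid,)              # split with the outer loop's __uid
        inner_uid = 0
        while inner_uid < inner_iters:
            dyn = dyn_outer + (inner_uid,)    # the inner loop appends its own __uid
            seen.append((static_bits, dyn))
            inner_uid += 1                    # Recur(..., Ref("__uid") + 1)
        outer_uid += 1
    return seen
```

Without the __uid splits, every iteration would see the same pair, and so the same random stream.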
For tables, we require every table producer to create a new row field “__uid” which is guaranteed to be distinct across rows. Streams are similar: if the element type is a struct, we add a “__uid” field; otherwise, if the element type is T, we replace it with (UID, T) (where UID is any valid type for dynamic splitting as described earlier). This special handling of struct element types is needed for stream join nodes, so that the join fields are still top-level fields.
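The element-type rule for streams can be sketched in Python as follows. This assumes structs are modelled as dicts and any other element type as an arbitrary value; with_uids is a hypothetical helper, and the uids passed in are simply element indices here.

```python
def with_uids(elements, uids):
    """Attach uids to stream elements.

    Struct elements (modelled as dicts) get a new "__uid" field, so join keys
    stay top-level; any other element T becomes the pair (uid, T).
    """
    out = []
    for uid, elem in zip(uids, elements):
        if isinstance(elem, dict):
            if "__uid" in elem:
                raise ValueError("__uid is a reserved field name")
            out.append({**elem, "__uid": uid})
        else:
            out.append((uid, elem))
    return out
```

For example, with_uids(["a", "b"], range(2)) gives [(0, "a"), (1, "b")], while struct rows keep their original fields alongside "__uid".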
We implement this with a single recursive traversal over stream/table pipelines. At a high level, starting at the final consumer of a stream/table pipeline:
- recursively process stream/table children which use randomness
- create uids if requested by the parent
- if the body IR (if any) uses randomness, split __rng_state in the body using the row UID
In the last two bullets, we must create uids for the resulting stream/table, to pass to the parent node and/or to split the RNG state in the body. A node does this by requesting uids from one or more of its stream/table children and using them to construct a new uid (how to do this is specific to each node). For the base case, for streams we can simply zip with index. For tables, we must add a flag to all table producer nodes like TableRead and TableParallelize to optionally create a uid field. Each node is free to do this any way it likes, but a good default uid which most nodes should be able to implement efficiently is a tuple (partitionNumber, indexWithinPartition).
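A sketch of how uids might be created and combined, in Python, with tables modelled as lists of partitions of dict rows. read_table and explode are hypothetical stand-ins for a table producer with the uid flag set and for a flatMap-like node; the explode rule, which pairs the parent row's uid with the element's index, is just one plausible choice, since how uids are combined is node-specific. The combined uid ((partitionNumber, indexWithinPartition), j) is a nested tuple of integers, matching the nested-tuple-of-Longs shape allowed for dyn.

```python
def read_table(partitions):
    """Hypothetical table producer with the uid flag set: tag each row with
    the default uid (partitionNumber, indexWithinPartition)."""
    return [
        [{**row, "__uid": (p, i)} for i, row in enumerate(rows)]
        for p, rows in enumerate(partitions)
    ]


def explode(partitions, field):
    """Hypothetical flatMap-style node: each output row's uid pairs the parent
    row's uid with the index of the element within the exploded array."""
    return [
        [{**row, field: x, "__uid": (row["__uid"], j)}
         for row in rows
         for j, x in enumerate(row[field])]
        for rows in partitions
    ]
```

Rows that explode to nothing simply contribute no uids, and distinct parent uids guarantee distinct child uids.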
Rebuilding an IR in Python like this can interfere with the Python object-based CSE, but in this case it’s safe since stream/table nodes can’t be CSE’d anyway.