Blocked MT Woes

Copied from https://hail.zulipchat.com/#narrow/stream/123011-Hail-Dev/topic/Blocked.20MT.20Conclusion.

OK, I can’t push this forward without addressing the optimizer.

My suggestion for a path forward / request for comments is under the third and
final header.

But first, let me make an argument for enabling this functionality, which is so
close to working in Hail as-is.

An Apology for Blocked MT

First, I think this is a transformational datatype: a blocked matrix table with
support for matrix multiply. This datatype addresses many of the issues the
Finucane lab has with BlockMatrix by giving them the full power of the
expression language. They can filter rows and columns based on relevant
information like locus and sample ID. They can modify the entries with the full
power of the expression language. No more “I just need hl.floor on each
entry”. With this object, we can finally realize pc-relate in Python! With
this object, KING is a few lines of Python!

Description

Second, a description. I created a table with the following type. Each row of
the table is a square block of the matrix. The row information and column
information are stored as separate, blocked tables. If they’re needed, they’re
joined in. This model easily generalizes to an arbitrary number of dimensions.

In [5]: da.m.describe()
----------------------------------------
Global fields:
    'n_rows': int32
    'n_cols': int32
    'n_block_rows': int32
    'n_block_cols': int32
    'block_size': int32
----------------------------------------
Row fields:
    'x': int32
    'y': int32
    'block': ndarray<float64, 2>
----------------------------------------
Key: ['x', 'y']
----------------------------------------
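
The block-count globals above follow from the dimensions by ceiling division; a
minimal sketch of that relationship (n_rows, n_cols, and block_size here are
plain Python ints, t is the block table, and none of this is the actual
implementation):

t = t.annotate_globals(
    n_rows=n_rows,
    n_cols=n_cols,
    block_size=block_size,
    # ceiling division: the trailing block row and column may be ragged
    n_block_rows=(n_rows + block_size - 1) // block_size,
    n_block_cols=(n_cols + block_size - 1) // block_size)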

Constructing one of these objects is easy to write, but requires an unavoidable
shuffle. I localize an MT. Then I group the entry array into an array of
arrays, each of length block_size. I explode that array. Then I group by the
row block id, mt.row_index // block_size, and the column block id, collecting a
block. That looks like this:

mt = mt.group_by(mt.x, mt.y).aggregate(
    block=hl.nd.array(hl.sorted(
        hl.agg.collect(hl.struct(row_index=mt.row_index, entries=mt.entries)),
        key=lambda x: x.row_index
    ).map(lambda x: x.entries)))
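
For completeness, here is a hedged sketch of the steps preceding that group_by
(the localize, chunk, and explode). The entry field name dosage, and the Python
ints block_size and n_block_cols, are my assumptions:

import hail as hl

mt = mt.add_row_index('row_index')
mt = mt.localize_entries('entries', 'cols')  # MT rows become table rows
# slice each row's entry array into one chunk per column block
mt = mt.annotate(chunks=hl.range(n_block_cols).map(
    lambda y: hl.struct(
        y=y,
        entries=mt.entries[y * block_size:(y + 1) * block_size]
                  .map(lambda e: e.dosage))))
mt = mt.explode('chunks')  # one row per (row, column-block) pair
mt = mt.key_by(x=hl.int32(mt.row_index // block_size), y=mt.chunks.y)
mt = mt.select('row_index', entries=mt.chunks.entries)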

This is great, but it’s not enough. BlockMatrix multiply works because we have
a 1:1 mapping from blocks to partitions. Luckily, Hail has the facilities to
achieve exactly this!

t = hl.read_table(fname, _intervals=[
    # each half-open interval [(x, y), (x, y+1)) contains exactly one key,
    # so every block lands in its own partition
    hl.Interval(hl.Struct(x=x, y=y),
                hl.Struct(x=x, y=y+1))
    for x in range(n_block_rows)
    for y in range(n_block_cols)])
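
If the read behaves as hoped, the partition count now equals the block count; a
quick sanity check (hedged, reusing the names above):

assert t.n_partitions() == n_block_rows * n_block_cols  # one block per partition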

Multiplication is also straightforward! I arrange for a
1-partition-per-block-multiply table with keys x, y, and z, which correspond to
the product’s rows, the product’s columns, and the summed inner dimension. Then
I just do this:

o = o.key_by('x', 'z', 'y')                  # key prefix (x, z) matches left's block key
o = o.annotate(left=left.m[o.x, o.z].block)
o = o.key_by('z', 'y', 'x')                  # key prefix (z, y) matches right's block key
o = o.annotate(right=right.m[o.z, o.y].block)
o = o.annotate(product=o.left @ o.right)
o = o.key_by('x', 'y', 'z')
o = o.group_by('x', 'y').aggregate(          # sum out the inner dimension z
    block=nd_array_sum_placeholder(o.product))

That’s actually incredibly clear code! I really like this. Much easier to read
than the current BlockMatrix index manipulation code!
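
For reference, here is one hedged way to arrange that
1-partition-per-block-multiply table in the first place (the range_table sizing
and the key arithmetic are my assumptions; n_inner is the number of blocks
along the summed dimension):

import hail as hl

n_parts = n_block_rows * n_inner * n_block_cols
o = hl.utils.range_table(n_parts, n_partitions=n_parts)
o = o.key_by(
    x=o.idx // (n_inner * n_block_cols),  # the product's block-row
    y=(o.idx // n_inner) % n_block_cols,  # the product's block-column
    z=o.idx % n_inner)                    # the summed inner dimension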

Issues

Third, and finally, the issues.

The heart of the issue is that Hail is dropping some of my keys and converting
“out-of-order partitions” cases into “full shuffle” cases.

There’s a separate issue: Hail scans the keys to ensure they’re ordered. This
is a non-issue. I already wrote code that notices 1:1 partitioners (this is
easy to do for keys consisting only of ints or longs) and skips the sortedness
check. For this to work, I need to preserve keys (otherwise I’ll forget they’re
1:1 partitioners).
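
As an illustration of noticing 1:1 partitioners: for keys consisting only of
ints, it suffices to check that each half-open interval contains exactly one
key. A sketch with intervals as plain Python tuples, not Hail's actual
partitioner types:

def is_one_to_one(intervals):
    # each element is (start_key, end_key), both tuples of ints, half-open:
    # [(a1, ..., an), (a1, ..., an + 1)) contains exactly the key (a1, ..., an)
    return all(start[:-1] == end[:-1] and end[-1] == start[-1] + 1
               for start, end in intervals)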

On the subject of eliding keys and full shuffles, I thought I could work around
this at the RVD layer, but that now seems like a bad idea (e.g. the RVD may need
keys that the Table level removes or modifies).

Two things I hope we can address:

  1. If I read a table that’s partitioned by x, y, and z, and Hail weakens my key
    to x and y, Hail will shuffle because it does not tolerate keys split across
    partitions. This plainly seems like a bug: we should never weaken a key if
    doing so would trigger a shuffle. (A sketch of this shape follows the list.)
  2. More generally, with 1:1 partitions, dropping keys breaks the 1:1-ness. Can
    we teach the optimizer to know about 1:1 partitions and not drop their keys?
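
A hedged sketch of the shape in item 1, where the explicit key_by stands in for
the key-weakening the optimizer performs (fname and intervals are the
(x, y, z) analogue of the read shown earlier):

import hail as hl

t = hl.read_table(fname, _intervals=intervals)  # keyed by ['x', 'y', 'z'],
                                                # one key per partition
t = t.key_by('x', 'y')  # rows sharing (x, y) now straddle partitions, and
                        # Hail responds with a full shuffle instead of
                        # tolerating the split key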

Can we add a TableReallyKeyBy node whose key is never weakened? The
combination of that node with my fast path for partitioners would enable dnd
arrays right now by preventing the following transformation. Before the
optimizer runs, the IR is:

(TableKeyByAndAggregate None 50
  (TableKeyBy (x y z) False
    (TableMapRows
      (TableMapRows
        (TableMapRows
          (TableLeftJoinRightDistinct __uid_31
            (TableKeyBy (z y x) False
              (TableMapRows
                (TableLeftJoinRightDistinct __uid_30
                  (TableKeyBy (x z y) False
and after optimization, with keys weakened and dropped:

(TableKeyByAndAggregate None 50
  (TableMapRows
    (TableKeyBy () False
      (TableLeftJoinRightDistinct __uid_31
        (TableKeyBy (z y) False
          (TableKeyBy () False
            (TableMapRows
              (TableLeftJoinRightDistinct __uid_30
                (TableKeyBy (x z) False
                  (TableKeyBy () False
                    (TableRead ....)