NDArray Matmul Performance Improvements

This post documents the process of substantially speeding up some NDArray operations. I thought it might be interesting to others, and I want it as a reminder to myself of how to do this in the future. Thanks to Dan for helping me out with this.

I ran the following test to get a sense of how fast NDArray matrix multiplication was compared to Breeze. For reference, multiplying a 4096x4096 matrix by itself in Breeze takes ~670ms.

ones = hl.nd.ones((4096, 4096))
hl.eval((ones @ ones).shape)

The above took 10 seconds, which is ridiculous.
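
For the curious, the Breeze reference number came from something along these lines (a sketch, not the exact benchmark harness; Breeze hands the multiply off to native BLAS when it's available):

import breeze.linalg.DenseMatrix

object MatmulBench {
  def main(args: Array[String]): Unit = {
    val m = DenseMatrix.ones[Double](4096, 4096)
    val start = System.nanoTime()
    val product = m * m // dgemm via netlib when a native BLAS is on the classpath
    println(s"multiply took ${(System.nanoTime() - start) / 1e6} ms")
    println(product(0, 0)) // touch the result so the multiply can't be optimized away
  }
}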

Here is the IR that gets compiled:

(Let __iruid_35
  (ArrayRange
    (I32 0)
    (I32 16777216)
    (I32 1))
  (Let __iruid_36
    (NDArrayReshape
      (NDArrayMap __iruid_37
        (MakeNDArray
          (Ref __iruid_35)
          (MakeTuple (0)
            (Cast Int64
              (ArrayLen
                (Ref __iruid_35))))
          (True))
        (F64 1.0))
      (Literal Tuple[Int64,Int64] <literal value>))
    (NDArrayShape
      (NDArrayMatMul
        (Ref __iruid_36)
        (Ref __iruid_36)))))

I inserted some timing statements (sketched below) to get a sense of where it was spending its time. The breakdown was:

4.5 seconds doing all the pre-multiply work: creating the array from the range, mapping over it to create the array of ones, then reshaping it to be square.

2.5 seconds to copy the two input NDArrays from row major to column major (required for BLAS).

711 milliseconds to actually do the matrix multiply.

1.5 seconds to copy the NDArray back to row major.

Somehow, the most complicated part of the code is the fastest, and the easy stuff is really slow.
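
For what it's worth, the timing statements were nothing fancier than wall-clock wrappers around each stage; a minimal sketch (timed is a name I made up here, not anything in Hail):

def timed[T](label: String)(block: => T): T = {
  val start = System.nanoTime()
  val result = block // run the stage being measured
  println(s"$label: ${(System.nanoTime() - start) / 1e6} ms")
  result
}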

To debug this, I used the following arguments to start an ipython interpreter that dumps JVM compilation information to the terminal:

PYSPARK_SUBMIT_ARGS='--conf spark.executor.extraJavaOptions=-XX:+PrintCompilation --conf spark.driver.extraJavaOptions=-XX:+PrintCompilation pyspark-shell' ipython

In the large printout, the interesting bit was these lines:

4202 3542 %     3       is.hail.codegen.generated.C0::m1_method @ 380 (3081 bytes)
compilation bailout: stack not empty at OSR entry point
4205 3542 %     3       is.hail.codegen.generated.C0::m1_method @ 380 (3081 bytes)   COMPILE SKIPPED: stack not empty at OSR entry point (retry at different tier)
4205 3543 %     4       is.hail.codegen.generated.C0::m1_method @ 380 (3081 bytes)
4205 3543 %     4       is.hail.codegen.generated.C0::m1_method @ 380 (3081 bytes)   COMPILE SKIPPED: OSR starts with non-empty stack (retry at different tier)

The JVM is bailing out of compiling some of my emitted code, and as a result it's performing very badly. OSR stands for "On Stack Replacement", which as I understand it is the process by which the JVM swaps running bytecode out for more optimized, JIT-compiled code as needed. Because my generated code fails to meet some internal JVM requirement, it isn't getting the benefits of the JIT.
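
To see ordinary OSR working, you can run a toy hot loop like the sketch below with -XX:+PrintCompilation; the % marker in the output (the same one in the log lines above) flags OSR compilations:

object OsrDemo {
  def main(args: Array[String]): Unit = {
    var sum = 0L
    var i = 0
    // a long-running loop: once it gets hot, the JVM compiles the method and
    // swaps the compiled code in mid-loop, i.e. on-stack replacement
    while (i < 500000000) {
      sum += i
      i += 1
    }
    println(sum)
  }
}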

If we could figure out what exactly causes this situation to pop up, I think that could be very useful across the board. In the meantime, though, a fix is to emit a new method and put the hot loop code in that method.

So this code that generates loops:

idxVars.zipWithIndex.foldRight(body) { case ((dimVar, dimIdx), innerLoops) =>
  Code(
    dimVar := 0L,
    Code.whileLoop(dimVar < outputShapeVariables(dimIdx),
      innerLoops,
      dimVar := dimVar + 1L
    )
  )
}

got bound to a variable and wrapped in a method:

val loops = idxVars.zipWithIndex.foldRight(body) { case ((dimVar, dimIdx), innerLoops) =>
  Code(
    dimVar := 0L,
    Code.whileLoop(dimVar < outputShapeVariables(dimIdx),
      innerLoops,
      dimVar := dimVar + 1L
    )
  )
}
// emit the loop nest into its own method and call it, so the hot loop
// starts with an empty stack and the JVM can JIT it normally
val eVti = typeToTypeInfo(TVoid)
val innerMethod = mb.fb.newMethod(eVti)
innerMethod.emit(loops)
innerMethod.invoke()

The above was enough to bring mapping over the ndarray down from 4 seconds to 200 milliseconds, and similar changes improved the copying from row major to column major. I'm going to change PNDArray to store ndarrays in column major anyway (Breeze does this), so those copies will be removed entirely in the future.


Sorry to hijack this topic, but I’m curious to know the best way to get the IR for a particular Hail statement or expression. Are there flags to an executable, parameters to a function, or standalone tools that make this easy? Or are you just adding print/log statements or running within a debugger and examining intermediate state?

No worries, good question. To see the initial IR that's getting built for a value expression, you can just check the _ir field of that Python object. If it's a table expression, there's a _tir field. BlockMatrixExpressions have ._bmir. I don't have a Python interpreter open right now, but my guess is that the matrix table one is _mtir.

The other easy way to see the IR is to look at the Hail log file, though that shows the entire pipeline as opposed to any specific piece. The log file also shows the successive IRs that are generated as we run optimizations.


Perhaps it’s worth having physical types for both row-major (C style) and column-major (Fortran style) memory layouts? numpy uses the order parameter to allow both, for example.

Haven’t explored it all the way yet, but I think what I’ll probably do is just have PNDArray continue to have the strides parameter that it does now, which explains how it’s laid out, change the default to be column major, and only flip it if necessary.