This post documents the process of substantially speeding up some NDArray operations. I thought it might be interesting to others, and I want it to remind me how to go through this process in the future. Thanks to Dan for helping me out with this.
I ran the following test to get a sense of how fast NDArray matrix multiplication was compared to Breeze. For reference, multiplying a 4096x4096 matrix by itself in Breeze takes ~670ms.
ones = hl.nd.ones((4096, 4096))
hl.eval((ones @ ones).shape)
The above took 10 seconds, which is ridiculous.
Here is the IR that gets compiled:
(Let __iruid_35
(ArrayRange
(I32 0)
(I32 16777216)
(I32 1))
(Let __iruid_36
(NDArrayReshape
(NDArrayMap __iruid_37
(MakeNDArray
(Ref __iruid_35)
(MakeTuple (0)
(Cast Int64
(ArrayLen
(Ref __iruid_35))))
(True))
(F64 1.0))
(Literal Tuple[Int64,Int64] <literal value>))
(NDArrayShape
(NDArrayMatMul
(Ref __iruid_36)
(Ref __iruid_36)))))
I inserted some timing statements to get a sense of where it was spending its time. The breakdown was:
4.5 seconds doing all the pre multiply work. That’s creating the array from the range and mapping over it to create the array of ones, then reshaping it to be square.
2.5 seconds to copy the two input NDArrays from row major to column major (required for BLAS).
711 milliseconds to actually do the matrix multiply.
1.5 seconds to copy the NDArray back to row major.
Somehow, the most complicated part of the code is the fastest, and the easy stuff is really slow.
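To see why those copies cost seconds, here is roughly what a row-major to column-major copy boils down to, as a hand-written Java sketch (this is illustrative, not Hail's actual emitted code; the method and variable names are made up):

```java
public class Layout {
    // row major:    element (i, j) of an n x m matrix lives at flat index i * m + j
    // column major: element (i, j) lives at flat index j * n + i
    // The copy touches every element once, but one side of the assignment
    // always strides through memory, which is cache-unfriendly at 4096x4096.
    static double[] rowToColMajor(double[] rowMajor, int n, int m) {
        double[] colMajor = new double[n * m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                colMajor[j * n + i] = rowMajor[i * m + j];
            }
        }
        return colMajor;
    }
}
```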
To debug this, I used the following arguments to start an ipython interpreter that dumps the JVM compilation information to the terminal:
PYSPARK_SUBMIT_ARGS='--conf spark.executor.extraJavaOptions=-XX:+PrintCompilation --conf spark.driver.extraJavaOptions=-XX:+PrintCompilation pyspark-shell' ipython
In the large printout, the interesting bit was the following lines:
4202 3542 % 3 is.hail.codegen.generated.C0::m1_method @ 380 (3081 bytes)
compilation bailout: stack not empty at OSR entry point
4205 3542 % 3 is.hail.codegen.generated.C0::m1_method @ 380 (3081 bytes) COMPILE SKIPPED: stack not empty at OSR entry point (retry at different tier)
4205 3543 % 4 is.hail.codegen.generated.C0::m1_method @ 380 (3081 bytes)
4205 3543 % 4 is.hail.codegen.generated.C0::m1_method @ 380 (3081 bytes) COMPILE SKIPPED: OSR starts with non-empty stack (retry at different tier)
The JVM is bailing out of compiling some of my emitted code, and as a result it's performing very badly. OSR stands for "On-Stack Replacement", which as I understand it is the process by which the JVM swaps a method's running, less-optimized code for a JIT-compiled version while the method is still on the stack. By failing to meet some JVM-internal requirement, I'm not getting JIT benefits.
If we could figure out what exactly causes this situation to pop up, I think that could be very useful across the board. In the meantime, though, a fix is to emit a new method and put the hot loop code in that method.
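In plain Java terms, the trick looks something like this (a hypothetical illustration of the before/after shape, not Hail's generated bytecode):

```java
public class OsrSketch {
    // "Before": the hot loop lives inline in one large method. If the JVM
    // decides the loop is hot while it is already running, the only way to
    // use compiled code is on-stack replacement, which can bail out
    // (e.g. "stack not empty at OSR entry point").
    static double inline(double[] out) {
        double acc = 0;
        for (int i = 0; i < out.length; i++) {
            out[i] = 1.0;
            acc += 1.0;
        }
        return acc;
    }

    // "After": the loop is extracted into its own method. Once hotLoop has
    // been called enough times, the JIT can compile the whole method at an
    // ordinary entry point, and no OSR is needed.
    static double extracted(double[] out) {
        return hotLoop(out);
    }

    static double hotLoop(double[] out) {
        double acc = 0;
        for (int i = 0; i < out.length; i++) {
            out[i] = 1.0;
            acc += 1.0;
        }
        return acc;
    }
}
```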
So this code that generates loops:
idxVars.zipWithIndex.foldRight(body) { case ((dimVar, dimIdx), innerLoops) =>
  Code(
    dimVar := 0L,
    Code.whileLoop(dimVar < outputShapeVariables(dimIdx),
      innerLoops,
      dimVar := dimVar + 1L
    )
  )
}
got bound to a variable and wrapped in a method:
val loops = idxVars.zipWithIndex.foldRight(body) { case ((dimVar, dimIdx), innerLoops) =>
  Code(
    dimVar := 0L,
    Code.whileLoop(dimVar < outputShapeVariables(dimIdx),
      innerLoops,
      dimVar := dimVar + 1L
    )
  )
}
val eVti = typeToTypeInfo(TVoid)
val innerMethod = mb.fb.newMethod(eVti)
innerMethod.emit(loops)
innerMethod.invoke()
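For intuition, the nested loops that foldRight builds for a 2-D ndarray correspond roughly to this hand-written Java (illustrative only; the variable names are made up, and the innermost body is a stand-in for the real element-wise work):

```java
public class EmittedLoops {
    // One counting loop per dimension, built inside-out, with the rightmost
    // (last) dimension innermost; the body runs once per output element.
    static long runBody(long[] outputShape) {
        long bodyRuns = 0;
        long i0 = 0L;
        while (i0 < outputShape[0]) {
            long i1 = 0L;
            while (i1 < outputShape[1]) {
                bodyRuns++; // stand-in for the real per-element body
                i1 += 1L;
            }
            i0 += 1L;
        }
        return bodyRuns;
    }
}
```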
The above was enough to bring the mapping over the ndarray down from 4 seconds to 200 milliseconds, and similar changes improved the copying from row major to column major. I'm going to change PNDArray to store ndarrays in column major order anyway (Breeze does this), so those copies will be removed entirely in the future.
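The reason storing column major removes the copies entirely: layout is just a choice of strides, and if the stored strides already match what BLAS expects, the buffer can be handed over directly. A small sketch (my own illustration; the helper name is made up):

```java
public class Strides {
    // Element (i0, i1, ..., ik) lives at the dot product of indices and
    // strides. Column-major strides grow left-to-right (first dimension is
    // contiguous), which is the layout BLAS and Breeze expect, so no
    // row-major -> column-major copy is needed before calling into BLAS.
    static long[] colMajorStrides(long[] shape) {
        long[] strides = new long[shape.length];
        long acc = 1;
        for (int d = 0; d < shape.length; d++) {
            strides[d] = acc;
            acc *= shape[d];
        }
        return strides;
    }
}
```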