Seeking some input on current implementation of LD Prune method

Overview of Method

Given an r^2 threshold and a window (specifying the maximum distance between variants to compare), the method prunes variants in linkage disequilibrium, ultimately returning a set of independent variants.

The method starts with a local prune that, within each partition, removes variants whose pairwise r^2 exceeds the threshold (Jackie’s method). Next, the genotypes are mean-imputed and standardized so that a correlation matrix can be computed via a matrix multiplication of block matrices.

Currently, this entire correlation matrix is computed, squared element-wise, and then filtered to just the entries we care about (pairs of variants on the same contig whose positions are within the window of each other).
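For reference, a small NumPy sketch of the linear algebra (illustrative only, not the Hail code; here each variant row is centered and scaled to unit norm so the product gives Pearson correlations directly):

import numpy as np

# Illustrative stand-in for the mean-imputed genotype matrix:
# one row per variant, one column per sample.
n_variants, n_samples = 1000, 250
G = np.random.rand(n_variants, n_samples)

# Standardize each variant: zero mean and unit L2 norm per row
# (constant rows would need special handling to avoid division by zero).
X = G - G.mean(axis=1, keepdims=True)
X /= np.linalg.norm(X, axis=1, keepdims=True)

corr = X @ X.T   # n_variants x n_variants correlation matrix
r2 = corr ** 2   # element-wise square gives the r^2 matrix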

The entries table of this squared correlation matrix has the schema [variant i, variant j, r2]. We filter the entries table to just those entries over the r2 threshold, then pass the edges (i, j) to the maximal_independent_set method, which returns a list of variants to prune. We filter out those variants, leaving a table of independent variants to return to the user.

Performance / Areas for Improvement

On a dataset of 25,956 variants (profile.vcf), Jackie’s method takes 200-300 seconds. My implementation takes around 800-900 seconds.

Around 85% of the time for my implementation is spent on the maximal independent set (MIS) method. This method uses a greedy algorithm with a binary heap for the edge set.
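For context, here is a rough, single-machine sketch of that kind of greedy method (my reading of it, not the actual implementation): repeatedly discard the highest-degree vertex, using a heap with lazy invalidation of stale entries, until no edges remain; the discarded vertices are the ones to prune.

import heapq
from collections import defaultdict

def greedy_mis_removals(edges):
    # Greedily drop the highest-degree vertex until no edges remain.
    # Returns the removed vertices (the complement of the independent set),
    # mirroring keep=False.  Sketch only, not the Hail implementation.
    adj = defaultdict(set)
    for i, j in edges:
        adj[i].add(j)
        adj[j].add(i)

    # Max-heap on degree; stale entries are skipped when popped.
    heap = [(-len(nbrs), v) for v, nbrs in adj.items()]
    heapq.heapify(heap)

    removed = set()
    while heap:
        neg_deg, v = heapq.heappop(heap)
        if v in removed or -neg_deg != len(adj[v]) or not adj[v]:
            continue
        removed.add(v)
        for u in adj[v]:
            adj[u].discard(v)
            heapq.heappush(heap, (-len(adj[u]), u))
        adj[v].clear()
    return removed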

I have been thinking about ways to speed up MIS. The MIS method could possibly be sped up if I could split the graph into several connected components, put all the edges from each component onto the same partition, then run MIS in parallel. But splitting the edge set into the connected components might be slow, so I’m not sure if the time saved from running MIS in parallel would be worth the time spent getting the components.
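A rough sketch of the splitting step, using union-find on the edge list (plain single-machine Python; a distributed version would need something like a Spark connected-components pass, which is the part I’m unsure is worth it):

from collections import defaultdict

def edges_by_component(edges):
    # Group edges by connected component with union-find, so each
    # component's MIS could be solved independently (and in parallel).
    # Illustrative sketch only.
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]   # path halving
            x = parent[x]
        return x

    for i, j in edges:
        ri, rj = find(i), find(j)
        if ri != rj:
            parent[ri] = rj

    components = defaultdict(list)
    for i, j in edges:
        components[find(i)].append((i, j))
    return list(components.values())

Each component’s edge list could then go to its own partition and through the greedy sketch above (or any other MIS routine).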

A banded block matrix, as mentioned in the original proposal, would also offer some improvement, but doesn’t seem to be the main bottleneck right now.

I have experimented a bit with using persist() to speed up my code, which helps somewhat, so perhaps I could be placing my persist() calls in better places to speed things up. This wouldn’t really help with the time spent on MIS, though.

Around 85% of the time for my implementation is spent on the maximal independent set (MIS) method.

This is good information. How many variants are input to and output from MIS in your test dataset?

I spent a little time looking at approximate MIS methods, some of which look like they’d be reasonably straightforward to implement (though not as straightforward as the priority queue method). I also think they would lend themselves to parallelization, in the multi-threaded shared-memory sense. If we think that’s the bottleneck, I’d be happy to collaborate on a faster MIS method.
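One example of the kind of parallel-friendly method I mean is Luby-style randomized MIS: each round, every surviving vertex draws a random priority and joins the set if it beats all of its surviving neighbors; chosen vertices and their neighbors are then dropped. A toy serial sketch (not any existing implementation):

import random
from collections import defaultdict

def luby_mis(edges):
    # Toy serial sketch of Luby-style randomized MIS; in a real
    # implementation each round is embarrassingly parallel over vertices.
    adj = defaultdict(set)
    for i, j in edges:
        adj[i].add(j)
        adj[j].add(i)
    alive = set(adj)
    independent = set()

    while alive:
        priority = {v: random.random() for v in alive}
        chosen = {v for v in alive
                  if all(priority[v] < priority[u]
                         for u in adj[v] if u in alive)}
        independent |= chosen
        dropped = set(chosen)
        for v in chosen:
            dropped |= adj[v] & alive
        alive -= dropped
    return independent   # the maximal independent set itself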

One more thing: remember RDDs are lazy. The collect is done in MIS, so it is possible that the matrix multiply, for example, is actually happening when collect is called and the time is incorrectly being attributed to MIS. You might time the collect separately from the MIS call that takes an Array (using printTime from the utils package, for example).

Due to laziness, I’d time the steps up to exporting the result of the matrix multiply, and then separately time the steps starting from reading the result back in.

Let me ask another question: how big is the matrix you’re multiplying? What block size? Are you multiplying/filtering the entire matrix? I wouldn’t be surprised if the matrix multiply is dominating here. One potential improvement is to only compute the blocks needed according to the variant window. This may make a big difference.
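As a sketch of what “only compute the blocks needed” might look like (illustrative only, not an existing BlockMatrix call): if variants are ordered by position and the base-pair window is converted to a conservative bound, band, on index distance, only block pairs whose closest indices are within that band can contain relevant entries.

def blocks_in_band(n_variants, block_size, band):
    # Return (row_block, col_block) pairs of an n_variants x n_variants
    # blocked matrix that can contain entries with |i - j| <= band.
    # Hypothetical helper, not part of the BlockMatrix API.
    n_blocks = -(-n_variants // block_size)   # ceiling division
    keep = []
    for bi in range(n_blocks):
        for bj in range(n_blocks):
            if bi == bj:
                min_dist = 0
            else:
                min_dist = (abs(bi - bj) - 1) * block_size + 1
            if min_dist <= band:
                keep.append((bi, bj))
    return keep

With the default block size of 4096 and ~12,595 variants, the product here is only a 4 x 4 grid of blocks, so the band restriction likely matters more as the variant count grows.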

Hm, okay, yes the matrix multiply is what’s dominating. I thought filtering the entries table that I get after the matrix multiply would force execution of the matrix multiply, but I guess not.

The matrix multiply is with a 12,595 x 2,535 matrix. I’m just using the default block size, which is 4096.

So, I suppose there’s not much need to optimize MIS in that case…

Hmm. Here’s that multiply in numpy:

In [1]: import numpy as np

In [2]: m = np.random.rand(12595, 2535)

In [3]: %%time
   ...: c = m.dot(m.T)
   ...: 
CPU times: user 8.29 s, sys: 387 ms, total: 8.68 s
Wall time: 2.7 s

In [4]: c.shape
Out[4]: (12595, 12595)

That’s 8.3s total compute (I think it is using 4 cores).

@jbloom found our distributed matrix multiply to be 2-3x slower than numpy (on the cloud working out of Google Storage, which I would expect to be slower than your laptop). If it is much worse than that, I’d be suspicious.

To start, I’d try to isolate the timing of the matrix multiply separately from the filter, if possible.

Can you post the code that filters the matrix multiply result? Do you use the expression language there? It can be quite slow. MIS uses the expression language for the edge coordinates. Also, it uses the legacy Table.rdd, which might also explain some overhead.

You might also profile to see what’s showing up overall.

So if I write to disk in between each operation, the matrix multiply takes about 200 seconds (including writing to disk) and squaring the matrix element-wise takes about 400 seconds (again, including writing to disk). I’m not sure how much of that time is taken up by the writing-to-disk part.

This is the code with the block matrix operations and the filtering:

block_matrix = BlockMatrix.from_matrix_table(normalized_mean_imputed_genotype_expr)
correlation_matrix = block_matrix.dot(block_matrix.T)
r2_matrix = BlockMatrix.multiply(correlation_matrix, correlation_matrix)

edges = r2_matrix.entries()

similar_filter = (edges['entry'] >= r2) & (edges['i'] != edges['j']) & (edges['i'] < edges['j'])
edges = edges.filter(similar_filter)

index_table = locally_pruned_ds.rows().select('locus', 'alleles', 'variant_idx').key_by('variant_idx')

edges = edges.annotate(locus_i=index_table[edges.i].locus, allele_i=index_table[edges.i].alleles,
                       locus_j=index_table[edges.j].locus, allele_j=index_table[edges.j].alleles)

contig_filter = edges.locus_i.contig == edges.locus_j.contig
window_filter = (hl.abs(edges.locus_i.position - edges.locus_j.position)) <= window
edges = edges.filter(contig_filter & window_filter)
# TODO: handle case where edges is empty

related_nodes_to_remove = maximal_independent_set(edges.i, edges.j, keep=False)

Is it true that the dimensions are:

block_matrix : n_snps x n_samples
correlation_matrix : n_snps x n_snps
r2_matrix : n_snps x n_snps

I’m vaguely horrified that element-wise squaring is twice as slow as dot product.

Can you share the code with the reads/writes?

Yes, those are the dimensions.

Here’s the code for the whole method, with the writes:

import timeit
start_time = timeit.default_timer()

sites_only_jvds = Env.hail().methods.LDPrune.apply(
    require_biallelic(dataset, 'ld_prune')._jvds, num_cores, r2, window, memory_per_core)

time_local_prune = timeit.default_timer() - start_time

sites_only_table = Table(sites_only_jvds)

locally_pruned_ds = dataset.filter_rows(hl.is_defined(sites_only_table[(dataset.locus, dataset.alleles)])).persist()

locally_pruned_ds = (locally_pruned_ds
    .annotate_rows(mean=sites_only_table[(locally_pruned_ds.locus, locally_pruned_ds.alleles)].mean,
                   sd_reciprocal=sites_only_table[
                       (locally_pruned_ds.locus, locally_pruned_ds.alleles)].sd_reciprocal)
    .index_rows('variant_idx'))

locally_pruned_ds.write("/Users/maccum/test_data/locallyPrunedDS.vds", overwrite=True)
time_filter_mt = timeit.default_timer() - start_time

normalized_mean_imputed_genotype_expr = (
    hl.cond(hl.is_defined(locally_pruned_ds['GT']),
            (locally_pruned_ds['GT'].num_alt_alleles() - locally_pruned_ds['mean'])
            * locally_pruned_ds['sd_reciprocal'],
            0))

block_matrix = BlockMatrix.from_matrix_table(normalized_mean_imputed_genotype_expr)
block_matrix.write("/Users/maccum/test_data/block_matrix")
time_block_matrix = timeit.default_timer() - start_time
correlation_matrix = block_matrix.dot(block_matrix.T)
correlation_matrix.write("/Users/maccum/test_data/correlation_matrix")
time_correlation_matrix = timeit.default_timer() - start_time
r2_matrix = BlockMatrix.multiply(correlation_matrix, correlation_matrix)

r2_matrix.write("/Users/maccum/test_data/r2_matrix")
time_r2_matrix = timeit.default_timer() - start_time

edges = r2_matrix.entries().persist()
edges.write("/Users/maccum/test_data/edges.kt",overwrite=True)
time_edges_entries = timeit.default_timer() - start_time

similar_filter = (edges['entry'] >= r2) & (edges['i'] != edges['j']) & (edges['i'] < edges['j'])
edges = edges.filter(similar_filter)

index_table = locally_pruned_ds.rows().select('locus', 'alleles', 'variant_idx').key_by('variant_idx')

edges = edges.annotate(locus_i=index_table[edges.i].locus, allele_i=index_table[edges.i].alleles,
                       locus_j=index_table[edges.j].locus, allele_j=index_table[edges.j].alleles)

contig_filter = edges.locus_i.contig == edges.locus_j.contig
window_filter = (hl.abs(edges.locus_i.position - edges.locus_j.position)) <= window
edges = edges.filter(contig_filter & window_filter)
# TODO: handle case where edges is empty

edges.write("/Users/maccum/test_data/edges.kt", overwrite=True)
time_filter_edges = timeit.default_timer() - start_time

related_nodes_to_remove = maximal_independent_set(edges.i, edges.j, keep=False)

time_mis = timeit.default_timer() - start_time

pruned_ds = locally_pruned_ds.filter_rows(hl.is_defined(related_nodes_to_remove[locally_pruned_ds.variant_idx]),
                                          keep=False)

info("LD prune step 3 of 3: nVariantsKept={}".format(pruned_ds.count_rows()))

time_filter_mis = timeit.default_timer() - start_time

print("TIME:\n\tLocal prune: {}\n\tFilter matrix table: {}\n\tStandardized GT Block Matrix: {}"
      "\n\tCorrelation Matrix: {}\n\tr2 matrix: {}\n\tEdges entries table: {}\n\tFilter edges: {}\n\t"
      "MIS: {}\n\tFilter after MIS: {}\n".format(
    time_local_prune,
    time_filter_mt-time_local_prune,
    time_block_matrix-time_filter_mt,
    time_correlation_matrix-time_block_matrix,
    time_r2_matrix-time_correlation_matrix,
    time_edges_entries-time_r2_matrix,
    time_filter_edges-time_edges_entries,
    time_mis-time_filter_edges,
    time_filter_mis-time_mis
))

return pruned_ds.rows().select('locus', 'alleles')

So I think the time for r2_matrix will include recomputing the correlation matrix here. I’m curious what happens if we run this:

correlation_matrix = block_matrix.dot(block_matrix.T)
correlation_matrix.write("/Users/maccum/test_data/correlation_matrix")
time_correlation_matrix = timeit.default_timer() - start_time
# I think this is called read_matrix, I'm not sure
correlation_matrix = hl.read_matrix("/Users/maccum/test_data/correlation_matrix")
r2_matrix = BlockMatrix.multiply(correlation_matrix, correlation_matrix)

Now we’re measuring the read time instead of the recompute time. If we also knew the time it took to do a read_matrix("...").count(), then we could get a better estimate of the pointwise-multiply + write time.

How about de-obfuscating the timing by doing a take(1) (or count()) to force evaluation of a previously-lazy RDD?

Maybe count() is better because it (possibly) ensures that all elements of the
result have been evaluated.
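A minimal sketch of that pattern (the helper name is made up; count() on a Hail Table forces evaluation of everything not already computed):

import timeit

def force_and_time(label, table):
    # Force evaluation of a lazily-built table with count() and report
    # wall time for this step alone (anything already persisted or
    # written out won't be recomputed).  Illustrative helper.
    start = timeit.default_timer()
    n = table.count()
    print("{}: {} rows in {:.1f}s".format(label, n, timeit.default_timer() - start))

# e.g.
# force_and_time("edges after window filter", edges)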