2017-12-15 25 views
1

私はcov(X * B)の対角線を必要とするのでRcppEigen速い共分散

cov(X * B) = X * cov(B) * X.transpose()

私はちょうど各行X_i * Bの共分散を取得し、

#include <RcppEigen.h> 
// [[Rcpp::depends(RcppEigen)]] 

using Eigen::Map; 
using Eigen::MatrixXd; 
using Eigen::VectorXd; 
using Eigen::SparseMatrix; 
using Eigen::MappedSparseMatrix; 
using namespace Rcpp; 
using namespace Eigen; 

double foo(const Eigen::MappedSparseMatrix<double>& mm, 
      const Eigen::MappedSparseMatrix<double>& vcov) { 

    int n = mm.rows(); 
    double out = 0; 
    SparseMatrix<double> mm_t = mm.adjoint(); 

    SparseMatrix<double> var(1, 1); 
    var.setZero(); 

    for (int i = 0; i < n; i++) { 
    var = mm.row(i) * vcov * mm_t.col(i); 
    out += var.coeff(0, 0); 
    } 

    return out; 
} 
それらを合計することができ、完全な行列の乗算を行う必要はありません。何らかの理由で、この機能は1M行では非常に遅いです。私は行単位で操作する代わりに "ブロック"を使用しようとしました。ブロックの値を操作することで、vcovによる行列乗算を高速化できると考えていました。これは関数をより速くすることはありませんでした。ここに再現可能な例があります

require(Matrix) 

set.seed(100) 
N = 2.5e5 
p = 100 

mm = rsparsematrix(N, p, .01) 
vcov = rsparsematrix(p, p, .5) 

system.time(foo(mm, vcov)) 

この機能を高速化する方法はありますか?

+0

対角要素の合計または完全なベクトルを返しますか? – ekstroem

+0

対角要素の合計 – JCWong

答えて

3

共分散行列が実数で対称である(そしてあなたの場合には共分散行列である)なら、簡単な数学的な "トリック"を使うことができます。

x %*% b %*% t(b) %*% t(x)の対角要素の和は、超高速で

sum((x %*% b)^2) 

として計算することができます。上記の数式はの平方根を計算し、数式を使用する必要があるので、 "サンドイッチ"の "ハム"部分としてb %*% t(b)があることに注意してください。

また、おそらく最適化することができ、以下のように、あなたはちょうど私がそれもRcppEigenに精通し、そこにスパース行列ではないよ

sum((mm %*% vcov) * mm) 

Rに直接、次の要素毎の製品を使用することができますが、それはそうですかなり

// [[Rcpp::export]]                               
double foo2(const Eigen::MappedSparseMatrix<double>& mm, 
      const Eigen::MappedSparseMatrix<double>& vcov) { 

    double out = 0; 
    SparseMatrix<double> mat; 

    mat = mm.cwiseProduct(mm*vcov); 


    for (int k=0; k<mat.outerSize(); ++k) { 
    for (SparseMatrix<double>::InnerIterator it(mat,k); it; ++it) 
     { 
     out +=it.value(); 
     } 
    } 

    return out; 
} 

ここでは簡単な速度比較は高速です

> microbenchmark::microbenchmark(foo(mm, vcov), foo2(mm, vcov), sum((mm %*% vcov) * mm), times=2) 
Unit: milliseconds 
        expr  min   lq  mean  median   uq 
      foo(mm, vcov) 32575.5488 32575.5488 33587.4147 33587.4147 34599.2806 
      foo2(mm, vcov) 463.9440 463.9440 492.4232 492.4232 520.9023 
sum((mm %*% vcov) * mm) 953.7902 953.7902 981.4750 981.4750 1009.1598 
     max neval cld 
34599.2806  2 b 
    520.9023  2 a 
    1009.1598  2 a 

かなり改善されています。 Rだけを使うから。

+0

これは最適なソリューションです!実際のデータセットで約60倍高速 – JCWong