2017-06-20 12 views
0

2つの行列 A と B があります。A のサイズは 2M × 50、B は 20k × 50 です。A の各行について、A %*% t(B) の上位10個の値（の列インデックス）を求めたいと考えています。以下のコードより高速な実装はあるでしょうか。（質問タイトル: R - 行列積の高速ソート）

library(parallel) 
library(pbapply) 

set.seed(1)

# Simulated inputs filled column-major with U(0, 1) draws:
# A has 2e6 rows and B has 2e4 rows, both with 50 columns.
A <- matrix(runif(2e6 * 50), ncol = 50)
B <- matrix(runif(2e4 * 50), ncol = 50)

# Number of top-scoring columns of A %*% t(B) to keep for each row of A.
n <- 10

# Parallel top-n search: for every row i of A, find the indices of the n
# columns of A %*% t(B) with the largest scores, ordered by decreasing score.
# pbsapply() simplifies the per-row results into an n x nrow(A) matrix Z.
cl <- makeCluster(detectCores())

# set.seed() only seeds this (master) R session; every worker has its own
# RNG, so the random tie-breaking below was not reproducible. Give the
# workers independent, seeded streams.
clusterSetRNGStream(cl, iseed = 1)

clusterExport(cl, c("A", "B", "n"))

Z <- tryCatch(
  pbsapply(seq_len(nrow(A)), function(i) {
    # tcrossprod(a, B) computes a %*% t(B) without rebuilding t(B) for every
    # row; a vector of length ncol(B) is treated as a row vector, so the
    # result is a 1 x nrow(B) matrix.
    score <- tcrossprod(A[i, ], B)
    # n-th largest score: partial sorting in sort() only supports increasing
    # order, hence the negation.
    nth_score <- -sort(-score, partial = n)[n]
    above <- which(score > nth_score)
    tied <- which(score == nth_score)
    # Fill the remaining slots at random from the scores tied with the n-th.
    # sample(v, k) on a length-1 numeric v would draw from 1:v, so skip it
    # when there is a single tied candidate (it is then exactly the one
    # needed to reach n).
    if (length(tied) != 1) {
      tied <- sample(tied, n - length(above))
    }
    top <- c(above, tied)
    top[order(score[top], decreasing = TRUE)]
  }, cl = cl),
  # Release the worker processes even if the computation fails.
  finally = stopCluster(cl)
)

回答

1

簡単にできる改善の一つは、A %*% t(B) を tcrossprod(A,B) に置き換えることです。

# Number of top-scoring columns to keep per row.
n <- 10

# Baseline (the question's code): indices of the n largest entries of
# A[x, ] %*% t(B), ordered by decreasing score. A, B and n are free
# variables looked up in the enclosing environment. When several scores tie
# with the n-th largest, the remaining slots are filled by sample().
func1 <- function(x) {
  row_scores <- A[x, ] %*% t(B)
  # n-th largest score via a partial sort of the negated scores.
  cutoff <- -sort(-row_scores, partial = n)[n]
  strictly_higher <- which(row_scores > cutoff)
  at_cutoff <- which(row_scores == cutoff)
  # sample(v, k) with a length-1 numeric v samples from 1:v, hence the guard.
  if (length(at_cutoff) != 1) {
    at_cutoff <- sample(at_cutoff, n - length(strictly_higher))
  }
  picked <- c(strictly_higher, at_cutoff)
  picked[order(row_scores[picked], decreasing = TRUE)]
}

# Indices of the n largest entries of A[x, ] %*% t(B), ordered by decreasing
# score, with the scores computed by tcrossprod() (no explicit transpose of
# B). A, B and n are free variables looked up in the enclosing environment.
# Ties with the n-th largest score are resolved with sample().
func2 <- function(x) {
  s <- tcrossprod(A[x, ], B)
  kth <- -sort(-s, partial = n)[n]
  hi <- which(s > kth)
  eq <- which(s == kth)
  needed <- n - length(hi)
  # A single candidate at the cutoff is used as-is: sample() on a length-1
  # numeric vector would draw from 1:value instead.
  if (length(eq) == 1) {
    idx <- c(hi, eq)
  } else {
    idx <- c(hi, sample(eq, needed))
  }
  ord <- order(s[idx], decreasing = TRUE)
  idx[ord]
}

# Sanity check: both variants should select the same columns for row 1.
all.equal(func1(1), func2(1))
# TRUE

# The microbenchmark package is never attached in this script (only parallel
# and pbapply are loaded), so call it through its namespace.
microbenchmark::microbenchmark(func1(1), func2(1))
# Unit: milliseconds
# expr  min  lq  mean median  uq  max neval
# func1(1) 6.527077 9.254476 9.757431 9.726585 10.311310 11.932170 100
# func2(1) 3.365654 3.721711 4.036532 3.998387 4.246175 5.405226 100
このように A %*% t(B) を tcrossprod(A,B) に置き換えるだけで、1行あたりの処理時間（中央値）は約 9.7 ms から約 4.0 ms に短縮されます。