2016-05-31 11 views
2

私は、2つのnumpyのアレイを有するm1m1サイズ(NX1)とm2サイズ(の1xN)であり、Iは、サイズの行列mもたらす乗算m1.dot(m2)を実行する(N×N個)マスクでnumpyドットプロダクトを高速化するにはどうすればよいですか?

あるm2私がしたいですm1およびm2の中で最も高いk個の要素を使用し、他のすべての要素を0(すべての要素が正である)にすることによって、おおよそm_approxを計算します。

私のためにサイズnが大きい(〜10k)ので、私は乗算を高速化しようとしています。私は小さなkを100と言い、実際に乗算をスピードアップしたいと考えています。ドットプロダクトロットを高速化するnumpyの疎な行列を使ってみましたが、m1とm2を疎ベクトルに変換するのは非常に遅いです。どうすればこれを達成できますか?私はマスクがこれを達成する方法かもしれないとは思いますが、どうしたらいいのでしょうか?

答えて

2

これはm1m2から選択された要素のドット積を選択して設定するための最大k要素とnp.ix_のインデックスを取得するnp.argpartitionを用いて解決することができます。したがって、次に説明するように、これを実装するために基本的に2つの段階があります。

まずオフ、インデックスがそうように、m1及びm2に最大k要素に対応する取得 -

m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel() 
m2_idx = np.argpartition(-m2,k)[:,:k].ravel() 

最後に、設定出力配列。 を使用して、出力配列内の設定する要素を選択するために、それぞれ行と列に沿ってm1およびインデックスをブロードキャストします。次は、そのように、m1_idxm2_idxでインデックスを使用してm1m2から得ることができるm1m2から最高kの要素間のドット積を計算する -

out = np.zeros((n,n)) 
out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx]) 

はのは、サンプルの実行と実装を検証してみましょう明示的に低いn-k要素を設定して0m1,m2に設定してからドットプロダクトを実行する別の実装に対して実行します。ここではサンプル実行は、チェックを実行するためだ -

1)入力:

In [170]: m1 
Out[170]: 
array([[ 0.26980423], 
     [ 0.30698416], 
     [ 0.60391089], 
     [ 0.73246763], 
     [ 0.35276247]]) 

In [171]: m2 
Out[171]: array([[ 0.30523552, 0.87411242, 0.01071218, 0.81835438, 0.21693231]]) 

In [172]: k = 2 

2)が提案し、実装を実行します。

In [173]: # Proposed solution code 
    ...: m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel() 
    ...: m2_idx = np.argpartition(-m2,k)[:,:k].ravel() 
    ...: out = np.zeros((n,n)) 
    ...: out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx]) 
    ...: 

3)の出力を得るために、代替の実装を使用します。

In [174]: # Explicit setting of lower n-k elements to zeros for m1 and m2 
    ...: m1[np.argpartition(-m1,k,axis=0)[k:]] = 0 
    ...: m2[:,np.argpartition(-m2,k)[:,k:].ravel()] = 0 
    ...: 

In [175]: m1 # Verify m1 and m2 have lower n-k elements set to 0s 
Out[175]: 
array([[ 0.  ], 
     [ 0.  ], 
     [ 0.60391089], 
     [ 0.73246763], 
     [ 0.  ]]) 

In [176]: m2 
Out[176]: array([[ 0.  , 0.87411242, 0.  , 0.81835438, 0.  ]]) 

In [177]: m1.dot(m2) # Use m1.dot(m2) to directly get output. This is expensive. 
Out[177]: 
array([[ 0.  , 0.  , 0.  , 0.  , 0.  ], 
     [ 0.  , 0.  , 0.  , 0.  , 0.  ], 
     [ 0.  , 0.52788601, 0.  , 0.49421312, 0.  ], 
     [ 0.  , 0.64025905, 0.  , 0.59941809, 0.  ], 
     [ 0.  , 0.  , 0.  , 0.  , 0.  ]]) 

4)提案された実装を確認してください:

In [178]: out # Print output from proposed solution obtained earlier 
Out[178]: 
array([[ 0.  , 0.  , 0.  , 0.  , 0.  ], 
     [ 0.  , 0.  , 0.  , 0.  , 0.  ], 
     [ 0.  , 0.52788601, 0.  , 0.49421312, 0.  ], 
     [ 0.  , 0.64025905, 0.  , 0.59941809, 0.  ], 
     [ 0.  , 0.  , 0.  , 0.  , 0.  ]]) 
+0

正確に私が探していたものは... np.ix_について知りませんでした! –

+0

@Adi喜んで助けてください! :) – Divakar

関連する問題