2017-11-04 7 views
1

nの値を保持し、それ以外はすべてゼロにする必要のあるデータ配列があります。numpy配列の各行のn個の最高値を保持し、他のすべてをゼロにします。

私の現在のソリューション:上記のコードで

import numpy as np 
np.random.seed(30) 

# keep only the n highest values 
n = 3 

# Simple 2x5 data field for this example, real life application will be exteremely large 
data = np.random.random((2,5)) 
#[[ 0.64414354 0.38074849 0.66304791 0.16365073 0.96260781] 
# [ 0.34666184 0.99175099 0.2350579 0.58569427 0.4066901 ]] 


# find indices of the n highest values per row 
idx = np.argsort(data)[:,-n:] 
#[[0 2 4] 
# [4 3 1]] 


# put those values back in a blank array 
data_ = np.zeros(data.shape) # blank slate 
for i in xrange(data.shape[0]): 
    data_[i,idx[i]] = data[i,idx[i]] 

# Each row contains only the 3 highest values per row or the original data 
#[[ 0.64414354 0.   0.66304791 0.   0.96260781] 
# [ 0.   0.99175099 0.   0.58569427 0.4066901 ]] 

は、data_n最高値を持っており、他のすべてはゼロにされます。 data.shape[1]nよりも小さくてもうまくいきます。しかし、唯一の問題はfor loopです。私の実際のユースケースは非常に大きな配列上にあるため、遅いです。

forループを取り除くことは可能ですか?

答えて

3

あなたはnp.argsortの結果に作用することができる - np.argsort 二回、インデックス順と第二のランクを取得するために取得する最初の - ベクトル化の方法で、その後、np.whereまたは単に乗算のいずれかを使用他のすべてをゼロにする:

In [116]: np.argsort(data) 
Out[116]: 
array([[3, 1, 0, 2, 4], 
     [2, 0, 4, 3, 1]]) 

In [117]: np.argsort(np.argsort(data)) # these are the ranks 
Out[117]: 
array([[2, 1, 3, 0, 4], 
     [1, 4, 0, 3, 2]]) 

In [118]: np.argsort(np.argsort(data)) >= data.shape[1] - 3 
Out[118]: 
array([[ True, False, True, False, True], 
     [False, True, False, True, True]], dtype=bool) 

In [119]: data * (np.argsort(np.argsort(data)) >= data.shape[1] - 3) 
Out[119]: 
array([[ 0.64414354, 0.  , 0.66304791, 0.  , 0.96260781], 
     [ 0.  , 0.99175099, 0.  , 0.58569427, 0.4066901 ]]) 

In [120]: np.where(np.argsort(np.argsort(data)) >= data.shape[1]-3, data, 0) 
Out[120]: 
array([[ 0.64414354, 0.  , 0.66304791, 0.  , 0.96260781], 
     [ 0.  , 0.99175099, 0.  , 0.58569427, 0.4066901 ]]) 
+0

わかりやすさのために私の解決策を編集しました。上記のデータを使用して、1行につきn個の最高値を得ることはできません。同じデータを使用してソリューションを使用して、違いを確認してください。 – Fnord

+0

@Fnord:oops、余分な 'argsort'を忘れました。ランクデータのように振る舞うためには別の引数が必要です(私はSeriesやDataFrame上で.rank( 'dense')を使うだけでした)。 – DSM

関連する問題