1
n
の値を保持し、それ以外はすべてゼロにする必要のあるデータ配列があります。numpy配列の各行のn個の最高値を保持し、他のすべてをゼロにします。
私の現在のソリューション:上記のコードで
import numpy as np
np.random.seed(30)
# keep only the n highest values
n = 3
# Simple 2x5 data field for this example, real life application will be exteremely large
data = np.random.random((2,5))
#[[ 0.64414354 0.38074849 0.66304791 0.16365073 0.96260781]
# [ 0.34666184 0.99175099 0.2350579 0.58569427 0.4066901 ]]
# find indices of the n highest values per row
idx = np.argsort(data)[:,-n:]
#[[0 2 4]
# [4 3 1]]
# put those values back in a blank array
data_ = np.zeros(data.shape) # blank slate
for i in xrange(data.shape[0]):
data_[i,idx[i]] = data[i,idx[i]]
# Each row contains only the 3 highest values per row or the original data
#[[ 0.64414354 0. 0.66304791 0. 0.96260781]
# [ 0. 0.99175099 0. 0.58569427 0.4066901 ]]
は、data_
はn
最高値を持っており、他のすべてはゼロにされます。 data.shape[1]
がn
よりも小さくてもうまくいきます。しかし、唯一の問題はfor loop
です。私の実際のユースケースは非常に大きな配列上にあるため、遅いです。
forループを取り除くことは可能ですか?
わかりやすさのために私の解決策を編集しました。上記のデータを使用して、1行につきn個の最高値を得ることはできません。同じデータを使用してソリューションを使用して、違いを確認してください。 – Fnord
@Fnord:oops、余分な 'argsort'を忘れました。ランクデータのように振る舞うためには別の引数が必要です(私はSeriesやDataFrame上で.rank( 'dense')を使うだけでした)。 – DSM