2017-11-20 6 views
1

マイコード:numpyのアレイにおけるNの最高値の取得インデックス

import numpy as np 
N = 2 
a = np.array([[0.5, 0.3, 0.2], 
       [0.2, 0.6, 0.2], 
       [0.3, 0.2, 0.7], 
       [np.nan, 0.2, 0.8],      
       [np.nan, np.nan, 0.8]      
       ]) 

ind = np.argsort(np.where(np.isnan(a), -1, a), axis=1)[:, -N:] 


a 
Out[2]: 
array([[ 0.5, 0.3, 0.2], 
     [ 0.2, 0.6, 0.2], 
     [ 0.3, 0.2, 0.7], 
     [ nan, 0.2, 0.8], 
     [ nan, nan, 0.8]]) 

ind 
Out[3]: 
array([[1, 0], 
     [2, 1], 
     [0, 2], 
     [1, 2], 
     [1, 2]], dtype=int64) 

株IND [:1]最高及びINDある[:0]の場合を除いて、微細で二番目に高い

最後の行に2人がいる。 2番目に高い値を無視する方法は? 所望の出力は次のようになります。

array([[1, 0], 
     [2, 1], 
     [0, 2], 
     [1, 2], 
     [nan, 2]], dtype=int64) 

ボーナス質問:どのようにランダムに[1 ,:]の場合にはネクタイを破りますか?

答えて

1

Advanced-indexなどのように、私たちはその後、選んを行うためにnp.whereで使用することができたマスクを、与えることNaNsをチェック - NaNを持つ配列がfloat DTYPEに変換されることを

In [244]: a_ind = a[np.arange(ind.shape[0])[:,None],ind] 

In [245]: mask = np.isnan(a_ind) 

In [246]: np.where(mask, np.nan, ind) 
Out[246]: 
array([[ 1., 0.], 
     [ 2., 1.], 
     [ 0., 2.], 
     [ 1., 2.], 
     [ nan, 2.]]) 

注意したがって、最終出力もfloatdtypeになります。

関連する問題