2016-11-30 13 views
1

私は2次元のNumPy ndarrayを持っています。ndarrayのすべてのargmaxを見つける方法

array([[ 0., 20., -2.], 
    [ 2., 1., 0.], 
    [ 4., 3., 20.]]) 

どのようにして最大要素のすべてのインデックスを取得できますか?だから私は出力配列([0,1]、[2,2])としてしたいと思います。

答えて

2

使用np.argwhere MAX-平等マスク -

np.argwhere(a == a.max()) 

サンプル実行 - あなたは浮動小数点数を使用している場合

In [552]: a # Input array 
Out[552]: 
array([[ 0., 20., -2.], 
     [ 2., 1., 0.], 
     [ 4., 3., 20.]]) 

In [553]: a == a.max() # Max equality mask 
Out[553]: 
array([[False, True, False], 
     [False, False, False], 
     [False, False, True]], dtype=bool) 

In [554]: np.argwhere(a == a.max()) # array of row, col indices of max-mask 
Out[554]: 
array([[0, 1], 
     [2, 2]]) 

を、あなたはそこにいくつかの許容値を使用する場合があります。だから、その考察では、デフォルト絶対値と相対許容値を持つnp.iscloseを使用することができます。これにより、以前のa == a.max()部分が置き換えられます。

In [555]: np.isclose(a, a.max()) 
Out[555]: 
array([[False, True, False], 
     [False, False, False], 
     [False, False, True]], dtype=bool) 
関連する問題