2017-02-06 14 views
1

python/numpyでは、同じ形状の別の軸がある軸(たとえば最初の軸)に沿って最大の多次元配列をサブセット化するにはどうすればよいですか?`numpy.ndarray`をサブセット化するにはどうすればいいですか?

2つの3 * 2 * 4アレイ、aおよびbがあるとします。私はaが第1の軸に沿って最大値を持つ場所でbの値を含む2 * 4配列を取得したいと思います。

import numpy as np 

np.random.seed(7) 
a = np.random.rand(3*2*4).reshape((3,2,4)) 
b = np.random.rand(3*2*4).reshape((3,2,4)) 

print a 
#[[[ 0.07630829 0.77991879 0.43840923 0.72346518] 
# [ 0.97798951 0.53849587 0.50112046 0.07205113]] 
# 
# [[ 0.26843898 0.4998825 0.67923  0.80373904] 
# [ 0.38094113 0.06593635 0.2881456 0.90959353]] 
# 
# [[ 0.21338535 0.45212396 0.93120602 0.02489923] 
# [ 0.60054892 0.9501295 0.23030288 0.54848992]]] 

print a.argmax(axis=0) #(I would like b at these locations along axis0) 
#[[1 0 2 1] 
# [0 2 0 1]] 

私はこれは本当に醜いマニュアルサブセット行うことができます。

index = zip(a.argmax(axis=0).flatten(), 
      [0]*a.shape[2]+[1]*a.shape[2], # a.shape[2] = 4 here 
      range(a.shape[2])+range(a.shape[2])) 
# [(1, 0, 0), (0, 0, 1), (2, 0, 2), (1, 0, 3), 
# (0, 1, 0), (2, 1, 1), (0, 1, 2), (1, 1, 3)] 

私の望ましい結果得ることができるようになる。これを達成するために、スマート、numpy方法です何

b_where_a_is_max_along0 = np.array([b[i] for i in index]).reshape(2,4) 

# For verification: 
print a.max(axis=0) == np.array([a[i] for i in index]).reshape(2,4) 
#[[ True True True True] 
# [ True True True True]] 

を?感謝:)

答えて

1

使用advanced-indexing -

m,n = a.shape[1:] 
b_out = b[a.argmax(0),np.arange(m)[:,None],np.arange(n)] 

サンプル実行 -

セットアップ入力配列aと、第1の軸に沿ってそのargmaxます -

-

In [185]: a = np.random.randint(11,99,(3,2,4)) 

In [186]: idx = a.argmax(0) 

In [187]: idx 
Out[187]: 
array([[0, 2, 1, 2], 
     [0, 1, 2, 0]]) 

In [188]: a 
Out[188]: 
array([[[49*, 58, 13, 69], # * are the max positions 
     [94*, 28, 55, 86*]], 

     [[34, 17, 57*, 50], 
     [48, 73*, 22, 80]], 

     [[19, 89*, 42, 71*], 
     [24, 12, 66*, 82]]]) 

bで結果を確認します

In [193]: b 
Out[193]: 
array([[[18*, 72, 35, 51], # Mark * at the same positions in b 
     [74*, 57, 50, 84*]], # and verify 

     [[58, 92, 53*, 65], 
     [51, 95*, 43, 94]], 

     [[85, 23*, 13, 17*], 
     [17, 64, 35*, 91]]]) 

In [194]: b[a.argmax(0),np.arange(2)[:,None],np.arange(4)] 
Out[194]: 
array([[18, 23, 53, 17], 
     [74, 95, 35, 84]]) 
1

あなたは使用することができogrid

>>> x = np.random.random((2,3,4)) 
>>> x 
array([[[ 0.87412737, 0.11069105, 0.86951092, 0.74895912], 
     [ 0.48237622, 0.67502597, 0.11935148, 0.44133397], 
     [ 0.65169681, 0.21843482, 0.52877862, 0.72662927]], 

     [[ 0.48979028, 0.97103611, 0.36459645, 0.80723839], 
     [ 0.90467511, 0.79118429, 0.31371856, 0.99443492], 
     [ 0.96329039, 0.59534491, 0.15071331, 0.52409446]]]) 
>>> y = np.argmax(x, axis=1) 
>>> y 
array([[0, 1, 0, 0], 
     [2, 0, 0, 1]]) 
>>> i, j = np.ogrid[:2,:4] 
>>> x[i ,y, j] 
array([[ 0.87412737, 0.67502597, 0.86951092, 0.74895912], 
     [ 0.96329039, 0.97103611, 0.36459645, 0.99443492]]) 
関連する問題