ここでは、任意の大きさの多次元配列を処理するためにreshaping
とlinear indexing
を使用して一つの手法だ -
shp = x.shape[:-1]
n_ele = np.prod(shp)
y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
のは6 dimensions
のndarray
とサンプルケースを取り、我々は最後にインデックスにm = x.argmax(axis=-1)
を使用しているとしましょうしましょう寸法。したがって、出力はx.max(-1)
になります。さんが提案された解決策のためにこれを確認してみましょう -
In [121]: x = np.random.randint(0,9,(4,5,3,3,2,4))
In [122]: m = x.argmax(axis=-1)
In [123]: shp = x.shape[:-1]
...: n_ele = np.prod(shp)
...: y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
...:
In [124]: np.allclose(x.max(-1),y_out)
Out[124]: True
私はその優雅さのために@B. M.'s
solutionが好きでした。
firstdims=np.indices(x.shape[:-1])
とあなたを追加します:だから、ここではランタイムテストは、これら2つのベンチマークにだ - -
def reshape_based(x,m):
shp = x.shape[:-1]
n_ele = np.prod(shp)
return x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
def indices_based(x,m): ## @B. M.'s solution
firstdims=np.indices(x.shape[:-1])
ind=tuple(firstdims)+(m,)
return x[ind]
タイミング
In [152]: x = np.random.randint(0,9,(4,5,3,3,4,3,6,2,4,2,5))
...: m = x.argmax(axis=-1)
...:
In [153]: %timeit indices_based(x,m)
10 loops, best of 3: 30.2 ms per loop
In [154]: %timeit reshape_based(x,m)
100 loops, best of 3: 5.14 ms per loop
ありがとう!これは機能します。私は最後の行に '.reshape(shp)'を追加するだけです。 – BlindDriver
@BlindDriverうん、それはタイプミスで、それを編集しました。 – Divakar