形状変更は、これらの最後の二つの軸をマージすると、次に使うnp.argmax
-
idx = x.reshape(x.shape[0],-1).argmax(-1)
out = np.unravel_index(idx, x.shape[-2:])
サンプル実行 -
In [263]: x = np.random.randint(10, size=(4,3,3))
In [264]: x
Out[264]:
array([[[0, 9, 2],
[7, 7, 8],
[2, 5, 9]],
[[1, 7, 2],
[8, 9, 0],
[2, 8, 3]],
[[7, 5, 0],
[7, 1, 6],
[5, 1, 1]],
[[0, 7, 3],
[5, 4, 1],
[9, 8, 9]]])
In [265]: idx = x.reshape(x.shape[0],-1).argmax(-1)
In [266]: np.unravel_index(idx, x.shape[-2:])
Out[266]: (array([0, 1, 0, 2]), array([1, 1, 0, 0]))
あなたがマージされたインデックスを取得意味場合は、そのシンプルな - は
x.reshape(x.shape[0],-1).argmax(1)
サンプルラン -
In [283]: x
Out[283]:
array([[[2, 3, 7],
[8, 1, 0],
[3, 6, 9]],
[[8, 0, 5],
[2, 2, 9],
[9, 0, 9]],
[[1, 9, 2],
[5, 0, 3],
[7, 2, 1]],
[[1, 6, 5],
[2, 3, 7],
[7, 4, 6]]])
In [284]: x.reshape(x.shape[0],-1).argmax(1)
Out[284]: array([8, 5, 1, 5])
私はマージされたインデックスを意味しました。ご協力いただきありがとうございます! – ted