ここには、各サンプルにargsort
を使用する解決策があります。かなりではなく、速くはありませんが、仕事をしています。
import numpy as np
from timeit import timeit
def dupl(a, axis=0, make_dict=True):
a = np.moveaxis(a, axis, -1)
i = np.argsort(a, axis=-1, kind='mergesort')
ai = a[tuple(np.ogrid[tuple(map(slice, a.shape))][:-1]) + (i,)]
same = np.zeros(a.shape[:-1] + (a.shape[-1]+1,), bool)
same[..., 1:-1] = np.diff(ai, axis=-1) == 0
uniqs = np.where((same[..., 1:] & ~same[..., :-1]).ravel())[0]
same = (same[...,1:]|same[...,:-1]).ravel()
reps = np.split(i.ravel()[same], np.cumsum(same)[uniqs[1:]-1])
grps = np.searchsorted(uniqs, np.arange(0, same.size, a.shape[-1]))
keys = ai.ravel()[uniqs]
if make_dict:
result = np.empty(a.shape[:-1], object)
result.ravel()[:] = [dict(zip(*p)) for p in np.split(
np.array([keys, reps], object), grps[1:], axis=-1)]
return result
else:
return keys, reps, grps
a = np.random.randint(0,10,(10,100,100,9))
axis = 0
result = dupl(a, axis)
print('shape, axis, time (sec) for 10 trials:',
a.shape, axis, timeit(lambda: dupl(a, axis=axis), number=10))
print('same without creating dict:',
a.shape, axis, timeit(lambda: dupl(a, axis=axis, make_dict=False),
number=10))
#check
print("checking result")
am = np.moveaxis(a, axis, -1)
for af, df in zip(am.reshape(-1, am.shape[-1]), result.ravel()):
assert len(set(af)) + sum(map(len, df.values())) == len(df) + am.shape[-1]
for k, v in df.items():
assert np.all(np.where(af == k)[0] == v)
print("no errors")
プリント:インデックスが重複しているかを示しブール配列:以下のスニペットで
shape, axis, time (sec) for 10 trials: (10, 100, 100, 9) 0 5.328339613042772
same without creating dict: (10, 100, 100, 9) 0 2.568383438978344
checking result
no errors
重複している値にフラグを付けるだけでいいですか、データ値のある辞書をキーとして使用し、インデックスを辞書値として複製しますか? – James
@ジェームスは、可能な解決策を拘束しないように返される正確なデータを指定しないでジェネリックなままにしていましたが、インデックスで重複を単にフラグするブール値の配列を考えています。 – BoltzmannBrain