2017-03-12 18 views
0

私のデータサンプルはそれぞれ形状がnumpyの配列です。 (100,100,9)であり、これらのうち10個が連結されて1つの配列foo(10,100,100,9)になっています。 10個のデータサンプルにわたって、繰り返し値のインデックスを探したいと思います。たとえば、foo[0, 42, 42, 3] = 0.72foo[0, 42, 42, 7] = 0.72の場合、これを反映した出力が必要です。そうする効率的な方法は何ですか?numpy nd配列の繰り返し値を見つける

私は形状のブール値の出力配列(100,100,9)を考えていますが、各データサンプル(データサンプル数(10)の2次ランタイム)を比較するループより優れたアプローチがありますか?

+0

重複している値にフラグを付けるだけでいいですか、データ値のある辞書をキーとして使用し、インデックスを辞書値として複製しますか? – James

+0

@ジェームスは、可能な解決策を拘束しないように返される正確なデータを指定しないでジェネリックなままにしていましたが、インデックスで重複を単にフラグするブール値の配列を考えています。 – BoltzmannBrain

答えて

0

dups所望の結果です。 deltaのしきい値もあるため、値の差異は< =このしきい値は重複しています。

delta = 0. 
dups = np.zeros(foo.shape[1:], dtype=bool) 
for i in xrange(foo.shape[0]): 
    for j in xrange(foo.shape[0]): 
     if i==j: continue 
     dups += abs(foo[i] - foo[j]) <= delta 
-1

ここには、各サンプルにargsortを使用する解決策があります。かなりではなく、速くはありませんが、仕事をしています。

import numpy as np 
from timeit import timeit 

def dupl(a, axis=0, make_dict=True): 
    a = np.moveaxis(a, axis, -1) 
    i = np.argsort(a, axis=-1, kind='mergesort') 
    ai = a[tuple(np.ogrid[tuple(map(slice, a.shape))][:-1]) + (i,)] 
    same = np.zeros(a.shape[:-1] + (a.shape[-1]+1,), bool) 
    same[..., 1:-1] = np.diff(ai, axis=-1) == 0 
    uniqs = np.where((same[..., 1:] & ~same[..., :-1]).ravel())[0] 
    same = (same[...,1:]|same[...,:-1]).ravel() 
    reps = np.split(i.ravel()[same], np.cumsum(same)[uniqs[1:]-1]) 
    grps = np.searchsorted(uniqs, np.arange(0, same.size, a.shape[-1])) 
    keys = ai.ravel()[uniqs] 
    if make_dict: 
     result = np.empty(a.shape[:-1], object) 
     result.ravel()[:] = [dict(zip(*p)) for p in np.split(
       np.array([keys, reps], object), grps[1:], axis=-1)] 
     return result 
    else: 
     return keys, reps, grps 

a = np.random.randint(0,10,(10,100,100,9)) 
axis = 0 
result = dupl(a, axis) 

print('shape, axis, time (sec) for 10 trials:', 
     a.shape, axis, timeit(lambda: dupl(a, axis=axis), number=10)) 
print('same without creating dict:', 
     a.shape, axis, timeit(lambda: dupl(a, axis=axis, make_dict=False), 
          number=10)) 

#check 
print("checking result") 
am = np.moveaxis(a, axis, -1) 
for af, df in zip(am.reshape(-1, am.shape[-1]), result.ravel()): 
    assert len(set(af)) + sum(map(len, df.values())) == len(df) + am.shape[-1] 
    for k, v in df.items(): 
     assert np.all(np.where(af == k)[0] == v) 
print("no errors") 

プリント:インデックスが重複しているかを示しブール配列:以下のスニペットで

shape, axis, time (sec) for 10 trials: (10, 100, 100, 9) 0 5.328339613042772 
same without creating dict: (10, 100, 100, 9) 0 2.568383438978344 
checking result 
no errors 
+0

このコードにはすべて匂いがあり、これは非常に雑多な並べ替えがあり、これは非効率的にしようとしているようです。 – BoltzmannBrain

+0

@BoltzmannBrain bit厳しいとは思わない?あなたのものとは異なり、これはO(n k^2)ではなくO(n(log n/k + k log k))という合理的な複雑さを持っています。それは私がそれを認めているが、それはあなたを超えているためにそれをスラッシュしないで、目では簡単ではない。 –

関連する問題