2016-11-12 14 views
1

np.uint8のnumpy配列Xを持っています(したがって、値は[0, 255]です)。私は[0, 255]から[0, 3]へのマッピングfも持っています。numpyでテンソルを効率的に塗りつぶす

y_{k, i, j} = 1 if k == f(x_{i, j})のような形状(4, n, m)の配列Yを作成したいと思います。今のところ、私は次のようにしています:

しかし、これは非常に遅いので、私はそれを行うより効率的な方法を見つけることができません。何か案は?

Yout = (f(X) == np.arange(4)[:,None,None]).astype(int) 

ランタイムテストと検証 - - あなたはbroadcastingを使用することができ、fは一行くにすべての反復値を操作したと仮定すると

+0

func 'f'の実装を共有できますか? – Divakar

+0

これは非常に問題の特定の機能だから分かりにくいですが、実行に時間がかからず、プロセス全体を高速化するために使用できるものはありません。これは次のようなものです: def f(x):v [x] ここでv = np.random.randint(4、size =(256、)) – dhokas

答えて

1

In [35]: def original_app(X,n,m): 
    ...:  Y = np.zeros((4, n, m)) 
    ...:  for i in range(256): 
    ...:   Y[f(i), X == i] = 1 
    ...:  return Y 
    ...: 

In [36]: # Setup Inputs 
    ...: n,m = 2000,2000 
    ...: X = np.random.randint(0,255,(n,m)).astype('uint8') 
    ...: v = np.random.randint(4, size=(256,)) 
    ...: def f(x): 
    ...:  return v[x] 
    ...: 

In [37]: Y = original_app(X,n,m) 
    ...: Yout = (f(X) == np.arange(4)[:,None,None]).astype(int) 
    ...: 

In [38]: np.allclose(Yout,Y) # Verify 
Out[38]: True 

In [39]: %timeit original_app(X,n,m) 
1 loops, best of 3: 3.77 s per loop 

In [40]: %timeit (f(X) == np.arange(4)[:,None,None]).astype(int) 
10 loops, best of 3: 74.5 ms per loop 
1

スカラーインデックスとブールのミックスに表示されますあなたのスピードを傷つけるでしょう:

In [706]: %%timeit 
    ...: Y=np.zeros((4,3,4)) 
    ...: for i in range(256): 
    ...: Y[f(i), X==i]+=1 
    ...: 

100 loops, best of 3: 12.5 ms per loop 

In [722]: %%timeit 
    ...: Y=np.zeros((4,3,4)) 
    ...: for i in range(256): 
    ...:  I,J=np.where(X==i) 
    ...:  Y[f(i),I,J] = 1 
    ...: 
100 loops, best of 3: 8.55 ms per loop 

これはです

In [718]: timeit K=[f(i) for i in range(256)] 
10000 loops, best of 3: 120 µs per loop 

が、X==iインデックスを得ることが遅いですが

In [720]: timeit K=[X==i for i in range(256)] 
1000 loops, best of 3: 1.29 ms per loop 
In [721]: timeit K=[np.where(X==i) for i in range(256)] 
100 loops, best of 3: 2.73 ms per loop 

我々はのX==i一部を再考する必要があります。この場合

X=np.arange(12,dtype=np.uint8).reshape(3,4) 
def f(i): 
    return i%4 
は、 f(i)は主要な時間の消費者ではありませんマッピングではなく、 f(i)部分です。助け、最後の2次元のフラット化

=====================

In [780]: %%timeit 
    ...: X1=X.ravel() 
    ...: Y=np.zeros((4,12)) 
    ...: for i in range(256): 
    ...:  Y[f(i),X1==i]=1 
    ...: Y.shape=(4,3,4) 
    ...: 
100 loops, best of 3: 3.16 ms per loop 
+0

私は驚きです。ブールインデックスは 'nonzeroとにかくボンネットの下で?あなたの2番目の例は、 '+ =' – Eric

+0

私のPC上でテストされました - 最初の部分で観察しているほとんどの時間差は '=' vs '+ =' typo、悲しく – Eric

+0

です。私は '' = 'を一度にテストしていました。なぜなら、バッファされていない 'add.at'が必要となる(' X'で値が重複する)場合を考えていたからです。 – hpaulj

関連する問題