2016-06-01 7 views
4

私は球面上に以下のコード例で 'obj'と 'ps'という2つの点集合を持っています。私は 'ps'点から特定の角度距離より近いすべての 'obj'点を特定したいと思います。Numbaによるナンシー最適化

これは、各点を3D単位ベクトルで表し、そのドット積をcos(最大分離)と比較することです。これはnumpy放送では簡単に行うことができますが、私のアプリケーションではn_obj〜500,000とn_ps〜50,000があり、放送のメモリ要件が大きすぎます。以下は、numbaを使って私の現在のテイクを貼り付けたものです。これはさらに最適化できますか?

from numba import jit 
import numpy as np 
from sklearn.preprocessing import normalize 

def gen_points(n): 
    """ 
    generate random 3D unit vectors (not uniform, but irrelevant here) 
    """ 
    vec = 2*np.random.rand(n,3)-1. 
    vec_norm = normalize(vec) 
    return vec_norm 

#@jit(nopython=True) 
@jit 
def angdist_threshold_numba(vec_obj,vec_ps,cos_maxsep): 
    """ 
    finds obj that are closer than maxsep to a ps 
    """  
    nps = len(vec_ps) 
    nobj = len(vec_obj)  

    #closeobj_all = [] 
    closeobj_all = np.empty(0) 
    dotprod = np.empty(nobj) 
    a = np.arange(nobj) 
    for ps in range(nps): 
     np.sum(vec_obj*vec_ps[ps],axis=1,out=dotprod) 
     #closeobj_all.extend(a[dotprod > cos_maxsep]) 
     closeobj_all = np.append(closeobj_all, a[dotprod > cos_maxsep]) 

    return closeobj_all 


vec_obj = gen_points(50000) #in reality ~500,000 
vec_ps = gen_points(5000) #in reality ~50,000 
cos_maxsep = np.cos(0.003) 

closeobj_all = np.unique(angdist_threshold_numba(vec_obj,vec_ps,cos_maxsep)) 

これは、コード内で与えられたテストケースを使用してパフォーマンスです:私は

@jit(nopython=True) 

を使用してそれをスピードアップしようとしているが、これは

NotImplementedError: Failed at nopython (nopython frontend) 
(<class 'numba.ir.Expr'>, build_list(items=[])) 
で失敗

%timeit np.unique(angdist_threshold_numba(vec_obj,vec_ps,cos_maxsep)) 
1 loops, best of 3: 4.53 s per loop 

編集:numbaを0.26に更新した後空リストの作成はPythonモードでも失敗します。これは、np.empty(0)で置換し、np.append()で.extend()を置き換えることで修正できます。これはほとんどの場合、パフォーマンスを変更しません。 https://github.com/numba/numba/issues/858 np.empty()によると

は今nopythonモードでサポートされていますが、私はまだ(= nopython true)を@jitでこれを実行することはできません。

TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x7ff3114a9310> 
+2

はあなただけ使用することはできませんscipy.spatial.distanceからのcdist?すなわち、 distance.cdist(vec_o​​bj、vec_ps、 'cosine') –

+0

これは高速で、必要な処理を行います。しかし、私がこの例で使った数字では、cdistが返した2D配列は〜2GBのメモリを使います。つまり、実際のサイズ(両方の軸に沿って10倍大きくなります)を使用すると、メモリの問題が再び発生します。 – user4319496

+1

あなたはカスタムルートに行くのが正当だと思われます。 numbaはリストに問題があるようですので、スパース行列を代わりに使用できますか? –

答えて

4

list.appendとは異なり、あなたがnumpy.appendを呼び出すことはありませんループで!これは、1つの要素を追加する場合でも、配列全体をコピーする必要があるからです。固有のobjにのみ興味があるので、ブール値配列を使用してこれまでに見つかったマッチにフラグを立てることができます。

Numbaは、すべてのループを書き出すと最も効果的です。だから、例えば:

@jit(nopython=True) 
def numba2(vec_obj, vec_ps, cos_maxsep): 
    nps = vec_ps.shape[0] 
    nobj = vec_obj.shape[0] 
    dim = vec_obj.shape[1] 
    found = np.zeros(nobj, np.bool_) 
    for i in range(nobj): 
     for j in range(nps): 
      cos = 0.0 
      for k in range(dim): 
       cos += vec_obj[i,k] * vec_ps[j,k] 
      if cos > cos_maxsep: 
       found[i] = True 
       break 
    return found.nonzero() 

追加の利点は、我々はできるだけ早く我々は現在objに一致するものを見つけるようps配列を超えるループから抜け出すことができるということです。

3次元空間の関数を特殊化することで、さらに高速化できます。また、何らかの理由で、別の高速化にヘルパー関数の結果にすべての配列との関連寸法を渡す:私が20,000 obj 2,000 psのために得る

def numba3(vec_obj, vec_ps, cos_maxsep): 
    nps = len(vec_ps) 
    nobj = len(vec_obj) 
    out = np.zeros(nobj, bool) 
    numba3_helper(vec_obj, vec_ps, cos_maxsep, out, nps, nobj) 
    return np.flatnonzero(out) 

@jit(nopython=True) 
def numba3_helper(vec_obj, vec_ps, cos_maxsep, out, nps, nobj): 
    for i in range(nobj): 
     for j in range(nps): 
      cos = (vec_obj[i,0]*vec_ps[j,0] + 
        vec_obj[i,1]*vec_ps[j,1] + 
        vec_obj[i,2]*vec_ps[j,2]) 
      if cos > cos_maxsep: 
       out[i] = True 
       break 
    return out 

タイミング:

%timeit angdist_threshold_numba(vec_obj,vec_ps,cos_maxsep) 
1 loop, best of 3: 2.99 s per loop 
%timeit numba2(vec_obj, vec_ps, cos_maxsep) 
1 loop, best of 3: 444 ms per loop 
%timeit numba3(vec_obj, vec_ps, cos_maxsep) 
10 loops, best of 3: 134 ms per loop 
+0

私は感動していると言わなければならない。上記のコメントで提案されたscipy.spatial.distance.cdistと比較して、スピードアップの20倍、さらには10倍です。また、非常に教育的! PS:numba2の返り値に一式の編集をしようとしましたが(一貫性のために[0]を追加しています)、SOは私を許さないでしょう。 – user4319496

+0

@ user4319496音が教育的であったなら、謝罪します。私はこれが組み込みではできないことにちょうど驚いています。それができないように見えるので、ここで作られるべきscipyへの価値ある貢献があるかもしれません - あなたがしていることはそれほど普通ではありません –