2017-06-03 11 views
0

私はnumbaでコンパイルしたい関数を持っていますが、その関数内で階乗を計算する必要があります。残念ながらnumbamath.factorialをサポートしていません:numba nopython関数の階乗を計算する最速の方法

import math 
import numba as nb 

@nb.njit 
def factorial1(x): 
    return math.factorial(x) 

factorial1(10) 
# UntypedAttributeError: Failed at nopython (nopython frontend) 

私はそれはそれは「整数値を表す浮動小数点数を返さない本当のmath.gamma機能にしかし反対に、(階乗を計算するのに使用することができる)math.gammaをサポートすることを見ました「:

@nb.njit 
def factorial2(x): 
    return math.gamma(x+1) 

factorial2(10) 
# 3628799.9999999995 <-- not exact 

math.gamma(11) 
# 3628800.0 <-- exact 

、それはmath.factorialに比べて遅いです:

%timeit factorial2(10) 
# 1.12 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 
%timeit math.factorial(10) 
# 321 ns ± 6.12 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 

は、だから私は自分自身の関数を定義することにしました:

@nb.njit 
def factorial3(x): 
    n = 1 
    for i in range(2, x+1): 
     n *= i 
    return n 

factorial3(10) 
# 3628800 

%timeit factorial3(10) 
# 821 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 

それはまだmath.factorialより遅いのですが、それはmath.gammaベースnumba機能よりも高速だし、値が「正確」です。

したがって、nopython numba関数の内部で、正の整数(< = 20;オーバーフローを避けるため)のfactorialを計算する最速の方法を探しています。

+2

整数「0..20」の階乗だけに関係する場合、ルックアップテーブルはおそらく速度のチェックに値するでしょう。 –

+0

Arrrgggh、前回のコメントであなたの書いた*あなたが*書いたはずです*。または*あなたの唯一の関心事は... * –

+0

numbaでpythonのアプローチを再実装することができます - それは特別な方法で乗算を注文するいくつかの追加ステップを経ます -​​ https://github.com/python/ cbathon/blob/3.6/Modules/mathmodule.c#L1275 – chrisb

答えて

1

値が< = 20の場合、コメントに示唆されているように、pythonはルックアップテーブルを使用しています。 pythonから呼び出され https://github.com/python/cpython/blob/3.6/Modules/mathmodule.c#L1452

LOOKUP_TABLE = np.array([ 
    1, 1, 2, 6, 24, 120, 720, 5040, 40320, 
    362880, 3628800, 39916800, 479001600, 
    6227020800, 87178291200, 1307674368000, 
    20922789888000, 355687428096000, 6402373705728000, 
    121645100408832000, 2432902008176640000], dtype='int64') 

@nb.jit 
def fast_factorial(n): 
    if n > 20: 
     raise ValueError 
    return LOOKUP_TABLE[n] 

それはnumbaディスパッチのオーバーヘッドのためにPythonのバージョンよりもわずかに遅いです。別のnumba関数内で呼ば

In [58]: %timeit math.factorial(10) 
10000000 loops, best of 3: 79.4 ns per loop 

In [59]: %timeit fast_factorial(10) 
10000000 loops, best of 3: 173 ns per loop 

しかしがはるかに高速になります。

def loop_python(): 
    for i in range(10000): 
     for n in range(21): 
      math.factorial(n) 

@nb.njit 
def loop_numba(): 
    for i in range(10000): 
     for n in range(21): 
      fast_factorial(n) 

In [65]: %timeit loop_python() 
10 loops, best of 3: 36.7 ms per loop 

In [66]: %timeit loop_numba() 
10000000 loops, best of 3: 73.6 ns per loop 
+0

Numbaは積極的なループ最適化を行いますので、 'fast_factorial'の結果を保存しないとループしません。 – MSeifert

関連する問題