Numbaを使用してコードを高速化しようとしています。コードはシンプルですが、基本的にループは単純な数値計算配列で計算されます。Numbaを使用して次のコードを高速化する
import numpy as np
import time
from numba import jit, double
def MinimizeSquareDiffBudget(x, budget):
if (budget > np.sum(x)):
return x
n = np.size(x,0)
j = 1
i = 0
y = np.zeros((n, 1))
while (budget > 0):
while (x[i] == x[j]) and (j < n-1):
j += 1
i = j - 1
if (np.std(x)<1e-10):
to_give = budget/n
y += to_give
x= x- to_give
break
to_give = min(budget, (x[0] - x[j])*j)
y[0:j] += to_give/j
x[0:j]=x[0:j]-to_give/j
budget = budget - to_give
j = 1
return y
は今、私は@jitを使用して最適化し、定義してみました。私はNumbaははるかに高速であると予想しながら
fastMinimizeSquareDiffBudget = jit(double[:,:](double[:,:], double[:,:]))(MinimizeSquareDiffBudget)
しかし、時間は、ほぼ同じです。
テストコード:
budget = 335.0
x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]
t = time.process_time()
y = MinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)
x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]
t = time.process_time()
y = fastMinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)
直接実装するための0.28秒とNumbaと最適化されたコードのために0.45秒を要します。 Cで書かれた同じコードは0.001秒未満です。
アイデア?