2011-08-14 15 views
6

私はKaratsubaの乗算アルゴリズムをC++で実装しようとしていますが、今はただPythonで動作させようとしています。Karatsubaアルゴリズムが多すぎる再帰

def mult(x, y, b, m): 
    if max(x, y) < b: 
     return x * y 

    bm = pow(b, m) 
    x0 = x/bm 
    x1 = x % bm 
    y0 = y/bm 
    y1 = y % bm 

    z2 = mult(x1, y1, b, m) 
    z0 = mult(x0, y0, b, m) 
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0 

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0 

私は何を得ることはありませんがある::z2は、z1、およびz0はどのように作成する必要があります。ここ

は私のコードですか? mult関数を再帰的に使用していますか?そうであれば、再帰が止まっていないので、どこかでうんざりしています。

どこの人がエラーを指摘できますか?

+1

もちろん、再帰は停止しません。再帰を停止させる条件はどこですか? – neurino

+2

私はおそらく、 'x0、x1 = divmod(x、bm)'を使う方が速いでしょう。 – utdemir

+0

@neurino、最初の行にif文が返されます。 – utdemir

答えて

5

NB:直接アドレス以下の応答OPの質問について 過度の再帰が、それは正しい カラツバ法を提供しようとしません。その他の回答は、 の方がはるかに有益です。

は、このバージョンをお試しください:

def mult(x, y, b, m): 
    bm = pow(b, m) 

    if min(x, y) <= bm: 
     return x * y 

    # NOTE the following 4 lines 
    x0 = x % bm 
    x1 = x/bm 
    y0 = y % bm 
    y1 = y/bm 

    z0 = mult(x0, y0, b, m) 
    z2 = mult(x1, y1, b, m) 
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0 

    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0 
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval) 
    return retval 

バージョンと最も深刻な問題はX0とし、Y0およびY1のX1のあなたの計算が反転していることです。また、x1y1が0の場合、アルゴリズムの派生は成立しません。この場合、分解ステップが無効になるからです。したがって、xとyの両方がb ** mよりも大きくなるようにすることで、この可能性を回避する必要があります。

EDIT:コードの入力ミスを修正しました。追加の明確化

EDIT2:

があなたの元のバージョンに直接コメントを、より明確にするには、次の

def mult(x, y, b, m): 
    # The termination condition will never be true when the recursive 
    # call is either 
    # mult(z2, bm ** 2, b, m) 
    # or mult(z1, bm, b, m) 
    # 
    # Since every recursive call leads to one of the above, you have an 
    # infinite recursion condition. 
    if max(x, y) < b: 
     return x * y 

    bm = pow(b, m) 

    # Even without the recursion problem, the next four lines are wrong 
    x0 = x/bm # RHS should be x % bm 
    x1 = x % bm # RHS should be x/bm 
    y0 = y/bm # RHS should be y % bm 
    y1 = y % bm # RHS should be y/bm 

    z2 = mult(x1, y1, b, m) 
    z0 = mult(x0, y0, b, m) 
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0 

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0 
+0

このバージョンと上記のバージョンとの違いを説明できますか?これは無限再帰にどのように対処していますか? – templatetypedef

+2

これは終了問題を解決しますが、 'retval'の代入における最後の2回の再帰呼び出しは間違っており、bmの計算も間違っています:これはKaratsubaのO(n ^(log_2(3))の実行時間を与えません。 – huitseeker

+0

@huitseekerによれば、OPは、最適な実行時間ではなく、基本的なロジックを動作させたいと説明しました。これは、すべての再帰で 'bm'の計算だけでなく、'% 'と'/'。最後の2つの再帰呼び出しに関して、私はそれらが元のアルゴリズムの一部であるかどうかわかりません(私は簡単に完全なバージョンを見つけることができませんでした)間違っていますが、示されているアルゴリズムはすべてのテスト実行で正しい結果を出しました。反例がある場合は投稿してください。 – kjo

1

この技術の背後にあるアイデアは、再帰アルゴリズムを使用してz iという用語が計算されますが、結果はそのようにまとめられていないと考えています。あなたが望む最終結果はあなたがBの適切な値を選択すると仮定すると

z0 B^2m + z1 B^m + z2 

あるので(例えば、2)あなたは、任意の乗算を行うことなく、B^Mを計算することができます。たとえば、B = 2を使用する場合、乗算ではなくビットシフトを使用してB^mを計算できます。つまり、最後のステップは何も乗算を行わずに行うことができます。

もう1つ - アルゴリズム全体で固定値mを選択したことに気付きました。通常、このアルゴリズムを実装するには、mを常にB^mがベースBに書き込まれるときのxとyの桁数の半分になるような値にすることです。もし2のべき乗を使用しているならば、 m = ceil((log x)/ 2)を選ぶ。

希望すると便利です。

+0

ありがとう。最終的にそれを変更して、bに2の累乗を使います。今のところ、これを稼働させようとしているだけです – calccrypto

4

通常大きな数字は、整数の配列として格納されています。各整数は1桁を表します。このアプローチでは、配列の単純な左シフトで任意の数に基数の累乗を掛けることができます。ここ

は私のリストベースの実装(バグを含んでいてもよい)である。

def normalize(l,b): 
    over = 0 
    for i,x in enumerate(l): 
     over,l[i] = divmod(x+over,b) 
    if over: l.append(over) 
    return l 
def sum_lists(x,y,b): 
    l = min(len(x),len(y)) 
    res = map(operator.add,x[:l],y[:l]) 
    if len(x) > l: res.extend(x[l:]) 
    else: res.extend(y[l:]) 
    return normalize(res,b) 
def sub_lists(x,y,b): 
    res = map(operator.sub,x[:len(y)],y) 
    res.extend(x[len(y):]) 
    return normalize(res,b) 
def lshift(x,n): 
    if len(x) > 1 or len(x) == 1 and x[0] != 0: 
     return [0 for i in range(n)] + x 
    else: return x 
def mult_lists(x,y,b): 
    if min(len(x),len(y)) == 0: return [0] 
    m = max(len(x),len(y)) 
    if (m == 1): return normalize([x[0]*y[0]],b) 
    else: m >>= 1 
    x0,x1 = x[:m],x[m:] 
    y0,y1 = y[:m],y[m:] 
    z0 = mult_lists(x0,y0,b) 
    z1 = mult_lists(x1,y1,b) 
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b) 
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m) 
    t2 = lshift(z1,m*2) 
    return sum_lists(sum_lists(z0,t1,b),t2,b) 

sum_listssub_lists戻る非正規化結果 - 一桁が基準値よりも大きくすることができます。 normalize関数がこの問題を解決しました。

すべての関数は、逆の順序で桁のリストを取得することを想定しています。たとえば、基数10の12は[2,1]と書かれます。上の改善分割し、代わりに4の3つの再帰呼び出しを行うことにより、乗算アルゴリズムを征服するKarastuba乗算の目標がある9987654321.

» a = [1,2,3,4,5,6,7,8,9] 
» res = mult_lists(a,a,10) 
» res.reverse() 
» res 
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1] 
4

の二乗を取ることができます。したがって、スクリプト内の乗算への再帰呼び出しを含む行は、z0,z1およびz2を割り当てる行だけです。他のものはあなたに悪化させます。乗算をまだ定義していないときにはpowを使用してb^mを計算することはできません(および2乗べき乗)。

このため、アルゴリズムでは、位置記法システムを使用しているという事実が決定的に使用されています。 bの数字の表現xがある場合は、x*b^mは、単純にその表現のビットを左に移動することによって得られます。mそのシフト操作は、任意の位置記法システムでは本質的に「フリー」である。つまり、実装したい場合は、この位置表記法と "フリー"シフトを再現する必要があります。どちらかをベースb=2で計算し、Pythonのビット演算子(または、テストプラットフォームにある場合は与えられた10進数、16進数、...のベースのビット演算子)を使用するか、を使用するか、教育目的のために任意のbとし、文字列、配列、リストなどでこの位置の算術を再現します

あなたはすでにsolution with listsです。 int(s,base)sという文字列に対応する整数を与えるため、Pythonで文字列を扱うのが好きです。baseの数値表現で見ればテストが簡単です。 ひどくコメントされた文字列ベースの実装を、文字列と数字から文字列へのプリミティブを含めて、hereとして投稿しました。

あなたはmultにベースと引数としての(等しい)長さで埋められた文字列を提供することにより、それをテストすることができます

In [169]: mult("987654321","987654321",10,9) 

Out[169]: '966551847789971041' 

あなたは、文字列の長さをパディングを把握またはカウントしたくない場合は、パディング機能は、あなたのためにそれを行うことができます。

In [170]: padding("987654321","2") 

Out[170]: ('987654321', '000000002', 9) 

そしてもちろん、それはb>10で動作します:

In [171]: mult('987654321', '000000002', 16,9) 

Out[171]: '130eca8642' 
を0

wolfram alphaで確認してください)

関連する問題