2017-08-01 11 views
0

としましょう。c = a + bですが、abndarrayです。その形状は必ずしも同じではありません。つまり、general broadcasting rulesに続く2つの配列である可能性があります。numpyでのブロードキャスト追加の微分を計算するためのもっとpythonicな方法?

私はいくつかの出力dl/dcの派生語を持っています。dl/daを計算したいと思います。 abが同じ形状の場合、dl/da = dl/db = dl/dcです。しかし、私はa.shape == (3,)b.shape == (2,3)のように、このようないくつかの追加があるかもしれないので、c[i][j] = a[j] + b[i][j]。つまり、dl/da[j] = sum_i c[i][j]を意味します。一般に、dl/daは、aでブロードキャストされたすべての軸上のdl/dcの合計です。

は、一般的に abのチェーンルール誘導体を計算するために、私は次の関数を書いたが、私はそれが非常にニシキヘビはないと感じ、そしておそらくより効率的に行うことができます:私は実際に終わった

def addition_derivatives(x, y, d): 
    flip = False 
    if x.ndim < y.ndim: # x should have higher ndim 
     flip = True 
     x, y = y, x 

    S = x.shape # shape of array with higher ndim 
    s = y.shape # shape of array with lower ndim 

    # figure out which axes will be broadcast in which arrays 
    n = len(S) 
    # impute missing ones in the shape of the smaller array as per: 
    # https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules 
    s = tuple(1 if i < len(S) - len(s) else s[i - (len(S) - len(s))] for i in range(n)) 
    axis_x = [] 
    axis_y = [] 
    for i in range(n): 
     assert s[i] == S[i] or s[i] == 1 or S[i] == 1 
     if S[i] == 1 and s[i] != 1: 
      axis_x.append(i) 
     if s[i] == 1 and S[i] != 1: 
      axis_y.append(i) 
    axis_x, axis_y = map(tuple, (axis_x, axis_y)) 

    # compute the derivatives 
    dx = np.sum(d, axis=axis_x).reshape(x.shape) 
    dy = np.sum(d, axis=axis_y).reshape(y.shape) 
    if flip: 
     dx, dy = dy, dx 

    return dx, dy 

答えて

0

np.broadcast_arraysnp.stridesを使用してこれを行うための一種のハッキングを見つけます。これがすべての場合に機能するかどうかはわかりませんが、np.stridesは次元1のすべての軸に対して0を返すため、これまで働いていました。

def addition_derivatives(x, y, d): 
    bx, by = np.broadcast_arrays(x, y) 
    ax = tuple(i for i, (dx, dy) in enumerate(zip(bx.strides, by.strides)) if dx == 0 and dy != 0) 
    ay = tuple(i for i, (dx, dy) in enumerate(zip(bx.strides, by.strides)) if dx != 0 and dy == 0) 
    dx = np.sum(d, ax).reshape(x.shape) 
    dy = np.sum(d, ay).reshape(y.shape) 
    return dx, dy 
関連する問題