2017-06-09 15 views
1

私はtheanoに新しく、走査のまわりで私の頭を包んでいます。私は、行の重みから重み付けされた行列を計算し、重みの確率で重み付けし、対応する重み付き行列を求めたいと思います。しかし、私は行列和を追跡しながらtheanoの重みを反復することに問題があります。おもちゃの例としてスキャンしたウェイトベクトルを与えられた重み付けされた行列(theano)

def get_weighted(mtx,weights,weights_p): 
    mtx_store = np.zeros(np.shape(mtx)) 
    for idx in xrange(len(weights)): 
     mtx_store += weights_p[idx] * (mtx * [weights[idx], 1-weights[idx]]) 
    return mtx_store 

import numpy as np 

mtx = np.asarray([[1,0],[0,1],[0.5,0.5]]) 
weights = np.asarray([0.1,0.8]) #weight 1 and weight 2 
weights_p = np.asarray([0.8, 0.2]) #prob. of weight 1 and weight 2 

が例示は、より一般的にnumpyの、インデックス、およびforループの私の所望の機能を使用すると、これを行うだろう

weights_p[0] * (mtx * [weights[0],(1-weights[0])]) +\ 
weights_p[1] * (mtx * [weights[1],(1-weights[1])]) 

だろう

今私はtheanoでこれを行う必要があります。私が試した:

import theano as t 
v,w = t.tensor.vectors('v','w') 
m,n = t.tensor.matrices('m','n') 

def step(v, w, m, cum_sum): 
    return v * (m * [w,1-w]) + cum_sum 

output, updates = t.scan(fn=step, 
         sequences=[v,w], 
         non_sequences=[m], 
         outputs_info=[n]) 

get_weighted = t.function(inputs=[v,w,m,n], 
        outputs=output, 
        updates=updates) 

私の考えは、反復的に合計を格納するための空の配列を持っていることでした。

mtx_store = np.zeros(np.shape(mtx)) 
get_weighted(weights_p, weights, mtx, mtx_store) 

しかし、私は取得しています:

array([[[ 1. , 0. ], 
     [ 0. , 1. ], 
     [ 0.5 , 0.5 ]], 

     [[ 1.16, 0. ], 
     [ 0. , 1.04], 
     [ 0.58, 0.52]]]) 

代わりの

array([[ 0.24, 0. ], 
     [ 0. , 0.76], 
     [ 0.12, 0.38]]) 

私はこれが私の病気に由来していると確信していますスキャンの誤り。何が間違っていて、どのように効率的に行うことができますか?

答えて

0

問題が見つかりました。後世のために:

sequences (if any), prior result(s) (if needed), non-sequences (if any) 

私はこの順序で引数を提供していたのに対し:

sequences, non-sequences, prior-results 

次のように正しいコードは次のとおりです。

def step(v, w, cum_sum,m): 
    return v * (m * [w,1-w]) + cum_sum 

output, updates = t.scan(fn=step, 
         sequences=[v,w], 
         non_sequences=[m], 
         outputs_info=[t.tensor.zeros_like(m)]) 

final_result = output[-1] #take the final outcome of the sum 


get_weighted = t.function(inputs=[v,w,m], 
        outputs=final_result, 
        updates=updates) 
主な問題は、スキャンの構文が望んでいることでした

(引数を格納するために行列を渡すことも明らかに必要ではありませんが、これは問題ではないと思いますが、上記の 'outputs_info'で直接指定することができます)

関連する問題