theano

2017-08-11 12 views
1

でscipy.convolveを再現するにはどうすればtheano conv1d問題に関する以前の回答を読んだが、私はそれを動作させるように見えることはできません。theano

x = np.arange(50) * 1. 
y = np.random.normal((x+0.1)/5, 1, 50) 

def tophat(x, centre, width, amplitude): 
    return tt.switch((x < centre + (width/2)) & (x >= centre - (width/2)), np.float64(amplitude)/width, np.float64(0.)) 

import theano.tensor.signal.conv 
def theano_convolve(x, y, filt_range, centre, width, amplitude): 
    a = tt.matrix('a', dtype='float64') 
    b = tt.matrix('b', dtype='float64') 

    filt = tophat(b, centre, width, amplitude) 

    func = tt.signal.conv.conv2d(a, filt, (1, y.shape[0]), (1, filt_range.shape[0]), border_mode='full')/filt.sum() 

    return theano.function([a, b], func)(y[None, :], filt_range[None, :]) 

from scipy.signal import convolve 

def scipy_convolve(x, y, filt_range, centre, width, amplitude): 
    a = tt.vector('a') 
    filt = theano.function([a], tophat(a, centre, width, amplitude))(filt_range) 
    return convolve(y, filt, mode='same')/sum(filt) 

convolved_theano = theano_convolve(x, y, np.linspace(-10, 10, len(x)), 0, 3, 1) 

convolved_scipy = scipy_convolve(x, y, np.linspace(-10, 10, len(x)), 0, 3, 1) 

plt.plot(x, y, '.', label='data') 
plt.plot(r[0], label='theano') 
plt.plot(convolved_scipy, label='scipy'); 
plt.legend(); 

enter image description here

これは、その結果theanoとのゼロパディング畳み込み。私はゼロを取り除くことができますが、私はむしろ何が起こっているのか知っています!

どのようにしてtoano関数のデータを1次元で畳み込みますか?あなたが見る

おかげ

答えて

1

動作は、次の2つの畳み込みのために使用するさまざまmodeによって引き起こされます。

mode='same'で、theano.tensor.signal.conv.conv2dではmode='full'を使用します。

mode='full'を使用するようにscipy.signal.convolveを変更すると、全く同じ ベクターが得られます。画像の場合、をtheanoベクトルに追加して、線を表示し、sicpy.convolveと重ならないようにしました。

scipy vs theano convolution

import numpy as np 
import theano.tensor as tt 
import seaborn as sns 

plt = sns.plt 

x = np.arange(50) * 1. 
y = np.random.normal((x+0.1)/5, 1, 50) 

def tophat(x, centre, width, amplitude): 
    return tt.switch((x < centre + (width/2)) & (x >= centre - (width/2)), np.float64(amplitude)/width, np.float64(0.)) 

import theano.tensor.signal.conv 
def theano_convolve(x, y, filt_range, centre, width, amplitude): 
    a = tt.matrix('a', dtype='float64') 
    b = tt.matrix('b', dtype='float64') 

    filt = tophat(b, centre, width, amplitude) 

    func = tt.signal.conv.conv2d(a, filt, (1, y.shape[0]), (1, filt_range.shape[0]), border_mode='full')/filt.sum() 

    return theano.function([a, b], func)(y[None, :], filt_range[None, :]) 

from scipy.signal import convolve 

def scipy_convolve(x, y, filt_range, centre, width, amplitude): 
    a = tt.vector('a') 
    filt = theano.function([a], tophat(a, centre, width, amplitude))(filt_range) 
    return convolve(y, filt, mode='full')/sum(filt) 

convolved_theano = theano_convolve(x, y, np.linspace(-10, 10, len(x)), 0, 3, 1) 

convolved_scipy = scipy_convolve(x, y, np.linspace(-10, 10, len(x)), 0, 3, 1) 

plt.plot(x, y, '.', label='data') 
plt.plot(convolved_theano[0]+0.1, label='theano') 
plt.plot(convolved_scipy, label='scipy') 
plt.legend() 
plt.show(block=True) 

は残念ながら border_mode='same'をサポートしていませんtheanoため theano documentation conv2dを見ています。

+0

はおそらくデータを畳み込みに送信する前に境界線を修正できます。 –