2017-12-06 12 views
11

FullyConnectedのようなビルド前の高水準関数を使用してモデルを作成できます。たとえば、このようにして行列と行列演算を明示的に使ってMXNetでモデルを構築する方法は?

X = mx.sym.Variable('data') 
P = mx.sym.FullyConnected(data = X, name = 'fc1', num_hidden = 2) 

私はシンボリック変数Xに依存しているシンボリック変数Pを取得します。言い換えれば、モデルを定義し、fitpredictのような演算を実行するために使用できる計算グラフがあります。

PからXを別の方法で表現したいと思います。より詳細には、高レベルの機能(FullyConnectedなど)を使用する代わりに、低レベルのテンソル演算(行列乗算など)とモデルパラメータを表す記号変数を使用して、PXの間の関係を明示的に指定したいとします重み行列)。

は上記と同じを達成するために、例えば、私はfollowigを試みた:

W = mx.sym.Variable('W') 
B = mx.sym.Variable('B') 
P = mx.sym.broadcast_plus(mx.sym.dot(X, W), B) 

しかし、P得られるこの方法は、以前得Pと等価ではありません。私はそれを同じように使うことはできません。特に、私が理解する限り、MXNetはWBに値がないことを訴えています(これは意味があります)。

私はまた、(彼らは値を持っているように)別の方法でWBを宣言しようとしている:それは同様に動作しません

w = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) 
b = np.array([7.0, 8.0]) 

W = mx.nd.array(w) 
B = mx.nd.array(b) 

。 MXNetはシンボリック変数を期待しているので不満を持ちますが、代わりにnd-arrayを取得すると思います。

私の質問は、低レベルのテンソル演算(行列乗算など)とモデルパラメータ(重み行列など)を表す明示的なオブジェクトを使用してモデルを構築する方法です。

答えて

5

Gluon APIをご検討ください。

def net(X): 
    ####################### 
    # Compute the first hidden layer 
    ####################### 
    h1_linear = nd.dot(X, W1) + b1 
    ... 

と実行自動勾配

for param in params: 
    param.attach_grad() 

モデルを定義するためにそれらをアタッチ

####################### 
# Allocate parameters for the first hidden layer 
####################### 
W1 = nd.random_normal(shape=(num_inputs, num_hidden), scale=weight_scale, ctx=model_ctx) 
b1 = nd.random_normal(shape=num_hidden, scale=weight_scale, ctx=model_ctx) 

params = [W1, b1, ...] 

たとえばここにパラメータを割り当てる、スクラッチからMLPを構築するなどのためのガイドでありますそれ

epochs = 10 
learning_rate = .001 
smoothing_constant = .01 

for e in range(epochs): 
    ... 
    for i, (data, label) in enumerate(train_data): 
     data = data.as_in_context(model_ctx).reshape((-1, 784)) 
     label = label.as_in_context(model_ctx) 
     ... 
     with autograd.record(): 
      output = net(data) 
      loss = softmax_cross_entropy(output, label_one_hot) 
     loss.backward() 
     SGD(params, learning_rate) 

ストレートドープの完全な例を見ることができます:

http://gluon.mxnet.io/chapter03_deep-neural-networks/mlp-scratch.html

関連する問題