2017-01-26 12 views
1

私はちょうどcntkを学び始めています。しかし、私は基本的な質問を持っています。なぜ矛盾した図形がnumpyかcntkですか?

import numpy as np 
from cntk import input_variable, plus 

def test_simple(self): 

    x_input = np.asarray([[1, 2, 2]], dtype=np.int64) 
    assert (1, 3) == x_input.shape 

    y_input = np.asarray([[5, 3, 3]], dtype=np.int64) 
    assert (1, 3) == y_input.shape 

    x = input_variable(x_input.shape[1]) 
    assert (3,) == x.shape 

    y = input_variable(y_input.shape[1]) 
    assert (3,) == y.shape 

    x_plus_y = plus(x, y) 
    assert (3,) == x_plus_y.shape 

    res = x_plus_y.eval({x: x_input, y: y_input}) 

    assert 6 == res[0, 0, 0] 
    assert 5 == res[0, 0, 1] 
    assert 5 == res[0, 0, 2] 

をIは、第1および第2の軸は、それぞれのバッチとデフォルト動的軸であるように、出力の形状が(1、1、3)であることを理解する:私は通過し、次の試験を持っています。

しかし、なぜ入力変数の形状を(1,3)ではなく(3)として設定する必要がありますか。 (1,3)を使用すると失敗します。

グラフの入力ノードの形状とそのノードへの入力として使用されるnumpyデータの間に矛盾があるのはなぜですか?

は、 水田

答えて

2

これをありがとうFunction.forwardための「引数」の説明に少し説明されています。もう1つの説明はhereです。あなたの混乱の理由はおそらく、CNTKがいくつかの「役に立つ」変換をしているからです。

入力を(1,3)として指定する場合、配列軸のないミニバッチの場合は(1,3)配列のリストを指定する必要がありますシーケンス軸を持つminibatchの場合(xはミニバッチ内の各シーケンスごとに異なる可能性があります)。同様に、入力を(3、)として指定する場合は、(3、)ベクトルまたは(x、3)ベクトルのリストのいずれかを指定する必要があります。

おそらく、リストが提供されていない場合に混乱が生じる可能性があります。その場合、CNTKは、与えられたテンソルの先頭の軸を反復し、これらの要素のリストを作成する。 (5,1,3)テンソルは(1,3)の形をした5つの要素のバッチになります。

関連する問題