2016-02-02 13 views

答えて

9

Theanoのdimshuffle TensorFlow中を実装するための3つの関連するOPSがあります。

  • tf.transpose()はテンソルの次元を置換するために使用されています。 dimshuffleの引数で指定されたパターンが入力テンソルの次元の順列である場合(すなわち、'x'または次元がない場合)tf.transpose()を使用してdimshuffle()を実装できます。

  • tf.expand_dims()は、テンソルに1つ以上のサイズ-1の次元を追加するために使用されます。これは、'x'dimshuffle()パターンの一部として指定されているが、既存の寸法の順序を変更しない場合を処理します。

  • tf.squeeze()は、テンソルから1つ以上のサイズ-1の次元を削除するために使用されます。これは、パターンがdimshuffle()パターンから省略されているが、既存の次元の順序を変更しない場合を処理します。

あなたの例は、(dimshuffle(0, 'x')が)のみtf.expand_dims()を用いて表現することができ、入力がベクトルであると仮定すると:

input = tf.placeholder(tf.float32, [None]) # Defines an arbitrary-sized vector. 
result = tf.expand_dims(input, 1) 

print result.get_shape() # ==> TensorShape([Dimension(None), Dimension(1)]) 

は、より複雑な例をとると、マトリックスに適用dimshuffle(1, 'x', 0)は次のようになります

input = tf.placeholder(tf.float32, [128, 32]) # Defines a matrix. 
output = tf.expand_dims(tf.transpose(input, [1, 0]), 1) 

print output.get_shape() 
# ==> TensorShape([Dimension(32), Dimension(1), Dimension(128)]) 
0

TensorFlowのdimshuffleour framework Returnn()に実装しました。)。コードはこれです:あなたのバックエンドtensorflowある場合

def expand_multiple_dims(x, axes, name="expand_multiple_dims"): 
    """ 
    :param tf.Tensor x: 
    :param list[int]|tuple[int] axes: after completion, tf.shape(y)[axis] == 1 for axis in axes 
    :param str name: scope name 
    :return: y where we have a new broadcast axis for each axis in axes 
    :rtype: tf.Tensor 
    """ 
    with tf.name_scope(name): 
    for i in sorted(axes): 
     x = tf.expand_dims(x, axis=i, name="expand_axis_%i" % i) 
    return x 


def dimshuffle(x, axes, name="dimshuffle"): 
    """ 
    Like Theanos dimshuffle. 
    Combines tf.transpose, tf.expand_dims and tf.squeeze. 

    :param tf.Tensor x: 
    :param list[int|str]|tuple[int|str] axes: 
    :param str name: scope name 
    :rtype: tf.Tensor 
    """ 
    with tf.name_scope(name): 
    assert all([i == "x" or isinstance(i, int) for i in axes]) 
    real_axes = [i for i in axes if isinstance(i, int)] 
    bc_axes = [i for (i, j) in enumerate(axes) if j == "x"] 
    if x.get_shape().ndims is None: 
     x_shape = tf.shape(x) 
     x = tf.reshape(x, [x_shape[i] for i in range(max(real_axes) + 1)]) # will have static ndims 
    assert x.get_shape().ndims is not None 

    # First squeeze missing axes. 
    i = 0 
    while i < x.get_shape().ndims: 
     if i not in real_axes: 
     x = tf.squeeze(x, axis=i) 
     real_axes = [(j if (j < i) else (j - 1)) for j in real_axes] 
     else: 
     i += 1 

    # Now permute. 
    assert list(sorted(real_axes)) == list(range(x.get_shape().ndims)) 
    if real_axes != list(range(x.get_shape().ndims)): 
     x = tf.transpose(x, real_axes) 

    # Now add broadcast dimensions. 
    if bc_axes: 
     x = expand_multiple_dims(x, bc_axes) 
    assert len(axes) == x.get_shape().ndims 
    return x 
0

from keras import baskend as K 
K.permute_dimension should do 
関連する問題