2016-04-03 17 views
5

私はテンソルBxWxHxDを持っていると言います。私はテンソルを処理して、各WxHスライスの最大要素のみが保持され、他のすべての値がゼロである新しいBxWxHxDテンソルを持つようにしたい。Tensorflow多次元argmax

つまり、これを達成する最善の方法は、何とかWxHスライスを横切る2D argmaxを取ることで、行と列のBxDインデックステンソルが1ホットBxWxHxDテンソルに変換できることですマスクとして使用する。この仕事をするにはどうすればいいですか?

答えて

1

以下の機能を出発点として使用できます。各バッチおよび各チャネルの最大要素のインデックスを計算します。結果の配列は、フォーマット(バッチサイズ、2、チャンネル数)の形式です。

def argmax_2d(tensor): 

    # input format: BxHxWxD 
    assert rank(tensor) == 4 

    # flatten the Tensor along the height and width axes 
    flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3])) 

    # argmax of the flat tensor 
    argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32) 

    # convert indexes into 2D coordinates 
    argmax_x = argmax // tf.shape(tensor)[2] 
    argmax_y = argmax % tf.shape(tensor)[2] 

    # stack and return 2D coordinates 
    return tf.stack((argmax_x, argmax_y), axis=1) 

def rank(tensor): 

    # return the rank of a Tensor 
    return len(tensor.get_shape())