2017-03-23 5 views
3

チュートリアルを使用してカスタムTensorflowを作成していますが、Tensorsの読み書き方法の理解に問題があります。Tensorflowカスタムオペレーション - Tensorsからの読み書きはどのようにしますか?

方法Iインデックスにすることができます、のはそのテンソルの形状、たとえば、[2、10、20]を持っている場合、私は、私は const Tensor& values_tensor = context->input(0);(ここで、コンテキスト= OpKernelConstruction*

から取得する私のOpKernelでテンソルを持っているとしましょうそれは(例えば、auto x = values_tensor[1, 4, 12]など)?同等

、私はoutput_tensor[1, 2, 3] = 11、などのように、output_tensorに割り当てることができる方法

Tensor *output_tensor = NULL; 
OP_REQUIRES_OK(context, context->allocate_output(
    0, 
    {batch_size, value_len - window_size, window_size}, 
    &output_tensor 
)); 

を持っている場合は?

:)

ばかな質問して申し訳ありませんが、ドキュメントは本当にここに私をトリップされており、内蔵のOPS用Tensorflowカーネルコードの例では、何らかの形で私は非常に混乱してしまいポイントにこれを難読化感謝君は!

答えて

1

tensorflow::Tensorオブジェクトを読み書きする最も簡単な方法は、tensorflow::Tensor::tensor<T, NDIMS>()メソッドを使用してEigen tensorに変換することです。テンプレートパラメータTとしてテンソルの要素の(C++)タイプを指定する必要があることに注意してください。例えば

DT_FLOAT32テンソルから特定の値を読むために:DT_FLOAT32テンソルに特定の値を書き込むには

const Tensor& values_tensor = context->input(0); 
auto x = value_tensor.tensor<float, 3>()(1, 4, 12); 

Tensor* output_tensor = ...; 
output_tensor->tensor<float, 3>()(1, 2, 3) = 11.0; 

scalarにアクセスするための便利な方法もあります。 、vector、またはmatrixである。

関連する問題