2017-05-24 19 views
1

私は、のおおよそ次のタイプのtensorの単一の属性を取り入れたC++の新しいOpを定義しています。 OPコードの被覆のバージョンは以下の通りです:TensorFlowは "tensor"型のAttrに対してどのようなPython型を受け入れますか?

#include "tensorflow/core/framework/op.h" 
#include "tensorflow/core/framework/op_kernel.h" 

using namespace tensorflow; 

REGISTER_OP("DoStuff") 
    .Attr("attr: tensor = { dtype: DT_FLOAT }") 
    .Input("in: float") 
    .Output("out: float"); 

class DoStuffOp : public OpKernel { 
public: 
    explicit DoStuffOp(OpKernelConstruction *context) : OpKernel(context) { 
     OP_REQUIRES_OK(context, context->GetAttr("attr", &attr_)); 
     // ... 
    } 

    void Compute(OpKernelContext *context) override { 
     // ... 
    } 

private: 
    Tensor attr_; 
}; 

REGISTER_KERNEL_BUILDER(Name("DoStuff").Device(DEVICE_CPU), DoStuffOp); 

私は.soファイルの罰金にオペアンプをコンパイルすることができます。しかし、私はattrの値をうまく通す方法を理解できません。私はPythonで次のように実行すると:

import tensorflow as tf 
dostufflib = tf.load_op_library('build/do_stuff.so') 
sess = tf.InteractiveSession() 

A = [[1.0, 2.0, 3.0], 
    [4.0, 5.0, 6.0]] 
X = tf.Variable(tf.constant(1.0)) 

Y = dostufflib.do_stuff(X, A) 

を私はTypeError: Don't know how to convert [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] to a TensorProto for argument 'attr'を取得します。私は、タイプ変換を満たすようなものはありません:listnumpy配列、tf.Tensortf.VariableなどPython変数をどのようにしてテンソル属性としてOpに渡しますか?

答えて

2

さらに多くの狩りの後、tf.TensorProtoオブジェクトに、pythonスカラー、pythonリスト、numpy ndarray、またはnumpyスカラーを変換する関数tf.contrib.util.make_tensor_protoが見つかりました。次の作品:

A = tf.contrib.util.make_tensor_proto([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]]) 
X = tf.Variable(tf.constant(1.0)) 

Y = dostufflib.do_stuff(X, A) 
関連する問題