2017-06-09 20 views
2

私の現在のコードを作成します。私自身のリソースタイプ(tf.resource)

// For Eigen::ThreadPoolDevice. 
#define EIGEN_USE_THREADS 1 

#include "tensorflow/core/framework/op.h" 
#include "tensorflow/core/framework/shape_inference.h" 
#include "tensorflow/core/framework/op_kernel.h" 
#include "tensorflow/core/framework/resource_mgr.h" 
#include "tensorflow/core/framework/resource_op_kernel.h" 
#include "tensorflow/core/framework/tensor.h" 
#include "tensorflow/core/framework/tensor_shape.h" 
#include "tensorflow/core/framework/types.h" 
#include "tensorflow/core/platform/macros.h" 
#include "tensorflow/core/platform/mutex.h" 
#include "tensorflow/core/platform/types.h" 

using namespace tensorflow; 

REGISTER_OP("ArrayContainerCreate") 
.Attr("T: type") 
.Attr("container: string = ''") 
.Attr("shared_name: string = ''") 
.Output("resource: resource") 
.SetIsStateful() 
.SetShapeFn(shape_inference::ScalarShape) 
.Doc(R"doc(Array container, random index access)doc"); 

REGISTER_OP("ArrayContainerGetSize") 
.Input("handle: resource") 
.Output("out: int32") 
.SetShapeFn(shape_inference::ScalarShape) 
; 

// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_mgr.h 
struct ArrayContainer : public ResourceBase { 
    ArrayContainer(const DataType& dtype) : dtype_(dtype) {} 

    string DebugString() override { return "ArrayContainer"; } 
    int64 MemoryUsed() const override { return 0; }; 

    mutex mu_; 
    const DataType dtype_; 

    int32 get_size() { 
    mutex_lock l(mu_); 
    return (int32) 42; 
    } 

}; 

// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_op_kernel.h 
class ArrayContainerCreateOp : public ResourceOpKernel<ArrayContainer> { 
public: 
    explicit ArrayContainerCreateOp(OpKernelConstruction* context) : ResourceOpKernel(context) { 
    OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_)); 
    } 

private: 
    virtual bool IsCancellable() const { return false; } 
    virtual void Cancel() {} 

    Status CreateResource(ArrayContainer** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) { 
    *ret = new ArrayContainer(dtype_); 
    if(*ret == nullptr) 
     return errors::ResourceExhausted("Failed to allocate"); 
    return Status::OK(); 
    } 

    Status VerifyResource(ArrayContainer* ar) override { 
    if(ar->dtype_ != dtype_) 
     return errors::InvalidArgument("Data type mismatch: expected ", DataTypeString(dtype_), 
            " but got ", DataTypeString(ar->dtype_), "."); 
    return Status::OK(); 
    } 

    DataType dtype_; 
}; 
REGISTER_KERNEL_BUILDER(Name("ArrayContainerCreate").Device(DEVICE_CPU), ArrayContainerCreateOp); 

class ArrayContainerGetSizeOp : public OpKernel { 
public: 
    using OpKernel::OpKernel; 

    void Compute(OpKernelContext* context) override { 
    ArrayContainer* ar; 
    OP_REQUIRES_OK(context, GetResourceFromContext(context, "handle", &ar)); 
    core::ScopedUnref unref(ar); 

    int32 size = ar->get_size(); 
    Tensor* tensor_size = nullptr; 
    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &tensor_size)); 
    tensor_size->flat<int32>().setConstant(size); 
    } 
}; 
REGISTER_KERNEL_BUILDER(Name("ArrayContainerGetSize").Device(DEVICE_CPU), ArrayContainerGetSizeOp); 

私はそれをコンパイルします。

from google.protobuf.pyext import _message as msg 
lib = msg.__file__ 

extra_compiler_flags = [ 
    "-Xlinker", "-rpath", "-Xlinker", os.path.dirname(lib), 
    "-L", os.path.dirname(lib), "-l", ":" + os.path.basename(lib)] 

私はそのhereについて読ん:私は最初のいくつかのundefined symbol: _ZN6google8protobuf8internal26fixed_address_empty_stringEエラーが発生しましたが、私はこれらの追加のコンパイラフラグを追加することで解決することに注意してください。

次に、モジュールとしてtf.load_op_library経由でロードします。

その後、私はこのPythonコードを持っている:私はsizeを評価しようとすると

handle = mod.array_container_create(T=tf.int32) 
size = mod.array_container_get_size(handle=handle) 

、私はエラーを取得する:

InvalidArgumentError (see above for traceback): Trying to access resource located in device 14ArrayContainer from device /job:localhost/replica:0/task:0/cpu:0 
     [[Node: ArrayContainerGetSize = ArrayContainerGetSize[_device="/job:localhost/replica:0/task:0/cpu:0"](array_container)]] 

は、デバイス名(14ArrayContainer)は何とか台無しにしているようです。何故ですか?コードの問題は何ですか?いくつかのより多くのテストのために

、私はArrayContainerCreateOpに、この追加のコードを追加しました:

ResourceHandle rhandle = MakeResourceHandle<ArrayContainer>(context, cinfo_.container(), cinfo_.name()); 
    printf("created. device: %s\n", rhandle.device().c_str()); 
    printf("container: %s\n", rhandle.container().c_str()); 
    printf("name: %s\n", rhandle.name().c_str()); 
    printf("actual device: %s\n", context->device()->attributes().name().c_str()); 
    printf("actual name: %s\n", cinfo_.name().c_str()); 

をこれは私に出力を提供します:

created. device: 14ArrayContainer 
container: 14ArrayContainer 
name: 14ArrayContainer 
actual device: /job:localhost/replica:0/task:0/cpu:0 
actual name: _2_array_container 

ので明確、問題のいくつかがあります。

これは、protobufで何かがうんざりしているようです。たぶん私は間違ったライブラリをリンクしていますか?しかし、私は代わりにどのlibをリンクするのか見当たりませんでした。

(私もこのhereに関する問題を掲載。)

答えて