私の現在のコードを作成します。私自身のリソースタイプ(tf.resource)
// For Eigen::ThreadPoolDevice.
#define EIGEN_USE_THREADS 1
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
using namespace tensorflow;
REGISTER_OP("ArrayContainerCreate")
.Attr("T: type")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(Array container, random index access)doc");
REGISTER_OP("ArrayContainerGetSize")
.Input("handle: resource")
.Output("out: int32")
.SetShapeFn(shape_inference::ScalarShape)
;
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_mgr.h
struct ArrayContainer : public ResourceBase {
ArrayContainer(const DataType& dtype) : dtype_(dtype) {}
string DebugString() override { return "ArrayContainer"; }
int64 MemoryUsed() const override { return 0; };
mutex mu_;
const DataType dtype_;
int32 get_size() {
mutex_lock l(mu_);
return (int32) 42;
}
};
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_op_kernel.h
class ArrayContainerCreateOp : public ResourceOpKernel<ArrayContainer> {
public:
explicit ArrayContainerCreateOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
}
private:
virtual bool IsCancellable() const { return false; }
virtual void Cancel() {}
Status CreateResource(ArrayContainer** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new ArrayContainer(dtype_);
if(*ret == nullptr)
return errors::ResourceExhausted("Failed to allocate");
return Status::OK();
}
Status VerifyResource(ArrayContainer* ar) override {
if(ar->dtype_ != dtype_)
return errors::InvalidArgument("Data type mismatch: expected ", DataTypeString(dtype_),
" but got ", DataTypeString(ar->dtype_), ".");
return Status::OK();
}
DataType dtype_;
};
REGISTER_KERNEL_BUILDER(Name("ArrayContainerCreate").Device(DEVICE_CPU), ArrayContainerCreateOp);
class ArrayContainerGetSizeOp : public OpKernel {
public:
using OpKernel::OpKernel;
void Compute(OpKernelContext* context) override {
ArrayContainer* ar;
OP_REQUIRES_OK(context, GetResourceFromContext(context, "handle", &ar));
core::ScopedUnref unref(ar);
int32 size = ar->get_size();
Tensor* tensor_size = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &tensor_size));
tensor_size->flat<int32>().setConstant(size);
}
};
REGISTER_KERNEL_BUILDER(Name("ArrayContainerGetSize").Device(DEVICE_CPU), ArrayContainerGetSizeOp);
私はそれをコンパイルします。
from google.protobuf.pyext import _message as msg
lib = msg.__file__
extra_compiler_flags = [
"-Xlinker", "-rpath", "-Xlinker", os.path.dirname(lib),
"-L", os.path.dirname(lib), "-l", ":" + os.path.basename(lib)]
私はそのhereについて読ん:私は最初のいくつかのundefined symbol: _ZN6google8protobuf8internal26fixed_address_empty_stringE
エラーが発生しましたが、私はこれらの追加のコンパイラフラグを追加することで解決することに注意してください。
次に、モジュールとしてtf.load_op_library
経由でロードします。
その後、私はこのPythonコードを持っている:私はsize
を評価しようとすると
handle = mod.array_container_create(T=tf.int32)
size = mod.array_container_get_size(handle=handle)
、私はエラーを取得する:
InvalidArgumentError (see above for traceback): Trying to access resource located in device 14ArrayContainer from device /job:localhost/replica:0/task:0/cpu:0
[[Node: ArrayContainerGetSize = ArrayContainerGetSize[_device="/job:localhost/replica:0/task:0/cpu:0"](array_container)]]
は、デバイス名(14ArrayContainer
)は何とか台無しにしているようです。何故ですか?コードの問題は何ですか?いくつかのより多くのテストのために
、私はArrayContainerCreateOp
に、この追加のコードを追加しました:
ResourceHandle rhandle = MakeResourceHandle<ArrayContainer>(context, cinfo_.container(), cinfo_.name());
printf("created. device: %s\n", rhandle.device().c_str());
printf("container: %s\n", rhandle.container().c_str());
printf("name: %s\n", rhandle.name().c_str());
printf("actual device: %s\n", context->device()->attributes().name().c_str());
printf("actual name: %s\n", cinfo_.name().c_str());
をこれは私に出力を提供します:
created. device: 14ArrayContainer
container: 14ArrayContainer
name: 14ArrayContainer
actual device: /job:localhost/replica:0/task:0/cpu:0
actual name: _2_array_container
ので明確、問題のいくつかがあります。
これは、protobufで何かがうんざりしているようです。たぶん私は間違ったライブラリをリンクしていますか?しかし、私は代わりにどのlibをリンクするのか見当たりませんでした。
(私もこのhereに関する問題を掲載。)