PythonでカスタムTensorflowを作成し、Protobufレジストリに登録して、hereのような操作をしたいとします。 Protobufの登録は、Pythonから直接このopを使用することはないため、キーとなりますが、C++のopのように登録され、Pythonランタイム環境にロードされている場合は、自分の環境で実行できます。私はコードが何かのように見えることを期待するPythonでカスタムTensorflow Opを作成して登録する
、
import tensorflow as tf
from google.protobuf import json_format
from tensorflow.python.ops.data_flow_ops import QueueBase, _as_type_list, _as_shape_list, _as_name_list
""" Missing the Python equivalent of,
class HDF5QueueOp : public ResourceOpKernel<QueueInterface> {
public:
// Implementation
};
REGISTER_OP("HDF5Queue")
.Output("handle: resource")
.Attr("filename: string")
.Attr("datasets: list(string)")
.Attr("overwrite: bool = false")
.Attr("component_types: list(type) >= 0 = []")
.Attr("shapes: list(shape) >= 0 = []")
.Attr("shared_name: string = ''")
.Attr("container: string = ''")
.Attr("capacity: int = -1")
.SetIsStateful()
.SetShapeFn(TwoElementOutput);
"""
class HDF5Queue(QueueBase):
def __init__(self, stream_id, stream_columns, dtypes=None, capacity=100,
shapes=None, names=None, name="hdf5_queue"):
if not dtypes:
dtypes = [tf.int64, tf.float32]
if not shapes:
shapes = [[1], [1]]
dtypes = _as_type_list(dtypes)
shapes = _as_shape_list(shapes, dtypes)
names = _as_name_list(names, dtypes)
queue_ref = _op_def_lib.apply_op("HDF5Queue", stream_id=stream_id,
stream_columns=stream_columns, capacity=capacity,
component_types=dtypes, shapes=shapes,
name=name, container=None, shared_name=None)
super(HDF5Queue, self).__init__(dtypes, shapes,
names, queue_ref)
上記は、TFからかなり標準です。これは、例えばFIFOQueueで見ることができます。 Python Wrapper,Protobuf Registration,C++ Implementation。コンパイル中にコンパイル時に生成されるPythonラッパーがありますが、実行する場所がわかりますgrep -A 10 -B 10 -n FIFO $(find /usr/local -name "*gen_data_flow*.py") /dev/null
以下は、TFグラフのProtobufメッセージをJSON形式でダンプします。私はこれがHDF5Queue操作のブロックをダンプすると思いますが、これはC++操作を記述した場合と同じです。
with tf.Session() as sess:
queue = HDF5Queue(stream_id=0xa)
write = queue.enqueue([[1], [1.2]])
read = queue.dequeue()
print json_format.MessageToJson(tf.train.export_meta_graph())