2017-05-01 6 views
1

PythonでカスタムTensorflowを作成し、Protobufレジストリに登録して、hereのような操作をしたいとします。 Protobufの登録は、Pythonから直接このopを使用することはないため、キーとなりますが、C++のopのように登録され、Pythonランタイム環境にロードされている場合は、自分の環境で実行できます。私はコードが何かのように見えることを期待するPythonでカスタムTensorflow Opを作成して登録する

import tensorflow as tf 
from google.protobuf import json_format 
from tensorflow.python.ops.data_flow_ops import QueueBase, _as_type_list, _as_shape_list, _as_name_list 

""" Missing the Python equivalent of,                                           

    class HDF5QueueOp : public ResourceOpKernel<QueueInterface> {                                    
    public:                                                  
     // Implementation                                              
    };                                                   

    REGISTER_OP("HDF5Queue")                                             
    .Output("handle: resource")                                             
    .Attr("filename: string")                                             
    .Attr("datasets: list(string)")                                            
    .Attr("overwrite: bool = false")                                           
    .Attr("component_types: list(type) >= 0 = []")                                        
    .Attr("shapes: list(shape) >= 0 = []")                                          
    .Attr("shared_name: string = ''")                                           
    .Attr("container: string = ''")                                            
    .Attr("capacity: int = -1")                                             
    .SetIsStateful()                                               
    .SetShapeFn(TwoElementOutput);                                            

""" 

class HDF5Queue(QueueBase): 
    def __init__(self, stream_id, stream_columns, dtypes=None, capacity=100, 
       shapes=None, names=None, name="hdf5_queue"): 
    if not dtypes: 
     dtypes = [tf.int64, tf.float32] 

    if not shapes: 
     shapes = [[1], [1]] 

    dtypes = _as_type_list(dtypes) 
    shapes = _as_shape_list(shapes, dtypes) 
    names = _as_name_list(names, dtypes) 
    queue_ref = _op_def_lib.apply_op("HDF5Queue", stream_id=stream_id, 
            stream_columns=stream_columns, capacity=capacity, 
            component_types=dtypes, shapes=shapes, 
            name=name, container=None, shared_name=None) 
    super(HDF5Queue, self).__init__(dtypes, shapes, 
            names, queue_ref) 

上記は、TFからかなり標準です。これは、例えばFIFOQueueで見ることができます。 Python Wrapper,Protobuf Registration,C++ Implementation。コンパイル中にコンパイル時に生成されるPythonラッパーがありますが、実行する場所がわかりますgrep -A 10 -B 10 -n FIFO $(find /usr/local -name "*gen_data_flow*.py") /dev/null

以下は、TFグラフのProtobufメッセージをJSON形式でダンプします。私はこれがHDF5Queue操作のブロックをダンプすると思いますが、これはC++操作を記述した場合と同じです。

with tf.Session() as sess: 
    queue = HDF5Queue(stream_id=0xa) 
    write = queue.enqueue([[1], [1.2]]) 
    read = queue.dequeue() 
    print json_format.MessageToJson(tf.train.export_meta_graph()) 

答えて

0

これは、py_funcを使用して並べ替えることができます。ここに例があります。

import tensorflow as tf 
from google.protobuf import json_format 
import sys, json, base64, numpy 
from tensorflow.python.ops.script_ops import _py_funcs as py_func_registry 
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef 

graph = tf.Graph() 
graph2 = tf.Graph() 

def f(x): 
    return x 

def g(x): 
    return 2*x 

with graph.as_default(): 
    x = tf.placeholder(tf.float32, shape=(3,), name='x') 
    y = tf.py_func(f, [x], tf.float32, name='y') 

    # py_func_registry._funcs.clear() # Optional line to clear the Python function registry 
    msg = json.loads(json_format.MessageToJson(tf.train.export_meta_graph())) 

# Change the function being used by py_func 
msg['graphDef']['node'][1]['attr']['token']['s'] = base64.b64encode(py_func_registry.insert(g)) 

with graph2.as_default():  
    # Load graph 
    meta_graph_def = MetaGraphDef() 
    json_format.Parse(json.dumps(msg), meta_graph_def) 
    tf.train.import_meta_graph(meta_graph_def) 

    sess = tf.Session(graph=graph2) 
    print sess.run('y:0', feed_dict={'x:0':numpy.array([1, 2, 3])}) 
    print g(numpy.array([1, 2, 3])) 
関連する問題