2017-10-14 3 views
0

nすべての入力に対してプレースホルダを持つネットワークがあり、それらをすべて共通の入力として別のプレースホルダ(後で作成)にリンクしたい。プレースホルダをテンソルフローの共通入力にグループ化する

class GroupOfNetworks(object): 
    def __init__(self,subtask_nets,ob_space): 
     self.x_inputs = [st_net.x for st_net in subtask_nets] #list of network inputs 

ここで、st_net.xは、次のように宣言されたプレースホルダです。

class Network(object): 
    def __init__(self, ob_space): 
      self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) `#single network input 

私は唯一の私feed_dictに1つのキーと値のペアを持っている必要がありますので、これらのネットワークのすべてに単一の共通の入力を持っているしたいと思います。私はプレースホルダ(以下のコードスニペット)に代入演算を試みましたが、テンソルで変数ではないので、これはエラーをスローします。今まで

#in class GroupOfNetworks... 
common_x = tf.placeholder(tf.float32, [None] + list(ob_space),"common_input") 
set_input = tf.assign(self.x_inputs[0].x,common_x,"link_subtask_input") # DOES NOT WORK 

私は、プログラム生成feed_dictを使用した(下記参照)が、これはグラフ上ではなく、.metaファイルからグラフのロード中にインポートすることができません。

def make_common_feed_dict(self,x): 
    return {placeholder:x for placeholder in self.x_inputs} 

もっと良い解決方法を知っている人はいますか?

答えて

1

ネットワーク内のすべてのネットワークに対して1つのプレースホルダ(したがって同じ入力)が必要なので、どこでも同じプレースホルダを使用してください。

オブジェクト内にプレースホルダを作成する代わりに、__init__メソッドを作成し、作成したすべてのオブジェクトに渡します。 はこのような何かを:

# Define your network in this way 
class Network(object): 
    def __init__(self, placeholder): 
      self.x = placeholder 

を、その後、Networkオブジェクトは、プレースホルダを定義する初期化する前に、その後、Network obejectsが出力テンソルをフェッチするget方法を得たことを想定し、それに

input_placeholder = tf.placeholder(tf.float32, [None] + list(ob_space)) 

network_a = Network(input_placeholder) 
network_b = Network(input_placeholder) 

それらを使用します同じ値でそれらを与えるnetwork_anetwork_bを実行することができます:

+0

はい、これは、 'Network'オブジェクトを単独で実行する必要がなく、変更することができれば機能します。 'Network'オブジェクトを変更することができず、プレースホルダへの参照のみを使用するケースを処理する方法があれば、あなたは気付いていますか? – yokian

関連する問題