2016-05-18 14 views
1

テンソルフローCIFAR-10 tutorialcifar10_inputs.py行174では、より良いデータ増強のために操作random_contrastとrandom_brightnessの順序をランダム化する必要があると言われています。TensorFlowグラフ内の条件を使用する場合

これを行うには、私が考える最初のことは、0と1の間の一様分布:p_orderから確率変数を引くことです。そして、実行します。

if p_order>0.5: 
    distorted_image=tf.image.random_contrast(image) 
    distorted_image=tf.image.random_brightness(distorted_image) 
else: 
    distorted_image=tf.image.random_brightness(image) 
    distorted_image=tf.image.random_contrast(distorted_image) 

しかしp_orderを取得するための2つの可能なオプションがあります:

1)私は、純粋なTFを望んでいたし、そのTFはnumpyのとtensorflow

を混合するために、そのユーザーを阻止するように私をdisatisfies numpyのを使用してTFを使用して

2)が、しかしp_orderように私が何をすべき場合、私は本当に知らないtf.Session() でのみ評価することができます。

with tf.Session() as sess2: 
    p_order_tensor=tf.random_uniform([1,],0.,1.) 
    p_order=float(p_order_tensor.eval()) 

これらの操作はすべて関数本体の内部にあり、別のセッション/グラフを持つ別のスクリプトから実行されます。あるいは、他のスクリプトのグラフをこの関数の引数として渡すこともできますが、私は混乱しています。 たとえテンソルがこのような関数や推論のように機能しているとしても、グラフがグローバルな方法で定義されているように見えますが、明示的に出力する必要はありません。

答えて

10

tf.cond(pred, fn1, fn2, name=None)see doc)を使用できます。 この関数を使用すると、predのブール値をTensorFlowグラフ内で使用できます(self.eval()またはsess.run()を呼び出す必要がないため、のは不要です)。

は、ここでそれを使用する方法の例です:

def fn1(): 
    distorted_image=tf.image.random_contrast(image) 
    distorted_image=tf.image.random_brightness(distorted_image) 
    return distorted_image 
def fn2(): 
    distorted_image=tf.image.random_brightness(image) 
    distorted_image=tf.image.random_contrast(distorted_image) 
    return distorted_image 

# Uniform variable in [0,1) 
p_order = tf.random_uniform(shape=[], minval=0., maxval=1., dtype=tf.float32) 
pred = tf.less(p_order, 0.5) 

distorted_image = tf.cond(pred, fn1, fn2) 
私が探していただけで何
+0

ありがとう! – jean

+0

喜び! 「TensorFlowグラフ内の条件を使用する」を使用して、質問のタイトルをより明確に変更します。 –

関連する問題