2017-10-03 1 views
3

オブジェクト検出のためのオブジェクト座標を訓練するコードがあります。 CNNネットワークを使用しました。出力レイヤーは、イメージ内のオブジェクトの(x0、y0、height、width)を返す回帰レイヤー(bound_box_output)です。この層の後、私は損失のステップの前に直接画像を保存しようとします。python-tf.write_fileはテンソルフローでは機能しません

i = 0 
    image_decoded = tf.image.decode_jpeg(tf.read_file('3.jpg'), channels=3) 
    cropped = tf.image.crop_to_bounding_box(image = image_decoded, 
              offset_height = tf.cast(bound_box_output[i,0], tf.int32), 
              offset_width = tf.cast(bound_box_output[i,1], tf.int32), 
              target_height = tf.cast(bound_box_output[i,2], tf.int32), 
              target_width = tf.cast(bound_box_output[i,3], tf.int32)) 

    enc = tf.image.encode_jpeg(cropped) 
    fname = tf.constant('4.jpeg') 
    fwrite = tf.write_file(fname, enc) 

でtf.train.SessionRunHook私はこの問題は、それが特定のフォルダ

に '4.jpeg' 画像を保存していないということです

def begin(self): 
     self._step = -1 
     self._start_time = time.time() 

def before_run(self, run_context): 
     self._step += 1 
     return tf.train.SessionRunArgs(loss) 
def after_run(self, run_context, run_values): 
     if self._step % LOG_FREQUENCY == 0: 
      current_time = time.time() 
      duration = current_time - self._start_time 
      self._start_time = current_time 

      loss_value = run_values.results 
      examples_per_sec = LOG_FREQUENCY * BATCH_SIZE/duration 
      sec_per_batch = float(duration/LOG_FREQUENCY) 


      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 
         'sec/batch)') 
      print (format_str % (datetime.now(), self._step, loss_value, 
           examples_per_sec, sec_per_batch)) 

     if self._step == MAX_STEPS-1: 
      loss_value = run_values.results 
      print("The final value of loss is:: ") 
      print(loss_value) 
      print(fwrite) 
      tf.train.SessionRunArgs(fwrite) 

それを実行します

注:私はテンソルフローを使用します。1.1.3python3.5

+1

最後の印刷文を印刷しましたか(「損失の最終価値は::」など)? –

+0

はい、それは練習の最後のステップです。 – CCCC

+1

それは知っておいてよかったです。コードが正しいと思われるので、何が間違っているのか分かりません。明示的に出力ファイルのパスを記述しようとしましたか? (例えば、 '/ tmp/4.jpeg'を' fname'として使用しています。ファイルの作者が実際に何かを保存していることを確認するだけです) –

答えて

2

TLDR; tf.train.SessionRunArgs(fwrite)run_context.session.run(fwrite)と置き換えてください。

SessionRunArgsは、実際には提供された操作を実行しません。 SessionRunArgsbefore_run()コールから返されます。彼らの役割は、次のsession.run()コールに引数を追加することです。

if self._step == MAX_STEPS-1: 
    loss_value = run_values.results 
    print("The final value of loss is:: ") 
    print(loss_value) 
    print(fwrite) 
    tf.train.SessionRunArgs(fwrite) # problematic line 

あなたはafter_run()の終わりにfwrite操作を実行しようとしています。ただし、単にオブジェクトSessionRunArgsをインスタンス化します。

run_context引数がafter_run()に供給されていることを利用して、望ましい動作を実現することができます。 run_contextのタイプはSessionRunContextで、タイプはsessionです。

run_context.session.run(fwrite)はあなたのためのトリックを行う必要があります。

+0

これは今、 – CCCC

関連する問題