2016-09-29 25 views
3

Caffeで最初に実装されたピクセルラベリング(FCNスタイル)のコードをTensorFlowに移植します。私はSlimのResNets(ResNet-101)の実装を16pxのストライドで使用し、それをアップコンバルションレイヤーでさらにアップサンプリングして8pxの最終的なストライドを達成しました。入力画像は任意のサイズであるため、私はbatch_size = 1を持っています。問題は、トレーニングが本当に遅いことです。オリジナルのcaffeの実装は、同じハードウェア(Tesla K40m)で30秒で実行しますが、約3.5分で100の画像を処理します。ここで私が持っているコードの縮小版だ:TF-Slimを使用した完全畳み込みResNetsが非常に遅いです。

import datetime as dt 

import tensorflow as tf 
import tensorflow.contrib.slim as slim 
from tensorflow.contrib.slim.nets import resnet_v1 

from MyDataset import MyDataset 
from TrainParams import TrainParams 

dataset = MyDataset() 
train_param = TrainParams() 

#tf.device('/gpu:0') 

num_classes = 15 

inputs = tf.placeholder(tf.float32, shape=[1, None, None, 3]) 

with slim.arg_scope(resnet_v1.resnet_arg_scope(False)): 
    mean = tf.constant([123.68, 116.779, 103.939], 
         dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean') 
    im_centered = inputs - mean 
    net, end_points = resnet_v1.resnet_v1_101(im_centered, 
               global_pool=False, output_stride=16) 

    pred_upconv = slim.conv2d_transpose(net, num_classes, 
             kernel_size = [3, 3], 
             stride = 2, 
             padding='SAME') 

    targets = tf.placeholder(tf.float32, shape=[1, None, None, num_classes]) 

    loss = slim.losses.sigmoid_cross_entropy(pred_upconv, targets) 


log_dir = 'logs/' 

variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"]) 
restorer = tf.train.Saver(variables_to_restore) 

with tf.Session() as sess: 

    sess.run(tf.initialize_all_variables()) 
    sess.run(tf.initialize_local_variables()) 

    restorer.restore(sess, '/path/to/ResNet-101.ckpt') 

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=.001) 
    train_step = optimizer.minimize(loss) 
    t1 = dt.datetime.now() 
    for it in range(10000): 
     n1=dt.datetime.now() 
     batch = dataset.next_batch() # my function that prepares training batch 
     sess.run(train_step, feed_dict={inputs: batch['inputs'], 
             targets: batch['targets']}) 
     n2=dt.datetime.now() 
     time = (n2-n1).microseconds/(1000) 
     print("iteration ", it, "time", time) 

私はフレームワークを学んでいる、と私は日のカップルで一緒にこのコードを置くので、私はそれは素敵ではないかもしれない理解しています。わかりましたように、私はまた、データ準備コードとネットワークの前方 - 後方パスにかかる実際の時間を測定しようとします。そして、この時間は実際のランタイムと比較して100回の反復で合計した場合、実際にはずっと小さくなります。私はいくつかのスレッド/プロセスの同期が進行していると思われますが、それは測定されていませんが、私はそれがかなり奇妙であることがわかります。 topコマンドは、おそらくそれによって生成された主なものと同じタイトルの約10のプロセスを表示します。

私はこれをどのようにスピードアップすることができますか?

ありがとうございました。

更新。より多くの研究の結果、私は「フィード」データがキューに比べて遅くなることが分かったので、別のスレッドでキューにコードを再実装しました:https://gist.github.com/eldar/0ecc058670be340b92e5a1044dc8a089しかし、ランタイムはほぼ同じです。

UPDATE2。速度に関する問題が何であるかを考えたようです。私は完全畳み込みで練習し、私のイメージは任意のサイズとアスペクト比です。固定サイズのダミーのランダムな数のテンソルを送ると、速く動作します。 10個の定義済みサイズの入力テンソルを生成する場合、最初の10回の反復は遅くなりますが、その後は高速化されます。 TensorFlowのように見えますが、各繰り返しですべてのテンソルのサイズ変更はCaffeほど効率的ではありません。プロジェクトのGitHubでチケットを発行します。

+0

はそれがおかしくなり、巨大なモデルだということを覚えておいてください

ソリューションは、環境変数TF_CUDNN_USE_AUTOTUNE=0を設定しました。 resnet_v1_101の "101"は、それが101コンバートのレイヤーであることに由来します。 – Julius

+0

あなたが得るものが期待されるかどうかわからないのに – Julius

+0

afaik彼らはそれを訓練するためにいくつかの異なるマシンを使用します – Julius

答えて

2

問題は、任意のサイズの入力画像によるものです。 TensorFlowにはオートチューニングと呼ばれるものがあります。実行時には、それぞれの特定の入力サイズに対してさまざまなアルゴリズムをプロファイリングし、最適なものを決定します。私の場合、これは各反復で起こっていました。このGithubのチケットで

export TF_CUDNN_USE_AUTOTUNE=0 
python myscript.py 

より:https://github.com/tensorflow/tensorflow/issues/5048

+0

リンクされた問題のコード:https://gist.github.com/eldar/0ecc058670be340b92e5a1044dc8a089 –

2

一般に、TensorFlow resnetの実装は、caffeよりもはるかに遅くすべきではありません。私はちょうどcaffe/barrista(https://github.com/classner/barrista/tree/master/examples/residual-nets)の実装とTensorflow(https://github.com/tensorflow/models/tree/master/resnet)の例を比較しましたが、それらは同じ速度で完全なトレーニングに比べて無視できるほどの違いがあります。

私はこのページに私を連れて来たTensorflowの実装にまず問題がありました。理由は、私が作成したgithubのバージョンは安定しておらず、中間開発コードが非常に遅いためでした。 A git pullと再コンパイルして問題を解決しました。

ただし、自分で再実装する場合は、BatchNormのアップデート操作がどのようにトリガーされるかに注意してください。テンソルフローの例では、これはresnet_model.py、lで行われます。それらはrunオペレーションの 'フェッチ'に直接追加され、したがって、並行して実行されます。

+0

あなたの返信ありがとう! ResNetのサポートは安定版ではなかったので、夜間のビルドも使用しました。どのバージョンを使用しましたか?また、独自のデータセットを使用しましたが、この場合、どのようにデータをロードしましたか?データローディングコードが最適でない可能性があります。 – SimpleMan

+0

私は0.11.0rc0バージョンに更新しました。同時に10個の他のpythonプロセスが同時に実行されていることがわかりましたが、これは良い兆候ですが、まだ遅いです。 – SimpleMan

関連する問題