2017-07-13 3 views
4

最も近い例は、この問題にあります。この最小再現可能なコードでhttps://github.com/tensorflow/tensorflow/issues/899TensorFlow:モデルのFLOPSを測定する方法はありますか?私は得ることができます

import tensorflow as tf 
import tensorflow.python.framework.ops as ops 
g = tf.Graph() 
with g.as_default(): 
    A = tf.Variable(tf.random_normal([25,16])) 
    B = tf.Variable(tf.random_normal([16,9])) 
    C = tf.matmul(A,B) # shape=[25,9] 
for op in g.get_operations(): 
    flops = ops.get_stats_for_node_def(g, op.node_def, 'flops').value 
    if flops is not None: 
    print 'Flops should be ~',2*25*16*9 
    print '25 x 25 x 9 would be',2*25*25*9 # ignores internal dim, repeats first 
    print 'TF stats gives',flops 

しかし、FLOPSが返さ常にNoneです。特にPBファイルでFLOPSを具体的に測定する方法はありますか?

答えて

4

少し遅れていますが、将来的には一部の訪問者を助けるかもしれません。あなたの例のために私は正常に次のコードをテストした:

g = tf.Graph() 
run_meta = tf.RunMetadata() 
with g.as_default(): 
    A = tf.Variable(tf.random_normal([25,16])) 
    B = tf.Variable(tf.random_normal([16,9])) 
    C = tf.matmul(A,B) # shape=[25,9] 

    opts = tf.profiler.ProfileOptionBuilder.float_operation()  
    flops = tf.profiler.profile(g, run_meta=run_meta, cmd='op', options=opts) 
    if flops is not None: 
     print('Flops should be ~',2*25*16*9) 
     print('25 x 25 x 9 would be',2*25*25*9) # ignores internal dim, repeats first 
     print('TF stats gives',flops.total_float_ops) 

それは、次のスニペットのようなKerasとの組み合わせでプロファイラを使用することも可能です:

import tensorflow as tf 
import keras.backend as K 
from keras.applications.mobilenet import MobileNet 

run_meta = tf.RunMetadata() 
with tf.Session(graph=tf.Graph()) as sess: 
    K.set_session(sess) 
    net = MobileNet(alpha=.75, input_tensor=tf.placeholder('float32', shape=(1,32,32,3))) 

    opts = tf.profiler.ProfileOptionBuilder.float_operation()  
    flops = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts) 

    opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()  
    params = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts) 

    print("{:,} --- {:,}".format(flops.total_float_ops, params.total_parameters)) 

私が助けることができる願っています!

関連する問題