5

TensorFlowを使用していくつかの実験を実行していますが、単純なケースのtf.train.GradientDescentOptimizerから始まって、いくつかの機能がどのように行われたかを正確に見るためにいくつかの関数の実装を見てみたいと思います。 、githubのからの完全なソースコードのzipファイルをダウンロードしたソースツリーの上にいくつかの検索を実行した、になった:勾配降下のコードはどこですか?

C:\tensorflow-master\tensorflow\python\training\gradient_descent.py 

class GradientDescentOptimizer(optimizer.Optimizer): 

    def _apply_dense(self, grad, var): 
    return training_ops.apply_gradient_descent(

わかりましたので、おそらく実際のコードがapply_gradient_descentであり、そのために検索...ではないが。ソースツリー全体で3回しか出現しません。すべてが定義ではなく、使用されています。

training_ops?暗黙の名前を持つソースファイルが存在します:

C:\tensorflow-master\tensorflow\python\training\training_ops.py 

from tensorflow.python.training import gen_training_ops 
# go/tf-wildcard-import 
# pylint: disable=wildcard-import 
from tensorflow.python.training.gen_training_ops import * 
# pylint: enable=wildcard-import 

...上記は、そのファイルの内容全体です。うーん。

私は、このファイルを見つけた:

C:\tensorflow-master\tensorflow\python\BUILD 

tf_gen_op_wrapper_private_py(
    name = "training_ops_gen", 
    out = "training/gen_training_ops.py", 
) 

などと、このような他のファイルを確認するように思われたが、ビルドプロセスで生成されたオブジェクトコード、ある - しかし、どこがから生成されたソースコードはありますか?

これは私があきらめて助けを求めるポイントです。 TensorFlowコードベースに精通している人なら、関連するソースコードの場所を教えてください。

答えて

7

さらに、ネイティブのC++コードに移行します。

template <typename T> 
struct ApplyGradientDescent<GPUDevice, T> { 
    void operator()(const GPUDevice& d, typename TTypes<T>::Flat var, 
        typename TTypes<T>::ConstScalar lr, 
        typename TTypes<T>::ConstFlat grad) { 
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast; 
    bcast[0] = grad.dimension(0); 
    Eigen::Sizes<1> single; 
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad; 
    } 
}; 

CPUの実装があるherecore/kernels/training_ops.cc):

template <typename T> 
struct ApplyGradientDescent<CPUDevice, T> { 
    void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, 
        typename TTypes<T>::ConstScalar lr, 
        typename TTypes<T>::ConstFlat grad) { 
    var.device(d) -= grad * lr(); 
    } 
}; 
ここ ApplyGradientDescent GPUの実装( core/kernels/training_ops_gpu.cu.cc)です
関連する問題