3

私はthe Dataset API in Tensorflow v1.3で遊んでいます。それは素晴らしい。 hereのように関数を使用してデータセットをマップすることは可能です。私はたとえばarg1のために、追加の引数を持つ関数を渡すことができる方法を知って興味があります:もちろんTF1.3の新しいDataset APIを使用して、追加のパラメータで関数をどのようにマップするのですか?

def _parse_function(example_proto, arg1): 
    features = {"image": tf.FixedLenFeature((), tf.string, default_value=""), 
       "label": tf.FixedLenFeature((), tf.int32, default_value=0)} 
    parsed_features = tf.parse_single_example(example_proto, features) 
    return parsed_features["image"], parsed_features["label"] 

arg1に合格する方法がないので、

dataset = dataset.map(_parse_function) 

が動作しません。 。ここで

+0

ジャストアイデア、多分私たちすることができますに渡すことによって、これを偽Pythonクラスのクラスメンバとしてarg1を定義し、__call__メソッドを定義しています。 –

+0

'arg1'はどんな種類の引数ですか?通常のPython変数(TensorFlowではなく)であれば、 'arg1'が知られている別の関数内で' _parse_function'関数を定義するだけで、もうそれを渡す必要はありません。 – CNugteren

答えて

2

たちは引数を渡したい先の機能ラップするラムダ式を使用した例である:それは動作することを確認するには

import tensorflow as tf 
def fun(x, arg): 
    return x * arg 

my_arg = tf.constant(2, dtype=tf.int64) 
ds = tf.data.Dataset.range(5) 
ds = ds.map(lambda x: fun(x, my_arg)) 

を、私たちは、マッピングが実際にして、各データセット要素を乗算することを観察することができます2:

iterator = ds.make_initializable_iterator() 
next_x = iterator.get_next() 
with tf.Session() as sess: 
    sess.run(iterator.initializer) 

    while True: 
     try: 
     print(sess.run(next_x)) 
     except tf.errors.OutOfRangeError: 
     break 

出力:

0 
2 
4 
6 
8 
関連する問題