2017-08-25 11 views
0

テンソルフローの1Dテンソルを補間しようとしています(効果的にはnp.interpに相当します)。私は同様のテンソルフローを見つけることができなかったので、私は補間を自分で実行しなければなりませんでした。テンソルフローのバイナリ検索と補間

第1のステップは、y値の対応するインデックスについてx値のソートされたリストを検索すること、すなわちバイナリ検索を実行することである。私はこれのためにwhileループを使用してみましたが、私は謎めいたランタイムエラーが発生します。ここではいくつかのコードがあります:

xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis') 
query = tf.placeholder(tf.float32, name='query') 

with tf.name_scope("binsearch"): 
    up = tf.Variable(0, dtype=tf.int32, name='up') 
    mid = tf.Variable(0, dtype=tf.int32, name='mid') 
    down = tf.Variable(0, dtype=tf.int32, name='down') 
    done = tf.Variable(-1, dtype=tf.int32, name='done')   

    def cond(up, down, mid, done): 
     return tf.logical_and(done<0,up-down>1) 

    def body(up, down, mid, done): 
     val = tf.gather(xaxis, mid) 
     done = tf.cond(val>query, 
         tf.cond(tf.gather(xaxis, mid-1)<query, lambda:mid-1, lambda: -1), 
         tf.cond(tf.gather(xaxis, mid+1)>query, lambda:mid, lambda: -1)) 
     up = tf.cond(val>query, lambda: mid, lambda: up) 
     down = tf.cond(val<query, lambda: mid, lambda: down) 

     with tf.control_dependencies([done, up, down]): 
      return up, down, (up+down)//2, done 

    up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0]-1, 0, (xaxis.shape[0]-1)//2, -1)) 

これは

AttributeError: 'int' object has no attribute 'name' 

に私は、Windows 7上のPython 3.6を使用していますし、GPUのサポートと1.1をtensorflow結果。どのようなアイデアが間違っている?おかげさまで

はここで完全なスタックトレースです:

AttributeError       Traceback (most recent call last) 
<ipython-input-185-693d3873919c> in <module>() 
    19    return up, down, (up+down)//2, done 
    20 
---> 21  up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0]-1, 0, (xaxis.shape[0]-1)//2, -1)) 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name) 
    2621  context = WhileContext(parallel_iterations, back_prop, swap_memory, name) 
    2622  ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context) 
-> 2623  result = context.BuildLoop(cond, body, loop_vars, shape_invariants) 
    2624  return result 
    2625 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildLoop(self, pred, body, loop_vars, shape_invariants) 
    2454  self.Enter() 
    2455  original_body_result, exit_vars = self._BuildLoop(
-> 2456   pred, body, original_loop_vars, loop_vars, shape_invariants) 
    2457  finally: 
    2458  self.Exit() 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants) 
    2404   structure=original_loop_vars, 
    2405   flat_sequence=vars_for_body_with_tensor_arrays) 
-> 2406  body_result = body(*packed_vars_for_body) 
    2407  if not nest.is_sequence(body_result): 
    2408  body_result = [body_result] 

<ipython-input-185-693d3873919c> in body(up, down, mid, done) 
    11   val = tf.gather(xaxis, mid) 
    12   done = tf.cond(val>query, 
---> 13      tf.cond(tf.gather(xaxis, mid-1)<query, lambda:mid-1, lambda: -1), 
    14      tf.cond(tf.gather(xaxis, mid+1)>query, lambda:mid, lambda: -1)) 
    15   up = tf.cond(val>query, lambda: mid, lambda: up) 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in cond(pred, fn1, fn2, name) 
    1746  context_f = CondContext(pred, pivot_2, branch=0) 
    1747  context_f.Enter() 
-> 1748  _, res_f = context_f.BuildCondBranch(fn2) 
    1749  context_f.ExitResult(res_f) 
    1750  context_f.Exit() 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildCondBranch(self, fn) 
    1666    real_v = sparse_tensor.SparseTensor(indices, values, dense_shape) 
    1667   else: 
-> 1668    real_v = self._ProcessOutputTensor(v) 
    1669   result.append(real_v) 
    1670  return original_r, result 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in _ProcessOutputTensor(self, val) 
    1624  """Process an output tensor of a conditional branch.""" 
    1625  real_val = val 
-> 1626  if val.name not in self._values: 
    1627  # Handle the special case of lambda: x 
    1628  self._values.add(val.name) 

AttributeError: 'int' object has no attribute 'name' 
+1

完全なスタックトレースを含めることができますか?これはそのエラーを生成するコードですか?あなたが持っているコードを実行すると、私は 'TypeError:true_fnは呼び出し可能でなければなりません。 – user3080953

答えて

1

私はあなたのエラーの原因を知りませんが、私はtf.while_loopは非常に遅くなる可能性が非常に高いことを伝えることができます。

import numpy as np 
import tensorflow as tf 

xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis') 
yaxis = tf.placeholder(tf.float32, shape=100, name='yaxis') 
query = tf.placeholder(tf.float32, name='query') 

# Add additional elements at the beginning and end for extrapolation 
xaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0) 
yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0) 

# Find the index of the interval containing query 
cmp = tf.cast(query >= xaxis_pad, dtype=tf.int32) 
diff = cmp[1:] - cmp[:-1] 
idx = tf.argmin(diff) 

# Interpolate 
alpha = (query - xaxis_pad[idx])/(xaxis_pad[idx + 1] - xaxis_pad[idx]) 
res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx] 

# Test with f(x) = 2 * x 
q = 5.4 
x = np.arange(100) 
y = 2 * x 
with tf.Session() as sess: 
    q_interp = sess.run(res, feed_dict={xaxis: x, yaxis: y, query: q}) 
print(q_interp) 
>>> 10.8 

パディング部分は、あなたが範囲外の値を渡す場合は、トラブルを避けるためだけであるが、それ以外の場合は、値がに開始比較して見つけるだけの問題である:あなたはこのように、ループせずに線形補間を実装することができますqueryより大きくする必要があります。

0

問題が見つかりました - tensorflowはcondのパラメータとしてpythonの整数が好きではありません。定数で最初にラップする必要があります。このコードの動作:

with tf.name_scope("binsearch"): 
    m_one = tf.constant(-1, dtype=tf.int32, name='minus_one') 
    up = tf.Variable(0, dtype=tf.int32, name='up') 
    mid = tf.Variable(0, dtype=tf.int32, name='mid') 
    down = tf.Variable(0, dtype=tf.int32, name='down') 
    done = tf.Variable(-1, dtype=tf.int32, name='done') 

    def cond(up, down, mid, done): 
     return tf.logical_and(done<0,up-down>1) 

    def body(up, down, mid, done): 

     def fn1(): 
      return mid, down, (mid+down)//2, tf.cond(tf.gather(xaxis, mid-1)<query, lambda:mid-1, lambda: m_one) 

     def fn2(): 
      return up, mid, (up+mid)//2, tf.cond(tf.gather(xaxis, mid+1)>query, lambda:mid, lambda: m_one) 

     return tf.cond(tf.gather(xaxis, mid)>query, fn1, fn2) 

    up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0]-1, 0, (xaxis.shape[0]-1)//2, -1))