2016-07-22 20 views
0

私はTensorflowで遊んでおり、k means clusteringアルゴリズムを実装しています。すべてうまくいきますが、listでセッションを実行したい場合は、listTensorまたはOperationに変換できないというエラーが表示されます。フェッチのリストでテンソルフローを実行できません。

documentationは、Session.run()にリストを明示的に記載しています。私は間違って何かしていますか?ここ

は、ソースコードである:ここ

import tensorflow as tf 
import numpy as np 

def tf_k_means(k, data, eps_=0.1): 
    eps = tf.constant(eps_) 

    cluster_means = tf.placeholder(tf.float32, [None, 2]) 
    tf_data = tf.placeholder(tf.float32, [None, 2], name='data') 

    model = tf.initialize_all_variables() 

    expanded_data = tf.expand_dims(tf_data, 0) 
    expanded_means = tf.expand_dims(cluster_means, 1) 
    distances = tf.reduce_sum(tf.square(tf.sub(expanded_means, expanded_data)), 2) 
    mins = tf.to_int32(tf.argmin(distances, 0)) 

    clusters = tf.dynamic_partition(tf_data, mins, k) 
    old_cluster_means = tf.identity(cluster_means) 
    new_means = tf.concat(0, [tf.expand_dims(tf.reduce_mean(cluster, 0), 0) for cluster in clusters]) 

    clusters_moved = tf.reduce_sum(tf.square(tf.sub(old_cluster_means, new_means)), 1) 
    converged = tf.reduce_all(tf.less(clusters_moved, eps)) 

    cms = data[np.random.randint(data.shape[0],size=k), :] 

    with tf.Session() as sess: 
     sess.run(model) 
     conv = False 
     while not conv: 
      ##################################### 
      # THE FOLLOWING LINE DOES NOT WORK: # 
      ##################################### 
      (cs, cms, conv) = sess.run([clusters, new_means, converged], 
             feed_dict={tf_data: data, cluster_means: cms})  

    return cs, cms 

はエラーメッセージである:

TypeError: Fetch argument [<tf.Tensor 'DynamicPartition_25:0' shape=(?, 2) dtype=float32>, 
<tf.Tensor 'DynamicPartition_25:1' shape=(?, 2) dtype=float32>, 
<tf.Tensor 'DynamicPartition_25:2' shape=(?, 2) dtype=float32>, 
<tf.Tensor 'DynamicPartition_25:3' shape=(?, 2) dtype=float32>] of 
[<tf.Tensor 'DynamicPartition_25:0' shape=(?, 2) dtype=float32>, 
<tf.Tensor 'DynamicPartition_25:1' shape=(?, 2) dtype=float32>, 
<tf.Tensor 'DynamicPartition_25:2' shape=(?, 2) dtype=float32>, 
<tf.Tensor 'DynamicPartition_25:3' shape=(?, 2) dtype=float32>] has 
invalid type <class 'list'>, must be a string or Tensor. (Can not 
convert a list into a Tensor or Operation.) 

答えて

2

tf.dynamic_partitionlist of Tensorsを返し、そうclustersリストそのものです。

clusters = tf.dynamic_partition(tf_data, mins, k) 

リストを別のリスト内のsess.runにフィードすると、問題が発生していると思います。試してみてください:

sess.run(clusters + [new_means, converged], ... 
関連する問題