2017-08-20 4 views
0

私はテンソルフローの使い方が新しく、テンソルフローのホットエンコーディングに関するいくつかの質問があります。Tensorflow onehot encode

最後の列がラベルを表すCVSファイルを読みたいとします。ラベルは1から7の整数値です。

softmaxモデルを使用して分類したいですか?

したがって、私のラベルはonehotテンソル形式にする必要がありますか?

私のラベルをホットに変換するための簡単な方法/テンソルフローが組み込まれていますか?

ファイル解析のチュートリアルによると、私はcsvを解析するために以下のコードを持っていますが、1つのホットエンコーディングの一部がありません。

def read_from_cvs(filename_queue): 
    reader = tf.TextLineReader() 
    key, value = reader.read(filename_queue) 
    record_defaults = [[] for col in range((num_attributes))] # no defaults, all values must be given 
    attributes = tf.decode_csv(value, record_defaults=record_defaults) 
    features = tf.stack(attributes[1:-1]) 
    labels = tf.stack(attributes[-1]) 
    return features, labels 

def input_pipeline(filename = 'dataset.csv', batch_size = 30, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(filename, num_epochs=num_epochs, shuffle=True) 
    features, labels = read_from_cvs(filename_queue) 

    min_after_dequeue = 10000 
    capacity = min_after_dequeue + 3 * batch_size 
    feature_batch, label_batch = tf.train.shuffle_batch(
     [features, labels], batch_size=batch_size, capacity=capacity, 
     min_after_dequeue=min_after_dequeue) 
    return feature_batch, label_batch 

答えて

2

を使用できtf.one_hot

# depth = num_clasess in general for the problem  
labels = tf.one_hot(tf.stack(attributes[-1]), depth) 
関連する問題