2016-10-15 9 views
1
image_size = 28 
num_labels = 10 

def reformat(dataset, labels): 
    dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32) 
    # Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...] 
    labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32) 
    return dataset, labels 
train_dataset, train_labels = reformat(train_dataset, train_labels) 
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels) 
test_dataset, test_labels = reformat(test_dataset, test_labels) 
print('Training set', train_dataset.shape, train_labels.shape) 
print('Validation set', valid_dataset.shape, valid_labels.shape) 
print('Test set', test_dataset.shape, test_labels.shape) 

この行は何を意味していますか?numpyのオーバーライドは==演算子でした。私はPythonコードを理解できません。

labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32) 

コードはそうです、それはその意味でオーバーロードされ、(音符のラインで行われているように)は、2つのnumpyの配列を比較するときにnumpyのでhttps://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/2_fullyconnected.ipynb

答えて

3

から、==オペレータは異なる何かを意味しています。 2つのnumpy配列を要素ごとに比較し、2つの入力と同じサイズのブール値numpy配列を返します。 >=,<などの他の比較についても同様である。

などのような他の比較についても同様である。 numpyのため

import numpy as np 
print(np.array([5,8,2]) == np.array([5,3,2])) 
# [True False True] 
print((np.array([5,8,2]) == np.array([5,3,2])).astype(np.float32)) 
# [1. 0. 1.] 
1

==演算子は、ブール配列を返す要素ごとの操作であるアレイが挙げられます。 astype関数は、ブール値True1.0に、False0.0に変換します。

0

https://docs.python.org/3/reference/expressions.html#value-comparisonsは、==のような値の比較を記述する。デフォルトの比較はidentityx is yですが、どちらかの引数が__eq__メソッドを実装しているかどうかが最初に確認されます。数字、リスト、辞書は独自のバージョンを実装しています。そしてnumpyもそうです。

numpy__eq__のユニークな点は、可能な場合は要素ごとに比較し、同じサイズのブール値の配列を返す点です。

In [426]: [1,2,3]==[1,2,3] 
Out[426]: True 
In [427]: z1=np.array([1,2,3]); z2=np.array([1,2,3]) 
In [428]: z1==z2 
Out[428]: array([ True, True, True], dtype=bool) 
In [432]: z1=np.array([1,2,3]); z2=np.array([1,2,4]) 
In [433]: z1==z2 
Out[433]: array([ True, True, False], dtype=bool) 
In [434]: (z1==z2).astype(float)  # change bool to float 
Out[434]: array([ 1., 1., 0.]) 

よくある質問は「なぜこのValueErrorを取得するのですか?」です。

In [435]: if z1==z2: print('yes') 
... 
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() 

これは、複数のTrue/False値を持つ配列が比較されるためです。

フロートの比較も一般的な問題です。その問題が起きたらiscloseallcloseを確認してください。

関連する問題