2017-05-02 9 views
0

私はMac OS 10.12.4、Anaconda Python 3.5、Tensorflow 1.1で動作しています。 以下に示す再現可能なコードを一緒に飾ってあります。 "my_model"を引数 "features"と "labels"で定義しました。 私はそれらを定義しませんでした。 "my_model"関数は引数なしで呼び出されます。 Spyderの「変数」ウィンドウには、プログラムの実行後に表示されません。 私の質問は:これらの変数はどこに定義されていますか?このTensorflowプログラムはなぜ機能しますか?

チャールズ

from sklearn import metrics, cross_validation 
from tensorflow.contrib import layers 
from tensorflow.contrib import learn 
from sklearn.preprocessing import LabelEncoder 
import pandas as pd 

# shut up the warnings 
import warnings 
warnings.filterwarnings('ignore') 
import logging 
logging.getLogger("tensorflow").setLevel(logging.ERROR) 
import os 
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 
import tensorflow as tf 
tf.logging.set_verbosity(tf.logging.ERROR) 

def my_model(features, labels): 
    labels = tf.one_hot(labels, 3, 1, 0) 
    features = layers.stack(features, layers.fully_connected, [10, 20, 10]) 
    prediction, loss = (learn.models.logistic_regression(features, labels)) 
    train_op = tf.contrib.layers.optimize_loss(
     loss, 
     tf.contrib.framework.get_global_step(), 
     optimizer='Adagrad', 
     learning_rate=0.1) 

    return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op 

df = pd.read_csv("iris.csv") 
df = df.sample(frac=1) # shuffle all rows 
print(df.head()) 
column_names = list(df.columns[:4]) 
X = df[column_names].as_matrix() 
y = df['Species'] 
le = LabelEncoder() 
le.fit(df["Species"]) 
y = le.transform(df["Species"]) 
x_train, x_test, y_train, y_test = cross_validation.train_test_split(
    X, y, test_size=0.2, random_state=35) 

classifier = tf.contrib.learn.Estimator(model_fn = my_model) 
classifier.fit(x_train, y_train, steps=1000) 

y_predicted = [p['class'] for p in classifier.predict(x_test, as_iterable=True)] 
score = metrics.accuracy_score(y_test, y_predicted) 
print('Accuracy: {0:f}'.format(score)) 

答えて

0

my_modelは、あなたのコード内で呼び出されません。フィーチャーとラベルの2つの引数を持つEstimatorによって呼び出されるコールバック関数です。

実際には、fit()関数の場合はx_trainy_trainです。 the docが言うように


「モデルの機能は、機能を取り、テンソルまたはテンソルのdictsを対象とし、予測や損失テンソル例:返します。 "(機能、ターゲット) - >(予測、損失)"

そして、あなたはmodel_fnがsource code of Estimatorにライン1125で呼び出され見ることができます:

model_fn_results = self._model_fn(features, labels, **kwargs) 
関連する問題