2017-04-03 10 views
0

テンソルフローのLinearClassifierデータを有名なタイタニックデータセットで再生しました。tf.contrib.learn.LinearClassifierウェイトを視覚化する

(私の質問自体は一番下にダウンしている - これは、モデル自体のすべてのいくつかのコードです)

だから私は私の機能のカラムがあります。

CONTINUOUS_COLS = ['Age', 'Fare'] 
CATEGORICAL_COLS = ['Sex', 'Pclass', 'Title'] 
LABELS_COL = 'Survived' 

sex_col = sparse_column_with_keys('Sex', keys=['male', 'female']) 
title_col = sparse_column_with_hash_bucket('Title', 10) 
fare_class_col = sparse_column_with_keys('Pclass', keys=['1','2','3']) 
age_col = real_valued_column('Age') 
fare_col = real_valued_column('Fare') 

私の入力機能:

def create_input_fn(df): 
    continous_features = {k: tf.constant(df[k].values) for k in CONTINUOUS_COLS} 
    categorical_features = {k : tf.SparseTensor(
     indices=[[0,i] for i in range(df[k].size)], 
     values=df[k].values, 
     dense_shape=[df[k].size, 1] 
    ) for k in CATEGORICAL_COLS} 
    feature_cols = {**continous_features, **categorical_features} 
    labels = tf.constant(df[LABELS_COL].values) 
    return feature_cols, labels 

と私のモデル:

clf = LinearClassifier(feature_columns=[sex_col, fare_class_col, age_col, fare_col, title_col], 
    optimizer=tf.train.FtrlOptimizer(
     learning_rate=0.5, 
     l1_regularization_strength=1.0, 
     l2_regularization_strength=1.0), 
    model_dir=tempfile.TemporaryDirectory().name) 

私はモデルを実行するとき、それはうまくいくので、モデルの重みを見て、それらをよりよく視覚化したいと思っています。

(非推奨として、それがリストされているが)だから、clf.weights_が存在するので、私はちょうどそれらを手動で引き出す:

for var in clf.get_variable_names(): 
    if var.endswith('weights'): 
     print(f'{var} -> {clf.get_variable_value(var)}') 

そして、私はいくつかのまともな結果を得る:

linear/Pclass/weights -> [[ 0.  ] 
[ 0.  ] 
[-0.01772301]] 
linear/Sex/weights -> [[-0.07285357] 
[ 0.  ]] 
linear/Title/weights -> [[ 0.  ] 
[ 0.  ] 
[ 0.  ] 
[-0.03760524] 
[ 0.  ] 
[ 0.  ] 
[ 0.  ] 
[ 0.  ] 
[ 0.  ] 
[ 0.  ]] 

今私の質問は - もともと使用されていたキーを引き出すにはどうすればよいですか? 私は、例えば、性別などの数字をよりよく一致させることができます。キーはもともと男性/女性にマッピングされています。

ありがとうございます! sparse_column_with_keysについては

答えて

0


sex_col.lookup_config.keys # ('male', 'female')

だから、のようなもの:

matched = {} 
weights = clf.get_variable_value('linear/Sex/weights') # np array 
for index, key in enumerate(sex_col.lookup_config.keys): 
    matched[key] = weights[index] 

、あなたdir(sex_col.lookup_config)他のいくつかの興味深い属性があるとのメソッドのドキュメンテーション文字列をチェックしてください。Source for SparseColumn Feature classes https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/feature_column.py

Iから地図を見つけられませんでしたまだ

あなたが持っていた場合は、チュートリアルでage_bucketsようtf.contrib.layers.bucketized_columnage_buckets.boundaries

関連する問題