2017-12-19 33 views
0

TensorflowのKerasで実装されているLSTMテキスト分類子を評価するためのメトリックの一部として精度、再呼び出し、FMeasureを計算する際に問題があります。私は、からKeras 2.02メトリックモジュールを知っています。私はsklearn.metrics.classification_reportにオブジェクトとして当てはめモデルと予測出力を解析しています別の方法としてSklearn Keras分類器の精度、再呼び出し、FMeasureのメトリック

# create the model 
embedding_vector_length = 32 
model = Sequential() 
# load the dataset with word embedding but only keep the top n words, zero the rest 
model.add(Embedding(top_words, embedding_vector_length, input_length=max_tweet_length)) 

model.add(LSTM(100)) 
model.add(Dense(1, activation='sigmoid')) 
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) 
print(model.summary()) 
model.fit(X_train, y_train, epochs=3, batch_size=64) 

# Final evaluation of the model 
scores = model.evaluate(X_test, y_test, verbose=0) 
print("Accuracy: %.2f%%" % (scores[1]*100)) 
print(scores) 

# print the classification report 
from sklearn.metrics import classification_report 
predicted = model.predict(X_test) 
report = classification_report(y_test, predicted) 
print(report) 

しかし私は、ターゲットのデータ型に関するエラーを得続けます。 Sigmoidアクティベーション機能を使用しているので、予測出力はfloat32形式ですが、ラベルはバイナリレベルの分類を持つテキストの集合です。私はKerasメトリクスから精度評価を得ていますが、精度、再現率、評価評価は問題です。

Traceback (most recent call last): 
    File "<stdin>", line 1, in <module> 
    File "/root/anaconda3/envs/py35/lib/python3.5/site-packages/sklearn/metrics/classification.py", line 1261, in precision_score 
    sample_weight=sample_weight) 
    File "/root/anaconda3/envs/py35/lib/python3.5/site-packages/sklearn/metrics/classification.py", line 1025, in precision_recall_fscore_support 
    y_type, y_true, y_pred = _check_targets(y_true, y_pred) 
    File "/root/anaconda3/envs/py35/lib/python3.5/site-packages/sklearn/metrics/classification.py", line 81, in _check_targets 
    "and {1} targets".format(type_true, type_pred)) 
ValueError: Classification metrics can't handle a mix of binary and continuous targets 
+0

例外はかなり簡単なようです。 'X_train'、' X_test'、 'y_train'、' y_test'の各要素の型は何ですか?あなたは0と1の束と潜在的にいくつかの無関係な要素を持つかもしれないようです。 – blacksite

答えて

0

明らかに、model.predictの出力が見つかりませんでした。実際には、binary_classificationを使用しました。を呼び出して、クラス/ラベルデータyに一致させる必要があります。

関連する問題