私の問題は、テンソルフローバックエンドを使用してケラスの予測方法で結果を得ることです。しかし、まず小さな紹介。Keras - センチメントデータの学習後のモデル予測は
私は
- のPythonを使用しています2.7.12
- Keras == 1.2.1
- numpyのの== 1.12.0
- tensorflowの== 0.12.1
私はこれらの文書のような畳み込みニューラルネットワークを作成しました:https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html
私は11842の準備されたtwitterテキストでネットワークを訓練しました。個々の変化は、結果(0,1,2)に対して3つの可能性があることだけです。私は次のコード行でそれを定義しました:
preds = Dense(3, activation='softmax')(x)
したがって、この方法は問題なく動作し、私は88-92%の間で達成しています。
model_fit = model.fit(x_train, y_train, validation_data=(x_val, y_val), nb_epoch=10, batch_size=128)
学習プロセスの後、私はモデルを.h5形式で保存しました(これもうまくいきます)。
私はモデルを読み込んで予測します。最初の例(trained_model)は、私がそれらを比較したいので、私が訓練するのに使ったのと同じデータを使って行われます。 2番目の例(trained_model_2)は新しいtwitterテキスト(私は以前に収集したもの)を介して行われます。両方のコード行が提起次のように入力dtype=int32
からされている。また
print data_train.shape # (11842, 1000)
print data_predict.shape # (46962, 1000)
: "ライブ/新しい" データセットでトレーニングデータセットを比較するための
trained_model = load_model("trained_model.h5")
prediction_result = trained_model.predict(data_train, batch_size=128)
print prediction_result.shape ### Prints: (11842, 3)
trained_model_2 = load_model("trained_model.h5")
prediction_result_2 = trained_model_2.predict(data_predict, batch_size=128)
最初のエラー:
prediction_result_2 = trained_model_2.predict(data_predict, batch_size=128)
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,999] = 13608 is not in [0, 13480) [[Node: Gather_1 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](embedding_1_W_1/read, _recv_input_1_1_0)]]
次のコード行は、2番目のエラーが発生します。私が作成したメソッドの
trained_model_2 = load_model("trained_model.h5")
InvalidArgumentError (see above for traceback): indices[0,999] = 13608 is not in [0, 13480) [[Node: Gather_1 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](embedding_1_W_1/read, _recv_input_1_1_0)]]
EDIT ソースコードを。メソッド "trainModule"はネットワークのトレーニング/保存にのみ使用されます。私の予測テストには、 "predict_sentiment"メソッドが使用されています。最初prediction_results作品や形状(11842、3) Code - pastbin
全体エラー出力以下でnumpyの配列を返します。Error output - pastbin
いくつかの追加情報が必要な場合、私は質問を更新します...
変換の種類あなたは、data_trainからx_val x_train得るために行うのですか? – rAyyy
モデルを学習するには?私は上記のリンクのようにしました。私はすべてのテキストをシャッフルし、トレーニングデータと検証データの間でそれらを分割した後に、 "VALIDATION_SPLIT"(私にとっては0.19)でそれらを掛けます。それとも、どうやってそれらを数の少ない配列にするのですか? – HauLuk