1

私はKerasを使って畳み込みニューラルネットワークを設計することを学んでいます。私はVGG16をベースにしたシンプルなモデルを開発しました。私は、データセットに約6クラスの画像を持っています。私のモデルのコードと説明は次のとおりです。ここKerasの複数分類における複数の予測?

model = models.Sequential() 
conv_base = VGG16(weights='imagenet' ,include_top=False, input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)) 
conv_base.trainable = False 
model.add(conv_base) 
model.add(layers.Flatten()) 
model.add(layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001))) 
model.add(layers.Dropout(0.5)) 
model.add(layers.Dense(6, activation='sigmoid')) 

1

モデルをコンパイルし、取り付けるためのコードである:ここ

model.compile(loss='categorical_crossentropy', 
     optimizer=optimizers.RMSprop(lr=1e-4), 
     metrics=['acc']) 
model.summary() 

callbacks = [ 
    EarlyStopping(monitor='acc', patience=1, mode='auto'), 
    ModelCheckpoint(monitor='val_loss', save_best_only=True, filepath=model_file_path) 
] 

history = model.fit_generator(
    train_generator, 
    steps_per_epoch=10, 
    epochs=EPOCHS, 
    validation_data=validation_generator, 
    callbacks = callbacks, 
    validation_steps=10) 

新しいイメージ

img = image.load_img(img_path, target_size=(IMAGE_SIZE, IMAGE_SIZE)) 
plt.figure(index) 
imgplot = plt.imshow(img) 

x = image.img_to_array(img) 
x = x.reshape((1,) + x.shape) 
prediction = model.predict(x)[0] 
# print(prediction) 

の予測しばしばモデルのコードです。 predict()メソッドは複数のクラスを予測します。

[0 1 1 0 0 0] 

私は、マルチクラス分類モデルは、複数の出力を予測するための質問

  1. のカップルが、それは正常ですか?
  2. 複数のクラスが予測された場合、トレーニング時間中に精度がどのように測定されますか?
  3. どのようにして1つのクラスだけが予測されるようにニューラルネットワークを変更できますか?

何か助けていただければ幸いです。どうもありがとうございます!

答えて

1

あなたはマルチクラスの分類ではなく、マルチラベルです。これは、出力層でのシグモイド活性化の使用によって引き起こされます。マルチクラスの分類を適切に行うには、出力でsoftmaxのアクティベーションを使用します。これにより、クラスに対して確率分布が生成されます。 確率が最大(argmax)のクラスを取得すると、期待どおりに1つのクラス予測が生成されます。

+0

ありがとうございます。それだった! Keras列車機能によって正確さがどのように測定されるのか説明できますか?例えば、訓練中に、モデルが[0,0.2,0.4,0.7,0.1,0]と予測されたと仮定する。ケラスは0.7の精度でクラス3を見て、それを真実と比較するモデルの出力として考えますか? – TMS

+0

このアプローチはライブラリーに固有のものではありません。一般的なマルチクラス分類確率は、n個の出力クラスでソフトマックス活性化を使用することであり、「ピック」を最も高い確率の1つとする。あなたのケースでは、はいクラス3が選択されたクラスとみなされます。 1つのサンプルの精度はバイナリであり、入力の平均値です。 @TMS。 – modesitt

関連する問題