2017-04-08 13 views
0

イム新しいdeeplearning4j(畳み込みネットワーク)で訓練を受けたカスタムモデルに新しいイメージを分類します。私はすでにword2vecの機能を試していて、すべてうまくいきました。しかし、今ではイメージの分類に関してちょっと混乱しています。私はこの例で遊んでいた: はdeeplearning4Jへ

https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/AnimalsClassification.java

は私がtrueに「保存」のフラグを変更し、私のモデルはmodel.binファイルに格納されています。 問題のある部分が出てきます(これは愚かな疑問に聞こえる場合は申し訳ありませんが、ここでは本当に明白なものがないかもしれません)

私はAnimalClassifierという別のクラスを作成しました。そこからニューラルネットワークを抽出し、復元されたネットワークを用いて単一の画像を分類する。私は必ずたかった(以前AnimalsClassification.javaでトレーニングプロセスで使用されたホッキョクグマの絵を置く> dl4j-例/ srcに/メイン/リソース/動物/ TEMP/- この単一のイメージのために、私は「一時」フォルダを作成しましたその画像は正しく分類されるので、 "bear"フォルダから画像を再利用しました)。ホッキョクグマを分類しようとしている

この私のコード:

protected static int height = 100; 
    protected static int width = 100; 
    protected static int channels = 3; 
    protected static int numExamples = 1; 
    protected static int numLabels = 1; 
    protected static int batchSize = 10; 

    protected static long seed = 42; 
    protected static Random rng = new Random(seed); 
    protected static int listenerFreq = 1; 
    protected static int iterations = 1; 
    protected static int epochs = 7; 
    protected static double splitTrainTest = 0.8; 
    protected static int nCores = 2; 
    protected static boolean save = true; 

    protected static String modelType = "AlexNet"; // 

    public static void main(String[] args) throws Exception { 

     String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/"); 
     MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(basePath + "model.bin", true); 

     ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); 
     File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/temp/"); 
     FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng); 
     BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize); 


     InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1); 
     InputSplit analysedData = inputSplit[0]; 


     ImageRecordReader recordReader = new ImageRecordReader(height, width, channels); 
     recordReader.initialize(analysedData); 
     DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 0, 4); 
     while (dataIter.hasNext()) { 
      DataSet testDataSet = dataIter.next(); 

      String expectedResult = testDataSet.getLabelName(0); 
      List<String> predict = multiLayerNetwork.predict(testDataSet); 
      String modelResult = predict.get(0); 
      System.out.println("\nFor example that is labeled " + expectedResult + " the model predicted " + modelResult + "\n\n"); 
     } 
    } 

これを実行した後、私はエラーを取得する:

java.lang.UnsupportedOperationExceptionが org.datavec.api.writable.ArrayWritable.toIntでorg.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:186)でorg.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.getDataSet(RecordReaderDataSetIterator.java:275) で(ArrayWritable.java:47) でorg.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:389) at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:52) at org.deeplearning4j.examples.convolution。 AnimalClassifier.main(AnimalClassifier.java:66)ターゲットVMから切断 、アドレス:「127.0.0.1:63967」、輸送:「ソケット」のスレッドで 例外「メイン」java.lang.IllegalStateException:ラベル名が定義されていませんこのデータセットでIDを持つgetLabelNameを使用するには、ラベル名を追加します。 org.deeplearning4j.examples.convolution.AnimalClassifier.mainでorg.nd4j.linalg.dataset.DataSet.getLabelName(DataSet.java:1106) で (AnimalClassifier.java:68)

私が見ることができる

がありますメソッドpublic無効MultiLayerNetwork.javaでsetLabels(INDArrayラベル)が、私は(それが引数INDArrayとして取る場合は特に)の使用方法を得ることはありません。

私はRecordReaderDataSetIteratorのコンストラクタで可能なラベルの数を指定する必要がなぜ私も混乱しています。私はそのモデルがどのラベルを使うべきかをすでに知っていることを期待しています(訓練中に自動的に使用されたラベルは使用しないでください)。私はおそらく、私は完全に間違った方法で画像をロードしていると思います...

だから私は単純に次のことを達成したいと思い、要約する:

    1. は(これが機能している)モデルからネットワークを復元負荷イメージに分類される(これも作業)
    2. 使用して、この画像を分類します訓練中に使用された同じラベル(クマ、シカ、カモ、カメ)

    ご協力ありがとうございました!

  • +0

    どのバージョンをお使いですか?現在RecordReaderDataSetIteratorを見て:https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j-core/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.java#L275にはライン275が存在しませんnumLabelsに関するあなたの質問について:これは遺産であり、修正される予定です:https://github.com/deeplearning4j/deeplearning4j/issues/3216それを指摘してくれてありがとうございます。 –

    +0

    ありがとうございます。バージョンは0.8.0です。現時点では、手順3(復元されたネットワーク上で単一イメージを分類する)を実行する方法を簡単な回避策ですか? –

    +0

    は、一般的にそれはかなりまっすぐ進むのです - あなたはこれを適応し、あなたが得るものを見てみてくださいすることができますhttps://github.com/deeplearning4j/dl4j-examples/tree/master/dl4j-examples/src/main/java/org/deeplearning4j/例/ dataexamples(同じワークフローをロード/セーブMNISTを見て)あなたはこれを見た場合、ラベルインデックスがオフになっているように、それが見えます:https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/ src/main/java/org/deeplearning4j/examples/dataexamples/MnistImagePipelineExampleLoad.java#L111 0は配列ではありません。0はintラベルではありません。 –

    答えて

    0

    だからここにあなたの複数の質問を要約:画像のための レコードは、コレクション内の2つのエントリです。 2番目のラベルはラベルです。ラベルのインデックスは、あなたが渡したレコードの種類に関連しています。

    複数のエントリは、データセットから離れていてもかまいません。このリストは、ミニブタの特定のにあるアイテムのラベルを参照します。

    関連する問題