2017-01-05 9 views
0

Deeplearning4jの公式ガイドは.csvファイルの使用方法を示していますが、カスタムモデルの使用方法を知りたいと思っています。私は適切なDataSet実装を探してみましたが、何も見つけられないようです。通常の.csvの内容(文字列形式)を取っても十分に良いでしょう。Deeplearning4jモデルをデータセットに解析する

モデル:

package com.example.kamil.deeplearningandroid; 

public class Job implements LearnableModel { 
private int type; 
private int salary; 
private int choice; 

public Job(String type, int salary, boolean choice) { 
    this.type = encodeType(type); 
    this.salary = salary; 
    this.choice = encodeChoice(choice); 
} 

private int encodeType(String job) { 
    switch (job) { 
     case "Mechanic": return 0; 
     case "Programmer": return 1; 
     case "Teacher": return 2; 
     case "Driver": return 3; 
     case "Cook": return 4; 
     default: return 5; 
    } 
} 

private int encodeChoice(boolean choice) { 
    return choice ? 1: 0; 
} 

@Override 
public String toString() { 
    return type + SEPARATOR + salary + SEPARATOR + choice + "\n"; 
} 
} 

とJobClassifierで:

private DataSet readStringDataset(List<LearnableModel> data, int batchSize, int labelIndex, int numClasses) throws IOException, InterruptedException { 
    RecordReader rr = new LineRecordReader(); 
    rr.initialize(new StringSplit(modelToString(data))); 
    DataSetIterator iterator = new RecordReaderDataSetIterator(rr,batchSize,labelIndex,numClasses); 
    return iterator.next(); 
} 

private String modelToString(List<LearnableModel> list) { 
    StringBuilder sb = new StringBuilder(); 
    for (LearnableModel model: list) { 
     sb.append(model.toString()); 
    } 
    return sb.toString(); 
} 

をすべてこれで私は取得しています:

W/System.err: java.lang.NumberFormatException: Invalid double: "1,10,0 
W/System.err: 1,15,1 
W/System.err: 4,7,0 
W/System.err: 5,10,1 
W/System.err: 3,10,0 
W/System.err: 3,20,0 
W/System.err: 4,5,0 
W/System.err: 4,12,1 
W/System.err: 2,20,1 
W/System.err: 2,4,0 
W/System.err: 5,12,1 
W/System.err: 0,10,0 
W/System.err: 5,5,0 
W/System.err: 1,10,0 
W/System.err: 2,16,1 
W/System.err: 3,30,1 
W/System.err: 4,16,1 
W/System.err: 5,19,1 
W/System.err: 5,6,0 
W/System.err: 1,11,0" 
W/System.err:  at java.lang.StringToReal.invalidReal(StringToReal.java:63) 
W/System.err:  at java.lang.StringToReal.initialParse(StringToReal.java:164) 
W/System.err:  at java.lang.StringToReal.parseDouble(StringToReal.java:282) 
W/System.err:  at java.lang.Double.parseDouble(Double.java:301) 
W/System.err:  at org.datavec.api.writable.Text.toDouble(Text.java:601) 
W/System.err:  at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.getDataSet(RecordReaderDataSetIterator.java:271) 
W/System.err:  at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:177) 
W/System.err:  at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:372) 
W/System.err:  at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:52) 
W/System.err:  at com.example.kamil.deeplearningandroid.JobClassifier.readStringDataset(JobClassifier.java:185) 
W/System.err:  at com.example.kamil.deeplearningandroid.JobClassifier.classify(JobClassifier.java:65) 
W/System.err:  at com.example.kamil.deeplearningandroid.MainActivity.onCreate(MainActivity.java:23) 
W/System.err:  at android.app.Activity.performCreate(Activity.java:6251) 
W/System.err:  at android.app.Instrumentation.callActivityOnCreate(Instrumentation.java:1107) 
W/System.err:  at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:2369) 
W/System.err:  at android.app.ActivityThread.handleLaunchActivity(ActivityThread.java:2476) 
W/System.err:  at android.app.ActivityThread.-wrap11(ActivityThread.java) 
W/System.err:  at android.app.ActivityThread$H.handleMessage(ActivityThread.java:1344) 
W/System.err:  at android.os.Handler.dispatchMessage(Handler.java:102) 
W/System.err:  at android.os.Looper.loop(Looper.java:148) 
W/System.err:  at android.app.ActivityThread.main(ActivityThread.java:5417) 
W/System.err:  at java.lang.reflect.Method.invoke(Native Method) 
W/System.err:  at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:726) 
W/System.err:  at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:616) 

答えて

1

あなたを想定している私はこのようにそれをやってみましたdatavecを使用する。 "データセットの実装"がありません。すべてがndarraysに変換されます。

http://deeplearning4j.org/DataVec

私たちの例で、このカバーよりも多く:http://gitub.com/deeplearning4j/dl4j-examples

編集:シンプルなCSV上の推論のために。あなたは単純なことができます: String [] arr = line.split( "、");

ダブル[]を作成するか、Float.parseFloatとDouble.parseDouble(VIA [] FLOAT)

、次に行う: INDArray ARR = Nd4j.create(フロート[])。またはINDArray arr = Nd4j.create(double []);

推論/スコアリングのためのデータセットは必要ありません。トレーニングでは、RecordReaderDataSetIteratorまたはSequenceRecordReaderDataSetIteratorでdatavecを使用します。

+0

現在、両方のリンクは動作していません。おそらくhttps://deeplearning4j.org/datavecdoc/とhttp://github.com/deeplearning4j/dl4j-examples –

関連する問題