package com.example.minwoo_k.neural_network;
import android.os.AsyncTask;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.reflections.vfs.CommonsVfs2UrlType;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import static android.R.id.input;
import static org.reflections.Reflections.log;
public class MainActivity extends AppCompatActivity {

    /** Logcat tag used for every message this activity writes. */
    private static final String TAG = "ANN";

    /** Mini-batch size used by both MNIST iterators. */
    private static final int BATCH_SIZE = 500;

    /** Number of MNIST training examples fed to {@code fit}. */
    private static final int NUM_TRAIN_EXAMPLES = 10000;

    /** Number of MNIST test examples used for evaluation. */
    private static final int NUM_TEST_EXAMPLES = 100;

    /** Number of digit classes the evaluation distinguishes. */
    private static final int NUM_CLASSES = 10;

    /**
     * Inflates the layout and starts building, training and evaluating the network
     * on a background executor.
     */
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        // Downloading MNIST and training are slow, and Android forbids network I/O on the
        // main thread, so the whole pipeline runs via AsyncTask's background executor.
        AsyncTask.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    createAndUseNetwork();
                } catch (IOException e) {
                    // Surface the failure in logcat with its stack trace instead of System.err.
                    Log.e(TAG, "Failed to load MNIST data", e);
                }
            }
        });
    }

    /**
     * Builds the network, trains it on the MNIST training split, evaluates it on the
     * MNIST test split and logs the evaluation statistics.
     *
     * @throws IOException if the MNIST data cannot be fetched or unpacked
     */
    private void createAndUseNetwork() throws IOException {
        Log.d(TAG, "****************Create ANN********************");
        MultiLayerNetwork myNetwork = buildNetwork();

        Log.d(TAG, "****************Get Data********************");
        // Must run before the first MnistDataSetIterator is constructed; see the helper's doc.
        redirectDl4jDataDirectory();
        // 3-arg form is (batch, numExamples, binarize) and reads the training split.
        DataSetIterator mnistTrain =
                new MnistDataSetIterator(BATCH_SIZE, NUM_TRAIN_EXAMPLES, true);
        // 6-arg form is (batch, numExamples, binarize, train, shuffle, rngSeed). train = false
        // selects the held-out test split; the 3-arg form would re-use training images and
        // report an optimistic accuracy.
        DataSetIterator mnistTest =
                new MnistDataSetIterator(BATCH_SIZE, NUM_TEST_EXAMPLES, true, false, false, 0);

        Log.d(TAG, "****************Train ANN********************");
        myNetwork.fit(mnistTrain);

        Log.d(TAG, "****************Evaluate ANN********************");
        Evaluation eval = new Evaluation(NUM_CLASSES);
        while (mnistTest.hasNext()) {
            DataSet next = mnistTest.next();
            // Network prediction for this batch, compared against the true labels.
            INDArray output = myNetwork.output(next.getFeatureMatrix());
            eval.eval(next.getLabels(), output);
        }
        // Use Android's logger, consistent with the rest of the class (Reflections.log is a
        // third-party static that is not meant for application logging).
        Log.i(TAG, eval.stats());
        Log.i(TAG, "****************Example finished********************");
    }

    /**
     * Builds and initialises a 784-200-10-10 network.
     *
     * The layers are two sigmoid dense layers followed by a softmax output layer. The network
     * uses Xavier weight init, SGD with learning rate 0.05 and 5 iterations, and logs the
     * score on every iteration.
     *
     * @return an initialised, untrained network
     */
    private MultiLayerNetwork buildNetwork() {
        // Named "Input", but this is the first dense (hidden) layer; it consumes the 784 pixels.
        DenseLayer inputLayer = new DenseLayer.Builder()
                .nIn(784)
                .nOut(200)
                .name("Input")
                .activation(Activation.SIGMOID)
                .build();
        DenseLayer hiddenLayer = new DenseLayer.Builder()
                .nIn(200)
                .nOut(10)
                .name("Hidden")
                .activation(Activation.SIGMOID)
                .build();
        OutputLayer outputLayer = new OutputLayer.Builder()
                .nIn(10)
                .nOut(10)
                .name("Output")
                .activation(Activation.SOFTMAX)
                .build();

        NeuralNetConfiguration.Builder nncBuilder = new NeuralNetConfiguration.Builder();
        nncBuilder.iterations(5);
        nncBuilder.learningRate(0.05);
        nncBuilder.weightInit(WeightInit.XAVIER);
        nncBuilder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);

        NeuralNetConfiguration.ListBuilder listBuilder = nncBuilder.list();
        listBuilder.layer(0, inputLayer);
        listBuilder.layer(1, hiddenLayer);
        listBuilder.layer(2, outputLayer);
        listBuilder.backprop(true);
        MultiLayerConfiguration configuration = listBuilder.build();

        MultiLayerNetwork network = new MultiLayerNetwork(configuration);
        network.init();
        network.setListeners(new ScoreIterationListener(1));
        return network;
    }

    /**
     * Points the MNIST download/cache location at this app's private files directory.
     *
     * The reported failure is {@code IOException: Could not mkdir /MNIST} thrown from
     * {@code MnistFetcher.downloadAndUntar}. The fetcher's base directory presumably comes
     * from the {@code user.home} system property, as in DL4J 0.9.x. On Android that property
     * does not name a writable directory, so the fetcher tries to create {@code /MNIST} at
     * the filesystem root. TODO confirm for the DL4J version in use; newer releases may
     * expose a dedicated resources-directory setting instead.
     *
     * The download also requires {@code android.permission.INTERNET} in the manifest.
     */
    private void redirectDl4jDataDirectory() {
        File dataDir = getFilesDir();
        System.setProperty("user.home", dataDir.getAbsolutePath());
        Log.d(TAG, "MNIST data directory: " + dataDir.getAbsolutePath());
    }
}
これは私のプログラムの完全なソースコードですが、MNISTデータを読み込むことができません。MNISTデータセットはどうすれば取得できますか？ AndroidでDL4Jの DataSetIterator（MnistDataSetIterator）からMNISTデータを取得するにはどうすればよいですか？
12-15 12:26:06.526 3910-3930/com.example.minwoo_k.neural_network W/System.err: java.io.IOException: Could not mkdir /MNIST
12-15 12:26:06.526 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at org.deeplearning4j.base.MnistFetcher.downloadAndUntar(MnistFetcher.java:66)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at org.deeplearning4j.datasets.fetchers.MnistDataFetcher.<init>(MnistDataFetcher.java:65)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator.<init>(MnistDataSetIterator.java:65)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator.<init>(MnistDataSetIterator.java:43)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at com.example.minwoo_k.neural_network.MainActivity.createAndUseNetwork(MainActivity.java:93)
12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at com.example.minwoo_k.neural_network.MainActivity.access$000(MainActivity.java:33)
12-15 12:26:06.531 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at com.example.minwoo_k.neural_network.MainActivity$1.run(MainActivity.java:44)
12-15 12:26:06.531 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at android.os.AsyncTask$SerialExecutor$1.run(AsyncTask.java:245)
12-15 12:26:06.532 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1162)
12-15 12:26:06.532 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:636)
12-15 12:26:06.532 3910-3930/com.example.minwoo_k.neural_network W/System.err:     at java.lang.Thread.run(Thread.java:764)
これは私のLogcatの出力です。この問題はどうすれば解決できますか？