2017-10-03 9 views
2

自分のバージョンのMNISTデータを作成しようとしています。トレーニングとテストのデータを次のファイルに変換しました。独自のMNISTデータセットの作成(MNISTフォーマットと同じ)

test-images-idx3-ubyte.gz 
test-labels-idx1-ubyte.gz 
train-images-idx3-ubyte.gz 
train-labels-idx1-ubyte.gz 

(興味がある人々のために、私は私が目指してるものに私が近づくように思われJPG-PNG-to-MNIST-NN-Formatを使用してこれをした。)

しかし、これは、ファイルの種類と形式と全く同じではありませんMNISTデータ(mnist.pkl.gz)。私はpklはデータが節約されたことを意味すると理解していますが、私は実際にデータを酸洗するプロセスを理解していません - 酸洗いに特定の順序がありますか?誰かが自分のデータをピクルに使うべきコードを提供することはできますか?

答えて

1
import gzip 
import os 

import numpy as np 
import six 
from six.moves.urllib import request 

parent = 'http://yann.lecun.com/exdb/mnist' 
train_images = 'train-images-idx3-ubyte.gz' 
train_labels = 'train-labels-idx1-ubyte.gz' 
test_images = 't10k-images-idx3-ubyte.gz' 
test_labels = 't10k-labels-idx1-ubyte.gz' 
num_train = 17010 
num_test = 3010 
dim = 32*32 


def load_mnist(images, labels, num): 
    data = np.zeros(num * dim, dtype=np.uint8).reshape((num, dim)) 
    target = np.zeros(num, dtype=np.uint8).reshape((num,)) 

    with gzip.open(images, 'rb') as f_images,\ 
      gzip.open(labels, 'rb') as f_labels: 
     f_images.read(16) 
     f_labels.read(8) 
     for i in six.moves.range(num): 
      target[i] = ord(f_labels.read(1)) 
      for j in six.moves.range(dim): 
       data[i, j] = ord(f_images.read(1)) 

    return data, target 


def download_mnist_data(): 

    print('Converting training data...') 
    data_train, target_train = load_mnist(train_images, train_labels, 
              num_train) 
    print('Done') 
    print('Converting test data...') 
    data_test, target_test = load_mnist(test_images, test_labels, num_test) 
    mnist = {} 
    mnist['data'] = np.append(data_train, data_test, axis=0) 
    mnist['target'] = np.append(target_train, target_test, axis=0) 

    print('Done') 
    print('Save output...') 
    with open('mnist.pkl', 'wb') as output: 
     six.moves.cPickle.dump(mnist, output, -1) 
    print('Done') 
    print('Convert completed') 


def load_mnist_data(): 
    if not os.path.exists('mnist.pkl'): 
     download_mnist_data() 
    with open('mnist.pkl', 'rb') as mnist_pickle: 
     mnist = six.moves.cPickle.load(mnist_pickle) 
    return mnist 
download_mnist_data() 
+0

これは大変です - なぜ誰もそれが気に入らなかった – javadba

関連する問題