network
This commit is contained in:
parent
4399265e76
commit
1a315a092c
@ -8,6 +8,7 @@ import os
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from torchvision.io import read_image
|
||||
|
||||
from torchvision.transforms import Resize, Lambda, transforms, ToTensor
|
||||
|
||||
from Constants import *
|
||||
@ -25,7 +26,7 @@ class Neurons:
|
||||
#self.loadImages()
|
||||
self.train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
|
||||
|
||||
input_dim = 28 * 28
|
||||
input_dim = 100
|
||||
output_dim = 10
|
||||
hidden_dim = 10
|
||||
print("create model")
|
||||
@ -45,7 +46,7 @@ class Neurons:
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.001)
|
||||
|
||||
for epoch in range(n_iter):
|
||||
for image, label in next(iter(self.train_dataloader)):
|
||||
for image, label in self.train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(image)
|
||||
@ -122,7 +123,7 @@ class CustomImageDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
|
||||
image = Image.open(img_path)
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
label = self.img_labels.iloc[idx, 1]
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
|
1
venv/Lib/site-packages/mnist-0.2.2.dist-info/INSTALLER
Normal file
1
venv/Lib/site-packages/mnist-0.2.2.dist-info/INSTALLER
Normal file
@ -0,0 +1 @@
|
||||
pip
|
45
venv/Lib/site-packages/mnist-0.2.2.dist-info/METADATA
Normal file
45
venv/Lib/site-packages/mnist-0.2.2.dist-info/METADATA
Normal file
@ -0,0 +1,45 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: mnist
|
||||
Version: 0.2.2
|
||||
Summary: Python utilities to download and parse the MNIST dataset
|
||||
Home-page: https://github.com/datapythonista/mnist
|
||||
Author: Marc Garcia
|
||||
Author-email: garcia.marc@gmail.com
|
||||
License: BSD
|
||||
Platform: UNKNOWN
|
||||
Classifier: Development Status :: 4 - Beta
|
||||
Classifier: Environment :: Console
|
||||
Classifier: Operating System :: OS Independent
|
||||
Classifier: Intended Audience :: Science/Research
|
||||
Classifier: Programming Language :: Python :: 2
|
||||
Classifier: Programming Language :: Python :: 2.7
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3.5
|
||||
Classifier: Programming Language :: Python :: 3.6
|
||||
Classifier: Programming Language :: Python :: 3.7
|
||||
Classifier: Topic :: Scientific/Engineering
|
||||
Requires-Dist: numpy
|
||||
Requires-Dist: mock; python_version=="2.7"
|
||||
|
||||
|
||||
The MNIST database is available at http://yann.lecun.com/exdb/mnist/
|
||||
|
||||
The MNIST database is a dataset of handwritten digits. It has 60,000
|
||||
training samples, and 10,000 test samples. Each image is represented
|
||||
by 28x28 pixels, each containing a value 0 - 255 with its grayscale value.
|
||||
|
||||
It is a subset of a larger set available from NIST. The digits have been
|
||||
size-normalized and centered in a fixed-size image.
|
||||
|
||||
It is a good database for people who want to try learning techniques and
|
||||
pattern recognition methods on real-world data while spending minimal
|
||||
efforts on preprocessing and formatting.
|
||||
|
||||
There are four files available, which contain separately train and test,
|
||||
and images and labels.
|
||||
|
||||
Thanks to Yann LeCun, Corinna Cortes, Christopher J.C. Burges.
|
||||
|
||||
mnist makes it easier to download and parse MNIST files.
|
||||
|
||||
|
8
venv/Lib/site-packages/mnist-0.2.2.dist-info/RECORD
Normal file
8
venv/Lib/site-packages/mnist-0.2.2.dist-info/RECORD
Normal file
@ -0,0 +1,8 @@
|
||||
mnist-0.2.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
mnist-0.2.2.dist-info/METADATA,sha256=6EcFTjjdJvJY5OeJ0GK_83z2kM8-wJd2gt5qIYNwDag,1649
|
||||
mnist-0.2.2.dist-info/RECORD,,
|
||||
mnist-0.2.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
mnist-0.2.2.dist-info/WHEEL,sha256=J3CsTk7Mf2JNUyhImI-mjX-fmI4oDjyiXgWT4qgZiCE,110
|
||||
mnist-0.2.2.dist-info/top_level.txt,sha256=LZk4DcL6HJ7bDK2qWX3Jgg25BAR1cgCUo--jQ78YiHg,6
|
||||
mnist/__init__.py,sha256=FrH55GR3AqdLQ1Ddvl_e3x2QcaZQCxHQ6TSEtKiPElA,5992
|
||||
mnist/__pycache__/__init__.cpython-39.pyc,,
|
6
venv/Lib/site-packages/mnist-0.2.2.dist-info/WHEEL
Normal file
6
venv/Lib/site-packages/mnist-0.2.2.dist-info/WHEEL
Normal file
@ -0,0 +1,6 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.31.0)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py2-none-any
|
||||
Tag: py3-none-any
|
||||
|
@ -0,0 +1 @@
|
||||
mnist
|
4
venv/Lib/site-packages/mnist/__init__.py
Normal file
4
venv/Lib/site-packages/mnist/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .loader import MNIST
|
||||
from .packer import label_packer, img_packer
|
||||
|
||||
__all__ = [MNIST, label_packer, img_packer]
|
BIN
venv/Lib/site-packages/mnist/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
venv/Lib/site-packages/mnist/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/mnist/__pycache__/loader.cpython-39.pyc
Normal file
BIN
venv/Lib/site-packages/mnist/__pycache__/loader.cpython-39.pyc
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/mnist/__pycache__/packer.cpython-39.pyc
Normal file
BIN
venv/Lib/site-packages/mnist/__pycache__/packer.cpython-39.pyc
Normal file
Binary file not shown.
301
venv/Lib/site-packages/mnist/loader.py
Normal file
301
venv/Lib/site-packages/mnist/loader.py
Normal file
@ -0,0 +1,301 @@
|
||||
import gzip
|
||||
import os
|
||||
import struct
|
||||
from array import array
|
||||
import random
|
||||
|
||||
_allowed_modes = (
|
||||
# integer values in {0..255}
|
||||
'vanilla',
|
||||
|
||||
# integer values in {0,1}
|
||||
# values set at 1 (instead of 0) with probability p = orig/255
|
||||
# as in Ruslan Salakhutdinov and Iain Murray's paper
|
||||
# 'On The Quantitative Analysis of Deep Belief Network' (2008)
|
||||
'randomly_binarized',
|
||||
|
||||
# integer values in {0,1}
|
||||
# values set at 1 (instead of 0) if orig/255 > 0.5
|
||||
'rounded_binarized',
|
||||
)
|
||||
|
||||
_allowed_return_types = (
|
||||
# default return type. Computationally more expensive.
|
||||
# Useful if numpy is not installed.
|
||||
'lists',
|
||||
|
||||
# Numpy module will be dynamically loaded on demand.
|
||||
'numpy',
|
||||
)
|
||||
|
||||
np = None
|
||||
def _import_numpy():
|
||||
# will be called only when the numpy return type has been specifically
|
||||
# requested via the 'return_type' parameter in MNIST class' constructor.
|
||||
global np
|
||||
if np is None: # import only once
|
||||
try:
|
||||
import numpy as _np
|
||||
except ImportError as e:
|
||||
raise MNISTException(
|
||||
"need to have numpy installed to return numpy arrays."\
|
||||
+" Otherwise, please set return_type='lists' in constructor."
|
||||
)
|
||||
np = _np
|
||||
else:
|
||||
pass # was already previously imported
|
||||
return np
|
||||
|
||||
class MNISTException(Exception):
|
||||
pass
|
||||
|
||||
class MNIST(object):
|
||||
def __init__(self, path='.', mode='vanilla', return_type='lists', gz=False):
|
||||
self.path = path
|
||||
|
||||
assert mode in _allowed_modes, \
|
||||
"selected mode '{}' not in {}".format(mode,_allowed_modes)
|
||||
|
||||
self._mode = mode
|
||||
|
||||
assert return_type in _allowed_return_types, \
|
||||
"selected return_type '{}' not in {}".format(
|
||||
return_type,
|
||||
_allowed_return_types
|
||||
)
|
||||
|
||||
self._return_type = return_type
|
||||
|
||||
self.test_img_fname = 't10k-images-idx3-ubyte'
|
||||
self.test_lbl_fname = 't10k-labels-idx1-ubyte'
|
||||
|
||||
self.train_img_fname = 'train-images-idx3-ubyte'
|
||||
self.train_lbl_fname = 'train-labels-idx1-ubyte'
|
||||
|
||||
self.gz = gz
|
||||
self.emnistRotate = False
|
||||
|
||||
self.test_images = []
|
||||
self.test_labels = []
|
||||
|
||||
self.train_images = []
|
||||
self.train_labels = []
|
||||
|
||||
def select_emnist(self, dataset='digits'):
|
||||
'''
|
||||
Select one of the EMNIST datasets
|
||||
|
||||
Available datasets:
|
||||
- balanced
|
||||
- byclass
|
||||
- bymerge
|
||||
- digits
|
||||
- letters
|
||||
- mnist
|
||||
'''
|
||||
template = 'emnist-{0}-{1}-{2}-idx{3}-ubyte'
|
||||
|
||||
self.gz = True
|
||||
self.emnistRotate = True
|
||||
|
||||
self.test_img_fname = template.format(dataset, 'test', 'images', 3)
|
||||
self.test_lbl_fname = template.format(dataset, 'test', 'labels', 1)
|
||||
|
||||
self.train_img_fname = template.format(dataset, 'train', 'images', 3)
|
||||
self.train_lbl_fname = template.format(dataset, 'train', 'labels', 1)
|
||||
|
||||
@property # read only because set only once, via constructor
|
||||
def mode(self):
|
||||
return self._mode
|
||||
|
||||
@property # read only because set only once, via constructor
|
||||
def return_type(self):
|
||||
return self._return_type
|
||||
|
||||
def load_testing(self):
|
||||
ims, labels = self.load(os.path.join(self.path, self.test_img_fname),
|
||||
os.path.join(self.path, self.test_lbl_fname))
|
||||
|
||||
self.test_images = self.process_images(ims)
|
||||
self.test_labels = self.process_labels(labels)
|
||||
|
||||
return self.test_images, self.test_labels
|
||||
|
||||
def load_training(self):
|
||||
ims, labels = self.load(os.path.join(self.path, self.train_img_fname),
|
||||
os.path.join(self.path, self.train_lbl_fname))
|
||||
|
||||
self.train_images = self.process_images(ims)
|
||||
self.train_labels = self.process_labels(labels)
|
||||
|
||||
return self.train_images, self.train_labels
|
||||
|
||||
def load_training_in_batches(self, batch_size):
|
||||
if type(batch_size) is not int:
|
||||
raise ValueError('batch_size must be a int number')
|
||||
batch_sp = 0
|
||||
last = False
|
||||
self._get_dataset_size(os.path.join(self.path, self.train_img_fname),
|
||||
os.path.join(self.path, self.train_lbl_fname))
|
||||
|
||||
while True:
|
||||
ims, labels = self.load(
|
||||
os.path.join(self.path, self.train_img_fname),
|
||||
os.path.join(self.path, self.train_lbl_fname),
|
||||
batch=[batch_sp, batch_size])
|
||||
|
||||
self.train_images = self.process_images(ims)
|
||||
self.train_labels = self.process_labels(labels)
|
||||
|
||||
yield self.train_images, self.train_labels
|
||||
|
||||
if last:
|
||||
break
|
||||
|
||||
batch_sp += batch_size
|
||||
if batch_sp + batch_size > self.dataset_size:
|
||||
last = True
|
||||
batch_size = self.dataset_size - batch_sp
|
||||
|
||||
def _get_dataset_size(self, path_img, path_lbl):
|
||||
with self.opener(path_lbl, 'rb') as file:
|
||||
magic, lb_size = struct.unpack(">II", file.read(8))
|
||||
if magic != 2049:
|
||||
raise ValueError('Magic number mismatch, expected 2049,'
|
||||
'got {}'.format(magic))
|
||||
|
||||
with self.opener(path_img, 'rb') as file:
|
||||
magic, im_size = struct.unpack(">II", file.read(8))
|
||||
if magic != 2051:
|
||||
raise ValueError('Magic number mismatch, expected 2051,'
|
||||
'got {}'.format(magic))
|
||||
|
||||
if lb_size != im_size:
|
||||
raise ValueError('image size is not equal to label size')
|
||||
|
||||
self.dataset_size = lb_size
|
||||
|
||||
def process_images(self, images):
|
||||
if self.return_type is 'lists':
|
||||
return self.process_images_to_lists(images)
|
||||
elif self.return_type is 'numpy':
|
||||
return self.process_images_to_numpy(images)
|
||||
else:
|
||||
raise MNISTException("unknown return_type '{}'".format(self.return_type))
|
||||
|
||||
def process_labels(self, labels):
|
||||
if self.return_type is 'lists':
|
||||
return labels
|
||||
elif self.return_type is 'numpy':
|
||||
_np = _import_numpy()
|
||||
return _np.array(labels)
|
||||
else:
|
||||
raise MNISTException("unknown return_type '{}'".format(self.return_type))
|
||||
|
||||
def process_images_to_numpy(self,images):
|
||||
_np = _import_numpy()
|
||||
|
||||
images_np = _np.array(images)
|
||||
|
||||
if self.mode == 'vanilla':
|
||||
pass # no processing, return them vanilla
|
||||
|
||||
elif self.mode == 'randomly_binarized':
|
||||
r = _np.random.random(images_np.shape)
|
||||
images_np = (r <= ( images_np / 255)).astype('int') # bool to 0/1
|
||||
|
||||
elif self.mode == 'rounded_binarized':
|
||||
images_np = ((images_np / 255) > 0.5).astype('int') # bool to 0/1
|
||||
|
||||
else:
|
||||
raise MNISTException("unknown mode '{}'".format(self.mode))
|
||||
|
||||
return images_np
|
||||
|
||||
def process_images_to_lists(self,images):
|
||||
if self.mode == 'vanilla':
|
||||
pass # no processing, return them vanilla
|
||||
|
||||
elif self.mode == 'randomly_binarized':
|
||||
for i in range(len(images)):
|
||||
for j in range(len(images[i])):
|
||||
pixel = images[i][j]
|
||||
images[i][j] = int(random.random() <= pixel/255) # bool to 0/1
|
||||
|
||||
elif self.mode == 'rounded_binarized':
|
||||
for i in range(len(images)):
|
||||
for j in range(len(images[i])):
|
||||
pixel = images[i][j]
|
||||
images[i][j] = int(pixel/255 > 0.5) # bool to 0/1
|
||||
else:
|
||||
raise MNISTException("unknown mode '{}'".format(self.mode))
|
||||
|
||||
return images
|
||||
|
||||
def opener(self, path_fn, *args, **kwargs):
|
||||
if self.gz:
|
||||
return gzip.open(path_fn + '.gz', *args, **kwargs)
|
||||
else:
|
||||
return open(path_fn, *args, **kwargs)
|
||||
|
||||
def load(self, path_img, path_lbl, batch=None):
|
||||
if batch is not None:
|
||||
if type(batch) is not list or len(batch) is not 2:
|
||||
raise ValueError('batch should be a 1-D list'
|
||||
'(start_point, batch_size)')
|
||||
|
||||
with self.opener(path_lbl, 'rb') as file:
|
||||
magic, size = struct.unpack(">II", file.read(8))
|
||||
if magic != 2049:
|
||||
raise ValueError('Magic number mismatch, expected 2049,'
|
||||
'got {}'.format(magic))
|
||||
|
||||
labels = array("B", file.read())
|
||||
|
||||
with self.opener(path_img, 'rb') as file:
|
||||
magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
|
||||
if magic != 2051:
|
||||
raise ValueError('Magic number mismatch, expected 2051,'
|
||||
'got {}'.format(magic))
|
||||
|
||||
image_data = array("B", file.read())
|
||||
|
||||
if batch is not None:
|
||||
image_data = image_data[batch[0] * rows * cols:\
|
||||
(batch[0] + batch[1]) * rows * cols]
|
||||
labels = labels[batch[0]: batch[0] + batch[1]]
|
||||
size = batch[1]
|
||||
|
||||
images = []
|
||||
for i in range(size):
|
||||
images.append([0] * rows * cols)
|
||||
|
||||
for i in range(size):
|
||||
images[i][:] = image_data[i * rows * cols:(i + 1) * rows * cols]
|
||||
|
||||
# for some reason EMNIST is mirrored and rotated
|
||||
if self.emnistRotate:
|
||||
x = image_data[i * rows * cols:(i + 1) * rows * cols]
|
||||
|
||||
subs = []
|
||||
for r in range(rows):
|
||||
subs.append(x[(rows - r) * cols - cols:(rows - r)*cols])
|
||||
|
||||
l = list(zip(*reversed(subs)))
|
||||
fixed = [item for sublist in l for item in sublist]
|
||||
|
||||
images[i][:] = fixed
|
||||
|
||||
return images, labels
|
||||
|
||||
@classmethod
|
||||
def display(cls, img, width=28, threshold=200):
|
||||
render = ''
|
||||
for i in range(len(img)):
|
||||
if i % width == 0:
|
||||
render += '\n'
|
||||
if img[i] > threshold:
|
||||
render += '@'
|
||||
else:
|
||||
render += '.'
|
||||
return render
|
62
venv/Lib/site-packages/mnist/packer.py
Normal file
62
venv/Lib/site-packages/mnist/packer.py
Normal file
@ -0,0 +1,62 @@
|
||||
import gzip
|
||||
import os
|
||||
import struct
|
||||
|
||||
|
||||
def _binary_writter(data, filepath):
|
||||
with open(filepath, 'wb') as file:
|
||||
file.write(data)
|
||||
|
||||
|
||||
def _gzip_writter(data, filepath):
|
||||
with gzip.open(filepath, 'wb') as file:
|
||||
file.write(data)
|
||||
|
||||
|
||||
def img_packer(path, filename, imgs, gzip=False,
|
||||
magic=2051, rows=28, cols=28):
|
||||
data = b''
|
||||
data += struct.pack(">IIII", magic, len(imgs), rows, cols)
|
||||
|
||||
to_list = list()
|
||||
if type(imgs).__name__ == 'array':
|
||||
to_list = list(imgs)
|
||||
elif type(imgs).__name__ == 'ndarray':
|
||||
to_list = list(imgs)
|
||||
elif type(imgs).__name__ == 'list':
|
||||
to_list = imgs
|
||||
else:
|
||||
raise TypeError('Unsupported data type.')
|
||||
|
||||
for i in to_list:
|
||||
pack_format = '>' + 'B' * len(i)
|
||||
data += struct.pack(pack_format, *i)
|
||||
|
||||
if gzip:
|
||||
_gzip_writter(data, os.path.join(path, filename))
|
||||
else:
|
||||
_binary_writter(data, os.path.join(path, filename))
|
||||
|
||||
|
||||
def label_packer(path, filename, label,
|
||||
gzip=False, magic=2049):
|
||||
data = b''
|
||||
data += struct.pack(">II", magic, len(label))
|
||||
|
||||
to_list = list()
|
||||
if type(label).__name__ == 'array':
|
||||
to_list = list(label)
|
||||
elif type(label).__name__ == 'ndarray':
|
||||
to_list = list(label)
|
||||
elif type(label).__name__ == 'list':
|
||||
to_list = label
|
||||
else:
|
||||
raise TypeError('Unsupported label type.')
|
||||
|
||||
pack_format = '>' + 'B' * len(to_list)
|
||||
data += struct.pack(pack_format, *to_list)
|
||||
|
||||
if gzip:
|
||||
_gzip_writter(data, os.path.join(path, filename))
|
||||
else:
|
||||
_binary_writter(data, os.path.join(path, filename))
|
@ -0,0 +1 @@
|
||||
Richard Marko <srk(at)48(dot)io>
|
@ -0,0 +1 @@
|
||||
pip
|
25
venv/Lib/site-packages/python_mnist-0.7.dist-info/LICENSE
Normal file
25
venv/Lib/site-packages/python_mnist-0.7.dist-info/LICENSE
Normal file
@ -0,0 +1,25 @@
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, and the entire permission notice in its entirety,
|
||||
including the disclaimer of warranties.
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
3. The name of the author may not be used to endorse or promote
|
||||
products derived from this software without specific prior
|
||||
written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
|
||||
WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
|
||||
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, ALL OF
|
||||
WHICH ARE HEREBY DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
|
||||
OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
|
||||
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
|
||||
USE OF THIS SOFTWARE, EVEN IF NOT ADVISED OF THE POSSIBILITY OF SUCH
|
||||
DAMAGE.
|
169
venv/Lib/site-packages/python_mnist-0.7.dist-info/METADATA
Normal file
169
venv/Lib/site-packages/python_mnist-0.7.dist-info/METADATA
Normal file
@ -0,0 +1,169 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: python-mnist
|
||||
Version: 0.7
|
||||
Summary: Simple MNIST and EMNIST data parser written in pure Python
|
||||
Home-page: https://github.com/sorki/python-mnist
|
||||
Author: Richard Marko
|
||||
Author-email: srk@48.io
|
||||
License: BSD
|
||||
Platform: UNKNOWN
|
||||
Classifier: Development Status :: 5 - Production/Stable
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: Operating System :: OS Independent
|
||||
Classifier: Programming Language :: Python
|
||||
|
||||
python-mnist
|
||||
============
|
||||
|
||||
Simple MNIST and EMNIST data parser written in pure Python.
|
||||
|
||||
MNIST is a database of handwritten digits available on
|
||||
http://yann.lecun.com/exdb/mnist/. EMNIST is an extended MNIST database
|
||||
https://www.nist.gov/itl/iad/image-group/emnist-dataset.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
- Python 2 or Python 3
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
- ``git clone https://github.com/sorki/python-mnist``
|
||||
|
||||
- ``cd python-mnist``
|
||||
|
||||
- Get MNIST data:
|
||||
|
||||
::
|
||||
|
||||
./bin/mnist_get_data.sh
|
||||
|
||||
- Check preview with:
|
||||
|
||||
::
|
||||
|
||||
PYTHONPATH=. ./bin/mnist_preview
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
Get the package from PyPi:
|
||||
|
||||
::
|
||||
|
||||
pip install python-mnist
|
||||
|
||||
or install with ``setup.py``:
|
||||
|
||||
::
|
||||
|
||||
python setup.py install
|
||||
|
||||
Code sample:
|
||||
|
||||
::
|
||||
|
||||
from mnist import MNIST
|
||||
mndata = MNIST('./dir_with_mnist_data_files')
|
||||
images, labels = mndata.load_training()
|
||||
|
||||
To enable loading of gzip-ed files use:
|
||||
|
||||
::
|
||||
|
||||
mndata.gz = True
|
||||
|
||||
Library tries to load files named t10k-images-idx3-ubyte
|
||||
train-labels-idx1-ubyte train-images-idx3-ubyte and
|
||||
t10k-labels-idx1-ubyte. If loading throws an exception check if these
|
||||
names match.
|
||||
|
||||
EMNIST
|
||||
------
|
||||
|
||||
- Get EMNIST data:
|
||||
|
||||
::
|
||||
|
||||
./bin/emnist_get_data.sh
|
||||
|
||||
- Check preview with:
|
||||
|
||||
::
|
||||
|
||||
PYTHONPATH=. ./bin/emnist_preview
|
||||
|
||||
To use EMNIST datasets you need to call:
|
||||
|
||||
::
|
||||
|
||||
mndata.select_emnist('digits')
|
||||
|
||||
Where digits is one of the available EMNIST datasets. You can choose
|
||||
from
|
||||
|
||||
- balanced
|
||||
- byclass
|
||||
- bymerge
|
||||
- digits
|
||||
- letters
|
||||
- mnist
|
||||
|
||||
EMNIST loader uses gziped files by default, this can be disabled by by
|
||||
setting:
|
||||
|
||||
::
|
||||
|
||||
mndata.gz = False
|
||||
|
||||
You also need to unpack EMNIST files as bin/emnist_get_data.sh script
|
||||
won't do it for you. EMNIST loader also needs to mirror and rotate
|
||||
images so it is a bit slower (If this is an issue for you, you should
|
||||
repack the data to avoid mirroring and rotation on each load).
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
This package doesn't use numpy by design as when I've tried to find a
|
||||
working implementation all of them were based on some archaic version of
|
||||
numpy and none of them worked. This loads data files with struct.unpack
|
||||
instead.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
::
|
||||
|
||||
$ PYTHONPATH=. ./bin/mnist_preview
|
||||
Showing num: 3
|
||||
|
||||
............................
|
||||
............................
|
||||
............................
|
||||
............................
|
||||
............................
|
||||
............................
|
||||
.............@@@@@..........
|
||||
..........@@@@@@@@@@........
|
||||
.......@@@@@@......@@.......
|
||||
.......@@@........@@@.......
|
||||
.................@@.........
|
||||
................@@@.........
|
||||
...............@@@@@........
|
||||
.............@@@............
|
||||
.............@.......@......
|
||||
.....................@......
|
||||
.....................@@.....
|
||||
....................@@......
|
||||
...................@@@......
|
||||
.................@@@@.......
|
||||
................@@@@........
|
||||
....@........@@@@@..........
|
||||
....@@@@@@@@@@@@............
|
||||
......@@@@@@................
|
||||
............................
|
||||
............................
|
||||
............................
|
||||
............................
|
||||
|
19
venv/Lib/site-packages/python_mnist-0.7.dist-info/RECORD
Normal file
19
venv/Lib/site-packages/python_mnist-0.7.dist-info/RECORD
Normal file
@ -0,0 +1,19 @@
|
||||
../../Scripts/emnist_get_data.sh,sha256=KcyB6aRMSJc1ttt0GeQHfNeOE1K0jWLD9Ou0lvwGauQ,337
|
||||
../../Scripts/emnist_preview,sha256=jrWtZ6MYGFq89clptlQcL7OA0z1ULVXfPHx_DoDdFkk,1223
|
||||
../../Scripts/emnist_repack,sha256=g1RocXo1ZODzpSl2_uSK0VDieABwwKWzD2uoDvKp7Vo,1191
|
||||
../../Scripts/mnist_get_data.sh,sha256=pV1P4dtLPGNPgfIn4zcyztlUmHL6nwlAONdIb1k83sc,284
|
||||
../../Scripts/mnist_preview,sha256=rp9IsSiCEsNus46JkW1a22vaGLQmgrrHb0FrxjxSM_M,851
|
||||
mnist/__init__.py,sha256=4Qo1t78Wb6oniBzjcqqrDyCXXN1qIqJeopC6CANibHM,116
|
||||
mnist/__pycache__/__init__.cpython-39.pyc,,
|
||||
mnist/__pycache__/loader.cpython-39.pyc,,
|
||||
mnist/__pycache__/packer.cpython-39.pyc,,
|
||||
mnist/loader.py,sha256=VVx63_2QInSvoQiJYGsUktysLThTSKzwv85MtPPACD0,10039
|
||||
mnist/packer.py,sha256=K3IpvL1pV4d0INYK-9aD3V201Qq8_QEVgi7rOSd6ubk,1630
|
||||
python_mnist-0.7.dist-info/AUTHORS,sha256=hIRY5VS6F-AnrVF6Zvgnw7hy0I4xvopBFafEMEzEKuo,33
|
||||
python_mnist-0.7.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
python_mnist-0.7.dist-info/LICENSE,sha256=m3GKlGD-1ZUkZkISNbx560nU6erMkg16ndYoWrj9bG0,1391
|
||||
python_mnist-0.7.dist-info/METADATA,sha256=u-BxMiB8Eo3W5XVFchd5he3UPgl3syNknSrZxBvyq5U,3511
|
||||
python_mnist-0.7.dist-info/RECORD,,
|
||||
python_mnist-0.7.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
python_mnist-0.7.dist-info/WHEEL,sha256=8zNYZbwQSXoB9IfXOjPfeNwvAsALAjffgk27FqvCWbo,110
|
||||
python_mnist-0.7.dist-info/top_level.txt,sha256=LZk4DcL6HJ7bDK2qWX3Jgg25BAR1cgCUo--jQ78YiHg,6
|
6
venv/Lib/site-packages/python_mnist-0.7.dist-info/WHEEL
Normal file
6
venv/Lib/site-packages/python_mnist-0.7.dist-info/WHEEL
Normal file
@ -0,0 +1,6 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.33.6)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py2-none-any
|
||||
Tag: py3-none-any
|
||||
|
@ -0,0 +1 @@
|
||||
mnist
|
16
venv/Scripts/emnist_get_data.sh
Normal file
16
venv/Scripts/emnist_get_data.sh
Normal file
@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
if [ -d emnist_data ]; then
|
||||
echo "emnist_data directory already present, exiting"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir emnist_data
|
||||
pushd emnist_data
|
||||
#wget http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip
|
||||
wget http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip
|
||||
unzip gzip.zip
|
||||
rm -f gzip.zip
|
||||
mv gzip/* .
|
||||
rmdir gzip
|
||||
popd
|
44
venv/Scripts/emnist_preview
Normal file
44
venv/Scripts/emnist_preview
Normal file
@ -0,0 +1,44 @@
|
||||
#!c:\users\kratu\pycharmprojects\projekt_ai-automatyczny_saper\venv\scripts\python.exe
|
||||
|
||||
import random
|
||||
import argparse
|
||||
from mnist import MNIST
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--id", default=None, type=int,
|
||||
help="ID (position) of the letter to show")
|
||||
parser.add_argument("--training", action="store_true",
|
||||
help="Use training set instead of testing set")
|
||||
parser.add_argument("--dataset", default="digits",
|
||||
help="EMNIST dataset to load")
|
||||
parser.add_argument("--data", default="./emnist_data",
|
||||
help="Path to MNIST data dir")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mn = MNIST(args.data)
|
||||
mn.select_emnist(args.dataset)
|
||||
|
||||
if args.training:
|
||||
img, label = mn.load_training()
|
||||
else:
|
||||
img, label = mn.load_testing()
|
||||
|
||||
if args.id:
|
||||
which = args.id
|
||||
else:
|
||||
which = random.randrange(0, len(label))
|
||||
|
||||
print('Showing id {}, num: {}'.format(which, label[which]))
|
||||
|
||||
# letters dataset uses A=1 B=2 ...
|
||||
if args.dataset == 'letters':
|
||||
print('Letter "{}"'.format(chr(label[which] + ord('a') - 1)))
|
||||
|
||||
print(mn.display(img[which]))
|
||||
wat = img[which]
|
||||
#import IPython
|
||||
#IPython.embed()
|
43
venv/Scripts/emnist_repack
Normal file
43
venv/Scripts/emnist_repack
Normal file
@ -0,0 +1,43 @@
|
||||
#!c:\users\kratu\pycharmprojects\projekt_ai-automatyczny_saper\venv\scripts\python.exe
|
||||
|
||||
import argparse
|
||||
import os.path
|
||||
from mnist import MNIST
|
||||
from mnist import img_packer
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data", default="./emnist_data",
|
||||
help="Path to MNIST data dir")
|
||||
parser.add_argument("--output", default=None,
|
||||
help="Where to save result")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
DATASETS = ["balanced", "byclass", "bymerge",
|
||||
"digits", "letters", "mnist"]
|
||||
|
||||
mn = MNIST(args.data)
|
||||
|
||||
if not args.output:
|
||||
dest = args.data
|
||||
train_img_fname = 'rf_' + mn.train_img_fname
|
||||
test_img_fname = 'rf_' + mn.test_img_fname
|
||||
else:
|
||||
dest = args.output
|
||||
train_img_fname = mn.train_img_fname
|
||||
test_img_fname = mn.test_img_fname
|
||||
|
||||
for dt_name in DATASETS:
|
||||
mn.select_emnist(dt_name)
|
||||
|
||||
print("========procesing {} dataset========".format(dt_name))
|
||||
|
||||
tra_img, _ = mn.load_training()
|
||||
img_packer(dest, train_img_fname,
|
||||
tra_img, gzip=True)
|
||||
|
||||
tes_img, _ = mn.load_testing()
|
||||
img_packer(dest, test_img_fname,
|
||||
tes_img, gzip=True)
|
13
venv/Scripts/mnist_get_data.sh
Normal file
13
venv/Scripts/mnist_get_data.sh
Normal file
@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
if [ -d data ]; then
|
||||
echo "data directory already present, exiting"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir data
|
||||
wget --recursive --level=1 --cut-dirs=3 --no-host-directories \
|
||||
--directory-prefix=data --accept '*.gz' http://yann.lecun.com/exdb/mnist/
|
||||
pushd data
|
||||
gunzip *
|
||||
popd
|
34
venv/Scripts/mnist_preview
Normal file
34
venv/Scripts/mnist_preview
Normal file
@ -0,0 +1,34 @@
|
||||
#!c:\users\kratu\pycharmprojects\projekt_ai-automatyczny_saper\venv\scripts\python.exe
|
||||
|
||||
import random
|
||||
import argparse
|
||||
from mnist import MNIST
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--id", default=None, type=int,
|
||||
help="ID (position) of the letter to show")
|
||||
parser.add_argument("--training", action="store_true",
|
||||
help="Use training set instead of testing set")
|
||||
|
||||
parser.add_argument("--data", default="./data",
|
||||
help="Path to MNIST data dir")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mn = MNIST(args.data)
|
||||
|
||||
if args.training:
|
||||
img, label = mn.load_training()
|
||||
else:
|
||||
img, label = mn.load_testing()
|
||||
|
||||
if args.id:
|
||||
which = args.id
|
||||
else:
|
||||
which = random.randrange(0, len(label))
|
||||
|
||||
print('Showing num: {}'.format(label[which]))
|
||||
print(mn.display(img[which]))
|
Loading…
Reference in New Issue
Block a user