Projekt_AI-Automatyczny_saper/venv/Lib/site-packages/mnist/packer.py
2021-06-01 18:53:56 +02:00

63 lines
1.6 KiB
Python

import gzip
import os
import struct
def _binary_writter(data, filepath):
with open(filepath, 'wb') as file:
file.write(data)
def _gzip_writter(data, filepath):
with gzip.open(filepath, 'wb') as file:
file.write(data)
def img_packer(path, filename, imgs, gzip=False,
magic=2051, rows=28, cols=28):
data = b''
data += struct.pack(">IIII", magic, len(imgs), rows, cols)
to_list = list()
if type(imgs).__name__ == 'array':
to_list = list(imgs)
elif type(imgs).__name__ == 'ndarray':
to_list = list(imgs)
elif type(imgs).__name__ == 'list':
to_list = imgs
else:
raise TypeError('Unsupported data type.')
for i in to_list:
pack_format = '>' + 'B' * len(i)
data += struct.pack(pack_format, *i)
if gzip:
_gzip_writter(data, os.path.join(path, filename))
else:
_binary_writter(data, os.path.join(path, filename))
def label_packer(path, filename, label,
gzip=False, magic=2049):
data = b''
data += struct.pack(">II", magic, len(label))
to_list = list()
if type(label).__name__ == 'array':
to_list = list(label)
elif type(label).__name__ == 'ndarray':
to_list = list(label)
elif type(label).__name__ == 'list':
to_list = label
else:
raise TypeError('Unsupported label type.')
pack_format = '>' + 'B' * len(to_list)
data += struct.pack(pack_format, *to_list)
if gzip:
_gzip_writter(data, os.path.join(path, filename))
else:
_binary_writter(data, os.path.join(path, filename))