import sys
import ctypes
from ctypes import *
import unittest

# Names exported by a star import: the flag constants and the reader class.
__all__ = [
    'PAI_CONTIGUOUS',
    'PAI_FORTRAN',
    'PAI_ALIGNED',
    'PAI_NOTSWAPPED',
    'PAI_WRITEABLE',
    'PAI_ARR_HAS_DESCR',
    'ArrayInterface',
]

# ctypes lacks c_ssize_t in early Python versions; in that case synthesize
# c_size_t / c_ssize_t from the first unsigned/signed integer pair whose
# width equals that of a pointer.  Later branches are only evaluated when
# earlier ones fail, so c_ulonglong is not touched unless it is needed.
try:
    c_ssize_t
except NameError:
    if sizeof(c_void_p) == sizeof(c_uint):
        c_size_t, c_ssize_t = c_uint, c_int
    elif sizeof(c_void_p) == sizeof(c_ulong):
        c_size_t, c_ssize_t = c_ulong, c_long
    elif sizeof(c_void_p) == sizeof(c_ulonglong):
        c_size_t, c_ssize_t = c_ulonglong, c_longlong


# Py_intptr_t: the narrowest C integer type at least as wide as a pointer.
# It is the element type of the shape/strides arrays in PyArrayInterface.
SIZEOF_VOID_P = sizeof(c_void_p)
if SIZEOF_VOID_P <= sizeof(c_int):
    Py_intptr_t = c_int
elif SIZEOF_VOID_P <= sizeof(c_long):
    Py_intptr_t = c_long
elif 'c_longlong' in globals() and SIZEOF_VOID_P <= sizeof(c_longlong):
    # c_longlong may be absent from very old ctypes builds, hence the check.
    Py_intptr_t = c_longlong
else:
    # Report the size measured above.  The message must only use names
    # defined in this module, so this raises RuntimeError and not NameError.
    raise RuntimeError("Unrecognized pointer size %i" % (SIZEOF_VOID_P,))

class PyArrayInterface(Structure):
    """ctypes mirror of the C array-interface struct behind ``__array_struct__``.

    ``ArrayInterface`` casts a capsule pointer to this type, and ``Exporter``
    builds one to publish its buffer.  The field order and types therefore
    define the binary layout and must not be changed.
    """
    _fields_ = [
        ('two', c_int),                     # Exporter always stores 2 here
        ('nd', c_int),                      # number of dimensions
        ('typekind', c_char),               # one-byte kind code, e.g. b'u'
        ('itemsize', c_int),                # bytes per element
        ('flags', c_int),                   # combination of PAI_* bits
        ('shape', POINTER(Py_intptr_t)),    # nd extents
        ('strides', POINTER(Py_intptr_t)),  # nd byte strides
        ('data', c_void_p),                 # address of the first element
        ('descr', py_object),               # optional description object
    ]

# Pointer type used to view the struct addressed by an __array_struct__ object.
PAI_Ptr = POINTER(PyArrayInterface)

# Legacy CObject API.  When the interpreter does not export it, install a
# stand-in that raises TypeError; ArrayInterface catches that and falls back
# to the capsule API.
if hasattr(pythonapi, 'PyCObject_AsVoidPtr'):
    PyCObject_AsVoidPtr = pythonapi.PyCObject_AsVoidPtr
    PyCObject_AsVoidPtr.argtypes = [py_object]
    PyCObject_AsVoidPtr.restype = c_void_p
    PyCObject_GetDesc = pythonapi.PyCObject_GetDesc
    PyCObject_GetDesc.argtypes = [py_object]
    PyCObject_GetDesc.restype = c_void_p
else:
    def PyCObject_AsVoidPtr(o):
        raise TypeError("Not available")
# Capsule API bindings.  Without it, PyCapsule_IsValid reports every object
# as invalid, and GetPointer/GetContext are left undefined.
if hasattr(pythonapi, 'PyCapsule_IsValid'):
    PyCapsule_IsValid = pythonapi.PyCapsule_IsValid
    PyCapsule_IsValid.argtypes = [py_object, c_char_p]
    PyCapsule_IsValid.restype = c_int
    PyCapsule_GetPointer = pythonapi.PyCapsule_GetPointer
    PyCapsule_GetPointer.argtypes = [py_object, c_char_p]
    PyCapsule_GetPointer.restype = c_void_p
    PyCapsule_GetContext = pythonapi.PyCapsule_GetContext
    PyCapsule_GetContext.argtypes = [py_object]
    PyCapsule_GetContext.restype = c_void_p
else:
    def PyCapsule_IsValid(capsule, name):
        return 0

if sys.version_info < (3,):
    # Python 2: publish the address through a legacy CObject, no destructor.
    PyCObject_Destructor = CFUNCTYPE(None, c_void_p)
    PyCObject_FromVoidPtr = pythonapi.PyCObject_FromVoidPtr
    PyCObject_FromVoidPtr.argtypes = [c_void_p, POINTER(PyCObject_Destructor)]
    PyCObject_FromVoidPtr.restype = py_object

    def capsule_new(p):
        """Wrap the address of ctypes object *p* in a CObject."""
        return PyCObject_FromVoidPtr(addressof(p), None)
else:
    # Python 3: publish the address through an unnamed PyCapsule, no destructor.
    PyCapsule_Destructor = CFUNCTYPE(None, py_object)
    PyCapsule_New = pythonapi.PyCapsule_New
    PyCapsule_New.argtypes = [c_void_p, c_char_p, POINTER(PyCapsule_Destructor)]
    PyCapsule_New.restype = py_object

    def capsule_new(p):
        """Wrap the address of ctypes object *p* in an unnamed capsule.

        NOTE(review): the capsule stores only the raw address and does not
        keep *p* alive; the caller must hold a reference to *p* for as long
        as the capsule is used.
        """
        return PyCapsule_New(addressof(p), None, None)

# Bits of PyArrayInterface.flags.
PAI_CONTIGUOUS = 1 << 0      # laid out in C (row-major) order
PAI_FORTRAN = 1 << 1         # laid out in Fortran (column-major) order
PAI_ALIGNED = 1 << 8
PAI_NOTSWAPPED = 1 << 9      # elements are in native byte order
PAI_WRITEABLE = 1 << 10
PAI_ARR_HAS_DESCR = 1 << 11  # the descr field carries a description object

class ArrayInterface(object):
    """Reader for an object's ``__array_struct__`` interface.

    The object's ``__array_struct__`` must be a CObject or a PyCapsule whose
    pointer addresses a PyArrayInterface struct.  Unknown attribute lookups
    (``nd``, ``itemsize``, ``flags``, ``shape``, ``strides``, ``data``, ...)
    are forwarded to that struct, and ``typekind`` is decoded to ``str``.

    Raises TypeError from the constructor when the object has no
    ``__array_struct__``, when the value is false, or when it is neither a
    usable CObject nor a valid capsule.
    """

    def __init__(self, arr):
        try:
            self._cobj = arr.__array_struct__
        except AttributeError:
            raise TypeError("The array object lacks an array structure")
        if not self._cobj:
            raise TypeError("The array object has a NULL array structure value")
        try:
            # Raises TypeError when the CObject API is unavailable or the
            # value is not a CObject; fall back to the capsule API then.
            vp = PyCObject_AsVoidPtr(self._cobj)
        except TypeError:
            if PyCapsule_IsValid(self._cobj, None):
                vp = PyCapsule_GetPointer(self._cobj, None)
            else:
                raise TypeError("The array object has an invalid array structure")
            self.desc = PyCapsule_GetContext(self._cobj)
        else:
            self.desc = PyCObject_GetDesc(self._cobj)
        # self._cobj is kept as an attribute so the capsule stays referenced
        # while the struct is read through it.
        self._inter = cast(vp, PAI_Ptr)[0]

    def __getattr__(self, name):
        # Only called when normal lookup fails.  Fetch _inter from the instance
        # dict directly: reading self._inter here would re-enter __getattr__
        # without end on an instance that has no _inter (one built by
        # copy.copy, pickle or __new__, or whose __init__ failed), turning a
        # plain AttributeError into RecursionError.
        try:
            inter = self.__dict__['_inter']
        except KeyError:
            raise AttributeError(name)
        if name == 'typekind':
            return inter.typekind.decode('latin-1')
        return getattr(inter, name)

    def __str__(self):
        # NOTE(review): desc comes from a c_void_p restype (an int or None),
        # so the tuple branch looks unreachable; confirm before relying on it.
        if isinstance(self.desc, tuple):
            ver = self.desc[0]
        else:
            ver = "N/A"
        return ("nd: %i\n"
                "typekind: %s\n"
                "itemsize: %i\n"
                "flags: %s\n"
                "shape: %s\n"
                "strides: %s\n"
                "ver: %s\n" %
                (self.nd, self.typekind, self.itemsize,
                 format_flags(self.flags),
                 format_shape(self.nd, self.shape),
                 format_strides(self.nd, self.strides), ver))

def format_flags(flags):
    """Return the names of the PAI_* bits set in *flags*, joined by ', '.

    Names appear in a fixed order (CONTIGUOUS, FORTRAN, ALIGNED, NOTSWAPPED,
    WRITEABLE, ARR_HAS_DESCR).  Bits without a name are ignored.
    """
    known_bits = ((PAI_CONTIGUOUS, 'CONTIGUOUS'),
                  (PAI_FORTRAN, 'FORTRAN'),
                  (PAI_ALIGNED, 'ALIGNED'),
                  (PAI_NOTSWAPPED, 'NOTSWAPPED'),
                  (PAI_WRITEABLE, 'WRITEABLE'),
                  (PAI_ARR_HAS_DESCR, 'ARR_HAS_DESCR'))
    return ', '.join(label for bit, label in known_bits if bit & flags)

def format_shape(nd, shape):
    """Return shape[0] .. shape[nd-1] as a comma-separated string.

    *shape* only needs integer indexing (a sequence or a ctypes pointer).
    IndexError propagates if it has fewer than *nd* entries.
    """
    extents = (str(shape[axis]) for axis in range(nd))
    return ', '.join(extents)

def format_strides(nd, strides):
    """Return strides[0] .. strides[nd-1] as a comma-separated string.

    *strides* only needs integer indexing (a sequence or a ctypes pointer).
    IndexError propagates if it has fewer than *nd* entries.
    """
    steps = (str(strides[axis]) for axis in range(nd))
    return ', '.join(steps)

class Exporter(object):
    """Minimal object exporting a zero-filled ctypes buffer via __array_struct__.

    Parameters:
        shape: sequence of dimension extents.
        typekind: one-character kind code (default 'u').
        itemsize: bytes per element (default 1).
        strides: byte stride per dimension; C-order strides are computed
            when omitted.
        descr: stored in the struct's descr field; when not None the
            PAI_ARR_HAS_DESCR bit is set.
        flags: PAI_* bits (default WRITEABLE | ALIGNED | NOTSWAPPED).  The
            CONTIGUOUS / FORTRAN bits are added from the actual layout.

    Raises:
        ValueError: typekind is not length 1, or len(strides) != len(shape).
    """

    def __init__(self, shape,
                 typekind=None, itemsize=None, strides=None,
                 descr=None, flags=None):
        if typekind is None:
            typekind = 'u'
        if itemsize is None:
            itemsize = 1
        if flags is None:
            flags = PAI_WRITEABLE | PAI_ALIGNED | PAI_NOTSWAPPED
        if descr is not None:
            flags |= PAI_ARR_HAS_DESCR
        if len(typekind) != 1:
            raise ValueError("Argument 'typekind' must be length 1 string")
        nd = len(shape)
        self.typekind = typekind
        self.itemsize = itemsize
        self.nd = nd
        self.shape = tuple(shape)
        self._shape = (c_ssize_t * nd)(*self.shape)
        if strides is None:
            # C order: the last axis steps by itemsize, and each earlier axis
            # spans the whole of the following one.  A 0-d array has no
            # strides at all.
            self._strides = (c_ssize_t * nd)()
            if nd:
                self._strides[nd - 1] = itemsize
                for i in range(nd - 1, 0, -1):
                    self._strides[i - 1] = self.shape[i] * self._strides[i]
            self.strides = tuple(self._strides)
        elif len(strides) == nd:
            self.strides = tuple(strides)
            self._strides = (c_ssize_t * nd)(*self.strides)
        else:
            raise ValueError("Mismatch in length of strides and shape")
        self.descr = descr
        if self.is_contiguous('C'):
            flags |= PAI_CONTIGUOUS
        if self.is_contiguous('F'):
            flags |= PAI_FORTRAN
        self.flags = flags
        self._data = (c_ubyte * self._buffer_size())()
        self.data = addressof(self._data)
        self._inter = PyArrayInterface(2, nd, typekind.encode('latin_1'),
                                       itemsize, flags, self._shape,
                                       self._strides, self.data, descr)
        # Total payload size in bytes: itemsize times the element count.
        self.len = itemsize
        for extent in self.shape:
            self.len *= extent

    __array_struct__ = property(lambda self: capsule_new(self._inter))

    def _buffer_size(self):
        """Return the number of bytes to allocate for the data buffer.

        The historical rule (largest shape[i] * strides[i]) is kept so
        existing layouts keep their buffer size.  It is raised to the true
        extent (offset of the last element plus itemsize) whenever that is
        larger, e.g. for overlapping strides, so element access through
        ``data`` never runs past the buffer.  A 0-d array holds one item.
        """
        pairs = list(zip(self.shape, self.strides))
        spans = [n * s for n, s in pairs]
        if 0 in self.shape:
            extent = 0  # no elements at all
        else:
            extent = self.itemsize + sum((n - 1) * s for n, s in pairs)
        return max(spans + [extent])

    def is_contiguous(self, fortran):
        """Return True if the layout is contiguous in the requested order.

        *fortran* is 'C' (row-major), 'F' (column-major) or 'A' (either).
        A 0-d array counts as contiguous in every order.
        """
        if not self.nd:
            return fortran in "CFA"
        if fortran in "CA":
            if self.strides[-1] == self.itemsize:
                for i in range(self.nd - 1, 0, -1):
                    if self.strides[i - 1] != self.shape[i] * self.strides[i]:
                        break
                else:
                    return True
        if fortran in "FA":
            if self.strides[0] == self.itemsize:
                for i in range(0, self.nd - 1):
                    if self.strides[i + 1] != self.shape[i] * self.strides[i]:
                        break
                else:
                    return True
        return False

class Array(Exporter):
    """Exporter with element-wise ``a[i, j, ...]`` read and write access.

    Elements are accessed through the ctypes integer type selected by
    ``(typekind, itemsize)``.  When the PAI_NOTSWAPPED flag is clear, the
    opposite-endian variant of that type is used.  Unknown combinations fall
    back to a raw ``c_uint8 * itemsize`` array.
    """

    # (typekind, itemsize) -> native-order ctypes integer type.
    _ctypes = {('u', 1): c_uint8,
               ('u', 2): c_uint16,
               ('u', 4): c_uint32,
               ('u', 8): c_uint64,
               ('i', 1): c_int8,
               ('i', 2): c_int16,
               ('i', 4): c_int32,
               ('i', 8): c_int64}

    def __init__(self, *args, **kwds):
        super(Array, self).__init__(*args, **kwds)
        native = self._ctypes.get((self.typekind, self.itemsize))
        if native is None:
            element = c_uint8 * self.itemsize
        elif self.flags & PAI_NOTSWAPPED:
            element = native
        elif c_int.__ctype_le__ is c_int:
            # Little-endian host: a swapped element is big-endian.
            element = native.__ctype_be__
        else:
            element = native.__ctype_le__
        self._ctype = element
        self._ctype_p = POINTER(element)

    def __getitem__(self, key):
        """Read the element at index tuple *key* (a bare int for 1-d)."""
        target = cast(self._addr_at(key), self._ctype_p)
        return target[0]

    def __setitem__(self, key, value):
        """Store *value* into the element at index tuple *key*."""
        target = cast(self._addr_at(key), self._ctype_p)
        target[0] = value

    def _addr_at(self, key):
        """Return the absolute address of the element at *key*.

        Raises ValueError when the number of indices differs from nd, and
        IndexError (naming the offending axis) for an index outside
        [0, shape[axis]).
        """
        index = key if isinstance(key, tuple) else (key,)
        if len(index) != self.nd:
            raise ValueError("wrong number of indexes")
        for axis, (pos, extent) in enumerate(zip(index, self.shape)):
            if not 0 <= pos < extent:
                raise IndexError("index {} out of range".format(axis))
        offset = sum(pos * step for pos, step in zip(index, self.strides))
        return self.data + offset

class ExporterTest(unittest.TestCase):
    """Checks Exporter layout computation and its ArrayInterface round trip."""

    def test_strides(self):
        self.check_args(0, (10,), 'u', (2,), 20, 20, 2)
        self.check_args(0, (5, 3), 'u', (6, 2), 30, 30, 2)
        self.check_args(0, (7, 3, 5), 'u', (30, 10, 2), 210, 210, 2)
        self.check_args(0, (13, 5, 11, 3), 'u', (330, 66, 6, 2), 4290, 4290, 2)
        self.check_args(3, (7, 3, 5), 'i', (2, 14, 42), 210, 210, 2)
        self.check_args(3, (7, 3, 5), 'x', (2, 16, 48), 210, 240, 2)
        self.check_args(3, (13, 5, 11, 3), '%', (440, 88, 8, 2), 4290, 5720, 2)
        self.check_args(3, (7, 5), '-', (15, 3), 105, 105, 3)
        self.check_args(3, (7, 5), '*', (3, 21), 105, 105, 3)
        self.check_args(3, (7, 5), ' ', (3, 24), 105, 120, 3)

    def test_is_contiguous(self):
        a = Exporter((10,), itemsize=2)
        self.assertTrue(a.is_contiguous('C'))
        self.assertTrue(a.is_contiguous('F'))
        self.assertTrue(a.is_contiguous('A'))
        a = Exporter((10, 4), itemsize=2)
        self.assertTrue(a.is_contiguous('C'))
        self.assertTrue(a.is_contiguous('A'))
        self.assertFalse(a.is_contiguous('F'))
        a = Exporter((13, 5, 11, 3), itemsize=2, strides=(330, 66, 6, 2))
        self.assertTrue(a.is_contiguous('C'))
        self.assertTrue(a.is_contiguous('A'))
        self.assertFalse(a.is_contiguous('F'))
        a = Exporter((10, 4), itemsize=2, strides=(2, 20))
        self.assertTrue(a.is_contiguous('F'))
        self.assertTrue(a.is_contiguous('A'))
        self.assertFalse(a.is_contiguous('C'))
        a = Exporter((13, 5, 11, 3), itemsize=2, strides=(2, 26, 130, 1430))
        self.assertTrue(a.is_contiguous('F'))
        self.assertTrue(a.is_contiguous('A'))
        self.assertFalse(a.is_contiguous('C'))
        a = Exporter((2, 11, 6, 4), itemsize=2, strides=(576, 48, 8, 2))
        self.assertFalse(a.is_contiguous('A'))
        a = Exporter((2, 11, 6, 4), itemsize=2, strides=(2, 4, 48, 288))
        self.assertFalse(a.is_contiguous('A'))
        a = Exporter((3, 2, 2), itemsize=2, strides=(16, 8, 4))
        self.assertFalse(a.is_contiguous('A'))
        a = Exporter((3, 2, 2), itemsize=2, strides=(4, 12, 24))
        self.assertFalse(a.is_contiguous('A'))

    def check_args(self, call_flags,
                   shape, typekind, strides, length, bufsize, itemsize,
                   offset=0):
        """Build an Exporter and verify it and its ArrayInterface view.

        Bit 0 of *call_flags* passes *typekind* to the constructor; otherwise
        the constructor default is used and *typekind* must equal it.  Bit 1
        passes *strides*; otherwise the computed strides must equal
        *strides*.  *length* is the expected ``len`` (payload bytes),
        *bufsize* the expected buffer size, and *offset* the expected
        distance of ``data`` from the start of the buffer.
        """
        if call_flags & 1:
            typekind_arg = typekind
        else:
            typekind_arg = None
        if call_flags & 2:
            strides_arg = strides
        else:
            strides_arg = None
        a = Exporter(shape, typekind=typekind_arg, itemsize=itemsize,
                     strides=strides_arg)
        self.assertEqual(a.len, length)
        self.assertEqual(sizeof(a._data), bufsize)
        self.assertEqual(a.data, ctypes.addressof(a._data) + offset)
        m = ArrayInterface(a)
        self.assertEqual(m.data, a.data)
        self.assertEqual(m.typekind, typekind)
        self.assertEqual(m.itemsize, itemsize)
        self.assertEqual(tuple(m.shape[0:m.nd]), shape)
        self.assertEqual(tuple(m.strides[0:m.nd]), strides)

class ArrayTest(unittest.TestCase):
    """Element access, element-type selection and byte order of Array."""

    def setUp(self):
        # Every test gets a fresh 20x15 int32 array.  ctypes zero-fills newly
        # created buffers, so each test starts from an all-zero array.
        self.a = Array((20, 15), 'i', 4)

    def test__addr_at(self):
        a = self.a
        # Row stride is 15 * 4 = 60 bytes; column stride is 4 bytes.
        expected_offsets = [((0, 0), 0), ((0, 1), 4), ((1, 0), 60), ((1, 1), 64)]
        for index, offset in expected_offsets:
            self.assertEqual(a._addr_at(index), a.data + offset)

    def test_indices(self):
        a = self.a
        for index in [(0, 0), (19, 0), (0, 14), (19, 14), (5, 8)]:
            self.assertEqual(a[index], 0)
        a[0, 0] = 12
        a[5, 8] = 99
        self.assertEqual(a[0, 0], 12)
        self.assertEqual(a[5, 8], 99)
        for index in [(-1, 0), (0, -1), (20, 0), (0, 15)]:
            self.assertRaises(IndexError, a.__getitem__, index)
        for bad_key in [0, (0, 0, 0)]:
            self.assertRaises(ValueError, a.__getitem__, bad_key)
        vector = Array((3,), 'i', 4)
        vector[1] = 333
        self.assertEqual(vector[1], 333)

    def test_typekind(self):
        for kind, expected in [('i', c_int32), ('u', c_uint32)]:
            arr = Array((1,), kind, 4)
            self.assertTrue(arr._ctype is expected)
            self.assertTrue(arr._ctype_p is POINTER(expected))
        # float types unsupported: size system dependent
        fallback = Array((1,), 'f', 4)._ctype
        self.assertTrue(issubclass(fallback, ctypes.Array))
        self.assertEqual(sizeof(fallback), 4)

    def test_itemsize(self):
        for size in [1, 2, 4, 8]:
            element = Array((1,), 'i', size)._ctype
            self.assertTrue(issubclass(element, ctypes._SimpleCData))
            self.assertEqual(sizeof(element), size)

    def test_oddball_itemsize(self):
        for size in [3, 5, 6, 7, 9]:
            element = Array((1,), 'i', size)._ctype
            self.assertTrue(issubclass(element, ctypes.Array))
            self.assertEqual(sizeof(element), size)

    def test_byteswapped(self):
        a = Array((1,), 'u', 4, flags=(PAI_ALIGNED | PAI_WRITEABLE))
        element = a._ctype
        self.assertTrue(element is not c_uint32)
        if sys.byteorder == 'little':
            swapped = c_uint32.__ctype_be__
        else:
            swapped = c_uint32.__ctype_le__
        self.assertTrue(element is swapped)
        value = 0xa0b0c0d
        native = c_uint32(value)
        a[0] = value
        self.assertEqual(a[0], value)
        # The buffer must hold the native representation's bytes reversed.
        native_bytes = cast(addressof(native), POINTER(c_uint8))
        self.assertEqual(a._data[0:4], native_bytes[3:-1:-1])


if __name__ == "__main__":
    # Run this module's TestCase classes when executed as a script.
    unittest.main()