###############################################################################
# Synchronization primitives based on our SemLock implementation
#
# author: Thomas Moreau and Olivier Grisel
#
# adapted from multiprocessing/synchronize.py (17/02/2017)
#  * Remove ctx argument for compatibility reason
#  * Registers a cleanup function with the loky resource_tracker to remove the
#    semaphore when the process dies instead.
#
# TODO: investigate which Python version is required to be able to use
# multiprocessing.resource_tracker and therefore multiprocessing.synchronize
# instead of a loky-specific fork.

import os
import sys
import tempfile
import threading
import _multiprocessing
from time import time as _time
from multiprocessing import process, util
from multiprocessing.context import assert_spawning

from . import resource_tracker

__all__ = [
    "Lock",
    "RLock",
    "Semaphore",
    "BoundedSemaphore",
    "Condition",
    "Event",
]
# Try to import the mp.synchronize module cleanly, if it fails
# raise ImportError for platforms lacking a working sem_open implementation.
# See issue 3770
try:
    from _multiprocessing import SemLock as _SemLock
    from _multiprocessing import sem_unlink
except ImportError:
    raise ImportError(
        "This platform lacks a functioning sem_open"
        " implementation, therefore, the required"
        " synchronization primitives needed will not"
        " function, see issue 3770."
    )

#
# Constants
#

RECURSIVE_MUTEX, SEMAPHORE = range(2)
SEM_VALUE_MAX = _multiprocessing.SemLock.SEM_VALUE_MAX


#
# Base class for semaphores and mutexes; wraps `_multiprocessing.SemLock`
#


class SemLock:

    _rand = tempfile._RandomNameSequence()

    def __init__(self, kind, value, maxvalue, name=None):
        # unlink_now is only used on win32 or when we are using fork.
        unlink_now = False
        if name is None:
            # Try to find an unused name for the SemLock instance.
            for _ in range(100):
                try:
                    self._semlock = _SemLock(
                        kind, value, maxvalue, SemLock._make_name(), unlink_now
                    )
                except FileExistsError:  # pragma: no cover
                    pass
                else:
                    break
            else:  # pragma: no cover
                raise FileExistsError("cannot find name for semaphore")
        else:
            self._semlock = _SemLock(kind, value, maxvalue, name, unlink_now)
        self.name = name
        util.debug(
            f"created semlock with handle {self._semlock.handle} and name "
            f'"{self.name}"'
        )

        self._make_methods()

        def _after_fork(obj):
            obj._semlock._after_fork()

        util.register_after_fork(self, _after_fork)

        # When the object is garbage collected or the
        # process shuts down we unlink the semaphore name
        resource_tracker.register(self._semlock.name, "semlock")
        util.Finalize(
            self, SemLock._cleanup, (self._semlock.name,), exitpriority=0
        )

    @staticmethod
    def _cleanup(name):
        try:
            sem_unlink(name)
        except FileNotFoundError:
            # Already unlinked, possibly by user code: ignore and make sure to
            # unregister the semaphore from the resource tracker.
            pass
        finally:
            resource_tracker.unregister(name, "semlock")

    def _make_methods(self):
        self.acquire = self._semlock.acquire
        self.release = self._semlock.release

    def __enter__(self):
        return self._semlock.acquire()

    def __exit__(self, *args):
        return self._semlock.release()

    def __getstate__(self):
        assert_spawning(self)
        sl = self._semlock
        h = sl.handle
        return (h, sl.kind, sl.maxvalue, sl.name)

    def __setstate__(self, state):
        self._semlock = _SemLock._rebuild(*state)
        util.debug(
            f'recreated blocker with handle {state[0]!r} and name "{state[3]}"'
        )
        self._make_methods()

    @staticmethod
    def _make_name():
        # OSX does not support long names for semaphores
        return f"/loky-{os.getpid()}-{next(SemLock._rand)}"


#
# Semaphore
#


class Semaphore(SemLock):
    def __init__(self, value=1):
        SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX)

    def get_value(self):
        if sys.platform == "darwin":
            raise NotImplementedError("OSX does not implement sem_getvalue")
        return self._semlock._get_value()

    def __repr__(self):
        try:
            value = self._semlock._get_value()
        except Exception:
            value = "unknown"
        return f"<{self.__class__.__name__}(value={value})>"


#
# Bounded semaphore
#


class BoundedSemaphore(Semaphore):
    def __init__(self, value=1):
        SemLock.__init__(self, SEMAPHORE, value, value)

    def __repr__(self):
        try:
            value = self._semlock._get_value()
        except Exception:
            value = "unknown"
        return (
            f"<{self.__class__.__name__}(value={value}, "
            f"maxvalue={self._semlock.maxvalue})>"
        )


#
# Non-recursive lock
#


class Lock(SemLock):
    def __init__(self):
        super().__init__(SEMAPHORE, 1, 1)

    def __repr__(self):
        try:
            if self._semlock._is_mine():
                name = process.current_process().name
                if threading.current_thread().name != "MainThread":
                    name = f"{name}|{threading.current_thread().name}"
            elif self._semlock._get_value() == 1:
                name = "None"
            elif self._semlock._count() > 0:
                name = "SomeOtherThread"
            else:
                name = "SomeOtherProcess"
        except Exception:
            name = "unknown"
        return f"<{self.__class__.__name__}(owner={name})>"


#
# Recursive lock
#


class RLock(SemLock):
    def __init__(self):
        super().__init__(RECURSIVE_MUTEX, 1, 1)

    def __repr__(self):
        try:
            if self._semlock._is_mine():
                name = process.current_process().name
                if threading.current_thread().name != "MainThread":
                    name = f"{name}|{threading.current_thread().name}"
                count = self._semlock._count()
            elif self._semlock._get_value() == 1:
                name, count = "None", 0
            elif self._semlock._count() > 0:
                name, count = "SomeOtherThread", "nonzero"
            else:
                name, count = "SomeOtherProcess", "nonzero"
        except Exception:
            name, count = "unknown", "unknown"
        return f"<{self.__class__.__name__}({name}, {count})>"


#
# Condition variable
#


class Condition:
    def __init__(self, lock=None):
        self._lock = lock or RLock()
        self._sleeping_count = Semaphore(0)
        self._woken_count = Semaphore(0)
        self._wait_semaphore = Semaphore(0)
        self._make_methods()

    def __getstate__(self):
        assert_spawning(self)
        return (
            self._lock,
            self._sleeping_count,
            self._woken_count,
            self._wait_semaphore,
        )

    def __setstate__(self, state):
        (
            self._lock,
            self._sleeping_count,
            self._woken_count,
            self._wait_semaphore,
        ) = state
        self._make_methods()

    def __enter__(self):
        return self._lock.__enter__()

    def __exit__(self, *args):
        return self._lock.__exit__(*args)

    def _make_methods(self):
        self.acquire = self._lock.acquire
        self.release = self._lock.release

    def __repr__(self):
        try:
            num_waiters = (
                self._sleeping_count._semlock._get_value()
                - self._woken_count._semlock._get_value()
            )
        except Exception:
            num_waiters = "unknown"
        return f"<{self.__class__.__name__}({self._lock}, {num_waiters})>"

    def wait(self, timeout=None):
        assert (
            self._lock._semlock._is_mine()
        ), "must acquire() condition before using wait()"

        # indicate that this thread is going to sleep
        self._sleeping_count.release()

        # release lock
        count = self._lock._semlock._count()
        for _ in range(count):
            self._lock.release()

        try:
            # wait for notification or timeout
            return self._wait_semaphore.acquire(True, timeout)
        finally:
            # indicate that this thread has woken
            self._woken_count.release()

            # reacquire lock
            for _ in range(count):
                self._lock.acquire()

    def notify(self):
        assert self._lock._semlock._is_mine(), "lock is not owned"
        assert not self._wait_semaphore.acquire(False)

        # to take account of timeouts since last notify() we subtract
        # woken_count from sleeping_count and rezero woken_count
        while self._woken_count.acquire(False):
            res = self._sleeping_count.acquire(False)
            assert res

        if self._sleeping_count.acquire(False):  # try grabbing a sleeper
            self._wait_semaphore.release()  # wake up one sleeper
            self._woken_count.acquire()  # wait for the sleeper to wake

            # rezero _wait_semaphore in case a timeout just happened
            self._wait_semaphore.acquire(False)

    def notify_all(self):
        assert self._lock._semlock._is_mine(), "lock is not owned"
        assert not self._wait_semaphore.acquire(False)

        # to take account of timeouts since last notify*() we subtract
        # woken_count from sleeping_count and rezero woken_count
        while self._woken_count.acquire(False):
            res = self._sleeping_count.acquire(False)
            assert res

        sleepers = 0
        while self._sleeping_count.acquire(False):
            self._wait_semaphore.release()  # wake up one sleeper
            sleepers += 1

        if sleepers:
            for _ in range(sleepers):
                self._woken_count.acquire()  # wait for a sleeper to wake

            # rezero wait_semaphore in case some timeouts just happened
            while self._wait_semaphore.acquire(False):
                pass

    def wait_for(self, predicate, timeout=None):
        result = predicate()
        if result:
            return result
        if timeout is not None:
            endtime = _time() + timeout
        else:
            endtime = None
            waittime = None
        while not result:
            if endtime is not None:
                waittime = endtime - _time()
                if waittime <= 0:
                    break
            self.wait(waittime)
            result = predicate()
        return result


#
# Event
#


class Event:
    def __init__(self):
        self._cond = Condition(Lock())
        self._flag = Semaphore(0)

    def is_set(self):
        with self._cond:
            if self._flag.acquire(False):
                self._flag.release()
                return True
            return False

    def set(self):
        with self._cond:
            self._flag.acquire(False)
            self._flag.release()
            self._cond.notify_all()

    def clear(self):
        with self._cond:
            self._flag.acquire(False)

    def wait(self, timeout=None):
        with self._cond:
            if self._flag.acquire(False):
                self._flag.release()
            else:
                self._cond.wait(timeout)

            if self._flag.acquire(False):
                self._flag.release()
                return True
            return False