from __future__ import annotations

import logging
import operator
import re
import sys
import typing as t
from datetime import datetime
from datetime import timezone

if t.TYPE_CHECKING:
    from _typeshed.wsgi import WSGIEnvironment

    from .wrappers.request import Request

# Module-wide "werkzeug" logger; stays ``None`` until something assigns it.
_logger: logging.Logger | None = None


class _Missing:
    """Type of the ``_missing`` sentinel: a "no value" marker that is
    distinct from ``None``.
    """

    def __reduce__(self) -> str:
        # A string result tells pickle to store a reference to the global
        # named ``_missing`` in this module, so unpickling yields the
        # existing singleton instead of a fresh instance.
        return "_missing"

    def __repr__(self) -> str:
        return "no value"


_missing = _Missing()


@t.overload
def _make_encode_wrapper(reference: str) -> t.Callable[[str], str]: ...


@t.overload
def _make_encode_wrapper(reference: bytes) -> t.Callable[[str], bytes]: ...


def _make_encode_wrapper(reference: t.AnyStr) -> t.Callable[[str], t.AnyStr]:
    """Return a converter for ``str`` values that matches the type of
    ``reference``.

    For a ``str`` reference the converter is the identity function. For
    any other reference (``bytes``) it encodes its argument with latin-1.
    """
    if not isinstance(reference, str):
        return operator.methodcaller("encode", "latin1")

    return lambda value: value


def _check_str_tuple(value: tuple[t.AnyStr, ...]) -> None:
    """Ensure all items of ``value`` share one string type.

    The first item decides the expected type: ``str`` if it is a ``str``,
    ``bytes`` otherwise. An empty tuple always passes.

    :raises TypeError: If any item is not of the expected type.
    """
    if value:
        expected = str if isinstance(value[0], str) else bytes

        if not all(isinstance(item, expected) for item in value):
            raise TypeError(f"Cannot mix str and bytes arguments (got {value!r})")


# Captured once at import time; used as the default ``charset`` below.
_default_encoding = sys.getdefaultencoding()


def _to_bytes(
    x: str | bytes, charset: str = _default_encoding, errors: str = "strict"
) -> bytes:
    """Coerce ``x`` to ``bytes``.

    ``None`` and ``bytes`` are returned unchanged, ``bytearray`` and
    ``memoryview`` are copied into a new ``bytes`` object, and ``str`` is
    encoded using ``charset`` and ``errors``.

    :raises TypeError: For any other type.
    """
    # NOTE(review): the annotations are narrower than what is accepted
    # (None, bytearray and memoryview also work, and None is returned).
    if x is None or isinstance(x, bytes):
        return x

    if isinstance(x, str):
        return x.encode(charset, errors)

    if isinstance(x, (bytearray, memoryview)):
        return bytes(x)

    raise TypeError("Expected bytes")


@t.overload
def _to_str(  # type: ignore
    x: None,
    charset: str | None = ...,
    errors: str = ...,
    allow_none_charset: bool = ...,
) -> None: ...


@t.overload
def _to_str(
    x: t.Any,
    charset: str | None = ...,
    errors: str = ...,
    allow_none_charset: bool = ...,
) -> str: ...
def _to_str(
    x: t.Any | None,
    charset: str | None = _default_encoding,
    errors: str = "strict",
    allow_none_charset: bool = False,
) -> str | bytes | None:
    """Convert ``x`` to ``str``.

    ``None`` and ``str`` are returned unchanged. Anything that is not
    ``bytes`` or ``bytearray`` is converted with ``str()``. Byte values are
    decoded with ``charset`` and ``errors``.

    :param x: The value to convert.
    :param charset: Encoding used to decode byte values. May be ``None``
        only together with ``allow_none_charset``.
    :param errors: Error handling scheme passed to ``bytes.decode``.
    :param allow_none_charset: When ``charset`` is ``None``, return byte
        values undecoded instead of failing.
    :raises TypeError: If ``x`` is a byte value, ``charset`` is ``None`` and
        ``allow_none_charset`` is false.
    """
    if x is None or isinstance(x, str):
        return x

    if not isinstance(x, (bytes, bytearray)):
        return str(x)

    if charset is None:
        if allow_none_charset:
            return x

        # ``bytes.decode(None)`` would also raise TypeError, but with an
        # unhelpful message; raise the same type with a clear one.
        raise TypeError("Cannot decode bytes without a charset.")

    return x.decode(charset, errors)


def _wsgi_decoding_dance(
    s: str, charset: str = "utf-8", errors: str = "replace"
) -> str:
    """Turn a WSGI "bytes as latin-1" string into a properly decoded string.

    WSGI (PEP 3333) carries environ values as ``str`` whose code points are
    the raw bytes. Encoding with latin-1 recovers those bytes, which are
    then decoded with ``charset``; ``errors`` applies to that decode step.
    """
    return s.encode("latin1").decode(charset, errors)


def _wsgi_encoding_dance(
    s: str, charset: str = "utf-8", errors: str = "strict"
) -> str:
    """Turn a string into a WSGI "bytes as latin-1" string.

    Inverse of :func:`_wsgi_decoding_dance`: encode ``s`` with ``charset``
    and represent every resulting byte as one latin-1 code point.

    ``errors`` applies to the ``charset`` encode step, which is the only
    step that can fail; every byte value is valid latin-1.
    """
    return s.encode(charset, errors).decode("latin1")


def _get_environ(obj: WSGIEnvironment | Request) -> WSGIEnvironment:
    """Return the WSGI environ dict for ``obj``.

    ``obj`` may be the environ itself or any object with an ``environ``
    attribute.
    """
    env = getattr(obj, "environ", obj)
    # NOTE(review): ``assert`` is skipped under ``python -O``; kept as-is
    # because changing it would change the exception type callers see.
    assert isinstance(
        env, dict
    ), f"{type(obj).__name__!r} is not a WSGI environment (has to be a dict)"
    return env


def _has_level_handler(logger: logging.Logger) -> bool:
    """Check if there is a handler in the logging chain that will handle the
    given logger's effective level.

    Walks from ``logger`` up through its parents, stopping at the first
    logger with ``propagate`` disabled.
    """
    level = logger.getEffectiveLevel()
    current = logger

    while current:
        if any(handler.level <= level for handler in current.handlers):
            return True

        if not current.propagate:
            break

        current = current.parent  # type: ignore

    return False


class _ColorStreamHandler(logging.StreamHandler):
    """Stream handler that writes through colorama's ``AnsiToWin32``
    wrapper around ``sys.stderr`` when colorama is importable, for ANSI
    style support on Windows consoles.

    Without colorama, ``StreamHandler``'s default stream (``sys.stderr``)
    is used.
    """

    def __init__(self) -> None:
        try:
            import colorama
        except ImportError:
            stream = None
        else:
            stream = colorama.AnsiToWin32(sys.stderr)

        super().__init__(stream)


def _log(type: str, message: str, *args: t.Any, **kwargs: t.Any) -> None:
    """Log a message to the 'werkzeug' logger.

    The logger is created the first time it is needed. If there is no level
    set, it is set to :data:`logging.INFO`. If there is no handler for the
    logger's effective level, a :class:`logging.StreamHandler` is added.

    :param type: Name of the :class:`logging.Logger` method to call, for
        example ``"info"`` or ``"warning"``.
    :param message: The message; trailing whitespace is stripped.
    :param args: Passed through to the logger method.
    :param kwargs: Passed through to the logger method.
    """
    global _logger

    if _logger is None:
        _logger = logging.getLogger("werkzeug")

        if _logger.level == logging.NOTSET:
            _logger.setLevel(logging.INFO)

        if not _has_level_handler(_logger):
            _logger.addHandler(_ColorStreamHandler())

    getattr(_logger, type)(message.rstrip(), *args, **kwargs)


@t.overload
def _dt_as_utc(dt: None) -> None: ...


@t.overload
def _dt_as_utc(dt: datetime) -> datetime: ...


def _dt_as_utc(dt: datetime | None) -> datetime | None:
    """Return ``dt`` as an aware datetime in UTC.

    A naive datetime is assumed to already be UTC and only gets
    ``tzinfo`` attached. An aware datetime in another zone is converted.
    ``None`` is passed through.
    """
    if dt is None:
        return dt

    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    elif dt.tzinfo != timezone.utc:
        return dt.astimezone(timezone.utc)

    return dt


_TAccessorValue = t.TypeVar("_TAccessorValue")


class _DictAccessorProperty(t.Generic[_TAccessorValue]):
    """Baseclass for `environ_property` and `header_property`.

    A descriptor that proxies attribute access to the key ``name`` in the
    mapping returned by :meth:`lookup`, which subclasses must implement.
    """

    # Class-level default; ``__init__`` overrides it per instance only when
    # an explicit ``read_only`` argument is given.
    read_only = False

    def __init__(
        self,
        name: str,
        default: _TAccessorValue | None = None,
        load_func: t.Callable[[str], _TAccessorValue] | None = None,
        dump_func: t.Callable[[_TAccessorValue], str] | None = None,
        read_only: bool | None = None,
        doc: str | None = None,
    ) -> None:
        """
        :param name: Key to read and write in the looked-up mapping.
        :param default: Returned when the key is missing, or when
            ``load_func`` raises ``ValueError`` or ``TypeError``.
        :param load_func: Optional converter applied to the stored value
            on read.
        :param dump_func: Optional converter applied to the value on write.
        :param read_only: If not ``None``, overrides the class-level
            ``read_only`` flag.
        :param doc: Used as the property's ``__doc__``.
        """
        self.name = name
        self.default = default
        self.load_func = load_func
        self.dump_func = dump_func
        if read_only is not None:
            self.read_only = read_only
        self.__doc__ = doc

    def lookup(self, instance: t.Any) -> t.MutableMapping[str, t.Any]:
        """Return the mapping backing this property on ``instance``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @t.overload
    def __get__(
        self, instance: None, owner: type
    ) -> _DictAccessorProperty[_TAccessorValue]: ...

    @t.overload
    def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue: ...

    def __get__(
        self, instance: t.Any | None, owner: type
    ) -> _TAccessorValue | _DictAccessorProperty[_TAccessorValue]:
        # Class-level access returns the descriptor itself.
        if instance is None:
            return self

        storage = self.lookup(instance)

        if self.name not in storage:
            return self.default  # type: ignore

        value = storage[self.name]

        if self.load_func is not None:
            try:
                return self.load_func(value)
            except (ValueError, TypeError):
                # Unparsable stored values read as the default.
                return self.default  # type: ignore

        return value  # type: ignore

    def __set__(self, instance: t.Any, value: _TAccessorValue) -> None:
        if self.read_only:
            raise AttributeError("read only property")

        if self.dump_func is not None:
            self.lookup(instance)[self.name] = self.dump_func(value)
        else:
            self.lookup(instance)[self.name] = value

    def __delete__(self, instance: t.Any) -> None:
        if self.read_only:
            raise AttributeError("read only property")

        # Deleting a missing key is silently ignored.
        self.lookup(instance).pop(self.name, None)

    def __repr__(self) -> str:
        return f"<{type(self).__name__} {self.name}>"


def _decode_idna(domain: str) -> str:
    """Decode an IDNA (punycode) domain name to Unicode.

    Labels that cannot be decoded are kept in their ASCII form instead of
    failing the whole domain.

    :param domain: The domain name. A value that is not pure ASCII is
        treated as already decoded and returned unchanged.
    """
    try:
        data = domain.encode("ascii")
    except UnicodeEncodeError:
        # If the domain is not ASCII, it's decoded already.
        return domain

    try:
        # Try decoding in one shot.
        return data.decode("idna")
    except UnicodeError:
        # The ``idna`` codec can report invalid labels (e.g. ``b"xn--"``)
        # with a plain ``UnicodeError`` rather than ``UnicodeDecodeError``,
        # so catch the base class or the fallback below never runs.
        pass

    # Decode each part separately, leaving invalid parts as punycode.
    parts = []

    for part in data.split(b"."):
        try:
            parts.append(part.decode("idna"))
        except UnicodeError:
            parts.append(part.decode("ascii"))

    return ".".join(parts)


_plain_int_re = re.compile(r"-?\d+", re.ASCII)
_plain_float_re = re.compile(r"-?\d+\.\d+", re.ASCII)


def _plain_int(value: str) -> int:
    """Parse an int only if it is only ASCII digits and ``-``.

    This disallows ``+``, ``_``, and non-ASCII digits, which are accepted by
    ``int`` but are not allowed in HTTP header values.

    :raises ValueError: If ``value`` is not in that plain form.
    """
    if _plain_int_re.fullmatch(value) is None:
        raise ValueError

    return int(value)


def _plain_float(value: str) -> float:
    """Parse a float only if it is only ASCII digits and ``-``, and contains
    digits before and after the ``.``.

    This disallows ``+``, ``_``, non-ASCII digits, and ``.123``, which are
    accepted by ``float`` but are not allowed in HTTP header values.

    :raises ValueError: If ``value`` is not in that plain form.
    """
    if _plain_float_re.fullmatch(value) is None:
        raise ValueError

    return float(value)