Spaces:
Sleeping
Sleeping
| import inspect | |
| import sys | |
| import typing | |
| from collections.abc import Iterable, Sequence | |
| from enum import Flag | |
| from types import UnionType | |
| from typing import Annotated, Any, Union, get_args, get_origin | |
| import attrs | |
| from cyclopts.utils import is_class_and_subclass | |
| if sys.version_info < (3, 11): # pragma: no cover | |
| from typing_extensions import NotRequired, Required | |
| else: # pragma: no cover | |
| from typing import NotRequired, Required | |
| if sys.version_info >= (3, 12): # pragma: no cover | |
| from typing import TypeAliasType | |
| else: # pragma: no cover | |
| TypeAliasType = None | |
# ``from types import NoneType`` is available on Python >= 3.10; deriving it
# from ``None`` itself yields the same class.
NoneType = type(None)

# Runtime class of a subscripted ``Annotated[...]`` hint, captured once so that
# ``is_annotated`` can do a cheap identity comparison.
AnnotatedType = type(Annotated[int, 0])

# Classes/aliases that ``is_iterable_type`` treats as iterable collections.
ITERABLE_TYPES = {
    # Concrete builtin containers.
    list,
    set,
    frozenset,
    tuple,
    # Abstract collection types (both the ``collections.abc`` and ``typing`` spellings).
    Sequence,
    typing.Sequence,
    Iterable,
}
| def is_nonetype(hint): | |
| return hint is NoneType | |
| def is_union(type_: type | None) -> bool: | |
| """Checks if a type is a union.""" | |
| # Direct checks are faster than checking if the type is in a set that contains the union-types. | |
| if type_ is Union or type_ is UnionType: | |
| return True | |
| # The ``get_origin`` call is relatively expensive, so we'll check common types | |
| # that are passed in here to see if we can avoid calling ``get_origin``. | |
| if type_ is str or type_ is int or type_ is float or type_ is bool or is_annotated(type_): | |
| return False | |
| origin = get_origin(type_) | |
| return origin is Union or origin is UnionType | |
def is_pydantic(hint) -> bool:
    """Return ``True`` when ``hint`` exposes a ``__pydantic_core_schema__`` attribute.

    That attribute is the marker this module uses to recognise pydantic types;
    pydantic itself is never imported.
    """
    try:
        hint.__pydantic_core_schema__
    except AttributeError:
        return False
    return True
def is_pydantic_secret(hint) -> bool:
    """Check if a type is a Pydantic secret type (SecretStr, SecretBytes, Secret, etc.).

    A hint qualifies when its ``__module__`` is ``"pydantic.types"`` and it has a
    callable ``get_secret_value`` attribute.
    """
    if getattr(hint, "__module__", None) != "pydantic.types":
        return False
    # ``callable(None)`` is ``False``, so a missing attribute is rejected here too.
    accessor = getattr(hint, "get_secret_value", None)
    return callable(accessor)
def is_dataclass(hint) -> bool:
    """Return ``True`` when ``hint`` carries the ``__dataclass_fields__`` marker.

    Works for both dataclass classes and their instances, since the attribute is
    visible on either.
    """
    marker_attribute = "__dataclass_fields__"
    return hasattr(hint, marker_attribute)
def is_namedtuple(hint) -> bool:
    """Return whether ``hint`` looks like a namedtuple class.

    A hint qualifies when it passes ``is_class_and_subclass(hint, tuple)`` and
    also defines a ``_fields`` attribute.
    """
    tuple_subclass = is_class_and_subclass(hint, tuple)
    # Preserve short-circuiting: ``_fields`` is only probed for tuple subclasses.
    return tuple_subclass and hasattr(hint, "_fields")
def is_attrs(hint) -> bool:
    """Return the result of ``attrs.has`` for ``hint``."""
    has_attrs_fields = attrs.has(hint)
    return has_attrs_fields
def is_enum_flag(hint) -> bool:
    """Report whether ``hint`` is a class deriving from ``enum.Flag``.

    The class/subclass test is delegated to ``is_class_and_subclass``.
    """
    derives_from_flag = is_class_and_subclass(hint, Flag)
    return derives_from_flag
def is_annotated(hint) -> bool:
    """Return ``True`` when ``hint`` is a subscripted ``Annotated[...]`` hint.

    The check compares the exact runtime class of ``hint`` against
    ``AnnotatedType``; subclasses do not match.
    """
    hint_class = type(hint)
    return hint_class is AnnotatedType
def is_iterable_type(hint) -> bool:
    """Check whether a hint's origin is one of ``ITERABLE_TYPES`` (list, set, tuple, etc.).

    The hint is first passed through ``resolve``, so Annotated, Optional,
    TypeAlias, and NewType wrappers are stripped before inspection.

    Only the *origin* of the resolved hint is tested.
    NOTE(review): a hint with no origin (``get_origin`` returns ``None``) is
    presumably rejected by ``is_class_and_subclass`` - confirm against that helper.
    """
    resolved_origin = get_origin(resolve(hint))
    candidate_bases = tuple(ITERABLE_TYPES)
    return is_class_and_subclass(resolved_origin, candidate_bases)
def contains_hint(hint, target_type) -> bool:
    """Tell whether ``target_type`` occurs in a possibly annotated/unioned ``hint``.

    The hint is resolved first; unions are then searched member by member
    (recursively), while any other hint is tested directly with
    ``is_class_and_subclass``.

    E.g. ``contains_hint(Union[int, str], str) == True``
    """
    resolved = resolve(hint)
    if not is_union(resolved):
        return is_class_and_subclass(resolved, target_type)
    return any(contains_hint(member, target_type) for member in get_args(resolved))
def is_typeddict(hint) -> bool:
    """Decide whether a type annotation is (or, for unions, contains) a TypedDict.

    This is surprisingly hard! Modified from Beartype's implementation:
    https://github.com/beartype/beartype/blob/main/beartype/_util/hint/pep/proposal/utilpep589.py

    After resolution, a union matches if any member matches. Otherwise the hint
    must pass ``is_class_and_subclass(hint, dict)`` and expose every
    TypedDict-specific attribute checked below.
    """
    hint = resolve(hint)

    if is_union(hint):
        for member in get_args(hint):
            if is_typeddict(member):
                return True
        return False

    if not is_class_and_subclass(hint, dict):
        return False

    typeddict_markers = (
        "__annotations__",
        "__total__",
        "__required_keys__",
        "__optional_keys__",
    )
    return all(hasattr(hint, marker) for marker in typeddict_markers)
def resolve(
    type_: Any,
) -> type:
    """Perform all simplifying resolutions.

    A missing annotation (``inspect.Parameter.empty``) resolves to ``str``.
    Otherwise the type-alias, ``Annotated``, optional, ``Required``/``NotRequired``
    and ``NewType`` resolvers are applied in that order, repeatedly, until a
    full pass leaves the hint unchanged.
    """
    if type_ is inspect.Parameter.empty:
        return str

    previous = None
    while type_ != previous:
        previous = type_
        # One full pass of every resolver; the loop repeats until a fixed point.
        for simplify in (
            resolve_type_alias,
            resolve_annotated,
            resolve_optional,
            resolve_required,
            resolve_new_type,
        ):
            type_ = simplify(type_)
    return type_
def resolve_optional(type_: Any) -> Any:
    """Strip ``None`` from a union hint.

    * Non-union hints are returned unchanged (after type-alias resolution).
    * ``Optional[X]`` (a union of ``None`` and exactly one other type) becomes ``X``.
    * A union with several non-``None`` members becomes a ``Union`` of those
      members, each passed through ``resolve_optional`` again.

    Raises:
        ValueError: If the union has no non-``None`` member.
    """
    type_ = resolve_type_alias(type_)

    # Python will automatically flatten out nested unions when possible.
    # So we don't need to loop over resolution.
    if not is_union(type_):
        return type_

    non_none_types = [t for t in get_args(type_) if t is not NoneType]

    if not non_none_types:  # pragma: no cover
        # This should never happen; python simplifies:
        # ``Union[None, None] -> NoneType``
        raise ValueError("Union type cannot be all NoneType")

    if len(non_none_types) == 1:
        return non_none_types[0]

    # Two or more members remain: rebuild the union without ``None``.
    return Union[tuple(resolve_optional(x) for x in non_none_types)]  # pyright: ignore # noqa: UP007
def resolve_annotated(type_: Any) -> type:
    """Unwrap ``Annotated[X, ...]`` to ``X``; return any other hint as-is.

    Type aliases are resolved first so an alias pointing at an ``Annotated``
    hint is unwrapped too.
    """
    unaliased = resolve_type_alias(type_)
    if not is_annotated(unaliased):
        return unaliased
    # The first ``Annotated`` argument is the underlying type; the rest is metadata.
    return get_args(unaliased)[0]
def get_annotated_discriminator(annotation) -> Any:
    """Fetch the ``discriminator`` metadata of an ``Annotated[...]`` hint, else ``None``.

    Only ``Annotated`` hints are inspected. Other parameterized types
    (``list[X]``, ``dict[K, V]``, etc.) always yield ``None``, so an incidental
    ``.discriminator`` attribute on a type parameter cannot spuriously match.

    The metadata entries are scanned in order and the ``discriminator`` of the
    first entry that has one is returned.
    """
    if not is_annotated(annotation):
        return None

    # Unique sentinel so a ``discriminator`` whose value is ``None`` still counts as found.
    not_found = object()
    metadata = get_args(annotation)[1:]
    for entry in metadata:
        value = getattr(entry, "discriminator", not_found)
        if value is not not_found:
            return value
    return None
| def resolve_required(type_: Any) -> type: | |
| if get_origin(type_) in (Required, NotRequired): | |
| type_ = get_args(type_)[0] | |
| return type_ | |
def resolve_new_type(type_: Any) -> type:
    """Follow ``__supertype__`` links (as set by ``typing.NewType``) down to the base type.

    Objects without a ``__supertype__`` attribute are returned unchanged.
    """
    try:
        parent = type_.__supertype__
    except AttributeError:
        # Not a NewType (or the end of a NewType chain).
        return type_
    return resolve_new_type(parent)
| def resolve_type_alias(type_: Any) -> Any: | |
| """Resolve TypeAliasType (Python 3.12+ 'type' statement) to its underlying type.""" | |
| if TypeAliasType is not None and isinstance(type_, TypeAliasType): | |
| return type_.__value__ | |
| return type_ | |
def get_hint_name(hint) -> str:
    """Build a short, human-readable name for a type hint.

    Resolution order:

    * strings are returned verbatim;
    * ``NoneType`` -> ``"None"`` and ``Any`` -> ``"Any"``;
    * unions -> member names joined with ``|``;
    * parameterized hints -> ``origin[arg, ...]``;
    * otherwise ``__name__``, then a non-``None`` ``_name``, then ``str(hint)``.
    """
    if isinstance(hint, str):
        return hint
    if is_nonetype(hint):
        return "None"
    if hint is Any:
        return "Any"

    if is_union(hint):
        member_names = (get_hint_name(member) for member in get_args(hint))
        return "|".join(member_names)

    origin = get_origin(hint)
    if origin:
        name = get_hint_name(origin)
        parameters = get_args(hint)
        if parameters:
            rendered = ", ".join(get_hint_name(parameter) for parameter in parameters)
            name += "[" + rendered + "]"
        return name

    if hasattr(hint, "__name__"):
        return hint.__name__
    if getattr(hint, "_name", None) is not None:
        return hint._name
    return str(hint)