Spaces:
Sleeping
Sleeping
File size: 6,976 Bytes
import inspect
import sys
import typing
from collections.abc import Iterable, Sequence
from enum import Flag
from types import UnionType
from typing import Annotated, Any, Union, get_args, get_origin
import attrs
from cyclopts.utils import is_class_and_subclass
if sys.version_info < (3, 11): # pragma: no cover
from typing_extensions import NotRequired, Required
else: # pragma: no cover
from typing import NotRequired, Required
if sys.version_info >= (3, 12): # pragma: no cover
from typing import TypeAliasType
else: # pragma: no cover
TypeAliasType = None
# from types import NoneType is available >=3.10
NoneType = type(None)
# Runtime class of ``Annotated[...]`` aliases; ``is_annotated`` compares ``type(hint)`` against it.
AnnotatedType = type(Annotated[int, 0])
# Classes that ``is_iterable_type`` accepts as the origin of an iterable/collection hint.
ITERABLE_TYPES = {
    Iterable,
    typing.Sequence,
    Sequence,
    frozenset,
    list,
    set,
    tuple,
}
def is_nonetype(hint):
    """Return ``True`` only when ``hint`` is the ``NoneType`` class itself."""
    matches_none = hint is NoneType
    return matches_none
def is_union(type_: type | None) -> bool:
    """Report whether ``type_`` is a union.

    Both the bare ``typing.Union`` / ``types.UnionType`` objects and hints whose
    ``get_origin`` is one of them (``Union[int, str]``, ``int | str``) count.
    """
    # Identity comparisons come first; they are cheaper than a set-membership test.
    if type_ is Union:
        return True
    if type_ is UnionType:
        return True

    # ``get_origin`` is comparatively costly, so short-circuit the hints that
    # are passed in most often before falling back to it.
    if type_ is str or type_ is int or type_ is float or type_ is bool:
        return False
    if is_annotated(type_):
        return False

    origin = get_origin(type_)
    return origin is Union or origin is UnionType
def is_pydantic(hint) -> bool:
    """Return whether ``hint`` exposes a ``__pydantic_core_schema__`` attribute."""
    marker = "__pydantic_core_schema__"
    return hasattr(hint, marker)
def is_pydantic_secret(hint) -> bool:
    """Check if a type is a Pydantic secret type (SecretStr, SecretBytes, Secret, etc.).

    A hint qualifies when its ``__module__`` is ``"pydantic.types"`` and it has a
    callable ``get_secret_value`` attribute.
    """
    module_name = getattr(hint, "__module__", None)
    if module_name == "pydantic.types":
        # A missing attribute yields ``None``, which is not callable.
        return callable(getattr(hint, "get_secret_value", None))
    return False
def is_dataclass(hint) -> bool:
    """Return whether ``hint`` carries a ``__dataclass_fields__`` attribute."""
    marker = "__dataclass_fields__"
    return hasattr(hint, marker)
def is_namedtuple(hint) -> bool:
    """Return whether ``hint`` is a ``tuple`` subclass that defines ``_fields``."""
    if not is_class_and_subclass(hint, tuple):
        return False
    return hasattr(hint, "_fields")
def is_attrs(hint) -> bool:
    """Return the result of ``attrs.has`` for ``hint``."""
    recognized = attrs.has(hint)
    return recognized
def is_enum_flag(hint) -> bool:
    """Return whether ``hint`` is a class deriving from ``enum.Flag``."""
    derives_from_flag = is_class_and_subclass(hint, Flag)
    return derives_from_flag
def is_annotated(hint) -> bool:
    """Return whether ``hint`` is an ``Annotated[...]`` alias.

    This is an exact class comparison against ``AnnotatedType``, not an
    ``isinstance`` check.
    """
    hint_class = type(hint)
    return hint_class is AnnotatedType
def is_iterable_type(hint) -> bool:
    """Check if a type hint is a collection/iterable type (list, set, tuple, etc.).

    The hint is first passed through ``resolve`` (Annotated, Optional, TypeAlias,
    and NewType wrappers), then its ``get_origin`` is tested against the
    classes in ``ITERABLE_TYPES``.
    """
    resolved = resolve(hint)
    iterable_bases = tuple(ITERABLE_TYPES)
    return is_class_and_subclass(get_origin(resolved), iterable_bases)
def contains_hint(hint, target_type) -> bool:
    """Indicate whether ``target_type`` appears in a possibly annotated/unioned ``hint``.

    E.g. ``contains_hint(Union[int, str], str) == True``
    """
    resolved = resolve(hint)
    if not is_union(resolved):
        return is_class_and_subclass(resolved, target_type)
    # Unions match when any member (recursively) matches.
    return any(contains_hint(member, target_type) for member in get_args(resolved))
def is_typeddict(hint) -> bool:
    """Determine if a type annotation is a TypedDict.

    This is surprisingly hard! Modified from Beartype's implementation:
    https://github.com/beartype/beartype/blob/main/beartype/_util/hint/pep/proposal/utilpep589.py

    A union counts as a TypedDict when any of its members does.
    """
    resolved = resolve(hint)
    if is_union(resolved):
        return any(is_typeddict(member) for member in get_args(resolved))
    if not is_class_and_subclass(resolved, dict):
        return False

    # A ``dict`` subclass must expose all of these attributes to qualify.
    marker_attrs = (
        "__annotations__",
        "__total__",
        "__required_keys__",
        "__optional_keys__",
    )
    return all(hasattr(resolved, name) for name in marker_attrs)
def resolve(
    type_: Any,
) -> type:
    """Perform all simplifying resolutions.

    ``inspect.Parameter.empty`` resolves to ``str``. Any other hint is run
    through the individual resolvers repeatedly until a full pass leaves it
    unchanged.
    """
    if type_ is inspect.Parameter.empty:
        return str

    # Applied in this order on every pass.
    steps = (
        resolve_type_alias,
        resolve_annotated,
        resolve_optional,
        resolve_required,
        resolve_new_type,
    )
    previous = None
    while type_ != previous:
        previous = type_
        for step in steps:
            type_ = step(type_)
    return type_
def resolve_optional(type_: Any) -> Any:
    """Strip ``None`` from a union hint.

    The hint is first passed through ``resolve_type_alias``. Non-union hints are
    returned as-is. For a union:

    * ``None`` plus exactly one other member (i.e. ``Optional[X]``) resolves to
      that member.
    * ``None`` plus several other members resolves to a ``Union`` of those
      members, each individually passed through ``resolve_optional``.

    Raises:
        ValueError: If the union has no non-``None`` members.
    """
    type_ = resolve_type_alias(type_)
    # Python will automatically flatten out nested unions when possible.
    # So we don't need to loop over resolution.
    if not is_union(type_):
        return type_

    non_none_types = [t for t in get_args(type_) if t is not NoneType]
    if not non_none_types:  # pragma: no cover
        # Python simplifies ``Union[None, None] -> NoneType``, so a parameterized
        # union should not get here; a bare, unparameterized ``Union`` has no
        # args and does.
        raise ValueError("Union type cannot be all NoneType")
    if len(non_none_types) == 1:
        return non_none_types[0]
    return Union[tuple(resolve_optional(x) for x in non_none_types)]  # pyright: ignore # noqa: UP007
def resolve_annotated(type_: Any) -> type:
    """Unwrap ``Annotated[X, ...]`` to ``X``; other hints pass through.

    The hint is first passed through ``resolve_type_alias``.
    """
    unaliased = resolve_type_alias(type_)
    if not is_annotated(unaliased):
        return unaliased
    # The first argument of ``Annotated`` is the underlying type.
    return get_args(unaliased)[0]
def get_annotated_discriminator(annotation) -> Any:
    """Return the ``discriminator`` metadata from an ``Annotated[...]`` hint, else ``None``.

    Only inspects ``Annotated`` hints — for other parameterized types (``list[X]``,
    ``dict[K, V]``, etc.) this returns ``None`` so that an incidental
    ``.discriminator`` attribute on a type parameter cannot spuriously match.

    The value comes from the first metadata item that has a ``discriminator``
    attribute, whatever that value is.
    """
    if not is_annotated(annotation):
        return None

    missing = object()  # distinguishes "no attribute" from any real value
    metadata = get_args(annotation)[1:]  # skip the annotated type itself
    for item in metadata:
        found = getattr(item, "discriminator", missing)
        if found is not missing:
            return found
    return None
def resolve_required(type_: Any) -> type:
if get_origin(type_) in (Required, NotRequired):
type_ = get_args(type_)[0]
return type_
def resolve_new_type(type_: Any) -> type:
    """Follow ``__supertype__`` links until reaching an object without one.

    This unwraps (possibly nested) ``NewType`` definitions to their base type;
    anything lacking ``__supertype__`` is returned unchanged.
    """
    while True:
        try:
            type_ = type_.__supertype__
        except AttributeError:
            return type_
def resolve_type_alias(type_: Any) -> Any:
    """Resolve TypeAliasType (Python 3.12+ 'type' statement) to its underlying type.

    Returns ``type_`` unchanged when ``TypeAliasType`` is unavailable (``None``)
    or when ``type_`` is not an instance of it.
    """
    if TypeAliasType is None:
        return type_
    if isinstance(type_, TypeAliasType):
        return type_.__value__
    return type_
def get_hint_name(hint) -> str:
    """Build a display name for a type hint.

    Strings are returned as-is, ``NoneType`` becomes ``"None"`` and ``Any``
    becomes ``"Any"``. Union members are joined with ``|``. Parameterized hints
    are rendered as ``origin[arg, ...]``. Otherwise ``__name__``, then a
    non-``None`` ``_name``, then ``str(hint)`` is used.
    """
    if isinstance(hint, str):
        return hint
    if is_nonetype(hint):
        return "None"
    if hint is Any:
        return "Any"
    if is_union(hint):
        member_names = [get_hint_name(member) for member in get_args(hint)]
        return "|".join(member_names)

    origin = get_origin(hint)
    if origin:
        name = get_hint_name(origin)
        params = get_args(hint)
        if params:
            rendered = ", ".join(get_hint_name(param) for param in params)
            name = name + "[" + rendered + "]"
        return name

    if hasattr(hint, "__name__"):
        return hint.__name__
    private_name = getattr(hint, "_name", None)
    if private_name is not None:
        return private_name
    return str(hint)
|