File size: 6,976 Bytes
53dbcc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import inspect
import sys
import typing
from collections.abc import Iterable, Sequence
from enum import Flag
from types import UnionType
from typing import Annotated, Any, Union, get_args, get_origin

import attrs

from cyclopts.utils import is_class_and_subclass

if sys.version_info < (3, 11):  # pragma: no cover
    from typing_extensions import NotRequired, Required
else:  # pragma: no cover
    from typing import NotRequired, Required

if sys.version_info >= (3, 12):  # pragma: no cover
    from typing import TypeAliasType
else:  # pragma: no cover
    TypeAliasType = None

# ``from types import NoneType`` is available on Python >=3.10; derive it
# from ``None`` directly instead.
NoneType = type(None)
# Runtime class of a subscripted ``Annotated[...]`` hint. ``is_annotated``
# compares ``type(hint)`` against this with an identity check.
AnnotatedType = type(Annotated[int, 0])

# Container classes that ``is_iterable_type`` accepts as the origin of a
# parameterized hint.
ITERABLE_TYPES = {
    Iterable,
    typing.Sequence,
    Sequence,
    frozenset,
    list,
    set,
    tuple,
}


def is_nonetype(hint):
    return hint is NoneType


def is_union(type_: type | None) -> bool:
    """Checks if a type is a union."""
    # Direct checks are faster than checking if the type is in a set that contains the union-types.
    if type_ is Union or type_ is UnionType:
        return True

    # The ``get_origin`` call is relatively expensive, so we'll check common types
    # that are passed in here to see if we can avoid calling ``get_origin``.
    if type_ is str or type_ is int or type_ is float or type_ is bool or is_annotated(type_):
        return False
    origin = get_origin(type_)
    return origin is Union or origin is UnionType


def is_pydantic(hint) -> bool:
    """Return ``True`` when ``hint`` exposes a ``__pydantic_core_schema__`` attribute."""
    absent = object()
    return getattr(hint, "__pydantic_core_schema__", absent) is not absent


def is_pydantic_secret(hint) -> bool:
    """Check if a type is a Pydantic secret type (SecretStr, SecretBytes, Secret, etc.).

    A hint qualifies when its ``__module__`` is exactly ``"pydantic.types"``
    and it carries a callable ``get_secret_value`` attribute.
    """
    if not hasattr(hint, "__module__"):
        return False

    defining_module = hint.__module__
    if defining_module == "pydantic.types":
        # A missing attribute falls back to ``None``, which is not callable.
        accessor = getattr(hint, "get_secret_value", None)
        return callable(accessor)
    return False


def is_dataclass(hint) -> bool:
    """Return ``True`` when ``hint`` has a ``__dataclass_fields__`` attribute."""
    absent = object()
    fields_table = getattr(hint, "__dataclass_fields__", absent)
    return fields_table is not absent


def is_namedtuple(hint) -> bool:
    """Return ``True`` for a ``tuple`` subclass that defines ``_fields``."""
    if not is_class_and_subclass(hint, tuple):
        return False
    return hasattr(hint, "_fields")


def is_attrs(hint) -> bool:
    """Return whatever ``attrs.has`` reports for ``hint``."""
    recognized = attrs.has(hint)
    return recognized


def is_enum_flag(hint) -> bool:
    """Tell whether ``hint`` is a class deriving from ``enum.Flag``."""
    derives_from_flag = is_class_and_subclass(hint, Flag)
    return derives_from_flag


def is_annotated(hint) -> bool:
    """Return ``True`` when ``hint`` is a subscripted ``Annotated[...]`` form.

    The test is an exact class-identity check against ``AnnotatedType``.
    """
    hint_class = type(hint)
    return hint_class is AnnotatedType


def is_iterable_type(hint) -> bool:
    """Check if a type hint is a collection/iterable type (list, set, tuple, etc.).

    The hint is first simplified with ``resolve`` (which handles Annotated,
    Optional, TypeAlias, and NewType wrappers); the ``get_origin`` of the
    result is then compared against ``ITERABLE_TYPES``.
    """
    container = get_origin(resolve(hint))
    accepted = tuple(ITERABLE_TYPES)
    return is_class_and_subclass(container, accepted)


def contains_hint(hint, target_type) -> bool:
    """Indicates if ``target_type`` is in a possibly annotated/unioned ``hint``.

    The hint is simplified with ``resolve`` first. Union members are searched
    recursively; any other hint is tested directly as a subclass of
    ``target_type``.

    E.g. ``contains_hint(Union[int, str], str) == True``
    """
    resolved = resolve(hint)
    if not is_union(resolved):
        return is_class_and_subclass(resolved, target_type)

    for member in get_args(resolved):
        if contains_hint(member, target_type):
            return True
    return False


def is_typeddict(hint) -> bool:
    """Determine if a type annotation is a TypedDict.

    A union counts as a TypedDict when any of its members does. Otherwise the
    resolved hint must be a ``dict`` subclass carrying all of the marker
    attributes checked below.

    This is surprisingly hard! Modified from Beartype's implementation:

        https://github.com/beartype/beartype/blob/main/beartype/_util/hint/pep/proposal/utilpep589.py
    """
    resolved = resolve(hint)
    if is_union(resolved):
        for member in get_args(resolved):
            if is_typeddict(member):
                return True
        return False

    if not is_class_and_subclass(resolved, dict):
        return False

    marker_attributes = (
        "__annotations__",
        "__total__",
        "__required_keys__",
        "__optional_keys__",
    )
    return all(hasattr(resolved, name) for name in marker_attributes)


def resolve(
    type_: Any,
) -> type:
    """Perform all simplifying resolutions.

    A missing annotation (``inspect.Parameter.empty``) is treated as ``str``.
    Otherwise the individual resolvers are applied in order, repeatedly, until
    a full pass leaves the hint unchanged.
    """
    if type_ is inspect.Parameter.empty:
        return str

    # Order matters: each pass feeds its output to the next one.
    passes = (
        resolve_type_alias,
        resolve_annotated,
        resolve_optional,
        resolve_required,
        resolve_new_type,
    )
    previous = None
    while type_ != previous:
        previous = type_
        for simplify in passes:
            type_ = simplify(type_)
    return type_


def resolve_optional(type_: Any) -> Any:
    """Strip ``None`` from a union, collapsing ``Optional[X]`` down to ``X``.

    Type aliases are unwrapped first, and non-union hints are returned
    untouched. For a union, the ``NoneType`` members are dropped and then:

    * if exactly one member remains, that member is returned as-is;
    * if several members remain, a new ``Union`` of them is returned, with
      each member itself passed through ``resolve_optional``.

    Raises
    ------
    ValueError
        If every member of the union is ``NoneType``.
    """
    type_ = resolve_type_alias(type_)
    # Python will automatically flatten out nested unions when possible.
    # So we don't need to loop over resolution.
    if not is_union(type_):
        return type_

    non_none_types = [t for t in get_args(type_) if t is not NoneType]
    if not non_none_types:  # pragma: no cover
        # This should never happen; python simplifies:
        #    ``Union[None, None] -> NoneType``
        raise ValueError("Union type cannot be all NoneType")
    if len(non_none_types) == 1:
        return non_none_types[0]

    # Zero and one member are handled above, so two or more remain here; no
    # further fallback branch is reachable.
    return Union[tuple(resolve_optional(x) for x in non_none_types)]  # pyright: ignore  # noqa: UP007


def resolve_annotated(type_: Any) -> type:
    """Unwrap ``Annotated[X, ...]`` to ``X`` after resolving any type alias.

    Hints that are not ``Annotated`` are returned alias-resolved but otherwise
    unchanged.
    """
    unaliased = resolve_type_alias(type_)
    if not is_annotated(unaliased):
        return unaliased
    # The first argument of ``Annotated`` is the underlying type.
    return get_args(unaliased)[0]


def get_annotated_discriminator(annotation) -> Any:
    """Return the ``discriminator`` metadata from an ``Annotated[...]`` hint, else ``None``.

    The metadata entries are scanned in order and the ``discriminator``
    attribute of the first entry that has one is returned.

    Only inspects ``Annotated`` hints — for other parameterized types (``list[X]``,
    ``dict[K, V]``, etc.) this returns ``None`` so that an incidental
    ``.discriminator`` attribute on a type parameter cannot spuriously match.
    """
    if not is_annotated(annotation):
        return None

    absent = object()
    # Skip the first argument: it is the annotated type, not metadata.
    metadata = get_args(annotation)[1:]
    for entry in metadata:
        found = getattr(entry, "discriminator", absent)
        if found is not absent:
            return found
    return None


def resolve_required(type_: Any) -> type:
    if get_origin(type_) in (Required, NotRequired):
        type_ = get_args(type_)[0]
    return type_


def resolve_new_type(type_: Any) -> type:
    """Follow the ``__supertype__`` chain to its end and return the final object.

    Anything without a ``__supertype__`` attribute is returned unchanged.
    """
    while True:
        try:
            type_ = type_.__supertype__
        except AttributeError:
            # No further ``__supertype__`` link: this is the underlying type.
            return type_


def resolve_type_alias(type_: Any) -> Any:
    """Resolve TypeAliasType (Python 3.12+ 'type' statement) to its underlying type.

    When ``TypeAliasType`` is unavailable (it is ``None``), or ``type_`` is not
    an instance of it, ``type_`` is returned unchanged.
    """
    if TypeAliasType is None:
        return type_
    if isinstance(type_, TypeAliasType):
        return type_.__value__
    return type_


def get_hint_name(hint) -> str:
    """Build a short display name for a type hint.

    Strings are returned verbatim, ``NoneType`` becomes ``"None"`` and ``Any``
    becomes ``"Any"``. Union members are joined with ``|``. Parameterized
    hints are rendered as ``origin[arg, ...]``. Anything else falls back to
    ``__name__``, then ``_name``, then ``str(hint)``.
    """
    if isinstance(hint, str):
        return hint
    if is_nonetype(hint):
        return "None"
    if hint is Any:
        return "Any"
    if is_union(hint):
        return "|".join(map(get_hint_name, get_args(hint)))

    origin = get_origin(hint)
    if origin:
        rendered = get_hint_name(origin)
        parameters = get_args(hint)
        if parameters:
            rendered += "[" + ", ".join(map(get_hint_name, parameters)) + "]"
        return rendered

    if hasattr(hint, "__name__"):
        return hint.__name__
    private_name = getattr(hint, "_name", None)
    if private_name is not None:
        return private_name
    return str(hint)