| from __future__ import annotations |
|
|
| import re |
| import textwrap |
| from collections.abc import Iterable |
| from typing import Any, Optional, Callable |
|
|
| from . import inspect as mi, to_builtins |
|
|
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ("schema", "schema_components")
|
|
|
|
def schema(
    type: Any, *, schema_hook: Optional[Callable[[type], dict[str, Any]]] = None
) -> dict[str, Any]:
    """Generate a JSON Schema for a single type.

    Schemas for any (potentially) shared components are pulled out into a
    top-level ``"$defs"`` field and referenced from the main schema.

    When generating schemas for several types at once, or when you need more
    control over where shared components end up, use ``schema_components``
    instead.

    Parameters
    ----------
    type : type
        The type to generate the schema for.
    schema_hook : callable, optional
        An optional callback used to generate JSON schemas for custom types.
        It is called with the custom type and should return a dict
        representation of that type's JSON schema.

    Returns
    -------
    schema : dict
        The generated JSON Schema.

    See Also
    --------
    schema_components
    """
    schemas, components = schema_components((type,), schema_hook=schema_hook)
    # Exactly one type was requested, so exactly one schema comes back.
    result = schemas[0]
    if components:
        result["$defs"] = components
    return result
|
|
|
|
def schema_components(
    types: Iterable[Any],
    *,
    schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
    ref_template: str = "#/$defs/{name}",
) -> tuple[tuple[dict[str, Any], ...], dict[str, Any]]:
    """Generate JSON Schemas for one or more types.

    Schemas for any (potentially) shared components are extracted and
    returned separately in a ``components`` dict, with the per-type schemas
    referring to them via ``"$ref"``.

    Parameters
    ----------
    types : Iterable[type]
        An iterable of one or more types to generate schemas for.
    schema_hook : callable, optional
        An optional callback used to generate JSON schemas for custom types.
        It is called with the custom type and should return a dict
        representation of that type's JSON schema.
    ref_template : str, optional
        Template used when generating ``"$ref"`` fields, formatted with the
        component name as ``template.format(name=name)``. Useful when the
        ``components`` mapping will be stored somewhere other than a
        top-level ``"$defs"`` field -- e.g. ``"#/components/{name}"`` when
        generating an OpenAPI schema.

    Returns
    -------
    schemas : tuple[dict]
        A tuple of JSON Schemas, one for each type in ``types``.
    components : dict
        A mapping of name to schema for any shared components used by
        ``schemas``.

    See Also
    --------
    schema
    """
    type_infos = mi.multi_type_info(types)
    component_types = _collect_component_types(type_infos)
    name_map = _build_name_map(component_types)
    generator = _SchemaGenerator(name_map, schema_hook, ref_template)

    # Top-level schemas reference named components via "$ref".
    schemas = tuple(map(generator.to_schema, type_infos))

    # Component definitions are rendered inline (check_ref=False); otherwise
    # each definition would just be a "$ref" to itself.
    components = {}
    for cls, info in component_types.items():
        components[name_map[cls]] = generator.to_schema(info, False)
    return schemas, components
|
|
|
|
def _collect_component_types(type_infos: Iterable[mi.Type]) -> dict[Any, mi.Type]:
    """Walk the type trees and gather every "nameable" component type.

    Nameable types are worth extracting into a shared top-level components
    mapping. Currently these are Struct, Dataclass, NamedTuple, TypedDict and
    Enum types. The result maps each class to its type info, in the order the
    classes are first encountered during a depth-first walk.
    """
    found: dict[Any, mi.Type] = {}

    def visit(node):
        if isinstance(
            node,
            (mi.StructType, mi.TypedDictType, mi.DataclassType, mi.NamedTupleType),
        ):
            if node.cls in found:
                # Already visited: also stops infinite recursion on
                # self-referential types.
                return
            found[node.cls] = node
            children = [f.type for f in node.fields]
        elif isinstance(node, mi.EnumType):
            found[node.cls] = node
            children = []
        elif isinstance(node, mi.Metadata):
            children = [node.type]
        elif isinstance(node, mi.CollectionType):
            children = [node.item_type]
        elif isinstance(node, mi.TupleType):
            children = node.item_types
        elif isinstance(node, mi.DictType):
            children = [node.key_type, node.value_type]
        elif isinstance(node, mi.UnionType):
            children = node.types
        else:
            children = []
        for child in children:
            visit(child)

    for info in type_infos:
        visit(info)
    return found
|
|
|
|
def _type_repr(obj):
    """Render ``obj`` for use inside a generated component name.

    Classes render as their bare ``__name__``. Anything else (such as a
    nested generic alias or a plain value) falls back to ``repr``.
    """
    if isinstance(obj, type):
        return obj.__name__
    return repr(obj)
|
|
|
|
def _get_class_name(cls: Any) -> str:
    """Return a display name for ``cls``.

    Plain classes yield their ``__name__``. Parametrized generics (anything
    exposing ``__origin__``) yield ``"Origin[Arg1, Arg2]"``, with each
    argument rendered by ``_type_repr``.
    """
    if not hasattr(cls, "__origin__"):
        return cls.__name__
    base = cls.__origin__.__name__
    params = ", ".join(map(_type_repr, cls.__args__))
    return f"{base}[{params}]"
|
|
|
|
def _clean_docstring(doc: str) -> str:
    """Normalize a raw ``__doc__`` string into display text.

    ``textwrap.dedent`` is a no-op on the usual docstring layout: the summary
    sits right after the opening quotes with no indentation, so there is no
    common prefix to remove. ``inspect.cleandoc`` strips the first line and
    dedents the remaining lines independently, and drops leading/trailing
    blank lines.
    """
    # Stdlib ``inspect`` -- not the package-local ``inspect`` imported as ``mi``.
    from inspect import cleandoc

    return cleandoc(doc).strip("\r\n")


def _get_doc(t: mi.Type) -> str:
    """Return the user-written docstring for a nameable type, or ``""``.

    Parameters
    ----------
    t : mi.Type
        A type info object with a ``cls`` attribute (Struct, TypedDict,
        Dataclass, NamedTuple or Enum). For generic aliases, the docstring is
        read from the ``__origin__`` class.

    Returns
    -------
    str
        The cleaned docstring. Returns ``""`` if the class has none, or if
        the docstring is one of the auto-generated placeholders filtered
        below.
    """
    assert hasattr(t, "cls")
    cls = getattr(t.cls, "__origin__", t.cls)
    doc = getattr(cls, "__doc__", "")
    if not doc:
        return ""
    doc = _clean_docstring(doc)
    if isinstance(t, mi.EnumType):
        # Stock docstring that ``enum`` gives undocumented Enum subclasses on
        # some Python versions -- not useful as a description.
        if doc == "An enumeration.":
            return ""
    elif isinstance(t, (mi.NamedTupleType, mi.DataclassType)):
        # namedtuple/dataclass synthesize a signature-like "Name(...)" doc.
        if doc.startswith(f"{cls.__name__}(") and doc.endswith(")"):
            return ""
    return doc
|
|
|
|
def _build_name_map(component_types: dict[Any, mi.Type]) -> dict[Any, str]:
    """Map each nameable component type to a unique generated name.

    The generated name is usually a normalized version of the class name,
    including type parameters for generic aliases (e.g. ``Foo[int]`` becomes
    ``Foo_int_``). When two components share a name, both are expanded to
    their full import path (module plus qualified name) to disambiguate.

    Parameters
    ----------
    component_types : dict
        Component type infos keyed by class.

    Returns
    -------
    dict
        A mapping from each class in ``component_types`` to its name.
    """

    def normalize(name):
        # Replace every character outside ``[a-zA-Z0-9.-_]`` with ``_``.
        return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)

    def fullname(cls):
        # ``typing`` generic aliases (e.g. ``Foo[int]``) don't forward dunder
        # attributes such as ``__qualname__`` to their origin, so reading
        # ``cls.__qualname__`` raises AttributeError for them. Build the path
        # from the origin class and re-append the type parameters so
        # differently-parametrized aliases stay distinct.
        if hasattr(cls, "__origin__"):
            origin = cls.__origin__
            args = ", ".join(_type_repr(a) for a in cls.__args__)
            return normalize(f"{origin.__module__}.{origin.__qualname__}[{args}]")
        return normalize(f"{cls.__module__}.{cls.__qualname__}")

    conflicts = set()
    names: dict[str, Any] = {}

    for cls in component_types:
        name = normalize(_get_class_name(cls))
        if name in names:
            # First collision on this short name: move the earlier class to
            # its full name as well, and remember the name is ambiguous.
            old = names.pop(name)
            conflicts.add(name)
            names[fullname(old)] = old
        if name in conflicts:
            names[fullname(cls)] = cls
        else:
            names[name] = cls
    return {v: k for k, v in names.items()}
|
|
|
|
class _SchemaGenerator:
    """Render ``msgspec.inspect`` type trees as JSON Schema dicts.

    Types whose class appears in ``name_map`` are emitted as ``"$ref"``
    pointers (built from ``ref_template``) instead of being inlined, so the
    caller can store their full definitions in a shared components mapping.
    """

    def __init__(
        self,
        name_map: dict[Any, str],
        schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
        ref_template: str = "#/$defs/{name}",
    ):
        # Component class -> name used when building its "$ref".
        self.name_map = name_map
        # Optional user callback producing schemas for custom types.
        self.schema_hook = schema_hook
        # ``str.format`` template with a ``{name}`` field for "$ref" values.
        self.ref_template = ref_template

    @staticmethod
    def _sorted_values(values: Iterable[Any]) -> list[Any]:
        """Sort literal/enum values into a deterministic order.

        Mixed-kind literals such as ``Literal[1, "a"]`` cannot be compared
        with ``<``. Non-strings are therefore grouped before strings rather
        than raising ``TypeError``. Homogeneous inputs sort exactly as
        ``sorted(values)`` would.
        """
        return sorted(values, key=lambda v: (isinstance(v, str), v))

    @staticmethod
    def _count_trailing_optional(fields) -> int:
        """Count the consecutive non-required fields at the end of ``fields``.

        ``fields`` is a reversible sequence of objects with a ``required``
        attribute. Returns ``len(fields)`` when none are required.
        """
        count = 0
        for field in reversed(fields):
            if field.required:
                break
            count += 1
        return count

    def to_schema(self, t: mi.Type, check_ref: bool = True) -> dict[str, Any]:
        """Convert a ``msgspec.inspect`` type into a JSON Schema dict.

        Parameters
        ----------
        t : mi.Type
            The type to convert. Any ``Metadata`` wrappers are unwrapped and
            their ``extra_json_schema`` is merged into the result.
        check_ref : bool, optional
            If true (the default) and ``t`` is a named component, return a
            ``"$ref"`` to it instead of inlining its definition. Pass
            ``False`` when rendering the component definition itself.

        Returns
        -------
        schema : dict
            The generated JSON Schema.

        Raises
        ------
        TypeError
            If ``t`` (or a nested type) has no JSON Schema representation,
            e.g. msgpack Ext types, or custom types without a
            ``schema_hook`` result or ``extra_json_schema``.
        """
        schema: dict[str, Any] = {}

        # Unwrap (possibly nested) Metadata, accumulating extra_json_schema.
        while isinstance(t, mi.Metadata):
            schema = mi._merge_json(schema, t.extra_json_schema)
            t = t.type

        if check_ref and hasattr(t, "cls"):
            if name := self.name_map.get(t.cls):
                schema["$ref"] = self.ref_template.format(name=name)
                return schema

        if isinstance(t, (mi.AnyType, mi.RawType)):
            pass
        elif isinstance(t, mi.NoneType):
            schema["type"] = "null"
        elif isinstance(t, mi.BoolType):
            schema["type"] = "boolean"
        elif isinstance(t, (mi.IntType, mi.FloatType)):
            schema["type"] = "integer" if isinstance(t, mi.IntType) else "number"
            if t.ge is not None:
                schema["minimum"] = t.ge
            if t.gt is not None:
                schema["exclusiveMinimum"] = t.gt
            if t.le is not None:
                schema["maximum"] = t.le
            if t.lt is not None:
                schema["exclusiveMaximum"] = t.lt
            if t.multiple_of is not None:
                schema["multipleOf"] = t.multiple_of
        elif isinstance(t, mi.StrType):
            schema["type"] = "string"
            if t.max_length is not None:
                schema["maxLength"] = t.max_length
            if t.min_length is not None:
                schema["minLength"] = t.min_length
            if t.pattern is not None:
                schema["pattern"] = t.pattern
        elif isinstance(t, (mi.BytesType, mi.ByteArrayType, mi.MemoryViewType)):
            schema["type"] = "string"
            schema["contentEncoding"] = "base64"
            # Length bounds apply to the raw bytes; convert them to the
            # base64-encoded length, 4 * ceil(n / 3).
            if t.max_length is not None:
                schema["maxLength"] = 4 * ((t.max_length + 2) // 3)
            if t.min_length is not None:
                schema["minLength"] = 4 * ((t.min_length + 2) // 3)
        elif isinstance(t, mi.DateTimeType):
            schema["type"] = "string"
            if t.tz is True:
                schema["format"] = "date-time"
        elif isinstance(t, mi.TimeType):
            schema["type"] = "string"
            if t.tz is True:
                schema["format"] = "time"
            elif t.tz is False:
                schema["format"] = "partial-time"
        elif isinstance(t, mi.DateType):
            schema["type"] = "string"
            schema["format"] = "date"
        elif isinstance(t, mi.TimeDeltaType):
            schema["type"] = "string"
            schema["format"] = "duration"
        elif isinstance(t, mi.UUIDType):
            schema["type"] = "string"
            schema["format"] = "uuid"
        elif isinstance(t, mi.DecimalType):
            schema["type"] = "string"
            schema["format"] = "decimal"
        elif isinstance(t, mi.CollectionType):
            schema["type"] = "array"
            if not isinstance(t.item_type, mi.AnyType):
                schema["items"] = self.to_schema(t.item_type)
            if t.max_length is not None:
                schema["maxItems"] = t.max_length
            if t.min_length is not None:
                schema["minItems"] = t.min_length
        elif isinstance(t, mi.TupleType):
            schema["type"] = "array"
            schema["minItems"] = schema["maxItems"] = len(t.item_types)
            if t.item_types:
                schema["prefixItems"] = [self.to_schema(i) for i in t.item_types]
                schema["items"] = False
        elif isinstance(t, mi.DictType):
            schema["type"] = "object"
            # Only str keys carry constraints expressible via "propertyNames".
            if isinstance(key_type := t.key_type, mi.StrType):
                property_names: dict[str, Any] = {}
                if key_type.min_length is not None:
                    property_names["minLength"] = key_type.min_length
                if key_type.max_length is not None:
                    property_names["maxLength"] = key_type.max_length
                if key_type.pattern is not None:
                    property_names["pattern"] = key_type.pattern
                if property_names:
                    schema["propertyNames"] = property_names
            if not isinstance(t.value_type, mi.AnyType):
                schema["additionalProperties"] = self.to_schema(t.value_type)
            if t.max_length is not None:
                schema["maxProperties"] = t.max_length
            if t.min_length is not None:
                schema["minProperties"] = t.min_length
        elif isinstance(t, mi.UnionType):
            self._union_schema(t, schema)
        elif isinstance(t, mi.LiteralType):
            schema["enum"] = self._sorted_values(t.values)
        elif isinstance(t, mi.EnumType):
            schema.setdefault("title", t.cls.__name__)
            if doc := _get_doc(t):
                schema.setdefault("description", doc)
            schema["enum"] = self._sorted_values(e.value for e in t.cls)
        elif isinstance(t, mi.StructType):
            self._struct_schema(t, schema)
        elif isinstance(t, (mi.TypedDictType, mi.DataclassType, mi.NamedTupleType)):
            self._object_schema(t, schema)
        elif isinstance(t, mi.ExtType):
            raise TypeError("json-schema doesn't support msgpack Ext types")
        elif isinstance(t, mi.CustomType):
            if self.schema_hook:
                try:
                    hook_schema = self.schema_hook(t.cls)
                except NotImplementedError:
                    # The hook declined this type; fall back to any
                    # extra_json_schema collected from Metadata above.
                    pass
                else:
                    schema = mi._merge_json(hook_schema, schema)
            if not schema:
                raise TypeError(
                    "Generating JSON schema for custom types requires either:\n"
                    "- specifying a `schema_hook`\n"
                    "- annotating the type with `Meta(extra_json_schema=...)`\n"
                    "\n"
                    f"type {t.cls!r} is not supported"
                )
        else:
            # Type kinds without a JSON Schema mapping land here.
            raise TypeError(f"json-schema doesn't support type {t!r}")

        return schema

    def _union_schema(self, t: mi.UnionType, schema: dict[str, Any]) -> None:
        """Fill ``schema`` in place with an ``anyOf`` for a Union type.

        Object-like (non array-like) Struct members are keyed by their tag.
        When there are two or more, they are wrapped in a nested ``anyOf``
        with an OpenAPI-style ``discriminator`` on the tag field.
        """
        structs = {}
        other = []
        tag_field = None
        for subtype in t.types:
            real_type = subtype
            while isinstance(real_type, mi.Metadata):
                real_type = real_type.type
            if isinstance(real_type, mi.StructType) and not real_type.array_like:
                tag_field = real_type.tag_field
                structs[real_type.tag] = real_type
            else:
                other.append(subtype)

        options = [self.to_schema(a) for a in other]

        if len(structs) >= 2:
            mapping = {
                k: self.ref_template.format(name=self.name_map[v.cls])
                for k, v in structs.items()
            }
            struct_schema = {
                "anyOf": [self.to_schema(v) for v in structs.values()],
                "discriminator": {"propertyName": tag_field, "mapping": mapping},
            }
            if options:
                options.append(struct_schema)
                schema["anyOf"] = options
            else:
                schema.update(struct_schema)
        elif len(structs) == 1:
            _, subtype = structs.popitem()
            options.append(self.to_schema(subtype))
            schema["anyOf"] = options
        else:
            schema["anyOf"] = options

    def _struct_schema(self, t: mi.StructType, schema: dict[str, Any]) -> None:
        """Fill ``schema`` in place with the definition of a Struct type."""
        schema.setdefault("title", _get_class_name(t.cls))
        if doc := _get_doc(t):
            schema.setdefault("description", doc)
        required = []
        names = []
        fields = []

        # A tagged struct encodes its tag as an extra, leading, required field.
        if t.tag_field is not None:
            required.append(t.tag_field)
            names.append(t.tag_field)
            fields.append({"enum": [t.tag]})

        for field in t.fields:
            field_schema = self.to_schema(field.type)
            if field.required:
                required.append(field.encode_name)
            elif field.default is not mi.NODEFAULT:
                field_schema["default"] = to_builtins(field.default, str_keys=True)
            elif field.default_factory in (list, dict, set, bytearray):
                # Only these builtin factories are called. Their results go
                # through to_builtins, like plain defaults, so e.g. set()
                # becomes a list rather than a non-JSON object.
                field_schema["default"] = to_builtins(
                    field.default_factory(), str_keys=True
                )
            names.append(field.encode_name)
            fields.append(field_schema)

        if t.array_like:
            schema["type"] = "array"
            schema["prefixItems"] = fields
            # Trailing fields with defaults may be omitted from the array.
            schema["minItems"] = len(fields) - self._count_trailing_optional(t.fields)
            if t.forbid_unknown_fields:
                schema["maxItems"] = len(fields)
        else:
            schema["type"] = "object"
            schema["properties"] = dict(zip(names, fields))
            schema["required"] = required
            if t.forbid_unknown_fields:
                schema["additionalProperties"] = False

    def _object_schema(self, t: mi.Type, schema: dict[str, Any]) -> None:
        """Fill ``schema`` in place for a TypedDict, Dataclass or NamedTuple.

        NamedTuples are encoded as arrays; the other two as objects.
        """
        schema.setdefault("title", _get_class_name(t.cls))
        if doc := _get_doc(t):
            schema.setdefault("description", doc)
        names = []
        fields = []
        required = []
        for field in t.fields:
            field_schema = self.to_schema(field.type)
            if field.required:
                required.append(field.encode_name)
            elif field.default is not mi.NODEFAULT:
                field_schema["default"] = to_builtins(field.default, str_keys=True)
            names.append(field.encode_name)
            fields.append(field_schema)
        if isinstance(t, mi.NamedTupleType):
            schema["type"] = "array"
            schema["prefixItems"] = fields
            schema["minItems"] = len(required)
            schema["maxItems"] = len(fields)
        else:
            schema["type"] = "object"
            schema["properties"] = dict(zip(names, fields))
            schema["required"] = required
|
|