Spaces:
Sleeping
Sleeping
feat(schemas): add group field to ParamSpec (basic/advanced)
Browse files- server/schemas.py +2 -0
- tests/test_schemas.py +13 -0
server/schemas.py
CHANGED
|
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, model_validator
|
|
| 7 |
|
| 8 |
|
| 9 |
ParamType = Literal["float", "int", "bool", "enum"]
|
|
|
|
| 10 |
ModelStatus = Literal["idle", "loading", "loaded", "error"]
|
| 11 |
|
| 12 |
|
|
@@ -25,6 +26,7 @@ class ParamSpec(BaseModel):
|
|
| 25 |
step: float | int | None = None
|
| 26 |
choices: list[str] | None = None
|
| 27 |
help: str = ""
|
|
|
|
| 28 |
|
| 29 |
@model_validator(mode="after")
|
| 30 |
def _validate(self) -> "ParamSpec":
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
ParamType = Literal["float", "int", "bool", "enum"]
|
| 10 |
+
ParamGroup = Literal["basic", "advanced"]
|
| 11 |
ModelStatus = Literal["idle", "loading", "loaded", "error"]
|
| 12 |
|
| 13 |
|
|
|
|
| 26 |
step: float | int | None = None
|
| 27 |
choices: list[str] | None = None
|
| 28 |
help: str = ""
|
| 29 |
+
group: ParamGroup = "basic"
|
| 30 |
|
| 31 |
@model_validator(mode="after")
|
| 32 |
def _validate(self) -> "ParamSpec":
|
tests/test_schemas.py
CHANGED
|
@@ -80,3 +80,16 @@ def test_error_body_serializable():
|
|
| 80 |
def test_generate_params_accepts_arbitrary_dict():
|
| 81 |
g = GenerateParams(values={"temperature": 0.8, "cfg_weight": 0.5})
|
| 82 |
assert g.values["temperature"] == 0.8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
def test_generate_params_accepts_arbitrary_dict():
|
| 81 |
g = GenerateParams(values={"temperature": 0.8, "cfg_weight": 0.5})
|
| 82 |
assert g.values["temperature"] == 0.8
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_param_spec_default_group_is_basic():
|
| 86 |
+
p = ParamSpec(name="t", label="T", type="float", default=0.5, min=0.0, max=1.0)
|
| 87 |
+
assert p.group == "basic"
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def test_param_spec_advanced_group_round_trips():
|
| 91 |
+
p = ParamSpec(
|
| 92 |
+
name="seed", label="Seed", type="int", default=-1, min=-1, group="advanced",
|
| 93 |
+
)
|
| 94 |
+
assert p.group == "advanced"
|
| 95 |
+
assert p.model_dump()["group"] == "advanced"
|