"""End-to-end text encoder smoke test for API/local/auto modes."""

from __future__ import annotations

import argparse
import json
import time

from kimodo.model.load_model import DEFAULT_TEXT_ENCODER_URL, _select_text_encoder_conf
from kimodo.model.loading import get_env_var, instantiate_from_dict


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Kimodo text encoder smoke test")
    parser.add_argument(
        "--prompt",
        default="A person walks forward.",
        help="Prompt used for the end-to-end encoding call.",
    )
    parser.add_argument(
        "--strict",
        action="store_true",
        help="Exit with status 2 if the readiness probe or the encode call fails.",
    )
    parser.add_argument(
        "--retry-delay-sec",
        type=float,
        default=10.0,
        help="Delay before a single retry when the first cold-start attempt fails.",
    )
    return parser.parse_args()


def main() -> int:
    args = parse_args()
    text_encoder_url = get_env_var("TEXT_ENCODER_URL", DEFAULT_TEXT_ENCODER_URL)
    mode = get_env_var("TEXT_ENCODER_MODE", "auto").lower()

    report = {
        "mode": mode,
        "text_encoder_url": text_encoder_url,
        "encoder_target": None,
        "ready": False,
        "encode_ok": False,
        "elapsed_ms": None,
        "output_shape": None,
        "lengths": None,
        "error": None,
    }

    # Use a monotonic clock so the measurement is unaffected by wall-clock adjustments.
    started = time.perf_counter()
    conf = None
    encoder = None
    # Try at most twice: the first attempt may fail while the encoder cold-starts.
    for attempt in range(2):
        report["attempts"] = attempt + 1
        try:
            if conf is None:
                conf = _select_text_encoder_conf(text_encoder_url)
                report["encoder_target"] = conf.get("_target_")
            if encoder is None:
                encoder = instantiate_from_dict(conf)

            # Probe readiness with a throwaway input before encoding the real prompt.
            encoder(["healthcheck"])
            report["ready"] = True

            encoded, lengths = encoder([args.prompt])
            report["encode_ok"] = True
            report["output_shape"] = tuple(encoded.shape)
            report["lengths"] = lengths
            # Clear any error recorded by a failed first attempt.
            report["error"] = None
            break
        except Exception as error:  # pragma: no cover - runtime/network dependent
            report["error"] = f"{type(error).__name__}: {error}"
            if attempt == 0:
                # Wait, then rebuild the encoder on the retry; the selected conf is reused.
                time.sleep(max(0.0, args.retry_delay_sec))
                encoder = None

    # Elapsed time covers both attempts, including the retry delay.
    report["elapsed_ms"] = int((time.perf_counter() - started) * 1000)

    print(json.dumps(report, indent=2, sort_keys=True))

    if args.strict and (not report["ready"] or not report["encode_ok"]):
        return 2
    return 0
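

# Illustrative sketch, not part of the original script: `_json_safe` is a
# hypothetical helper name. It assumes that report values such as `lengths`
# may be tensor- or array-like objects exposing `.tolist()`, which
# `json.dumps` cannot serialize directly. One possible use would be
# `json.dumps(_json_safe(report), indent=2, sort_keys=True)`.
def _json_safe(value):
    """Best-effort conversion of a value into JSON-serializable types."""
    if isinstance(value, dict):
        return {str(key): _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    if hasattr(value, "tolist"):  # e.g. NumPy arrays or torch tensors
        return _json_safe(value.tolist())
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    return str(value)  # fall back to a readable representation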


if __name__ == "__main__":
    raise SystemExit(main())