import time
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)


def _parse_series(series: Any) -> np.ndarray:
| """ |
| Accepts: list[float|int], list[dict{'y'|'value'}], or dict with 'values'/'y'. |
| Returns: 1D float32 numpy array. |
| """ |
| if series is None: |
| raise ValueError("series is required") |
|
|
| if isinstance(series, dict): |
| series = series.get("values") or series.get("y") |
|
|
| vals: List[float] = [] |
| if isinstance(series, (list, tuple)): |
| if series and isinstance(series[0], dict): |
| for item in series: |
| if "y" in item: |
| vals.append(float(item["y"])) |
| elif "value" in item: |
| vals.append(float(item["value"])) |
| else: |
| vals = [float(x) for x in series] |
| else: |
| raise ValueError("series must be a list/tuple or dict with 'values'/'y'") |
|
|
| if not vals: |
| raise ValueError("series is empty") |
| return np.asarray(vals, dtype=np.float32) |
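
# Illustrative calls (hypothetical data), covering the accepted shapes:
#   _parse_series([1, 2, 3])                   -> array([1., 2., 3.], dtype=float32)
#   _parse_series([{"y": 1.0}, {"value": 2}])  -> array([1., 2.], dtype=float32)
#   _parse_series({"values": [1, 2]})          -> array([1., 2.], dtype=float32)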


def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
    """Pull a JSON object out of raw text or a ```json fenced block."""
    s = s.strip()
    if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
        try:
            obj = json.loads(s)
            return obj if isinstance(obj, dict) else None
        except Exception:
            pass
    if "```" in s:
        parts = s.split("```")
        for i in range(1, len(parts), 2):
            block = parts[i]
            if block.lstrip().lower().startswith("json"):
                block = block.split("\n", 1)[-1]
            try:
                obj = json.loads(block.strip())
                return obj if isinstance(obj, dict) else None
            except Exception:
                continue
    return None
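
# Illustrative behaviour: both bare JSON and fenced blocks are recognised.
#   _extract_json_from_text('{"horizon": 8}')                -> {"horizon": 8}
#   _extract_json_from_text('```json\n{"horizon": 8}\n```')  -> {"horizon": 8}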


def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Merge JSON embedded in the last user message into the top-level payload."""
    msgs = payload.get("messages")
    if not isinstance(msgs, list):
        return payload
    for m in reversed(msgs):
        if not isinstance(m, dict) or m.get("role") != "user":
            continue
        content = m.get("content")
        texts: List[str] = []
        if isinstance(content, list):
            texts = [
                p.get("text")
                for p in content
                if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str)
            ]
        elif isinstance(content, str):
            texts = [content]
        for t in reversed(texts):
            obj = _extract_json_from_text(t)
            if isinstance(obj, dict):
                return {**payload, **obj}
        break
    return payload
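
# Sketch of the merge (hypothetical OpenAI-style payload):
#   p = {"messages": [{"role": "user", "content": '{"series": [1, 2], "horizon": 4}'}]}
#   _merge_openai_message_json(p)
#   -> {"messages": [...], "series": [1, 2], "horizon": 4}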


class TimesFMBackend(ChatBackend):
    """
    TimesFM 2.5 backend.

    Input JSON may arrive as top-level keys, inside a CloudEvents 'data'
    envelope, or embedded in the last user message.
    Keys:
        series: list[float|int|{y|value}], or a list of such lists for a batch
        horizon: int (> 0)
    Optional:
        quantiles: bool (default True) -> include quantile forecasts
        max_context, max_horizon: ints overriding the compiled defaults
    """

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._model = None

    def _ensure_model(self) -> None:
        if self._model is not None:
            return
        try:
            import timesfm

            hf_token = getattr(settings, "HF_TOKEN", None) or os.environ.get("HF_TOKEN")
            cache_dir = getattr(settings, "TIMESFM_CACHE_DIR", None)

            model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
                self.model_id,
                token=hf_token,
                cache_dir=cache_dir,
                local_files_only=False,
            )

            try:
                # Move the underlying torch module to the target device if exposed.
                target = getattr(model, "model", model)
                target.to(self.device)
            except Exception:
                pass

            cfg = timesfm.ForecastConfig(
                max_context=1024,
                max_horizon=256,
                normalize_inputs=True,
                use_continuous_quantile_head=True,
                force_flip_invariance=True,
                infer_is_positive=True,
                fix_quantile_crossing=True,
            )
            model.compile(cfg)
            self._model = model
            logger.info("TimesFM 2.5 model loaded on %s", self.device)
        except Exception as e:
            logger.exception("TimesFM 2.5 init failed")
            raise RuntimeError(f"timesfm 2.5 init failed: {e}") from e

    def _prepare_inputs(self, payload: Dict[str, Any]) -> Tuple[List[np.ndarray], int, bool, Dict[str, int]]:
        # Unwrap CloudEvents-style envelopes first.
        if isinstance(payload.get("data"), dict):
            payload = {**payload, **payload["data"]}
        if isinstance(payload.get("timeseries"), dict):
            payload = {**payload, **payload["timeseries"]}
        # Then fall back to JSON embedded in an OpenAI-style message list.
        payload = _merge_openai_message_json(payload)

        horizon = int(payload.get("horizon", 0))
        if horizon <= 0:
            raise ValueError("horizon must be a positive integer")

        quantiles = bool(payload.get("quantiles", True))
        mc = int(payload.get("max_context", 1024))
        mh = int(payload.get("max_horizon", 256))

        series = payload.get("series")
        inputs: List[np.ndarray]
        if isinstance(series, list) and series and isinstance(series[0], (list, tuple)):
            # Batch: a list of series. A list of point dicts is a single series,
            # so dict items are deliberately excluded from this check.
            inputs = [_parse_series(s) for s in series]
        else:
            # Single series (number list, dict list, or dict form).
            inputs = [_parse_series(series)]

        return inputs, horizon, quantiles, {"max_context": mc, "max_horizon": mh}

    async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        inputs, horizon, want_quantiles, cfg_overrides = self._prepare_inputs(payload)
        self._ensure_model()

        # Recompile only when the request overrides the default limits.
        try:
            import timesfm
            if cfg_overrides["max_context"] != 1024 or cfg_overrides["max_horizon"] != 256:
                cfg = timesfm.ForecastConfig(
                    max_context=cfg_overrides["max_context"],
                    max_horizon=cfg_overrides["max_horizon"],
                    normalize_inputs=True,
                    use_continuous_quantile_head=want_quantiles,
                    force_flip_invariance=True,
                    infer_is_positive=True,
                    fix_quantile_crossing=True,
                )
                self._model.compile(cfg)
        except Exception:
            pass

        try:
            point, quant = self._model.forecast(horizon=horizon, inputs=inputs)
            point_list = [row.astype(float).tolist() for row in point]
            quant_list = None
            if want_quantiles and quant is not None:
                # Each row is a (horizon, num_quantiles) array; emit nested lists.
                quant_list = [row.astype(float).tolist() for row in quant]
        except Exception as e:
            logger.exception("TimesFM 2.5 forecast failed")
            raise RuntimeError(f"forecast failed: {e}") from e

        # Unwrap single-series requests so callers get flat lists back.
        single = len(inputs) == 1
        return {
            "model": self.model_id,
            "horizon": horizon,
            "forecast": point_list[0] if single else point_list,
            "quantiles": (quant_list[0] if single else quant_list) if quant_list is not None else None,
            "backend": "timesfm-2.5",
        }

    async def stream(self, request: Dict[str, Any]):
        """Emit the forecast as a single OpenAI-style chat completion chunk."""
        rid = f"chatcmpl-timesfm-{int(time.time())}"
        now = int(time.time())
        try:
            result = await self.forecast(dict(request) if isinstance(request, dict) else {})
            content = json.dumps(result, separators=(",", ":"), ensure_ascii=False)
        except Exception as e:
            content = json.dumps({"error": str(e)}, separators=(",", ":"), ensure_ascii=False)

        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": self.model_id,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
            ],
        }


class StubImagesBackend(ImagesBackend):
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in TimesFM backend.")
        # Base64 of a 1x1 placeholder PNG.
        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="