from __future__ import annotations

from typing import Any, Dict, Optional

import torch
from torch import nn
from transformers import PreTrainedModel

try:
    from .configuration_jnu_tsb import JNUTSBConfig
except ImportError:  # pragma: no cover - local execution fallback
    # Relative import fails when this file is run outside a package
    # (e.g. directly from a checkout); fall back to a top-level import.
    from configuration_jnu_tsb import JNUTSBConfig


class JNUTSBModel(PreTrainedModel):
    """Tiny Hugging Face model wrapper for JNU-TSB.

    The actual computation lives in ``runtime.JNUTSBRuntime``. This class
    exists so that ``AutoModel.from_pretrained(..., trust_remote_code=True)``
    and the custom Transformers pipeline can load the repo like a normal HF
    model.
    """

    config_class = JNUTSBConfig
    base_model_prefix = "jnu_tsb"
    main_input_name = "inputs"

    def __init__(self, config: JNUTSBConfig) -> None:
        """Create the wrapper from a ``JNUTSBConfig``.

        The runtime itself is not built here; it is created lazily on the
        first call to ``get_runtime()``.
        """
        super().__init__(config)
        # A single frozen parameter so the module owns at least one tensor;
        # presumably this keeps HF/torch utilities that inspect parameters
        # (device/dtype lookup, save/load) working — confirm if it matters.
        self.dummy = nn.Parameter(torch.zeros(1), requires_grad=False)
        # Built on demand by ``get_runtime()``.
        self._runtime = None
        # Standard final step for ``PreTrainedModel`` subclasses: runs weight
        # initialisation/tying and other post-construction setup that
        # ``from_pretrained`` and newer Transformers versions rely on.
        self.post_init()

    def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Return a usage hint instead of performing computation.

        All arguments are accepted and ignored; use ``predict`` (or the
        ``jnu-tsb`` pipeline) to run the router.
        """
        return {
            "message": "JNU-TSB is a router wrapper. Use model.predict(...) or pipeline(task='jnu-tsb', ...).",
            "repo_id": self.config.repo_id,
        }

    def get_runtime(self) -> Any:
        """Return the cached ``JNUTSBRuntime``, creating it on first use.

        The runtime module is imported lazily (package-relative first, then
        top-level as a local-execution fallback) and built with
        ``JNUTSBRuntime.from_config(self.config)``; later calls reuse the
        same instance.
        """
        if self._runtime is None:
            try:
                from .runtime import JNUTSBRuntime
            except ImportError:  # pragma: no cover
                from runtime import JNUTSBRuntime
            self._runtime = JNUTSBRuntime.from_config(self.config)
        return self._runtime

    def predict(self, inputs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
        """Run the 3-way router.

        Supports either:

            model.predict({"stock": ..., "news": ..., }, prediction_length=5)

        or:

            model.predict(stock=..., news=..., prediction_length=5)

        Args:
            inputs: Optional mapping of data inputs. It is copied, never
                mutated.
            **kwargs: The data keys ``stock``, ``news``, ``future_news`` and
                ``future_covariates`` are moved into the input payload (and
                override same-named entries in ``inputs``); every other
                keyword is forwarded unchanged to ``JNUTSBRuntime.predict``.

        Returns:
            Whatever ``JNUTSBRuntime.predict`` returns.
        """
        # Explicit ``None`` check rather than ``inputs or {}`` so the truth
        # value of mapping-like containers is never evaluated.
        payload = dict(inputs) if inputs is not None else {}
        for key in ("stock", "news", "future_news", "future_covariates"):
            if key in kwargs:
                payload[key] = kwargs.pop(key)
        return self.get_runtime().predict(inputs=payload, **kwargs)