File size: 2,093 Bytes
cf02581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from __future__ import annotations

from typing import Any, Dict, Optional

import torch
from torch import nn
from transformers import PreTrainedModel

try:
    from .configuration_jnu_tsb import JNUTSBConfig
except ImportError:  # pragma: no cover - local execution fallback
    from configuration_jnu_tsb import JNUTSBConfig


class JNUTSBModel(PreTrainedModel):
    """Thin ``PreTrainedModel`` shell around the JNU-TSB router.

    No forecasting logic is implemented here: ``predict`` hands its arguments
    to a ``JNUTSBRuntime`` object that is built from ``self.config`` the first
    time it is needed. The wrapper is what lets
    ``AutoModel.from_pretrained(..., trust_remote_code=True)`` and the custom
    Transformers pipeline treat the repo as a regular HF model.
    """

    config_class = JNUTSBConfig
    base_model_prefix = "jnu_tsb"
    main_input_name = "inputs"

    def __init__(self, config: JNUTSBConfig) -> None:
        """Initialise the wrapper with one frozen parameter and no runtime yet."""
        super().__init__(config)
        # Single non-trainable scalar; presumably present so the module owns at
        # least one parameter for HF save/load -- TODO confirm.
        placeholder = torch.zeros(1)
        self.dummy = nn.Parameter(placeholder, requires_grad=False)
        # Populated lazily by get_runtime().
        self._runtime = None

    def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Return a usage hint; positional and keyword inputs are ignored."""
        hint = "JNU-TSB is a router wrapper. Use model.predict(...) or pipeline(task='jnu-tsb', ...)."
        response = {"message": hint, "repo_id": self.config.repo_id}
        return response

    def get_runtime(self):
        """Return the cached runtime, constructing it from the config if absent."""
        cached = self._runtime
        if cached is not None:
            return cached

        # Package-relative import first; fall back to a flat import when the
        # file is executed outside of a package.
        try:
            from .runtime import JNUTSBRuntime
        except ImportError:  # pragma: no cover
            from runtime import JNUTSBRuntime

        self._runtime = JNUTSBRuntime.from_config(self.config)
        return self._runtime

    def predict(self, inputs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
        """Forward a prediction request to the runtime.

        Two call shapes are accepted:
          model.predict({"stock": ..., "news": ...}, prediction_length=5)
          model.predict(stock=..., news=..., prediction_length=5)

        The keyword arguments ``stock``, ``news``, ``future_news`` and
        ``future_covariates`` are removed from ``kwargs`` and merged into a
        copy of ``inputs`` (overriding same-named entries); everything else in
        ``kwargs`` is passed through to the runtime's ``predict`` unchanged.
        """
        routed_names = ("stock", "news", "future_news", "future_covariates")

        # Work on a copy so the caller's mapping is never modified.
        merged = {} if not inputs else dict(inputs)
        overrides = {name: kwargs.pop(name) for name in routed_names if name in kwargs}
        merged.update(overrides)

        runtime = self.get_runtime()
        return runtime.predict(inputs=merged, **kwargs)