File size: 1,894 Bytes
cf02581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from __future__ import annotations

from typing import Any, Dict, Optional, Tuple

from transformers import Pipeline


class JNUTSBPipeline(Pipeline):
    """Custom Transformers pipeline for JNU-TSB.

    Thin adapter between the ``transformers`` pipeline hooks and the model's
    own ``predict(...)`` method:

    * ``preprocess`` rejects a ``None`` payload and otherwise returns it as-is;
    * ``_forward`` calls ``self.model.predict`` with the sanitized options;
    * ``postprocess`` returns the ``_forward`` result unmodified.

    Example:
        from transformers import pipeline
        pipe = pipeline("jnu-tsb", model="HONGRIZON/JNU-TSB", trust_remote_code=True)
        pipe({"stock": [...], "news": [...]}, prediction_length=5)
    """

    def _sanitize_parameters(
        self,
        prediction_length: Optional[int] = None,
        quantile_levels: Optional[list] = None,
        use_llm_extractor: Optional[bool] = None,
        allow_naive_fallback: Optional[bool] = None,
        **kwargs: Any,
    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """Split call-time options into (preprocess, forward, postprocess) dicts.

        All options are routed to ``_forward``; the preprocess and postprocess
        dicts are always empty. Unrecognized keyword arguments are forwarded
        verbatim. Each named option is included only when it is not ``None``,
        and is coerced as follows:

        * ``prediction_length``: ``int``
        * ``use_llm_extractor``, ``allow_naive_fallback``: ``bool``
        * ``quantile_levels``: passed through untouched
        """
        # (key, supplied value, coercion or None), in the order keys are added.
        named_options = (
            ("prediction_length", prediction_length, int),
            ("quantile_levels", quantile_levels, None),
            ("use_llm_extractor", use_llm_extractor, bool),
            ("allow_naive_fallback", allow_naive_fallback, bool),
        )
        to_forward: Dict[str, Any] = {**kwargs}
        for key, value, coerce in named_options:
            if value is None:
                # Unset options are omitted so predict() keeps its own defaults.
                continue
            to_forward[key] = value if coerce is None else coerce(value)
        return {}, to_forward, {}

    def preprocess(self, inputs: Any, **preprocess_params: Any) -> Any:
        """Return ``inputs`` unchanged after rejecting a missing payload.

        Raises:
            ValueError: if ``inputs`` is ``None``.
        """
        if inputs is not None:
            return inputs
        raise ValueError("JNU-TSB expects a dict with 'stock', 'news', or both.")

    def _forward(self, model_inputs: Any, **forward_params: Any) -> Any:
        """Delegate to ``self.model.predict`` with the sanitized forward options.

        Raises:
            TypeError: if the loaded model has no ``predict`` attribute.
        """
        if hasattr(self.model, "predict"):
            return self.model.predict(model_inputs, **forward_params)
        raise TypeError("The loaded model does not expose a predict(...) method.")

    def postprocess(self, model_outputs: Any, **postprocess_params: Any) -> Any:
        """Hand the ``_forward`` result back without modification."""
        result = model_outputs
        return result