from __future__ import annotations from typing import Any, Dict, Optional, Tuple from transformers import Pipeline class JNUTSBPipeline(Pipeline): """Custom Transformers pipeline for JNU-TSB. Example: from transformers import pipeline pipe = pipeline("jnu-tsb", model="HONGRIZON/JNU-TSB", trust_remote_code=True) pipe({"stock": [...], "news": [...]}, prediction_length=5) """ def _sanitize_parameters( self, prediction_length: Optional[int] = None, quantile_levels: Optional[list] = None, use_llm_extractor: Optional[bool] = None, allow_naive_fallback: Optional[bool] = None, **kwargs: Any, ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: forward_params: Dict[str, Any] = dict(kwargs) if prediction_length is not None: forward_params["prediction_length"] = int(prediction_length) if quantile_levels is not None: forward_params["quantile_levels"] = quantile_levels if use_llm_extractor is not None: forward_params["use_llm_extractor"] = bool(use_llm_extractor) if allow_naive_fallback is not None: forward_params["allow_naive_fallback"] = bool(allow_naive_fallback) return {}, forward_params, {} def preprocess(self, inputs: Any, **preprocess_params: Any) -> Any: if inputs is None: raise ValueError("JNU-TSB expects a dict with 'stock', 'news', or both.") return inputs def _forward(self, model_inputs: Any, **forward_params: Any) -> Any: if not hasattr(self.model, "predict"): raise TypeError("The loaded model does not expose a predict(...) method.") return self.model.predict(model_inputs, **forward_params) def postprocess(self, model_outputs: Any, **postprocess_params: Any) -> Any: return model_outputs