JNU-TSB / pipeline.py
HONGRIZON's picture
Upload 18 files
cf02581 verified
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
from transformers import Pipeline
class JNUTSBPipeline(Pipeline):
    """Transformers ``Pipeline`` wrapper that delegates to the JNU-TSB model's ``predict``.

    The pipeline does no tensor conversion of its own:

    * ``preprocess`` rejects ``None`` and otherwise passes the input through.
    * ``_forward`` calls ``self.model.predict(inputs, **options)``.
    * ``postprocess`` returns whatever ``predict`` produced.

    Example::

        from transformers import pipeline
        pipe = pipeline("jnu-tsb", model="HONGRIZON/JNU-TSB", trust_remote_code=True)
        pipe({"stock": [...], "news": [...]}, prediction_length=5)
    """

    def _sanitize_parameters(
        self,
        prediction_length: Optional[int] = None,
        quantile_levels: Optional[list] = None,
        use_llm_extractor: Optional[bool] = None,
        allow_naive_fallback: Optional[bool] = None,
        **kwargs: Any,
    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """Split call-time keyword arguments into (preprocess, forward, postprocess) dicts.

        Every option is routed to the forward step, so it ends up as a keyword
        argument to ``model.predict``:

        * Unrecognised keywords are forwarded unchanged.
        * The four named options are forwarded only when they are not ``None``.
          ``prediction_length`` is coerced with ``int`` and the two flags with
          ``bool``. ``quantile_levels`` is passed through as given.

        The preprocess and postprocess dicts are always empty.
        """
        # (forwarded key, supplied value, coercion or None for pass-through),
        # listed in the same order the conversions are applied.
        # NOTE(review): bool() makes strings such as "false" truthy — confirm
        # callers only pass real booleans here.
        typed_options = (
            ("prediction_length", prediction_length, int),
            ("quantile_levels", quantile_levels, None),
            ("use_llm_extractor", use_llm_extractor, bool),
            ("allow_naive_fallback", allow_naive_fallback, bool),
        )

        model_kwargs: Dict[str, Any] = {**kwargs}
        for key, value, coerce in typed_options:
            if value is None:
                # An omitted option is left for the model's own default.
                continue
            model_kwargs[key] = value if coerce is None else coerce(value)

        preprocess_kwargs: Dict[str, Any] = {}
        postprocess_kwargs: Dict[str, Any] = {}
        return preprocess_kwargs, model_kwargs, postprocess_kwargs

    def preprocess(self, inputs: Any, **preprocess_params: Any) -> Any:
        """Return ``inputs`` unchanged, rejecting ``None``.

        Raises:
            ValueError: if ``inputs`` is ``None``.
        """
        if inputs is not None:
            return inputs
        raise ValueError("JNU-TSB expects a dict with 'stock', 'news', or both.")

    def _forward(self, model_inputs: Any, **forward_params: Any) -> Any:
        """Run ``self.model.predict(model_inputs, **forward_params)`` and return its result.

        Raises:
            TypeError: if the loaded model has no ``predict`` attribute.
        """
        if hasattr(self.model, "predict"):
            return self.model.predict(model_inputs, **forward_params)
        raise TypeError("The loaded model does not expose a predict(...) method.")

    def postprocess(self, model_outputs: Any, **postprocess_params: Any) -> Any:
        """Return the model's prediction output as-is (no post-processing is applied)."""
        return model_outputs