HONGRIZON commited on
Commit
877bd6f
·
verified ·
1 Parent(s): bd81bdd

Delete pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +0 -70
pipeline.py DELETED
@@ -1,70 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any, Dict
4
-
5
- from transformers import Pipeline
6
-
7
- try:
8
- from .runtime import JNUTSBRuntime
9
- except ImportError: # pragma: no cover - local execution fallback
10
- from runtime import JNUTSBRuntime
11
-
12
-
13
- class JNUTSBPipeline(Pipeline):
14
- """Transformers custom pipeline for JNU-TSB.
15
-
16
- Example:
17
- from transformers import pipeline
18
- pipe = pipeline(
19
- task="jnu-tsb",
20
- model="HONGRIZON/JNU-TSB",
21
- trust_remote_code=True,
22
- device=-1,
23
- )
24
- pipe({"stock": [...], "news": [...]}, prediction_length=5)
25
- """
26
-
27
- def __init__(self, *args: Any, **kwargs: Any) -> None:
28
- self._runtime = None
29
- super().__init__(*args, **kwargs)
30
-
31
- @property
32
- def runtime(self) -> JNUTSBRuntime:
33
- if self._runtime is None:
34
- device = getattr(self, "device", None)
35
- if device is None:
36
- device_str = None
37
- elif getattr(device, "type", "cpu") == "cuda":
38
- idx = getattr(device, "index", None)
39
- device_str = "cuda" if idx is None else f"cuda:{idx}"
40
- else:
41
- device_str = "cpu"
42
- self._runtime = JNUTSBRuntime.from_config(self.model.config, device=device_str)
43
- return self._runtime
44
-
45
- def _sanitize_parameters(self, **kwargs: Any):
46
- allowed = {
47
- "news",
48
- "stock",
49
- "future_news",
50
- "future_covariates",
51
- "prediction_length",
52
- "quantile_levels",
53
- "timestamp_column",
54
- "target",
55
- "id_column",
56
- "use_llm_extractor",
57
- }
58
- forward_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
59
- return {}, forward_kwargs, {}
60
-
61
- def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
62
- if not isinstance(inputs, dict):
63
- raise TypeError("inputs는 {'stock': ..., 'news': ...} 형태의 dict여야 합니다.")
64
- return inputs
65
-
66
- def _forward(self, model_inputs: Dict[str, Any], **forward_kwargs: Any) -> Dict[str, Any]:
67
- return self.runtime.predict(inputs=model_inputs, **forward_kwargs)
68
-
69
- def postprocess(self, model_outputs: Dict[str, Any]) -> Dict[str, Any]:
70
- return model_outputs