HONGRIZON commited on
Commit
40a8006
·
verified ·
1 Parent(s): 3a30dcb

Delete runtime.py

Browse files
Files changed (1) hide show
  1. runtime.py +0 -374
runtime.py DELETED
@@ -1,374 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import os
5
- from pathlib import Path
6
- from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
7
-
8
- import pandas as pd
9
- import torch
10
-
11
- try:
12
- from .configuration_jnu_tsb import JNUTSBConfig
13
- from .event_extractor import COVARIATE_COLUMNS, EventExtractor
14
- except ImportError: # pragma: no cover - local execution fallback
15
- from configuration_jnu_tsb import JNUTSBConfig
16
- from event_extractor import COVARIATE_COLUMNS, EventExtractor
17
-
18
-
19
- class JNUTSBRuntime:
20
- """Runtime used by the custom pipeline, handler.py, Gradio Space, and R examples.
21
-
22
- Routes inputs into three paths:
23
- 1. stock only -> Chronos-2 forecast
24
- 2. news only -> Polyglot/keyword event extraction
25
- 3. both -> news covariates + stock context -> Chronos-2 forecast
26
- """
27
-
28
- def __init__(
29
- self,
30
- chronos_model_id: str = "amazon/chronos-2",
31
- llm_model_id: str = "EleutherAI/polyglot-ko-1.3b",
32
- device: Optional[str] = None,
33
- quantile_levels: Optional[Sequence[float]] = None,
34
- use_llm_extractor: bool = True,
35
- max_new_tokens: int = 96,
36
- ) -> None:
37
- self.chronos_model_id = chronos_model_id
38
- self.llm_model_id = llm_model_id
39
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
40
- self.quantile_levels = list(quantile_levels or [0.1, 0.5, 0.9])
41
- self.use_llm_extractor = bool(use_llm_extractor)
42
- self.max_new_tokens = int(max_new_tokens)
43
-
44
- self._chronos = None
45
- self._tokenizer = None
46
- self._llm = None
47
- self._extractor = None
48
-
49
- @classmethod
50
- def from_config(cls, config: Union[JNUTSBConfig, Dict[str, Any]], **overrides: Any) -> "JNUTSBRuntime":
51
- if isinstance(config, JNUTSBConfig):
52
- data = config.to_runtime_dict()
53
- else:
54
- data = dict(config)
55
- data.update({k: v for k, v in overrides.items() if v is not None})
56
- return cls(
57
- chronos_model_id=data.get("chronos_model_id", "amazon/chronos-2"),
58
- llm_model_id=data.get("llm_model_id", "EleutherAI/polyglot-ko-1.3b"),
59
- quantile_levels=data.get("quantile_levels", [0.1, 0.5, 0.9]),
60
- use_llm_extractor=data.get("use_llm_extractor", True),
61
- device=data.get("device"),
62
- max_new_tokens=data.get("max_new_tokens", 96),
63
- )
64
-
65
- @classmethod
66
- def from_config_dir(cls, model_dir: Union[str, os.PathLike[str]], **overrides: Any) -> "JNUTSBRuntime":
67
- config_path = Path(model_dir) / "config.json"
68
- with open(config_path, "r", encoding="utf-8") as f:
69
- config = json.load(f)
70
- return cls.from_config(config, **overrides)
71
-
72
- @property
73
- def chronos(self):
74
- if self._chronos is None:
75
- from chronos import Chronos2Pipeline
76
-
77
- self._chronos = Chronos2Pipeline.from_pretrained(
78
- self.chronos_model_id,
79
- device_map=self.device,
80
- )
81
- return self._chronos
82
-
83
- @property
84
- def tokenizer(self):
85
- if self._tokenizer is None:
86
- from transformers import AutoTokenizer
87
-
88
- self._tokenizer = AutoTokenizer.from_pretrained(self.llm_model_id)
89
- if self._tokenizer.pad_token is None:
90
- self._tokenizer.pad_token = self._tokenizer.eos_token
91
- return self._tokenizer
92
-
93
- @property
94
- def llm(self):
95
- if self._llm is None:
96
- from transformers import AutoModelForCausalLM
97
-
98
- dtype = torch.float16 if self.device.startswith("cuda") else torch.float32
99
- self._llm = AutoModelForCausalLM.from_pretrained(
100
- self.llm_model_id,
101
- torch_dtype=dtype,
102
- device_map="auto" if self.device.startswith("cuda") else None,
103
- )
104
- if not self.device.startswith("cuda"):
105
- self._llm.to(self.device)
106
- self._llm.eval()
107
- return self._llm
108
-
109
- @property
110
- def extractor(self) -> EventExtractor:
111
- if self._extractor is None:
112
- self._extractor = EventExtractor(
113
- generate_fn=self._generate_text if self.use_llm_extractor else None,
114
- use_llm=self.use_llm_extractor,
115
- )
116
- return self._extractor
117
-
118
- def _generate_text(self, prompt: str) -> str:
119
- tokenizer = self.tokenizer
120
- model = self.llm
121
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1536)
122
- model_device = next(model.parameters()).device
123
- inputs = {k: v.to(model_device) for k, v in inputs.items()}
124
- with torch.no_grad():
125
- out = model.generate(
126
- **inputs,
127
- max_new_tokens=self.max_new_tokens,
128
- do_sample=False,
129
- pad_token_id=tokenizer.pad_token_id,
130
- eos_token_id=tokenizer.eos_token_id,
131
- )
132
- gen_ids = out[0, inputs["input_ids"].shape[1]:]
133
- return tokenizer.decode(gen_ids, skip_special_tokens=True)
134
-
135
- def predict(
136
- self,
137
- inputs: Optional[Dict[str, Any]] = None,
138
- *,
139
- news: Optional[Iterable[Dict[str, Any]]] = None,
140
- stock: Optional[Union[pd.DataFrame, List[Dict[str, Any]], Dict[str, Any], str, os.PathLike[str]]] = None,
141
- future_news: Optional[Iterable[Dict[str, Any]]] = None,
142
- future_covariates: Optional[Union[pd.DataFrame, List[Dict[str, Any]], Dict[str, Any], str, os.PathLike[str]]] = None,
143
- prediction_length: int = 5,
144
- quantile_levels: Optional[Sequence[float]] = None,
145
- timestamp_column: str = "timestamp",
146
- target: str = "target",
147
- id_column: str = "item_id",
148
- use_llm_extractor: Optional[bool] = None,
149
- ) -> Dict[str, Any]:
150
- inputs = dict(inputs or {})
151
- news = news if news is not None else inputs.get("news")
152
- stock = stock if stock is not None else inputs.get("stock")
153
- future_news = future_news if future_news is not None else inputs.get("future_news")
154
- future_covariates = future_covariates if future_covariates is not None else inputs.get("future_covariates")
155
-
156
- old_use_llm = self.use_llm_extractor
157
- if use_llm_extractor is not None and bool(use_llm_extractor) != old_use_llm:
158
- self.use_llm_extractor = bool(use_llm_extractor)
159
- self._extractor = None
160
-
161
- try:
162
- news_list = self._normalize_news(news)
163
- future_news_list = self._normalize_news(future_news)
164
- has_text = len(news_list) > 0
165
-
166
- stock_df = self._to_dataframe(stock) if stock is not None else None
167
- has_numeric = stock_df is not None and len(stock_df) > 0
168
-
169
- if not has_text and not has_numeric:
170
- raise ValueError("news와 stock 중 최소 하나는 필요합니다.")
171
-
172
- q = list(quantile_levels or self.quantile_levels)
173
-
174
- if has_text and not has_numeric:
175
- daily_covariates = self.extractor.aggregate_to_daily(news_list, timestamp_column=timestamp_column)
176
- return {
177
- "model": "JNU-TSB",
178
- "route": "text_only",
179
- "events": [self.extractor.extract(str(item.get("title", ""))) for item in news_list],
180
- "daily_covariates": self._df_to_records(daily_covariates),
181
- }
182
-
183
- stock_df = self._prepare_stock_df(stock_df, timestamp_column=timestamp_column, target=target, id_column=id_column)
184
-
185
- if has_text and has_numeric:
186
- context_df = self._merge_news_covariates(stock_df, news_list, timestamp_column=timestamp_column)
187
- future_df = self._prepare_future_covariates(
188
- stock_df=context_df,
189
- future_news=future_news_list,
190
- future_covariates=future_covariates,
191
- prediction_length=prediction_length,
192
- timestamp_column=timestamp_column,
193
- id_column=id_column,
194
- )
195
- pred = self._predict_chronos_df(
196
- context_df,
197
- future_df=future_df,
198
- prediction_length=int(prediction_length),
199
- quantile_levels=q,
200
- id_column=id_column,
201
- timestamp_column=timestamp_column,
202
- target=target,
203
- )
204
- return {
205
- "model": "JNU-TSB",
206
- "route": "hybrid",
207
- "prediction": self._df_to_records(pred),
208
- "context_columns": list(context_df.columns),
209
- "future_covariates_used": future_df is not None,
210
- "notes": "News was converted to daily covariates and merged into the Chronos-2 context.",
211
- }
212
-
213
- pred = self._predict_chronos_df(
214
- stock_df,
215
- future_df=None,
216
- prediction_length=int(prediction_length),
217
- quantile_levels=q,
218
- id_column=id_column,
219
- timestamp_column=timestamp_column,
220
- target=target,
221
- )
222
- return {"model": "JNU-TSB", "route": "chronos_only", "prediction": self._df_to_records(pred)}
223
- finally:
224
- if use_llm_extractor is not None and bool(use_llm_extractor) != old_use_llm:
225
- self.use_llm_extractor = old_use_llm
226
- self._extractor = None
227
-
228
- def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
229
- return self.predict(*args, **kwargs)
230
-
231
- def _normalize_news(self, news: Optional[Iterable[Dict[str, Any]]]) -> List[Dict[str, Any]]:
232
- if news is None:
233
- return []
234
- if isinstance(news, dict):
235
- if "data" in news and isinstance(news["data"], list):
236
- news = news["data"]
237
- else:
238
- news = [news]
239
- out: List[Dict[str, Any]] = []
240
- for item in list(news):
241
- if not isinstance(item, dict):
242
- continue
243
- title = item.get("title") or item.get("headline") or item.get("text") or ""
244
- date = item.get("date") or item.get("timestamp")
245
- normalized = dict(item)
246
- normalized["title"] = str(title)
247
- if date is not None:
248
- normalized["date"] = date
249
- out.append(normalized)
250
- return out
251
-
252
- def _to_dataframe(self, data: Union[pd.DataFrame, List[Dict[str, Any]], Dict[str, Any], str, os.PathLike[str], None]) -> Optional[pd.DataFrame]:
253
- if data is None:
254
- return None
255
- if isinstance(data, pd.DataFrame):
256
- return data.copy()
257
- if isinstance(data, (str, os.PathLike)):
258
- return pd.read_csv(data)
259
- if isinstance(data, list):
260
- return pd.DataFrame(data)
261
- if isinstance(data, dict):
262
- if "data" in data and isinstance(data["data"], list):
263
- return pd.DataFrame(data["data"])
264
- try:
265
- return pd.DataFrame(data)
266
- except ValueError:
267
- return pd.DataFrame([data])
268
- raise TypeError(f"지원하지 않는 데이터 타입입니다: {type(data)}")
269
-
270
- def _prepare_stock_df(self, df: pd.DataFrame, timestamp_column: str, target: str, id_column: str) -> pd.DataFrame:
271
- df = df.copy()
272
- if timestamp_column not in df.columns:
273
- for candidate in ["date", "Date", "datetime", "time"]:
274
- if candidate in df.columns:
275
- df = df.rename(columns={candidate: timestamp_column})
276
- break
277
- if target not in df.columns:
278
- for candidate in ["close", "Close", "price", "value", "y"]:
279
- if candidate in df.columns:
280
- df = df.rename(columns={candidate: target})
281
- break
282
- if timestamp_column not in df.columns or target not in df.columns:
283
- raise ValueError(f"stock에는 `{timestamp_column}`와 `{target}` 컬럼이 필요합니다.")
284
- if id_column not in df.columns:
285
- df[id_column] = "series_0"
286
- df[timestamp_column] = pd.to_datetime(df[timestamp_column])
287
- df[target] = pd.to_numeric(df[target], errors="coerce")
288
- df = df.dropna(subset=[timestamp_column, target])
289
- return df.sort_values([id_column, timestamp_column]).reset_index(drop=True)
290
-
291
- def _merge_news_covariates(self, stock_df: pd.DataFrame, news: Iterable[Dict[str, Any]], timestamp_column: str) -> pd.DataFrame:
292
- cov = self.extractor.aggregate_to_daily(news, timestamp_column=timestamp_column)
293
- context = stock_df.copy()
294
- day_col = "__day__"
295
- context[day_col] = pd.to_datetime(context[timestamp_column]).dt.floor("D")
296
- cov = cov.rename(columns={timestamp_column: day_col})
297
- merged = context.merge(cov, on=day_col, how="left").drop(columns=[day_col])
298
- for col in COVARIATE_COLUMNS:
299
- if col not in merged.columns:
300
- merged[col] = 0.0
301
- merged[col] = merged[col].fillna(0.0).astype(float)
302
- return merged
303
-
304
- def _prepare_future_covariates(
305
- self,
306
- stock_df: pd.DataFrame,
307
- future_news: Optional[List[Dict[str, Any]]],
308
- future_covariates: Optional[Union[pd.DataFrame, List[Dict[str, Any]], Dict[str, Any], str, os.PathLike[str]]],
309
- prediction_length: int,
310
- timestamp_column: str,
311
- id_column: str,
312
- ) -> Optional[pd.DataFrame]:
313
- if future_covariates is not None:
314
- fut = self._to_dataframe(future_covariates)
315
- if fut is None or len(fut) == 0:
316
- return None
317
- if id_column not in fut.columns:
318
- fut[id_column] = stock_df[id_column].iloc[0]
319
- if timestamp_column not in fut.columns:
320
- raise ValueError(f"future_covariates에는 `{timestamp_column}` 컬럼이 필요합니다.")
321
- fut[timestamp_column] = pd.to_datetime(fut[timestamp_column])
322
- for col in COVARIATE_COLUMNS:
323
- if col not in fut.columns:
324
- fut[col] = 0.0
325
- return fut
326
-
327
- if not future_news:
328
- return None
329
-
330
- first_id = stock_df[id_column].iloc[0]
331
- timestamps = pd.to_datetime(stock_df[timestamp_column]).drop_duplicates().sort_values()
332
- last_ts = timestamps.max()
333
- freq = pd.infer_freq(timestamps) or "D"
334
- future_dates = pd.date_range(start=last_ts, periods=int(prediction_length) + 1, freq=freq)[1:]
335
- base = pd.DataFrame({id_column: first_id, timestamp_column: future_dates})
336
-
337
- cov = self.extractor.aggregate_to_daily(future_news, timestamp_column=timestamp_column)
338
- base["__day__"] = pd.to_datetime(base[timestamp_column]).dt.floor("D")
339
- cov = cov.rename(columns={timestamp_column: "__day__"})
340
- fut = base.merge(cov, on="__day__", how="left").drop(columns=["__day__"])
341
- for col in COVARIATE_COLUMNS:
342
- if col not in fut.columns:
343
- fut[col] = 0.0
344
- fut[col] = fut[col].fillna(0.0).astype(float)
345
- return fut
346
-
347
- def _predict_chronos_df(
348
- self,
349
- context_df: pd.DataFrame,
350
- *,
351
- future_df: Optional[pd.DataFrame],
352
- prediction_length: int,
353
- quantile_levels: Sequence[float],
354
- id_column: str,
355
- timestamp_column: str,
356
- target: str,
357
- ) -> pd.DataFrame:
358
- kwargs: Dict[str, Any] = {
359
- "prediction_length": int(prediction_length),
360
- "quantile_levels": list(quantile_levels),
361
- "id_column": id_column,
362
- "timestamp_column": timestamp_column,
363
- "target": target,
364
- }
365
- if future_df is not None:
366
- kwargs["future_df"] = future_df
367
- return self.chronos.predict_df(context_df, **kwargs)
368
-
369
- def _df_to_records(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
370
- out = df.copy()
371
- for col in out.columns:
372
- if pd.api.types.is_datetime64_any_dtype(out[col]):
373
- out[col] = out[col].astype(str)
374
- return out.to_dict(orient="records")