File size: 14,818 Bytes
cf02581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union

import pandas as pd
import torch

try:
    from .configuration_jnu_tsb import JNUTSBConfig
    from .event_extractor import COVARIATE_COLUMNS, EventExtractor
except ImportError:  # pragma: no cover - local execution fallback
    from configuration_jnu_tsb import JNUTSBConfig
    from event_extractor import COVARIATE_COLUMNS, EventExtractor


class JNUTSBRuntime:
    """Runtime used by the model wrapper, pipeline, Endpoint handler, Gradio, and R examples.

    Routes inputs into three paths:
      1. stock only -> Chronos-2 forecast
      2. news only -> event extraction and daily covariates
      3. stock + news -> news covariates + stock context -> Chronos-2 forecast

    The Chronos-2 pipeline, the event extractor and the optional causal LM used by the
    extractor are all created lazily on first access.
    """

    # Column names accepted as stand-ins for the configured timestamp / target columns.
    # They are tried in order and the first match is renamed to the configured name.
    _TIMESTAMP_ALIASES = ("date", "Date", "datetime", "time")
    _TARGET_ALIASES = ("close", "Close", "price", "value", "y")
    # Temporary join key used while attaching daily news covariates to a frame.
    _DAY_KEY = "__day__"

    def __init__(
        self,
        config: Union[JNUTSBConfig, Dict[str, Any]],
        chronos_device_map: Optional[str] = None,
        llm_device_map: Optional[str] = None,
    ) -> None:
        """Create a runtime; no model weights are loaded here.

        Args:
            config: A ``JNUTSBConfig`` or a dict of its constructor keyword arguments.
            chronos_device_map: Device map for Chronos-2. Defaults to the
                ``JNU_TSB_CHRONOS_DEVICE_MAP`` environment variable, else ``"cpu"``.
            llm_device_map: Device map for the extractor LLM. Defaults to the
                ``JNU_TSB_LLM_DEVICE_MAP`` environment variable, else ``"cpu"``.
        """
        if isinstance(config, dict):
            config = JNUTSBConfig(**config)
        self.config = config
        self.chronos_device_map = chronos_device_map or os.getenv("JNU_TSB_CHRONOS_DEVICE_MAP", "cpu")
        self.llm_device_map = llm_device_map or os.getenv("JNU_TSB_LLM_DEVICE_MAP", "cpu")
        self._chronos = None
        self._llm_pipe = None
        self._extractor = None

    @classmethod
    def from_config(cls, config: Union[JNUTSBConfig, Dict[str, Any]], **kwargs: Any) -> "JNUTSBRuntime":
        """Alternate constructor; ``kwargs`` are forwarded to ``__init__``."""
        return cls(config=config, **kwargs)

    @classmethod
    def from_config_dir(cls, path: Union[str, os.PathLike[str]], **kwargs: Any) -> "JNUTSBRuntime":
        """Build a runtime from ``<path>/config.json``; ``kwargs`` are forwarded to ``__init__``."""
        config_path = Path(path) / "config.json"
        with config_path.open("r", encoding="utf-8") as f:
            payload = json.load(f)
        return cls(config=payload, **kwargs)

    @property
    def chronos(self):
        """Lazily load and cache the Chronos-2 pipeline named by ``config.chronos_model_id``.

        Raises:
            ImportError: If the ``chronos`` package cannot be imported.
        """
        if self._chronos is None:
            try:
                from chronos import Chronos2Pipeline
            except Exception as exc:  # pragma: no cover
                raise ImportError(
                    "chronos-forecasting is required for Chronos-2 inference. "
                    "Install it with: pip install chronos-forecasting"
                ) from exc
            self._chronos = Chronos2Pipeline.from_pretrained(
                self.config.chronos_model_id,
                device_map=self.chronos_device_map,
            )
        return self._chronos

    @property
    def extractor(self) -> EventExtractor:
        """Lazily build and cache the event extractor.

        The extractor is wired to the LLM only when ``config.use_llm_extractor`` is set.
        """
        if self._extractor is None:
            self._extractor = EventExtractor(
                generate_fn=self._generate_with_polyglot if self.config.use_llm_extractor else None,
                categories=self.config.event_categories,
                use_llm=self.config.use_llm_extractor,
            )
        return self._extractor

    def _generate_with_polyglot(self, prompt: str) -> str:
        """Greedy-decode up to 96 new tokens for ``prompt`` with ``config.llm_model_id``.

        The text-generation pipeline is built on the first call and cached afterwards.
        """
        if self._llm_pipe is None:
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline as hf_pipeline

            # Choose float16 only when the model is actually going to a GPU. The default
            # device map is "cpu", and float16 math on CPU is slow (and unsupported on some
            # torch builds) even when a CUDA device happens to be present.
            targets_cpu = str(self.llm_device_map).strip().lower() == "cpu"
            use_half = torch.cuda.is_available() and not targets_cpu

            tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_id)
            model = AutoModelForCausalLM.from_pretrained(
                self.config.llm_model_id,
                torch_dtype=torch.float16 if use_half else torch.float32,
                device_map=self.llm_device_map,
            )
            self._llm_pipe = hf_pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
        output = self._llm_pipe(
            prompt,
            max_new_tokens=96,
            do_sample=False,
            return_full_text=False,
        )
        if isinstance(output, list) and output:
            return output[0].get("generated_text", "")
        return str(output)

    def predict(
        self,
        inputs: Optional[Dict[str, Any]] = None,
        prediction_length: Optional[int] = None,
        quantile_levels: Optional[Sequence[float]] = None,
        use_llm_extractor: Optional[bool] = None,
        allow_naive_fallback: Optional[bool] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Route one request to the hybrid, stock-only or news-only path.

        Args:
            inputs: Request payload. Recognised keys are ``stock``, ``news``,
                ``future_news`` and ``future_covariates``.
            prediction_length: Forecast horizon. Falsy values fall back to the config.
            quantile_levels: Quantiles to forecast. Falsy values fall back to the config.
            use_llm_extractor: If given and different from the config, switches the
                extractor mode for this runtime instance (the config is updated in place).
            allow_naive_fallback: Overrides ``config.allow_naive_fallback`` for this call.
            **kwargs: Merged into ``inputs``; keyword arguments win on key clashes.

        Returns:
            A JSON-friendly dict whose ``route`` key is ``"hybrid"``, ``"chronos_only"``
            or ``"text_only"``.

        Raises:
            ValueError: If neither stock nor news data is supplied.
        """
        payload: Dict[str, Any] = dict(inputs or {})
        payload.update(kwargs)

        if use_llm_extractor is not None and bool(use_llm_extractor) != self.config.use_llm_extractor:
            # Rebuild extractor with the requested setting for this runtime instance.
            self.config.use_llm_extractor = bool(use_llm_extractor)
            self._extractor = None

        prediction_length = int(prediction_length or self.config.prediction_length)
        quantile_levels = list(quantile_levels or self.config.quantile_levels)
        allow_naive_fallback = self.config.allow_naive_fallback if allow_naive_fallback is None else bool(allow_naive_fallback)

        news = payload.get("news")
        stock = payload.get("stock")
        future_news = payload.get("future_news")
        future_covariates = payload.get("future_covariates")

        has_news = bool(news)
        stock_df = self._prepare_stock_df(stock)
        has_stock = stock_df is not None and not stock_df.empty

        if has_news and has_stock:
            context_df = self._merge_news_covariates(stock_df, news)
            future_df = self._prepare_future_covariates(
                stock_df=context_df,
                future_news=future_news,
                future_covariates=future_covariates,
                prediction_length=prediction_length,
            )
            return self._forecast(
                context_df=context_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                route="hybrid",
                future_df=future_df,
                allow_naive_fallback=allow_naive_fallback,
            )

        if has_stock:
            return self._forecast(
                context_df=stock_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                route="chronos_only",
                future_df=None,
                allow_naive_fallback=allow_naive_fallback,
            )

        if has_news:
            events = [self.extractor.extract(self._news_text(item)) for item in news]
            daily_covariates = self.extractor.aggregate_to_daily(news)
            return {
                "route": "text_only",
                "repo_id": self.config.repo_id,
                "events": events,
                "daily_covariates": self._df_to_records(daily_covariates),
            }

        raise ValueError("JNU-TSB expects at least one of: stock, news.")

    @staticmethod
    def _news_text(item: Dict[str, Any]) -> str:
        """Return the first non-empty of ``title`` / ``headline`` / ``text``, else ``""``."""
        return item.get("title") or item.get("headline") or item.get("text") or ""

    def _forecast(
        self,
        context_df: pd.DataFrame,
        prediction_length: int,
        quantile_levels: Sequence[float],
        route: str,
        future_df: Optional[pd.DataFrame] = None,
        allow_naive_fallback: bool = True,
    ) -> Dict[str, Any]:
        """Forecast with Chronos-2, optionally falling back to a last-value forecast.

        ``future_df`` is passed to Chronos only when it is a non-empty frame.

        Raises:
            Exception: Whatever Chronos raised, when ``allow_naive_fallback`` is false.
        """
        try:
            kwargs: Dict[str, Any] = dict(
                prediction_length=prediction_length,
                quantile_levels=list(quantile_levels),
                id_column=self.config.id_column,
                timestamp_column=self.config.timestamp_column,
                target=self.config.target_column,
            )
            if future_df is not None and not future_df.empty:
                kwargs["future_df"] = future_df
            pred = self.chronos.predict_df(context_df, **kwargs)
            return {
                "route": route,
                "repo_id": self.config.repo_id,
                "engine": self.config.chronos_model_id,
                "forecast": self._df_to_records(pred),
                "used_naive_fallback": False,
            }
        except Exception as exc:
            if not allow_naive_fallback:
                raise
            pred = self._naive_forecast(context_df, prediction_length, quantile_levels)
            return {
                "route": route,
                "repo_id": self.config.repo_id,
                "engine": "naive_last_value_fallback",
                "forecast": self._df_to_records(pred),
                "used_naive_fallback": True,
                "warning": f"Chronos-2 inference failed or was unavailable: {type(exc).__name__}: {exc}",
            }

    @staticmethod
    def _to_frame(data: Any, error_message: str) -> pd.DataFrame:
        """Copy a DataFrame, or build one from a list of dicts / dict of columns.

        Raises:
            TypeError: With ``error_message`` for any other input type.
        """
        if isinstance(data, pd.DataFrame):
            return data.copy()
        if isinstance(data, (list, dict)):
            return pd.DataFrame(data)
        raise TypeError(error_message)

    @staticmethod
    def _rename_first_alias(df: pd.DataFrame, canonical: str, aliases: Sequence[str]) -> pd.DataFrame:
        """Rename the first present alias to ``canonical`` unless ``canonical`` already exists."""
        if canonical in df.columns:
            return df
        for alias in aliases:
            if alias in df.columns:
                return df.rename(columns={alias: canonical})
        return df

    def _normalise_series_frame(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add the default item id if missing, parse timestamps, and sort by (id, timestamp).

        Mutates ``df``; callers pass a frame they own.
        """
        timestamp_col = self.config.timestamp_column
        id_col = self.config.id_column
        if id_col not in df.columns:
            df[id_col] = self.config.default_item_id
        df[timestamp_col] = pd.to_datetime(df[timestamp_col])
        return df.sort_values([id_col, timestamp_col]).reset_index(drop=True)

    def _prepare_stock_df(self, stock: Any) -> Optional[pd.DataFrame]:
        """Normalise the ``stock`` payload into a sorted long-format frame.

        Returns ``None`` for ``None`` input. An empty frame is returned as-is.

        Raises:
            TypeError: For unsupported input types.
            ValueError: If timestamp or target columns cannot be found (aliases included).
        """
        if stock is None:
            return None
        df = self._to_frame(stock, "stock must be a pandas DataFrame, list of dicts, or dict of columns.")
        if df.empty:
            return df

        timestamp_col = self.config.timestamp_column
        target_col = self.config.target_column
        df = self._rename_first_alias(df, timestamp_col, self._TIMESTAMP_ALIASES)
        df = self._rename_first_alias(df, target_col, self._TARGET_ALIASES)
        if timestamp_col not in df.columns or target_col not in df.columns:
            raise ValueError(f"stock must contain '{timestamp_col}' and '{target_col}' columns.")
        return self._normalise_series_frame(df)

    def _prepare_future_df(self, data: Any) -> Optional[pd.DataFrame]:
        """Normalise a user-supplied ``future_covariates`` payload.

        Same rules as ``_prepare_stock_df``, except that no target column is required.

        Raises:
            TypeError: For unsupported input types.
            ValueError: If no timestamp column can be found (aliases included).
        """
        if data is None:
            return None
        df = self._to_frame(data, "future_covariates must be a pandas DataFrame, list of dicts, or dict of columns.")
        if df.empty:
            return df

        timestamp_col = self.config.timestamp_column
        df = self._rename_first_alias(df, timestamp_col, self._TIMESTAMP_ALIASES)
        if timestamp_col not in df.columns:
            raise ValueError(f"future_covariates must contain a '{timestamp_col}' column.")
        return self._normalise_series_frame(df)

    @staticmethod
    def _fill_covariates(df: pd.DataFrame) -> pd.DataFrame:
        """Fill missing values in known covariate columns with 0 and cast them to float (in place)."""
        for col in COVARIATE_COLUMNS:
            if col in df.columns:
                df[col] = df[col].fillna(0).astype(float)
        return df

    def _attach_daily_covariates(self, frame: pd.DataFrame, cov: pd.DataFrame) -> pd.DataFrame:
        """Left-join daily covariates onto ``frame`` by calendar day of its timestamp.

        ``cov`` is expected to carry its day in a ``timestamp`` column.
        NOTE(review): this assumes the extractor emits day-floored timestamps whose tz
        matches the stock data — confirm against ``EventExtractor.aggregate_to_daily``.
        """
        keyed = frame.copy()
        keyed[self._DAY_KEY] = pd.to_datetime(keyed[self.config.timestamp_column]).dt.floor("D")
        daily = cov.rename(columns={"timestamp": self._DAY_KEY})
        merged = keyed.merge(daily, on=self._DAY_KEY, how="left").drop(columns=[self._DAY_KEY])
        return self._fill_covariates(merged)

    def _merge_news_covariates(self, stock_df: pd.DataFrame, news: Iterable[Dict[str, Any]]) -> pd.DataFrame:
        """Aggregate ``news`` to daily covariates and join them onto the stock context."""
        cov = self.extractor.aggregate_to_daily(news)
        if cov.empty and "timestamp" not in cov.columns:
            # Nothing to join on (a column-less empty frame would make merge raise KeyError).
            return stock_df.copy()
        return self._attach_daily_covariates(stock_df, cov)

    def _prepare_future_covariates(
        self,
        stock_df: pd.DataFrame,
        future_news: Optional[Iterable[Dict[str, Any]]],
        future_covariates: Any,
        prediction_length: int,
    ) -> Optional[pd.DataFrame]:
        """Build the known-future covariate frame for Chronos-2.

        An explicit, non-empty ``future_covariates`` payload wins (its target column,
        if any, is dropped). Otherwise, when ``future_news`` is given, a horizon of
        ``prediction_length`` future timestamps is built per series in ``stock_df``,
        and the news-derived daily covariates are joined onto it. Returns ``None``
        when neither source is available.
        """
        if future_covariates is not None:
            fut = self._prepare_future_df(future_covariates)
            if fut is not None and not fut.empty:
                return fut.drop(columns=[self.config.target_column], errors="ignore")

        if not future_news:
            return None

        timestamp_col = self.config.timestamp_column
        id_col = self.config.id_column
        # One horizon per series: Chronos expects future rows for every item in the context.
        horizons = [
            pd.DataFrame({
                timestamp_col: self._future_timestamps(group[timestamp_col], prediction_length),
                id_col: item_id,
            })
            for item_id, group in stock_df.groupby(id_col)
        ]
        if not horizons:
            return None
        base = pd.concat(horizons, ignore_index=True)

        cov = self.extractor.aggregate_to_daily(future_news)
        if cov.empty:
            return base
        return self._attach_daily_covariates(base, cov)

    @staticmethod
    def _infer_frequency(timestamps: pd.Series, default: str = "D") -> str:
        """Infer a pandas frequency alias from ``timestamps``, else return ``default``.

        ``pd.infer_freq`` raises ``ValueError`` for fewer than three distinct points
        and returns ``None`` for irregular spacing. Both cases map to ``default``.
        """
        unique = pd.DatetimeIndex(pd.to_datetime(timestamps)).dropna().drop_duplicates().sort_values()
        if len(unique) < 3:
            return default
        try:
            return pd.infer_freq(unique) or default
        except ValueError:
            return default

    @classmethod
    def _future_timestamps(cls, timestamps: pd.Series, prediction_length: int) -> pd.DatetimeIndex:
        """Return the next ``prediction_length`` timestamps after the latest one given."""
        history = pd.to_datetime(timestamps)
        last_ts = history.max()
        freq = cls._infer_frequency(history)
        candidates = pd.date_range(start=last_ts, periods=prediction_length + 1, freq=freq)
        # date_range includes the start when it is on-offset; keep only strictly-future steps.
        return candidates[candidates > last_ts][:prediction_length]

    def _naive_forecast(self, context_df: pd.DataFrame, prediction_length: int, quantile_levels: Sequence[float]) -> pd.DataFrame:
        """Repeat each series' last observed value over the horizon (Chronos fallback).

        Every quantile column (both ``"<q>"`` and ``"q<q>"`` spellings), plus ``mean`` and
        ``prediction``, carries that value. A series with no observed target yields NaN.
        """
        timestamp_col = self.config.timestamp_column
        target_col = self.config.target_column
        id_col = self.config.id_column

        rows: List[Dict[str, Any]] = []
        for item_id, group in context_df.groupby(id_col):
            observed = group.sort_values(timestamp_col)[target_col].dropna()
            # Use the last *observed* value, so a trailing gap does not turn the forecast into NaN.
            last_value = float(observed.iloc[-1]) if not observed.empty else float("nan")
            for ts in self._future_timestamps(group[timestamp_col], prediction_length):
                row: Dict[str, Any] = {id_col: item_id, timestamp_col: ts}
                for q in quantile_levels:
                    row[str(q)] = last_value
                    row[f"q{q}"] = last_value
                row["mean"] = last_value
                row["prediction"] = last_value
                rows.append(row)
        return pd.DataFrame(rows)

    def _df_to_records(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
        """Convert ``df`` to a list of row dicts, with datetime columns rendered as strings."""
        out = df.copy()
        for col in out.columns:
            if pd.api.types.is_datetime64_any_dtype(out[col]):
                out[col] = out[col].astype(str)
        return out.to_dict(orient="records")