# JNU-TSB / configuration_jnu_tsb.py
# Author: HONGRIZON (Hugging Face Hub upload "Upload 18 files", commit cf02581, verified)
from __future__ import annotations
from typing import Any, Dict, List, Optional
from transformers import PretrainedConfig
# Default event-category labels; JNUTSBConfig falls back to a copy of this
# list when no ``event_categories`` argument is given.
DEFAULT_EVENT_CATEGORIES: List[str] = [
    "earnings",
    "product",
    "macro",
    "regulation",
    "supply_chain",
    "competition",
    "other",
]

# Default covariate column names; JNUTSBConfig falls back to a copy of this
# list when no ``covariate_columns`` argument is given. Built from three groups:
# one count column per event category above (same order), then the three
# sentiment-bucket count columns, then the remaining summary columns.
DEFAULT_COVARIATE_COLUMNS: List[str] = (
    [
        "cov_earnings_count",
        "cov_product_count",
        "cov_macro_count",
        "cov_regulation_count",
        "cov_supply_chain_count",
        "cov_competition_count",
        "cov_other_count",
    ]
    + [
        "cov_sentiment_pos_count",
        "cov_sentiment_neg_count",
        "cov_sentiment_neu_count",
    ]
    + [
        "cov_news_count",
        "cov_sentiment_mean",
        "cov_confidence_mean",
        "cov_event_score",
    ]
)
class JNUTSBConfig(PretrainedConfig):
    """Configuration for the JNU-TSB router wrapper.

    The repository stores lightweight code and metadata only. The upstream
    models, amazon/chronos-2 and EleutherAI/polyglot-ko-1.3b, are loaded lazily
    at runtime when the corresponding route is used.

    Args:
        repo_id: Hub repository id of this wrapper.
        project_name: Short project name.
        project_full_name: Human-readable project name.
        chronos_model_id: Hub id of the upstream Chronos model.
        llm_model_id: Hub id of the upstream language model.
        timestamp_column: Name of the timestamp column.
        target_column: Name of the target column.
        id_column: Name of the series/item id column.
        default_item_id: Item id to fall back on (presumably used when the
            input carries no id column; confirm in the router).
        prediction_length: Prediction length; coerced with ``int()``.
        quantile_levels: Quantile levels. ``None`` selects
            ``[0.1, 0.5, 0.9]``; any other value is copied into a new list.
        event_categories: Event category labels. ``None`` selects a copy of
            ``DEFAULT_EVENT_CATEGORIES``; any other value is copied.
        covariate_columns: Covariate column names. ``None`` selects a copy
            of ``DEFAULT_COVARIATE_COLUMNS``; any other value (including an
            explicit empty list) is copied as given.
        use_llm_extractor: Flag stored as ``bool(use_llm_extractor)``.
        allow_naive_fallback: Flag stored as ``bool(allow_naive_fallback)``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "jnu_tsb"

    def __init__(
        self,
        repo_id: str = "HONGRIZON/JNU-TSB",
        project_name: str = "JNU-TSB",
        project_full_name: str = "Jeju National University Time-Series Bridge",
        chronos_model_id: str = "amazon/chronos-2",
        llm_model_id: str = "EleutherAI/polyglot-ko-1.3b",
        timestamp_column: str = "timestamp",
        target_column: str = "target",
        id_column: str = "item_id",
        default_item_id: str = "series_0",
        prediction_length: int = 5,
        quantile_levels: Optional[List[float]] = None,
        event_categories: Optional[List[str]] = None,
        covariate_columns: Optional[List[str]] = None,
        use_llm_extractor: bool = True,
        allow_naive_fallback: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.repo_id = repo_id
        self.project_name = project_name
        self.project_full_name = project_full_name
        self.chronos_model_id = chronos_model_id
        self.llm_model_id = llm_model_id
        self.timestamp_column = timestamp_column
        self.target_column = target_column
        self.id_column = id_column
        self.default_item_id = default_item_id
        self.prediction_length = int(prediction_length)
        # Use `is None` rather than truthiness so an explicit empty list is
        # respected instead of being silently replaced by the defaults, and
        # copy every list so the config never aliases a caller's (or the
        # module-level default) list.
        self.quantile_levels = (
            [0.1, 0.5, 0.9] if quantile_levels is None else list(quantile_levels)
        )
        self.event_categories = list(
            DEFAULT_EVENT_CATEGORIES if event_categories is None else event_categories
        )
        self.covariate_columns = list(
            DEFAULT_COVARIATE_COLUMNS if covariate_columns is None else covariate_columns
        )
        self.use_llm_extractor = bool(use_llm_extractor)
        self.allow_naive_fallback = bool(allow_naive_fallback)

    def to_router_dict(self) -> Dict[str, Any]:
        """Return the router settings as a plain dict.

        List-valued entries are fresh copies, so mutating the returned dict
        (or its lists) does not modify this config.
        """
        return {
            "repo_id": self.repo_id,
            "project_name": self.project_name,
            "project_full_name": self.project_full_name,
            "chronos_model_id": self.chronos_model_id,
            "llm_model_id": self.llm_model_id,
            "timestamp_column": self.timestamp_column,
            "target_column": self.target_column,
            "id_column": self.id_column,
            "default_item_id": self.default_item_id,
            "prediction_length": self.prediction_length,
            "quantile_levels": list(self.quantile_levels),
            "event_categories": list(self.event_categories),
            "covariate_columns": list(self.covariate_columns),
            "use_llm_extractor": self.use_llm_extractor,
            "allow_naive_fallback": self.allow_naive_fallback,
        }