Update config.py
Browse files
config.py
CHANGED
|
@@ -4,66 +4,56 @@ import os
|
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from pathlib import Path
|
| 6 |
|
| 7 |
-
|
| 8 |
BASE_DIR = Path(__file__).resolve().parent
|
| 9 |
DATA_DIR = BASE_DIR / "data"
|
| 10 |
LOCAL_LOG_DIR = BASE_DIR / "logs"
|
| 11 |
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
@dataclass(frozen=True)
|
| 14 |
class Settings:
|
|
|
|
| 15 |
app_name: str = os.getenv("APP_NAME", "Trading Game Study AI")
|
| 16 |
app_version: str = os.getenv("APP_VERSION", "2.0.0")
|
| 17 |
port: int = int(os.getenv("PORT", "7860"))
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
)
|
| 23 |
-
|
| 24 |
-
"CROSS_ENCODER_MODEL",
|
| 25 |
-
"cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 26 |
-
)
|
| 27 |
-
generator_model: str = os.getenv(
|
| 28 |
-
"GENERATOR_MODEL",
|
| 29 |
-
"google/flan-t5-small",
|
| 30 |
-
)
|
| 31 |
-
generator_task: str = os.getenv(
|
| 32 |
-
"GENERATOR_TASK",
|
| 33 |
-
"text2text-generation",
|
| 34 |
-
)
|
| 35 |
generator_max_new_tokens: int = int(os.getenv("GENERATOR_MAX_NEW_TOKENS", "220"))
|
| 36 |
generator_temperature: float = float(os.getenv("GENERATOR_TEMPERATURE", "0.6"))
|
| 37 |
generator_top_p: float = float(os.getenv("GENERATOR_TOP_P", "0.9"))
|
| 38 |
-
generator_do_sample: bool =
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
str(DATA_DIR / "gmat_question_seed.jsonl"),
|
| 47 |
-
)
|
| 48 |
-
topic_index_path: str = os.getenv(
|
| 49 |
-
"TOPIC_INDEX_PATH",
|
| 50 |
-
str(DATA_DIR / "gmat_topic_index.json"),
|
| 51 |
-
)
|
| 52 |
dataset_repo_id: str = os.getenv("DATASET_REPO_ID", "j-js/gmat-quant-corpus")
|
| 53 |
dataset_split: str = os.getenv("DATASET_SPLIT", "train")
|
| 54 |
retrieval_k: int = int(os.getenv("RETRIEVAL_K", "8"))
|
| 55 |
rerank_k: int = int(os.getenv("RERANK_K", "4"))
|
| 56 |
max_chunks_to_show: int = int(os.getenv("MAX_CHUNKS_TO_SHOW", "3"))
|
| 57 |
max_reply_chars: int = int(os.getenv("MAX_REPLY_CHARS", "1600"))
|
| 58 |
-
enable_remote_dataset_fallback: bool =
|
| 59 |
|
|
|
|
| 60 |
local_log_dir: str = os.getenv("LOCAL_LOG_DIR", str(LOCAL_LOG_DIR))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
ingest_api_key: str = os.getenv("INGEST_API_KEY", "")
|
| 62 |
research_api_key: str = os.getenv("RESEARCH_API_KEY", "")
|
| 63 |
-
hf_token: str = os.getenv("HF_TOKEN", "")
|
| 64 |
-
log_dataset_repo_id: str = os.getenv("LOG_DATASET_REPO_ID", "")
|
| 65 |
-
log_dataset_private: bool = os.getenv("LOG_DATASET_PRIVATE", "1") == "1"
|
| 66 |
-
push_logs_to_hub: bool = os.getenv("PUSH_LOGS_TO_HUB", "0") == "1"
|
| 67 |
|
| 68 |
|
| 69 |
-
settings = Settings()
|
|
|
|
from dataclasses import dataclass
from pathlib import Path

# All filesystem locations are anchored at the directory containing this
# module, so the app works regardless of the process's current directory.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"      # bundled corpus / index files
LOCAL_LOG_DIR = BASE_DIR / "logs"  # default destination for local logs
def env_bool(name: str, default: bool = False) -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    Truthy spellings (case-insensitive, surrounding whitespace ignored):
    "1", "true", "yes", "on". Any other non-empty value is falsy.

    An unset variable — or one set to an empty/blank string, as ``FOO=``
    in a .env file commonly produces — yields *default*.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        # Fix: previously a set-but-empty variable stripped to "" and was
        # forced to False even when default=True; blank now means "unset".
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}
@dataclass(frozen=True)
class Settings:
    """Immutable application configuration.

    Every field is resolved from the environment exactly once, at class
    definition time, with a hard-coded fallback. Frozen so a shared
    instance cannot be mutated after import.
    """

    # --- Application identity ---
    app_name: str = os.getenv("APP_NAME", "Trading Game Study AI")
    app_version: str = os.getenv("APP_VERSION", "2.0.0")
    port: int = int(os.getenv("PORT", "7860"))

    # --- Model identifiers and generation parameters ---
    embedding_model: str = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    cross_encoder_model: str = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
    generator_model: str = os.getenv("GENERATOR_MODEL", "google/flan-t5-small")
    generator_task: str = os.getenv("GENERATOR_TASK", "text2text-generation")
    generator_max_new_tokens: int = int(os.getenv("GENERATOR_MAX_NEW_TOKENS", "220"))
    generator_temperature: float = float(os.getenv("GENERATOR_TEMPERATURE", "0.6"))
    generator_top_p: float = float(os.getenv("GENERATOR_TOP_P", "0.9"))
    generator_do_sample: bool = env_bool("GENERATOR_DO_SAMPLE", True)

    # --- Local data files (default to paths under DATA_DIR) ---
    local_chunks_path: str = os.getenv("LOCAL_CHUNKS_PATH", str(DATA_DIR / "gmat_hf_chunks.jsonl"))
    question_seed_path: str = os.getenv("QUESTION_SEED_PATH", str(DATA_DIR / "gmat_question_seed.jsonl"))
    topic_index_path: str = os.getenv("TOPIC_INDEX_PATH", str(DATA_DIR / "gmat_topic_index.json"))

    # --- Retrieval / reranking ---
    dataset_repo_id: str = os.getenv("DATASET_REPO_ID", "j-js/gmat-quant-corpus")
    dataset_split: str = os.getenv("DATASET_SPLIT", "train")
    retrieval_k: int = int(os.getenv("RETRIEVAL_K", "8"))
    rerank_k: int = int(os.getenv("RERANK_K", "4"))
    max_chunks_to_show: int = int(os.getenv("MAX_CHUNKS_TO_SHOW", "3"))
    max_reply_chars: int = int(os.getenv("MAX_REPLY_CHARS", "1600"))
    enable_remote_dataset_fallback: bool = env_bool("ENABLE_REMOTE_DATASET_FALLBACK", True)

    # --- Logging ---
    local_log_dir: str = os.getenv("LOCAL_LOG_DIR", str(LOCAL_LOG_DIR))
    push_logs_to_hub: bool = env_bool("PUSH_LOGS_TO_HUB", False)
    log_dataset_repo_id: str = os.getenv("LOG_DATASET_REPO_ID", "")
    log_dataset_private: bool = env_bool("LOG_DATASET_PRIVATE", True)

    # --- Secrets (empty string when not configured) ---
    hf_token: str = os.getenv("HF_TOKEN", "")
    ingest_api_key: str = os.getenv("INGEST_API_KEY", "")
    research_api_key: str = os.getenv("RESEARCH_API_KEY", "")
# Module-level singleton: importers should use this shared instance
# rather than constructing their own Settings.
settings = Settings()