j-js committed on
Commit
7834040
·
verified ·
1 Parent(s): 020d9c6

Update config.py

Browse files
Files changed (1) hide show
  1. config.py +26 -36
config.py CHANGED
@@ -4,66 +4,56 @@ import os
4
  from dataclasses import dataclass
5
  from pathlib import Path
6
 
7
-
8
  BASE_DIR = Path(__file__).resolve().parent
9
  DATA_DIR = BASE_DIR / "data"
10
  LOCAL_LOG_DIR = BASE_DIR / "logs"
11
 
12
 
 
 
 
 
13
  @dataclass(frozen=True)
14
  class Settings:
 
15
  app_name: str = os.getenv("APP_NAME", "Trading Game Study AI")
16
  app_version: str = os.getenv("APP_VERSION", "2.0.0")
17
  port: int = int(os.getenv("PORT", "7860"))
18
 
19
- embedding_model: str = os.getenv(
20
- "EMBEDDING_MODEL",
21
- "sentence-transformers/all-MiniLM-L6-v2",
22
- )
23
- cross_encoder_model: str = os.getenv(
24
- "CROSS_ENCODER_MODEL",
25
- "cross-encoder/ms-marco-MiniLM-L-6-v2",
26
- )
27
- generator_model: str = os.getenv(
28
- "GENERATOR_MODEL",
29
- "google/flan-t5-small",
30
- )
31
- generator_task: str = os.getenv(
32
- "GENERATOR_TASK",
33
- "text2text-generation",
34
- )
35
  generator_max_new_tokens: int = int(os.getenv("GENERATOR_MAX_NEW_TOKENS", "220"))
36
  generator_temperature: float = float(os.getenv("GENERATOR_TEMPERATURE", "0.6"))
37
  generator_top_p: float = float(os.getenv("GENERATOR_TOP_P", "0.9"))
38
- generator_do_sample: bool = os.getenv("GENERATOR_DO_SAMPLE", "1") == "1"
39
 
40
- local_chunks_path: str = os.getenv(
41
- "LOCAL_CHUNKS_PATH",
42
- str(DATA_DIR / "gmat_hf_chunks.jsonl"),
43
- )
44
- question_seed_path: str = os.getenv(
45
- "QUESTION_SEED_PATH",
46
- str(DATA_DIR / "gmat_question_seed.jsonl"),
47
- )
48
- topic_index_path: str = os.getenv(
49
- "TOPIC_INDEX_PATH",
50
- str(DATA_DIR / "gmat_topic_index.json"),
51
- )
52
  dataset_repo_id: str = os.getenv("DATASET_REPO_ID", "j-js/gmat-quant-corpus")
53
  dataset_split: str = os.getenv("DATASET_SPLIT", "train")
54
  retrieval_k: int = int(os.getenv("RETRIEVAL_K", "8"))
55
  rerank_k: int = int(os.getenv("RERANK_K", "4"))
56
  max_chunks_to_show: int = int(os.getenv("MAX_CHUNKS_TO_SHOW", "3"))
57
  max_reply_chars: int = int(os.getenv("MAX_REPLY_CHARS", "1600"))
58
- enable_remote_dataset_fallback: bool = os.getenv("ENABLE_REMOTE_DATASET_FALLBACK", "1") == "1"
59
 
 
60
  local_log_dir: str = os.getenv("LOCAL_LOG_DIR", str(LOCAL_LOG_DIR))
 
 
 
 
 
 
61
  ingest_api_key: str = os.getenv("INGEST_API_KEY", "")
62
  research_api_key: str = os.getenv("RESEARCH_API_KEY", "")
63
- hf_token: str = os.getenv("HF_TOKEN", "")
64
- log_dataset_repo_id: str = os.getenv("LOG_DATASET_REPO_ID", "")
65
- log_dataset_private: bool = os.getenv("LOG_DATASET_PRIVATE", "1") == "1"
66
- push_logs_to_hub: bool = os.getenv("PUSH_LOGS_TO_HUB", "0") == "1"
67
 
68
 
69
- settings = Settings()
 
4
  from dataclasses import dataclass
5
  from pathlib import Path
6
 
 
7
# Filesystem anchors: paths are resolved relative to this module's own
# directory so the app behaves the same regardless of the current working
# directory it is launched from.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"        # local data files (chunks, seeds, index)
LOCAL_LOG_DIR = BASE_DIR / "logs"   # default directory for local logs
10
 
11
 
12
def env_bool(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment variable *name*.

    Recognized truthy values (case-insensitive, whitespace-stripped):
    "1", "true", "yes", "on". Recognized falsy values: "0", "false",
    "no", "off". An unset variable — or an unrecognized value — falls
    back to *default*.

    Fix over the original: the original returned False for any
    unrecognized value, silently ignoring *default* (e.g.
    env_bool("X", True) with X="enabled" yielded False).
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    value = raw.strip().lower()
    if value in {"1", "true", "yes", "on"}:
        return True
    if value in {"0", "false", "no", "off"}:
        return False
    # Unrecognized content: honor the caller's default rather than
    # silently coercing to False.
    return default
16
@dataclass(frozen=True)
class Settings:
    """Immutable application configuration.

    Every field reads its value from an environment variable at class
    definition time, falling back to the default shown. Instantiated
    once at module level as ``settings``.
    """

    # App identity and serving port
    app_name: str = os.getenv("APP_NAME", "Trading Game Study AI")
    app_version: str = os.getenv("APP_VERSION", "2.0.0")
    port: int = int(os.getenv("PORT", "7860"))

    # Models (ids passed to the embedding / reranking / generation stack)
    # and generation parameters.
    embedding_model: str = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    cross_encoder_model: str = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
    generator_model: str = os.getenv("GENERATOR_MODEL", "google/flan-t5-small")
    generator_task: str = os.getenv("GENERATOR_TASK", "text2text-generation")
    generator_max_new_tokens: int = int(os.getenv("GENERATOR_MAX_NEW_TOKENS", "220"))
    generator_temperature: float = float(os.getenv("GENERATOR_TEMPERATURE", "0.6"))
    generator_top_p: float = float(os.getenv("GENERATOR_TOP_P", "0.9"))
    generator_do_sample: bool = env_bool("GENERATOR_DO_SAMPLE", True)

    # Local data files (defaults live under DATA_DIR next to this module)
    local_chunks_path: str = os.getenv("LOCAL_CHUNKS_PATH", str(DATA_DIR / "gmat_hf_chunks.jsonl"))
    question_seed_path: str = os.getenv("QUESTION_SEED_PATH", str(DATA_DIR / "gmat_question_seed.jsonl"))
    topic_index_path: str = os.getenv("TOPIC_INDEX_PATH", str(DATA_DIR / "gmat_topic_index.json"))

    # Retrieval / reranking knobs. NOTE(review): retrieval_k appears to be
    # the candidate count and rerank_k the post-rerank cut — confirm
    # against the retrieval module before relying on this.
    dataset_repo_id: str = os.getenv("DATASET_REPO_ID", "j-js/gmat-quant-corpus")
    dataset_split: str = os.getenv("DATASET_SPLIT", "train")
    retrieval_k: int = int(os.getenv("RETRIEVAL_K", "8"))
    rerank_k: int = int(os.getenv("RERANK_K", "4"))
    max_chunks_to_show: int = int(os.getenv("MAX_CHUNKS_TO_SHOW", "3"))
    max_reply_chars: int = int(os.getenv("MAX_REPLY_CHARS", "1600"))
    enable_remote_dataset_fallback: bool = env_bool("ENABLE_REMOTE_DATASET_FALLBACK", True)

    # Logging (optional push of logs to a Hub dataset; off by default)
    local_log_dir: str = os.getenv("LOCAL_LOG_DIR", str(LOCAL_LOG_DIR))
    push_logs_to_hub: bool = env_bool("PUSH_LOGS_TO_HUB", False)
    log_dataset_repo_id: str = os.getenv("LOG_DATASET_REPO_ID", "")
    log_dataset_private: bool = env_bool("LOG_DATASET_PRIVATE", True)

    # Secrets — empty string means "not configured"
    hf_token: str = os.getenv("HF_TOKEN", "")
    ingest_api_key: str = os.getenv("INGEST_API_KEY", "")
    research_api_key: str = os.getenv("RESEARCH_API_KEY", "")
 
 
 
 
57
 
58
 
59
# Module-level singleton; the rest of the app imports this instance.
settings = Settings()