gaurv007 commited on
Commit
a2396a0
·
verified ·
1 Parent(s): 4c4a751

Upload alpha_factory/run.py

Browse files
Files changed (1) hide show
  1. alpha_factory/run.py +61 -7
alpha_factory/run.py CHANGED
@@ -1,8 +1,11 @@
1
  """
2
- Alpha Factory — Entry Point v3
3
  Single asyncio event loop, no session leaks.
 
4
  Run: python -m alpha_factory.run [--dry-run] [--batch-size N] [--interactive]
5
  [--proven] [--enable-brain]
 
 
6
  """
7
  import os
8
  import asyncio
@@ -18,14 +21,27 @@ except ImportError:
18
  from rich.console import Console
19
  from .config import load_config
20
  from .infra import ModelManager, interactive_model_select, LLMClient
 
21
  from .orchestration import AlphaPipeline
22
 
23
  console = Console()
24
 
25
 
26
- async def setup_models(interactive: bool = False, hf_token: str = None) -> ModelManager:
27
- """Discover models and optionally let user pick interactively."""
28
- manager = ModelManager(hf_token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  console.print("\n[bold]🔍 Discovering available models...[/]")
31
  await manager.discover_all()
@@ -34,6 +50,24 @@ async def setup_models(interactive: bool = False, hf_token: str = None) -> Model
34
  selections = interactive_model_select(manager)
35
  for tier, model in selections.items():
36
  manager.select_model(tier, model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  else:
38
  manager.auto_assign_defaults()
39
 
@@ -41,9 +75,13 @@ async def setup_models(interactive: bool = False, hf_token: str = None) -> Model
41
  return manager
42
 
43
 
44
- async def run_pipeline(config, ollama_url: str = "http://localhost:11434", manager: ModelManager | None = None):
 
 
 
 
45
  """Run the full pipeline under a single event loop.
46
-
47
  Args:
48
  config: Pipeline configuration
49
  ollama_url: Base URL for Ollama (default: http://localhost:11434)
@@ -84,11 +122,16 @@ def main():
84
  parser = argparse.ArgumentParser(description="Alpha Factory — LLM-Driven Alpha Generation Pipeline")
85
  parser.add_argument("--dry-run", action="store_true", help="Run without BRAIN submissions")
86
  parser.add_argument("--batch-size", type=int, default=10, help="Number of candidates per batch")
87
- parser.add_argument("--interactive", action="store_true", help="Interactively select models")
88
  parser.add_argument("--hf-token", type=str, default=None, help="HuggingFace API token (or set HF_TOKEN env)")
89
  parser.add_argument("--ollama-url", type=str, default="http://localhost:11434", help="Ollama server URL")
90
  parser.add_argument("--proven", action="store_true", help="Use proven templates (no LLM, deterministic generation)")
91
  parser.add_argument("--enable-brain", action="store_true", help="Enable live BRAIN submission (requires BRAIN_SESSION_TOKEN)")
 
 
 
 
 
92
  args = parser.parse_args()
93
 
94
  config = load_config()
@@ -102,6 +145,15 @@ def main():
102
  # Resolve HF token: CLI arg > env var (loaded from .env)
103
  hf_token = args.hf_token or os.getenv("HF_TOKEN")
104
 
 
 
 
 
 
 
 
 
 
105
  mode_str = "PROVEN TEMPLATES" if args.proven else "LLM GENERATION"
106
  brain_str = "LIVE (BRAIN submissions)" if config.enable_brain_client else "DRY RUN"
107
 
@@ -131,6 +183,8 @@ def main():
131
  manager = await setup_models(
132
  interactive=args.interactive,
133
  hf_token=hf_token,
 
 
134
  )
135
  return await run_pipeline(config, ollama_url=args.ollama_url, manager=manager)
136
  asyncio.run(_discover_and_run())
 
1
  """
2
+ Alpha Factory — Entry Point v4
3
  Single asyncio event loop, no session leaks.
4
+ Per-tier model selection via CLI or UI.
5
  Run: python -m alpha_factory.run [--dry-run] [--batch-size N] [--interactive]
6
  [--proven] [--enable-brain]
7
+ [--microfish MODEL] [--tinyfish MODEL]
8
+ [--mediumfish MODEL] [--bigfish MODEL]
9
  """
10
  import os
11
  import asyncio
 
21
  from rich.console import Console
22
  from .config import load_config
23
  from .infra import ModelManager, interactive_model_select, LLMClient
24
+ from .infra.model_manager import ModelInfo, ModelProvider
25
  from .orchestration import AlphaPipeline
26
 
27
  console = Console()
28
 
29
 
30
+ async def setup_models(
31
+ interactive: bool = False,
32
+ hf_token: str = None,
33
+ ollama_url: str = "http://localhost:11434",
34
+ per_tier: dict[str, str] | None = None,
35
+ ) -> ModelManager:
36
+ """Discover models and optionally let user pick interactively.
37
+
38
+ Args:
39
+ interactive: Open CLI prompt to pick per-tier models.
40
+ hf_token: HuggingFace token for cloud model validation.
41
+ ollama_url: Ollama server URL for local model discovery.
42
+ per_tier: Pre-selected model names (e.g. {"microfish": "qwen2.5:1.5b"}).
43
+ """
44
+ manager = ModelManager(ollama_url=ollama_url, hf_token=hf_token)
45
 
46
  console.print("\n[bold]🔍 Discovering available models...[/]")
47
  await manager.discover_all()
 
50
  selections = interactive_model_select(manager)
51
  for tier, model in selections.items():
52
  manager.select_model(tier, model)
53
+ elif per_tier:
54
+ # Resolve per-tier model names to ModelInfo objects
55
+ all_models = manager.get_all_models()
56
+ for tier, model_name in per_tier.items():
57
+ if not model_name:
58
+ continue
59
+ # Try exact match first
60
+ matched = [m for m in all_models if m.name == model_name]
61
+ if not matched:
62
+ # Try case-insensitive partial match
63
+ matched = [m for m in all_models if model_name.lower() in m.name.lower()]
64
+ if matched:
65
+ manager.select_model(tier, matched[0])
66
+ console.print(f" [green]{tier}[/]: {matched[0].display_name()}")
67
+ else:
68
+ console.print(f" [yellow]{tier}[/]: '{model_name}' not found — using default")
69
+ # Fill remaining tiers with defaults
70
+ manager.auto_assign_defaults()
71
  else:
72
  manager.auto_assign_defaults()
73
 
 
75
  return manager
76
 
77
 
78
+ async def run_pipeline(
79
+ config,
80
+ ollama_url: str = "http://localhost:11434",
81
+ manager: ModelManager | None = None,
82
+ ):
83
  """Run the full pipeline under a single event loop.
84
+
85
  Args:
86
  config: Pipeline configuration
87
  ollama_url: Base URL for Ollama (default: http://localhost:11434)
 
122
  parser = argparse.ArgumentParser(description="Alpha Factory — LLM-Driven Alpha Generation Pipeline")
123
  parser.add_argument("--dry-run", action="store_true", help="Run without BRAIN submissions")
124
  parser.add_argument("--batch-size", type=int, default=10, help="Number of candidates per batch")
125
+ parser.add_argument("--interactive", action="store_true", help="Interactively select models (CLI)")
126
  parser.add_argument("--hf-token", type=str, default=None, help="HuggingFace API token (or set HF_TOKEN env)")
127
  parser.add_argument("--ollama-url", type=str, default="http://localhost:11434", help="Ollama server URL")
128
  parser.add_argument("--proven", action="store_true", help="Use proven templates (no LLM, deterministic generation)")
129
  parser.add_argument("--enable-brain", action="store_true", help="Enable live BRAIN submission (requires BRAIN_SESSION_TOKEN)")
130
+ # Per-tier model overrides
131
+ parser.add_argument("--microfish", type=str, default=None, help="Model for hypothesis generation (bulk)")
132
+ parser.add_argument("--tinyfish", type=str, default=None, help="Model for expression compilation")
133
+ parser.add_argument("--mediumfish", type=str, default=None, help="Model for crowd scout + performance surgeon")
134
+ parser.add_argument("--bigfish", type=str, default=None, help="Model for gatekeeper (final memo)")
135
  args = parser.parse_args()
136
 
137
  config = load_config()
 
145
  # Resolve HF token: CLI arg > env var (loaded from .env)
146
  hf_token = args.hf_token or os.getenv("HF_TOKEN")
147
 
148
+ # Collect per-tier selections from CLI
149
+ per_tier = {
150
+ "microfish": args.microfish,
151
+ "tinyfish": args.tinyfish,
152
+ "mediumfish": args.mediumfish,
153
+ "bigfish": args.bigfish,
154
+ }
155
+ has_per_tier = any(v is not None for v in per_tier.values())
156
+
157
  mode_str = "PROVEN TEMPLATES" if args.proven else "LLM GENERATION"
158
  brain_str = "LIVE (BRAIN submissions)" if config.enable_brain_client else "DRY RUN"
159
 
 
183
  manager = await setup_models(
184
  interactive=args.interactive,
185
  hf_token=hf_token,
186
+ ollama_url=args.ollama_url,
187
+ per_tier=per_tier if has_per_tier else None,
188
  )
189
  return await run_pipeline(config, ollama_url=args.ollama_url, manager=manager)
190
  asyncio.run(_discover_and_run())