v5 refactor: add main.py entry point (with W&B init)
Browse files- neurogolf_solver/main.py +88 -0
neurogolf_solver/main.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
ARC-AGI NeuroGolf Championship - Main Entry Point
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python -m neurogolf_solver.main --data_dir ARC-AGI/data/training/ --output_dir submission
|
| 7 |
+
python -m neurogolf_solver.main --kaggle --output_dir /kaggle/working/submission
|
| 8 |
+
python -m neurogolf_solver.main --data_dir ARC-AGI/data/training/ --arcgen_dir ARC-GEN-100K/ --use_wandb
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import time
|
| 13 |
+
import onnxruntime as ort
|
| 14 |
+
from .config import get_providers
|
| 15 |
+
from .data_loader import load_tasks_dir, load_tasks_kaggle
|
| 16 |
+
from .submission import run_tasks, generate_submission, print_summary
|
| 17 |
+
from .constants import EXCLUDED_TASKS
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import wandb
|
| 21 |
+
except ImportError:
|
| 22 |
+
wandb = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def main():
|
| 26 |
+
parser = argparse.ArgumentParser(description='NeuroGolf Solver v5')
|
| 27 |
+
parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
|
| 28 |
+
parser.add_argument('--arcgen_dir', default='', help='Path to ARC-GEN-100K/ directory')
|
| 29 |
+
parser.add_argument('--output_dir', default='/kaggle/working/submission')
|
| 30 |
+
parser.add_argument('--kaggle', action='store_true', help='Use Kaggle task format')
|
| 31 |
+
parser.add_argument('--conv_budget', type=float, default=30.0, help='Seconds per conv solver per task')
|
| 32 |
+
parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
|
| 33 |
+
parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
|
| 34 |
+
parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
|
| 35 |
+
args = parser.parse_args()
|
| 36 |
+
|
| 37 |
+
providers = get_providers(args.device)
|
| 38 |
+
|
| 39 |
+
config = {
|
| 40 |
+
"device": args.device,
|
| 41 |
+
"conv_budget": args.conv_budget,
|
| 42 |
+
"data_dir": args.data_dir,
|
| 43 |
+
"arcgen_dir": args.arcgen_dir,
|
| 44 |
+
"tasks": args.tasks,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
ort.set_default_logger_severity(3)
|
| 48 |
+
print(f"Using providers: {providers}")
|
| 49 |
+
|
| 50 |
+
# Load tasks
|
| 51 |
+
if args.kaggle:
|
| 52 |
+
tasks = load_tasks_kaggle(args.data_dir)
|
| 53 |
+
else:
|
| 54 |
+
arcgen = args.arcgen_dir if args.arcgen_dir else None
|
| 55 |
+
tasks = load_tasks_dir(args.data_dir, arcgen_dir=arcgen)
|
| 56 |
+
|
| 57 |
+
total_arcgen = sum(len(t['data'].get('arc-gen', [])) for t in tasks.values())
|
| 58 |
+
print(f"Loaded {len(tasks)} tasks ({total_arcgen} ARC-GEN examples)")
|
| 59 |
+
print(f"Excluded tasks: {sorted(EXCLUDED_TASKS)}")
|
| 60 |
+
|
| 61 |
+
task_nums = [int(t) for t in args.tasks.split(',')] if args.tasks else sorted(tasks.keys())
|
| 62 |
+
active_tasks = [t for t in task_nums if t not in EXCLUDED_TASKS]
|
| 63 |
+
print(f"Solving {len(active_tasks)} active tasks (skipping {len(task_nums) - len(active_tasks)} excluded)")
|
| 64 |
+
print(f"Conv budget: {args.conv_budget}s per task")
|
| 65 |
+
print("=" * 70)
|
| 66 |
+
|
| 67 |
+
t0 = time.time()
|
| 68 |
+
|
| 69 |
+
if args.use_wandb and wandb is not None:
|
| 70 |
+
with wandb.init(project="neurogolf", name="solver_run", config=config):
|
| 71 |
+
results, costs_dict, total_score = run_tasks(
|
| 72 |
+
task_nums, tasks, args.output_dir, providers,
|
| 73 |
+
args.conv_budget, EXCLUDED_TASKS, use_wandb=True
|
| 74 |
+
)
|
| 75 |
+
else:
|
| 76 |
+
results, costs_dict, total_score = run_tasks(
|
| 77 |
+
task_nums, tasks, args.output_dir, providers,
|
| 78 |
+
args.conv_budget, EXCLUDED_TASKS, use_wandb=False
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
elapsed = time.time() - t0
|
| 82 |
+
|
| 83 |
+
submission_info = generate_submission(args.output_dir, results, costs_dict, active_tasks)
|
| 84 |
+
print_summary(results, submission_info, elapsed)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
if __name__ == '__main__':
|
| 88 |
+
main()
|