Upload own-solver/neurogolf_solver/main.py
Browse files
own-solver/neurogolf_solver/main.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
ARC-AGI NeuroGolf Championship - Main Entry Point
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python -m neurogolf_solver.main --data_dir ARC-AGI/data/training/ --output_dir submission
|
| 7 |
+
python -m neurogolf_solver.main --kaggle --output_dir /kaggle/working/submission
|
| 8 |
+
python -m neurogolf_solver.main --data_dir ARC-AGI/data/training/ --arcgen_dir ARC-GEN-100K/ --use_wandb
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import time
|
| 15 |
+
import onnxruntime as ort
|
| 16 |
+
from .config import get_providers
|
| 17 |
+
from .data_loader import load_tasks_dir, load_tasks_kaggle
|
| 18 |
+
from .submission import run_tasks, generate_submission, print_summary
|
| 19 |
+
from .profiler import score_network
|
| 20 |
+
from .constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
import wandb
|
| 24 |
+
except ImportError:
|
| 25 |
+
wandb = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def check_all_models(output_dir, strict_size, strict_score):
|
| 29 |
+
"""Check all .onnx files for size limit and scoreability."""
|
| 30 |
+
size_problems = []
|
| 31 |
+
score_problems = []
|
| 32 |
+
for f in sorted(os.listdir(output_dir)):
|
| 33 |
+
if not f.endswith('.onnx'):
|
| 34 |
+
continue
|
| 35 |
+
fpath = os.path.join(output_dir, f)
|
| 36 |
+
fsize = os.path.getsize(fpath)
|
| 37 |
+
if fsize > MAX_ONNX_FILESIZE:
|
| 38 |
+
size_problems.append((f, fsize))
|
| 39 |
+
macs, memory, params = score_network(fpath)
|
| 40 |
+
if macs is None or memory is None or params is None:
|
| 41 |
+
score_problems.append(f)
|
| 42 |
+
if size_problems:
|
| 43 |
+
print(f"\n{'!'*70}")
|
| 44 |
+
print(f"FATAL: {len(size_problems)} .onnx files exceed 1.44MB limit:")
|
| 45 |
+
for f, sz in size_problems:
|
| 46 |
+
print(f" {f}: {sz:,} bytes ({sz/1024:.1f} KB)")
|
| 47 |
+
print(f"{'!'*70}")
|
| 48 |
+
if strict_size:
|
| 49 |
+
sys.exit(1)
|
| 50 |
+
if score_problems:
|
| 51 |
+
print(f"\nWARNING: {len(score_problems)} .onnx files unscorable by onnx_tool:")
|
| 52 |
+
for f in score_problems:
|
| 53 |
+
print(f" {f}")
|
| 54 |
+
if strict_score:
|
| 55 |
+
print("Stopping (--strict_score is on).")
|
| 56 |
+
sys.exit(1)
|
| 57 |
+
if not size_problems and not score_problems:
|
| 58 |
+
print(f"\nAll .onnx files pass size and score checks.")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main():
|
| 62 |
+
parser = argparse.ArgumentParser(description='NeuroGolf Solver v5')
|
| 63 |
+
parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
|
| 64 |
+
parser.add_argument('--arcgen_dir', default='', help='Path to ARC-GEN-100K/ directory')
|
| 65 |
+
parser.add_argument('--output_dir', default='/kaggle/working/submission')
|
| 66 |
+
parser.add_argument('--kaggle', action='store_true', help='Use Kaggle task format')
|
| 67 |
+
parser.add_argument('--conv_budget', type=float, default=30.0, help='Seconds per conv solver per task')
|
| 68 |
+
parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
|
| 69 |
+
parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
|
| 70 |
+
parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
|
| 71 |
+
parser.add_argument('--strict_size', type=bool, default=True, help='Halt if any .onnx > 1.44MB (default: True)')
|
| 72 |
+
parser.add_argument('--strict_score', type=bool, default=False, help='Halt if any model unscorable (default: False)')
|
| 73 |
+
args = parser.parse_args()
|
| 74 |
+
providers = get_providers(args.device)
|
| 75 |
+
config = {"device": args.device, "conv_budget": args.conv_budget, "data_dir": args.data_dir, "arcgen_dir": args.arcgen_dir, "tasks": args.tasks}
|
| 76 |
+
ort.set_default_logger_severity(3)
|
| 77 |
+
print(f"Using providers: {providers}")
|
| 78 |
+
print(f"Strict size: {args.strict_size} | Strict score: {args.strict_score}")
|
| 79 |
+
print(f"Max .onnx file size: {MAX_ONNX_FILESIZE:,} bytes")
|
| 80 |
+
if args.kaggle:
|
| 81 |
+
tasks = load_tasks_kaggle(args.data_dir)
|
| 82 |
+
else:
|
| 83 |
+
arcgen = args.arcgen_dir if args.arcgen_dir else None
|
| 84 |
+
tasks = load_tasks_dir(args.data_dir, arcgen_dir=arcgen)
|
| 85 |
+
total_arcgen = sum(len(t['data'].get('arc-gen', [])) for t in tasks.values())
|
| 86 |
+
print(f"Loaded {len(tasks)} tasks ({total_arcgen} ARC-GEN examples)")
|
| 87 |
+
task_nums = [int(t) for t in args.tasks.split(',')] if args.tasks else sorted(tasks.keys())
|
| 88 |
+
print(f"Solving {len(task_nums)} tasks")
|
| 89 |
+
print(f"Conv budget: {args.conv_budget}s per task")
|
| 90 |
+
print("=" * 70)
|
| 91 |
+
t0 = time.time()
|
| 92 |
+
if args.use_wandb and wandb is not None:
|
| 93 |
+
with wandb.init(project="neurogolf", name="solver_run", config=config):
|
| 94 |
+
results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, providers, args.conv_budget, EXCLUDED_TASKS, use_wandb=True)
|
| 95 |
+
else:
|
| 96 |
+
results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, providers, args.conv_budget, EXCLUDED_TASKS, use_wandb=False)
|
| 97 |
+
elapsed = time.time() - t0
|
| 98 |
+
check_all_models(args.output_dir, args.strict_size, args.strict_score)
|
| 99 |
+
submission_info = generate_submission(args.output_dir, results, costs_dict, task_nums)
|
| 100 |
+
print_summary(results, submission_info, elapsed)
|
| 101 |
+
|
| 102 |
+
if __name__ == '__main__':
|
| 103 |
+
main()
|