main.py: add --strict_size (default True) — stops entire run if any .onnx > 1.44MB
Browse files- neurogolf_solver/main.py +31 -1
neurogolf_solver/main.py
CHANGED
|
@@ -9,12 +9,14 @@ Usage:
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
import argparse
|
|
|
|
|
|
|
| 12 |
import time
|
| 13 |
import onnxruntime as ort
|
| 14 |
from .config import get_providers
|
| 15 |
from .data_loader import load_tasks_dir, load_tasks_kaggle
|
| 16 |
from .submission import run_tasks, generate_submission, print_summary
|
| 17 |
-
from .constants import EXCLUDED_TASKS
|
| 18 |
|
| 19 |
try:
|
| 20 |
import wandb
|
|
@@ -22,6 +24,29 @@ except ImportError:
|
|
| 22 |
wandb = None
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def main():
|
| 26 |
parser = argparse.ArgumentParser(description='NeuroGolf Solver v5')
|
| 27 |
parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
|
|
@@ -32,6 +57,7 @@ def main():
|
|
| 32 |
parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
|
| 33 |
parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
|
| 34 |
parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
|
|
|
|
| 35 |
args = parser.parse_args()
|
| 36 |
|
| 37 |
providers = get_providers(args.device)
|
|
@@ -46,6 +72,7 @@ def main():
|
|
| 46 |
|
| 47 |
ort.set_default_logger_severity(3)
|
| 48 |
print(f"Using providers: {providers}")
|
|
|
|
| 49 |
|
| 50 |
# Load tasks
|
| 51 |
if args.kaggle:
|
|
@@ -78,6 +105,9 @@ def main():
|
|
| 78 |
|
| 79 |
elapsed = time.time() - t0
|
| 80 |
|
|
|
|
|
|
|
|
|
|
| 81 |
submission_info = generate_submission(args.output_dir, results, costs_dict, task_nums)
|
| 82 |
print_summary(results, submission_info, elapsed)
|
| 83 |
|
|
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
import argparse
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
import time
|
| 15 |
import onnxruntime as ort
|
| 16 |
from .config import get_providers
|
| 17 |
from .data_loader import load_tasks_dir, load_tasks_kaggle
|
| 18 |
from .submission import run_tasks, generate_submission, print_summary
|
| 19 |
+
from .constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
|
| 20 |
|
| 21 |
try:
|
| 22 |
import wandb
|
|
|
|
| 24 |
wandb = None
|
| 25 |
|
| 26 |
|
| 27 |
+
def check_output_sizes(output_dir, strict, max_bytes=None):
    """Report .onnx files in ``output_dir`` that exceed the size limit.

    Args:
        output_dir: Directory scanned (non-recursively) for ``*.onnx`` files.
        strict: When true and at least one file is oversized, the report is
            printed and the process is terminated via ``sys.exit(1)``.
        max_bytes: Size limit in bytes. Defaults to ``MAX_ONNX_FILESIZE``
            when omitted, which keeps existing two-argument calls unchanged.

    Returns:
        A list of ``(filename, size_in_bytes)`` tuples, ordered by filename,
        one per oversized file. Empty when every file fits or when
        ``output_dir`` does not exist.

    Raises:
        SystemExit: With exit code 1 when ``strict`` is true and at least one
            file is oversized.
    """
    limit = MAX_ONNX_FILESIZE if max_bytes is None else max_bytes

    try:
        names = sorted(os.listdir(output_dir))
    except FileNotFoundError:
        # No output directory means no models were written, so nothing can
        # be oversized; avoid dying with a traceback here.
        return []

    oversized = []
    for name in names:
        if not name.endswith('.onnx'):
            continue
        path = os.path.join(output_dir, name)
        # A directory that merely carries an .onnx suffix is not a model file.
        if not os.path.isfile(path):
            continue
        size = os.path.getsize(path)
        if size > limit:
            oversized.append((name, size))

    if oversized:
        banner = '!' * 70
        print(f"\n{banner}")
        # Report the limit actually applied instead of a hard-coded figure.
        print(f"FATAL: {len(oversized)} .onnx files exceed the {limit:,}-byte limit:")
        for name, size in oversized:
            print(f" {name}: {size:,} bytes ({size/1024:.1f} KB) — OVER by {size - limit:,} bytes")
        print("Kaggle WILL reject this submission.")
        print(banner)
        if strict:
            print("Stopping (--strict_size is on). Fix oversized models before submitting.")
            sys.exit(1)
    return oversized
|
| 48 |
+
|
| 49 |
+
|
| 50 |
def main():
|
| 51 |
parser = argparse.ArgumentParser(description='NeuroGolf Solver v5')
|
| 52 |
parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
|
|
|
|
| 57 |
parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
|
| 58 |
parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
|
| 59 |
parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
|
| 60 |
+
# argparse's type=bool calls bool() on the raw text, so any non-empty string
# (including "False") parsed as True and the check could never be turned off.
# Parse the text explicitly; unrecognised values stay on the safe (strict) side.
parser.add_argument('--strict_size', type=lambda text: text.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off'), default=True, help='Stop run if any .onnx > 1.44MB (default: True)')
|
| 61 |
args = parser.parse_args()
|
| 62 |
|
| 63 |
providers = get_providers(args.device)
|
|
|
|
| 72 |
|
| 73 |
ort.set_default_logger_severity(3)
|
| 74 |
print(f"Using providers: {providers}")
|
| 75 |
+
print(f"Strict size check: {args.strict_size} (max {MAX_ONNX_FILESIZE:,} bytes per .onnx)")
|
| 76 |
|
| 77 |
# Load tasks
|
| 78 |
if args.kaggle:
|
|
|
|
| 105 |
|
| 106 |
elapsed = time.time() - t0
|
| 107 |
|
| 108 |
+
# Check all output files BEFORE generating submission
|
| 109 |
+
check_output_sizes(args.output_dir, args.strict_size)
|
| 110 |
+
|
| 111 |
submission_info = generate_submission(args.output_dir, results, costs_dict, task_nums)
|
| 112 |
print_summary(results, submission_info, elapsed)
|
| 113 |
|