rogermt commited on
Commit
2231985
·
verified ·
1 Parent(s): 2b04dc9

main.py: add --strict_size (default True) — stops entire run if any .onnx > 1.44MB

Browse files
Files changed (1) hide show
  1. neurogolf_solver/main.py +31 -1
neurogolf_solver/main.py CHANGED
@@ -9,12 +9,14 @@ Usage:
9
  """
10
 
11
  import argparse
 
 
12
  import time
13
  import onnxruntime as ort
14
  from .config import get_providers
15
  from .data_loader import load_tasks_dir, load_tasks_kaggle
16
  from .submission import run_tasks, generate_submission, print_summary
17
- from .constants import EXCLUDED_TASKS
18
 
19
  try:
20
  import wandb
@@ -22,6 +24,29 @@ except ImportError:
22
  wandb = None
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def main():
26
  parser = argparse.ArgumentParser(description='NeuroGolf Solver v5')
27
  parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
@@ -32,6 +57,7 @@ def main():
32
  parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
33
  parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
34
  parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
 
35
  args = parser.parse_args()
36
 
37
  providers = get_providers(args.device)
@@ -46,6 +72,7 @@ def main():
46
 
47
  ort.set_default_logger_severity(3)
48
  print(f"Using providers: {providers}")
 
49
 
50
  # Load tasks
51
  if args.kaggle:
@@ -78,6 +105,9 @@ def main():
78
 
79
  elapsed = time.time() - t0
80
 
 
 
 
81
  submission_info = generate_submission(args.output_dir, results, costs_dict, task_nums)
82
  print_summary(results, submission_info, elapsed)
83
 
 
9
  """
10
 
11
  import argparse
12
+ import os
13
+ import sys
14
  import time
15
  import onnxruntime as ort
16
  from .config import get_providers
17
  from .data_loader import load_tasks_dir, load_tasks_kaggle
18
  from .submission import run_tasks, generate_submission, print_summary
19
+ from .constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
20
 
21
  try:
22
  import wandb
 
24
  wandb = None
25
 
26
 
27
+ def check_output_sizes(output_dir, strict):
28
+ """Check all .onnx files are within 1.44MB limit.
29
+ If strict=True and any file is over, print error and exit immediately."""
30
+ oversized = []
31
+ for f in sorted(os.listdir(output_dir)):
32
+ if f.endswith('.onnx'):
33
+ fpath = os.path.join(output_dir, f)
34
+ fsize = os.path.getsize(fpath)
35
+ if fsize > MAX_ONNX_FILESIZE:
36
+ oversized.append((f, fsize))
37
+ if oversized:
38
+ print(f"\n{'!'*70}")
39
+ print(f"FATAL: {len(oversized)} .onnx files exceed 1.44MB limit:")
40
+ for f, sz in oversized:
41
+ print(f" {f}: {sz:,} bytes ({sz/1024:.1f} KB) — OVER by {sz - MAX_ONNX_FILESIZE:,} bytes")
42
+ print(f"Kaggle WILL reject this submission.")
43
+ print(f"{'!'*70}")
44
+ if strict:
45
+ print("Stopping (--strict_size is on). Fix oversized models before submitting.")
46
+ sys.exit(1)
47
+ return oversized
48
+
49
+
50
  def main():
51
  parser = argparse.ArgumentParser(description='NeuroGolf Solver v5')
52
  parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
 
57
  parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
58
  parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
59
  parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
60
+ parser.add_argument('--strict_size', type=bool, default=True, help='Stop run if any .onnx > 1.44MB (default: True)')
61
  args = parser.parse_args()
62
 
63
  providers = get_providers(args.device)
 
72
 
73
  ort.set_default_logger_severity(3)
74
  print(f"Using providers: {providers}")
75
+ print(f"Strict size check: {args.strict_size} (max {MAX_ONNX_FILESIZE:,} bytes per .onnx)")
76
 
77
  # Load tasks
78
  if args.kaggle:
 
105
 
106
  elapsed = time.time() - t0
107
 
108
+ # Check all output files BEFORE generating submission
109
+ check_output_sizes(args.output_dir, args.strict_size)
110
+
111
  submission_info = generate_submission(args.output_dir, results, costs_dict, task_nums)
112
  print_summary(results, submission_info, elapsed)
113