rogermt commited on
Commit
f6b5eb9
·
verified ·
1 Parent(s): 1c86eb5

Fix 4: main.py — add --strict_score, stop if any model unscorable

Browse files
Files changed (1) hide show
  1. neurogolf_solver/main.py +32 -20
neurogolf_solver/main.py CHANGED
@@ -16,6 +16,7 @@ import onnxruntime as ort
16
  from .config import get_providers
17
  from .data_loader import load_tasks_dir, load_tasks_kaggle
18
  from .submission import run_tasks, generate_submission, print_summary
 
19
  from .constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
20
 
21
  try:
@@ -24,27 +25,36 @@ except ImportError:
24
  wandb = None
25
 
26
 
27
- def check_output_sizes(output_dir, strict):
28
- """Check all .onnx files are within 1.44MB limit.
29
- If strict=True and any file is over, print error and exit immediately."""
30
- oversized = []
31
  for f in sorted(os.listdir(output_dir)):
32
- if f.endswith('.onnx'):
33
- fpath = os.path.join(output_dir, f)
34
- fsize = os.path.getsize(fpath)
35
- if fsize > MAX_ONNX_FILESIZE:
36
- oversized.append((f, fsize))
37
- if oversized:
 
 
 
 
 
 
 
38
  print(f"\n{'!'*70}")
39
- print(f"FATAL: {len(oversized)} .onnx files exceed 1.44MB limit:")
40
- for f, sz in oversized:
41
- print(f" {f}: {sz:,} bytes ({sz/1024:.1f} KB) — OVER by {sz - MAX_ONNX_FILESIZE:,} bytes")
42
- print(f"Kaggle WILL reject this submission.")
43
  print(f"{'!'*70}")
44
- if strict:
45
- print("Stopping (--strict_size is on). Fix oversized models before submitting.")
46
  sys.exit(1)
47
- return oversized
 
 
 
48
 
49
 
50
  def main():
@@ -57,7 +67,8 @@ def main():
57
  parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
58
  parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
59
  parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
60
- parser.add_argument('--strict_size', type=bool, default=True, help='Stop run if any .onnx > 1.44MB (default: True)')
 
61
  args = parser.parse_args()
62
 
63
  providers = get_providers(args.device)
@@ -72,7 +83,8 @@ def main():
72
 
73
  ort.set_default_logger_severity(3)
74
  print(f"Using providers: {providers}")
75
- print(f"Strict size check: {args.strict_size} (max {MAX_ONNX_FILESIZE:,} bytes per .onnx)")
 
76
 
77
  # Load tasks
78
  if args.kaggle:
@@ -106,7 +118,7 @@ def main():
106
  elapsed = time.time() - t0
107
 
108
  # Check all output files BEFORE generating submission
109
- check_output_sizes(args.output_dir, args.strict_size)
110
 
111
  submission_info = generate_submission(args.output_dir, results, costs_dict, task_nums)
112
  print_summary(results, submission_info, elapsed)
 
16
  from .config import get_providers
17
  from .data_loader import load_tasks_dir, load_tasks_kaggle
18
  from .submission import run_tasks, generate_submission, print_summary
19
+ from .profiler import score_network
20
  from .constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
21
 
22
  try:
 
25
  wandb = None
26
 
27
 
28
+ def check_all_models(output_dir, strict_size, strict_score):
29
+ """Check all .onnx files for size limit and scoreability.
30
+ If strict and any fail, exit immediately."""
31
+ problems = []
32
  for f in sorted(os.listdir(output_dir)):
33
+ if not f.endswith('.onnx'):
34
+ continue
35
+ fpath = os.path.join(output_dir, f)
36
+ fsize = os.path.getsize(fpath)
37
+
38
+ if fsize > MAX_ONNX_FILESIZE:
39
+ problems.append((f, f"OVERSIZED: {fsize:,} bytes ({fsize/1024:.1f} KB) > {MAX_ONNX_FILESIZE:,}"))
40
+
41
+ macs, memory, params = score_network(fpath)
42
+ if macs is None or memory is None or params is None:
43
+ problems.append((f, "UNSCORABLE: score_network returned (None, None, None) — Kaggle will reject"))
44
+
45
+ if problems:
46
  print(f"\n{'!'*70}")
47
+ print(f"PROBLEMS FOUND: {len(problems)} .onnx files will be REJECTED by Kaggle:")
48
+ for f, msg in problems:
49
+ print(f" {f}: {msg}")
 
50
  print(f"{'!'*70}")
51
+ if strict_size or strict_score:
52
+ print("Stopping. Fix these models before submitting.")
53
  sys.exit(1)
54
+ else:
55
+ print(f"\nAll .onnx files pass size and score checks.")
56
+
57
+ return problems
58
 
59
 
60
  def main():
 
67
  parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
68
  parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
69
  parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
70
+ parser.add_argument('--strict_size', type=bool, default=True, help='Stop if any .onnx > 1.44MB (default: True)')
71
+ parser.add_argument('--strict_score', type=bool, default=True, help='Stop if any model unscorable (default: True)')
72
  args = parser.parse_args()
73
 
74
  providers = get_providers(args.device)
 
83
 
84
  ort.set_default_logger_severity(3)
85
  print(f"Using providers: {providers}")
86
+ print(f"Strict size: {args.strict_size} | Strict score: {args.strict_score}")
87
+ print(f"Max .onnx file size: {MAX_ONNX_FILESIZE:,} bytes")
88
 
89
  # Load tasks
90
  if args.kaggle:
 
118
  elapsed = time.time() - t0
119
 
120
  # Check all output files BEFORE generating submission
121
+ check_all_models(args.output_dir, args.strict_size, args.strict_score)
122
 
123
  submission_info = generate_submission(args.output_dir, results, costs_dict, task_nums)
124
  print_summary(results, submission_info, elapsed)