rogermt commited on
Commit
17e36c1
·
verified ·
1 Parent(s): 5598fb7

Fix: --strict_score default False (warnings only), --strict_size default True (halt on oversized)

Browse files
Files changed (1) hide show
  1. neurogolf_solver/main.py +23 -15
neurogolf_solver/main.py CHANGED
@@ -27,8 +27,10 @@ except ImportError:
27
 
28
  def check_all_models(output_dir, strict_size, strict_score):
29
  """Check all .onnx files for size limit and scoreability.
30
- If strict and any fail, exit immediately."""
31
- problems = []
 
 
32
  for f in sorted(os.listdir(output_dir)):
33
  if not f.endswith('.onnx'):
34
  continue
@@ -36,25 +38,31 @@ def check_all_models(output_dir, strict_size, strict_score):
36
  fsize = os.path.getsize(fpath)
37
 
38
  if fsize > MAX_ONNX_FILESIZE:
39
- problems.append((f, f"OVERSIZED: {fsize:,} bytes ({fsize/1024:.1f} KB) > {MAX_ONNX_FILESIZE:,}"))
40
 
41
  macs, memory, params = score_network(fpath)
42
  if macs is None or memory is None or params is None:
43
- problems.append((f, "UNSCORABLE: score_network returned (None, None, None) — Kaggle will reject"))
44
 
45
- if problems:
46
  print(f"\n{'!'*70}")
47
- print(f"PROBLEMS FOUND: {len(problems)} .onnx files will be REJECTED by Kaggle:")
48
- for f, msg in problems:
49
- print(f" {f}: {msg}")
50
  print(f"{'!'*70}")
51
- if strict_size or strict_score:
52
- print("Stopping. Fix these models before submitting.")
53
  sys.exit(1)
54
- else:
55
- print(f"\nAll .onnx files pass size and score checks.")
56
 
57
- return problems
 
 
 
 
 
 
 
 
 
58
 
59
 
60
  def main():
@@ -67,8 +75,8 @@ def main():
67
  parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
68
  parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
69
  parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
70
- parser.add_argument('--strict_size', type=bool, default=True, help='Stop if any .onnx > 1.44MB (default: True)')
71
- parser.add_argument('--strict_score', type=bool, default=True, help='Stop if any model unscorable (default: True)')
72
  args = parser.parse_args()
73
 
74
  providers = get_providers(args.device)
 
27
 
28
  def check_all_models(output_dir, strict_size, strict_score):
29
  """Check all .onnx files for size limit and scoreability.
30
+ strict_size=True halt on oversized files.
31
+ strict_score=False → warn on unscorable but don't halt."""
32
+ size_problems = []
33
+ score_problems = []
34
  for f in sorted(os.listdir(output_dir)):
35
  if not f.endswith('.onnx'):
36
  continue
 
38
  fsize = os.path.getsize(fpath)
39
 
40
  if fsize > MAX_ONNX_FILESIZE:
41
+ size_problems.append((f, fsize))
42
 
43
  macs, memory, params = score_network(fpath)
44
  if macs is None or memory is None or params is None:
45
+ score_problems.append(f)
46
 
47
+ if size_problems:
48
  print(f"\n{'!'*70}")
49
+ print(f"FATAL: {len(size_problems)} .onnx files exceed 1.44MB limit:")
50
+ for f, sz in size_problems:
51
+ print(f" {f}: {sz:,} bytes ({sz/1024:.1f} KB)")
52
  print(f"{'!'*70}")
53
+ if strict_size:
 
54
  sys.exit(1)
 
 
55
 
56
+ if score_problems:
57
+ print(f"\nWARNING: {len(score_problems)} .onnx files unscorable by onnx_tool:")
58
+ for f in score_problems:
59
+ print(f" {f}")
60
+ if strict_score:
61
+ print("Stopping (--strict_score is on).")
62
+ sys.exit(1)
63
+
64
+ if not size_problems and not score_problems:
65
+ print(f"\nAll .onnx files pass size and score checks.")
66
 
67
 
68
  def main():
 
75
  parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
76
  parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
77
  parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
78
+ parser.add_argument('--strict_size', type=bool, default=True, help='Halt if any .onnx > 1.44MB (default: True)')
79
+ parser.add_argument('--strict_score', type=bool, default=False, help='Halt if any model unscorable (default: False)')
80
  args = parser.parse_args()
81
 
82
  providers = get_providers(args.device)