Fix: --strict_score default False (warnings only), --strict_size default True (halt on oversized)
Browse files- neurogolf_solver/main.py +23 -15
neurogolf_solver/main.py
CHANGED
|
@@ -27,8 +27,10 @@ except ImportError:
|
|
| 27 |
|
| 28 |
def check_all_models(output_dir, strict_size, strict_score):
|
| 29 |
"""Check all .onnx files for size limit and scoreability.
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
| 32 |
for f in sorted(os.listdir(output_dir)):
|
| 33 |
if not f.endswith('.onnx'):
|
| 34 |
continue
|
|
@@ -36,25 +38,31 @@ def check_all_models(output_dir, strict_size, strict_score):
|
|
| 36 |
fsize = os.path.getsize(fpath)
|
| 37 |
|
| 38 |
if fsize > MAX_ONNX_FILESIZE:
|
| 39 |
-
|
| 40 |
|
| 41 |
macs, memory, params = score_network(fpath)
|
| 42 |
if macs is None or memory is None or params is None:
|
| 43 |
-
|
| 44 |
|
| 45 |
-
if
|
| 46 |
print(f"\n{'!'*70}")
|
| 47 |
-
print(f"
|
| 48 |
-
for f,
|
| 49 |
-
print(f" {f}: {
|
| 50 |
print(f"{'!'*70}")
|
| 51 |
-
if strict_size
|
| 52 |
-
print("Stopping. Fix these models before submitting.")
|
| 53 |
sys.exit(1)
|
| 54 |
-
else:
|
| 55 |
-
print(f"\nAll .onnx files pass size and score checks.")
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
def main():
|
|
@@ -67,8 +75,8 @@ def main():
|
|
| 67 |
parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
|
| 68 |
parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
|
| 69 |
parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
|
| 70 |
-
parser.add_argument('--strict_size', type=bool, default=True, help='
|
| 71 |
-
parser.add_argument('--strict_score', type=bool, default=
|
| 72 |
args = parser.parse_args()
|
| 73 |
|
| 74 |
providers = get_providers(args.device)
|
|
|
|
| 27 |
|
| 28 |
def check_all_models(output_dir, strict_size, strict_score):
|
| 29 |
"""Check all .onnx files for size limit and scoreability.
|
| 30 |
+
strict_size=True → halt on oversized files.
|
| 31 |
+
strict_score=False → warn on unscorable but don't halt."""
|
| 32 |
+
size_problems = []
|
| 33 |
+
score_problems = []
|
| 34 |
for f in sorted(os.listdir(output_dir)):
|
| 35 |
if not f.endswith('.onnx'):
|
| 36 |
continue
|
|
|
|
| 38 |
fsize = os.path.getsize(fpath)
|
| 39 |
|
| 40 |
if fsize > MAX_ONNX_FILESIZE:
|
| 41 |
+
size_problems.append((f, fsize))
|
| 42 |
|
| 43 |
macs, memory, params = score_network(fpath)
|
| 44 |
if macs is None or memory is None or params is None:
|
| 45 |
+
score_problems.append(f)
|
| 46 |
|
| 47 |
+
if size_problems:
|
| 48 |
print(f"\n{'!'*70}")
|
| 49 |
+
print(f"FATAL: {len(size_problems)} .onnx files exceed 1.44MB limit:")
|
| 50 |
+
for f, sz in size_problems:
|
| 51 |
+
print(f" {f}: {sz:,} bytes ({sz/1024:.1f} KB)")
|
| 52 |
print(f"{'!'*70}")
|
| 53 |
+
if strict_size:
|
|
|
|
| 54 |
sys.exit(1)
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
if score_problems:
|
| 57 |
+
print(f"\nWARNING: {len(score_problems)} .onnx files unscorable by onnx_tool:")
|
| 58 |
+
for f in score_problems:
|
| 59 |
+
print(f" {f}")
|
| 60 |
+
if strict_score:
|
| 61 |
+
print("Stopping (--strict_score is on).")
|
| 62 |
+
sys.exit(1)
|
| 63 |
+
|
| 64 |
+
if not size_problems and not score_problems:
|
| 65 |
+
print(f"\nAll .onnx files pass size and score checks.")
|
| 66 |
|
| 67 |
|
| 68 |
def main():
|
|
|
|
| 75 |
parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
|
| 76 |
parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
|
| 77 |
parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
|
| 78 |
+
parser.add_argument('--strict_size', type=bool, default=True, help='Halt if any .onnx > 1.44MB (default: True)')
|
| 79 |
+
parser.add_argument('--strict_score', type=bool, default=False, help='Halt if any model unscorable (default: False)')
|
| 80 |
args = parser.parse_args()
|
| 81 |
|
| 82 |
providers = get_providers(args.device)
|