Fix 4: main.py — add --strict_score, stop if any model unscorable
Browse files- neurogolf_solver/main.py +32 -20
neurogolf_solver/main.py
CHANGED
|
@@ -16,6 +16,7 @@ import onnxruntime as ort
|
|
| 16 |
from .config import get_providers
|
| 17 |
from .data_loader import load_tasks_dir, load_tasks_kaggle
|
| 18 |
from .submission import run_tasks, generate_submission, print_summary
|
|
|
|
| 19 |
from .constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
|
| 20 |
|
| 21 |
try:
|
|
@@ -24,27 +25,36 @@ except ImportError:
|
|
| 24 |
wandb = None
|
| 25 |
|
| 26 |
|
| 27 |
-
def
|
| 28 |
-
"""Check all .onnx files
|
| 29 |
-
If strict
|
| 30 |
-
|
| 31 |
for f in sorted(os.listdir(output_dir)):
|
| 32 |
-
if f.endswith('.onnx'):
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
print(f"\n{'!'*70}")
|
| 39 |
-
print(f"
|
| 40 |
-
for f,
|
| 41 |
-
print(f" {f}: {
|
| 42 |
-
print(f"Kaggle WILL reject this submission.")
|
| 43 |
print(f"{'!'*70}")
|
| 44 |
-
if
|
| 45 |
-
print("Stopping
|
| 46 |
sys.exit(1)
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def main():
|
|
@@ -57,7 +67,8 @@ def main():
|
|
| 57 |
parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
|
| 58 |
parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
|
| 59 |
parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
|
| 60 |
-
parser.add_argument('--strict_size', type=bool, default=True, help='Stop
|
|
|
|
| 61 |
args = parser.parse_args()
|
| 62 |
|
| 63 |
providers = get_providers(args.device)
|
|
@@ -72,7 +83,8 @@ def main():
|
|
| 72 |
|
| 73 |
ort.set_default_logger_severity(3)
|
| 74 |
print(f"Using providers: {providers}")
|
| 75 |
-
print(f"Strict size
|
|
|
|
| 76 |
|
| 77 |
# Load tasks
|
| 78 |
if args.kaggle:
|
|
@@ -106,7 +118,7 @@ def main():
|
|
| 106 |
elapsed = time.time() - t0
|
| 107 |
|
| 108 |
# Check all output files BEFORE generating submission
|
| 109 |
-
|
| 110 |
|
| 111 |
submission_info = generate_submission(args.output_dir, results, costs_dict, task_nums)
|
| 112 |
print_summary(results, submission_info, elapsed)
|
|
|
|
| 16 |
from .config import get_providers
|
| 17 |
from .data_loader import load_tasks_dir, load_tasks_kaggle
|
| 18 |
from .submission import run_tasks, generate_submission, print_summary
|
| 19 |
+
from .profiler import score_network
|
| 20 |
from .constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
|
| 21 |
|
| 22 |
try:
|
|
|
|
| 25 |
wandb = None
|
| 26 |
|
| 27 |
|
| 28 |
+
def check_all_models(output_dir, strict_size, strict_score):
|
| 29 |
+
"""Check all .onnx files for size limit and scoreability.
|
| 30 |
+
If strict and any fail, exit immediately."""
|
| 31 |
+
problems = []
|
| 32 |
for f in sorted(os.listdir(output_dir)):
|
| 33 |
+
if not f.endswith('.onnx'):
|
| 34 |
+
continue
|
| 35 |
+
fpath = os.path.join(output_dir, f)
|
| 36 |
+
fsize = os.path.getsize(fpath)
|
| 37 |
+
|
| 38 |
+
if fsize > MAX_ONNX_FILESIZE:
|
| 39 |
+
problems.append((f, f"OVERSIZED: {fsize:,} bytes ({fsize/1024:.1f} KB) > {MAX_ONNX_FILESIZE:,}"))
|
| 40 |
+
|
| 41 |
+
macs, memory, params = score_network(fpath)
|
| 42 |
+
if macs is None or memory is None or params is None:
|
| 43 |
+
problems.append((f, "UNSCORABLE: score_network returned (None, None, None) — Kaggle will reject"))
|
| 44 |
+
|
| 45 |
+
if problems:
|
| 46 |
print(f"\n{'!'*70}")
|
| 47 |
+
print(f"PROBLEMS FOUND: {len(problems)} .onnx files will be REJECTED by Kaggle:")
|
| 48 |
+
for f, msg in problems:
|
| 49 |
+
print(f" {f}: {msg}")
|
|
|
|
| 50 |
print(f"{'!'*70}")
|
| 51 |
+
if strict_size or strict_score:
|
| 52 |
+
print("Stopping. Fix these models before submitting.")
|
| 53 |
sys.exit(1)
|
| 54 |
+
else:
|
| 55 |
+
print(f"\nAll .onnx files pass size and score checks.")
|
| 56 |
+
|
| 57 |
+
return problems
|
| 58 |
|
| 59 |
|
| 60 |
def main():
|
|
|
|
| 67 |
parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers')
|
| 68 |
parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
|
| 69 |
parser.add_argument('--use_wandb', action='store_true', help='Enable W&B logging')
|
| 70 |
+
parser.add_argument('--strict_size', type=bool, default=True, help='Stop if any .onnx > 1.44MB (default: True)')
|
| 71 |
+
parser.add_argument('--strict_score', type=bool, default=True, help='Stop if any model unscorable (default: True)')
|
| 72 |
args = parser.parse_args()
|
| 73 |
|
| 74 |
providers = get_providers(args.device)
|
|
|
|
| 83 |
|
| 84 |
ort.set_default_logger_severity(3)
|
| 85 |
print(f"Using providers: {providers}")
|
| 86 |
+
print(f"Strict size: {args.strict_size} | Strict score: {args.strict_score}")
|
| 87 |
+
print(f"Max .onnx file size: {MAX_ONNX_FILESIZE:,} bytes")
|
| 88 |
|
| 89 |
# Load tasks
|
| 90 |
if args.kaggle:
|
|
|
|
| 118 |
elapsed = time.time() - t0
|
| 119 |
|
| 120 |
# Check all output files BEFORE generating submission
|
| 121 |
+
check_all_models(args.output_dir, args.strict_size, args.strict_score)
|
| 122 |
|
| 123 |
submission_info = generate_submission(args.output_dir, results, costs_dict, task_nums)
|
| 124 |
print_summary(results, submission_info, elapsed)
|