PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
nraptisss committed on
Commit
e0c5f96
·
verified ·
1 Parent(s): e16ed3a

Add sampled zero-shot baseline runner

Browse files
Files changed (1) hide show
  1. scripts/run_zero_shot_baseline.sh +41 -0
scripts/run_zero_shot_baseline.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env bash
set -euo pipefail

# Run a sampled zero-shot baseline for publication comparison.
# By default this evaluates Qwen/Qwen3-8B without any adapter on 200 examples
# per split.
#
# Usage (run from the repository root: .venv/, src/ and scripts/ are resolved
# relative to the current working directory):
#   EVAL_BATCH_SIZE=4 BASELINE_MAX_SAMPLES=200 \
#     bash scripts/run_zero_shot_baseline.sh outputs/baselines/qwen3-8b-zero-shot
#
# Arguments:
#   $1  Output directory (default: outputs/baselines/qwen3-8b-zero-shot).
#
# Environment overrides:
#   BASELINE_MODEL           value passed as --model (default: Qwen/Qwen3-8B)
#   BASELINE_MAX_SAMPLES     value passed as --max_samples_per_split (default: 200)
#   EVAL_BATCH_SIZE          value passed as --batch_size (default: 4)
#   EVAL_MAX_NEW_TOKENS      value passed as --max_new_tokens (default: 1536)
#   EVAL_GOLD_LENGTH_BUFFER  value passed as --gold_length_buffer (default: 96)
#   EVAL_SAVE_EVERY          value passed as --save_every (default: 25)
#   CUDA_VISIBLE_DEVICES     exported as-is if already set (default: 0)

OUT_DIR="${1:-outputs/baselines/qwen3-8b-zero-shot}"
MODEL_ID="${BASELINE_MODEL:-Qwen/Qwen3-8B}"
MAX_SAMPLES="${BASELINE_MAX_SAMPLES:-200}"

source .venv/bin/activate
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
export PYTHONPATH="$PWD/src:${PYTHONPATH:-}"
export TOKENIZERS_PARALLELISM=false

# Under `set -e`, a non-zero exit from the GPU check stops the run here.
python scripts/check_gpu.py
mkdir -p "$OUT_DIR"

python scripts/evaluate_model.py \
  --model "$MODEL_ID" \
  --dataset nraptisss/TMF921-intent-to-config-research-sota \
  --output_dir "$OUT_DIR" \
  --batch_size "${EVAL_BATCH_SIZE:-4}" \
  --max_samples_per_split "$MAX_SAMPLES" \
  --max_new_tokens "${EVAL_MAX_NEW_TOKENS:-1536}" \
  --gold_length_buffer "${EVAL_GOLD_LENGTH_BUFFER:-96}" \
  --save_every "${EVAL_SAVE_EVERY:-25}"

python scripts/normalize_eval_metrics.py --eval_dir "$OUT_DIR"

# Print a one-line summary per split from all_normalized_metrics.json.
# The heredoc delimiter is quoted so the shell does not expand anything inside
# the Python source; the output directory and model id are passed as argv
# instead, which keeps quotes/backslashes in those values from corrupting (or
# injecting into) the Python code.
python - "$OUT_DIR" "$MODEL_ID" <<'PY'
import json
import sys
from pathlib import Path

out_dir, model_id = sys.argv[1], sys.argv[2]
p = Path(out_dir) / "all_normalized_metrics.json"
m = json.loads(p.read_text())


def fmt(value):
    """Format a metric with 4 decimals; a missing/non-numeric metric prints as n/a."""
    return f"{value:.4f}" if isinstance(value, (int, float)) else "n/a"


print("Zero-shot baseline:", model_id)
for split, s in m.items():
    print(
        f"{split}: n={s.get('num_examples')} "
        f"parse={fmt(s.get('parse_json'))} "
        f"norm_field_f1={fmt(s.get('norm_field_f1'))} "
        f"norm_key_f1={fmt(s.get('norm_key_f1'))} "
        f"norm_exact={fmt(s.get('norm_exact_match'))}"
    )
PY