#!/usr/bin/env python3
"""Build a weak-layer-focused second-stage SFT dataset.
Creates local parquet files:
<output_dir>/train.parquet
<output_dir>/validation.parquet
<output_dir>/manifest.json
The train split contains:
- all examples from weak target layers,
- extra duplicated rows for very rare weak layers up to a configurable minimum,
- a replay buffer sampled from non-weak layers to reduce catastrophic forgetting.
All eval/test reporting should still use the official research-sota OOD splits.
"""

import argparse
import json
from pathlib import Path

import pandas as pd
from datasets import load_dataset

WEAK_LAYERS_DEFAULT = [
    "o1_nrm",
    "a1_policy",
    "tmf921_lifecycle_report",
    "tmf921_lifecycle_monitor",
    "tmf921_lifecycle_scale",
]


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    p.add_argument("--train_split", default="train_sota")
    p.add_argument("--validation_split", default="validation")
    p.add_argument("--output_dir", required=True)
    p.add_argument("--weak_layers", nargs="+", default=WEAK_LAYERS_DEFAULT)
    p.add_argument(
        "--rare_min_per_layer",
        type=int,
        default=1500,
        help="Duplicate rows for weak layers below this minimum row count",
    )
    p.add_argument(
        "--replay_ratio",
        type=float,
        default=0.30,
        help="Replay rows as a fraction of the final weak-layer row count",
    )
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()


def main():
    args = parse_args()
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    ds = load_dataset(args.dataset)
    train = ds[args.train_split].to_pandas()
    val = ds[args.validation_split].to_pandas()
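
    # Per-layer bookkeeping: row counts before/after rare-layer duplication
    # are kept and written to the manifest for auditability.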
    weak_layers = set(args.weak_layers)
    weak_parts = []
    layer_counts_before = {}
    layer_counts_after = {}
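
    # Oversample rare weak layers: cycle through a layer's existing rows and
    # append copies with fresh ids until it reaches --rare_min_per_layer.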
    for layer in args.weak_layers:
        part = train[train["target_layer"] == layer].copy()
        layer_counts_before[layer] = int(len(part))
        if len(part) == 0:
            continue
        if len(part) < args.rare_min_per_layer:
            reps = []
            needed = args.rare_min_per_layer - len(part)
            for i in range(needed):
                r = part.iloc[i % len(part)].copy(deep=True)
                original_id = r["id"]
                r["id"] = f"{original_id}-stage2weak-{i:05d}"
                r["is_augmented"] = True
                r["augmentation_type"] = f"stage2_weak_duplicate_{layer}"
                r["source_id"] = r.get("source_id", original_id)
                reps.append(r)
            if reps:
                part = pd.concat([part, pd.DataFrame(reps)], ignore_index=True)
        layer_counts_after[layer] = int(len(part))
        weak_parts.append(part)
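
    # All weak-layer rows after duplication; fall back to an empty frame with
    # the training schema when no weak layer matched any rows.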
    weak_df = pd.concat(weak_parts, ignore_index=True) if weak_parts else pd.DataFrame(columns=train.columns)
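
    # Replay buffer: a random sample of non-weak rows, sized as a fraction of
    # the post-duplication weak rows, to reduce catastrophic forgetting on the
    # layers the model already handles well.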
    nonweak = train[~train["target_layer"].isin(weak_layers)].copy()
    replay_n = min(len(nonweak), int(len(weak_df) * args.replay_ratio))
    replay = nonweak.sample(n=replay_n, random_state=args.seed).copy() if replay_n > 0 else pd.DataFrame(columns=train.columns)
    if "augmentation_type" in replay.columns:
        replay["augmentation_type"] = replay["augmentation_type"].astype(str) + "+stage2_replay"
    else:
        replay["augmentation_type"] = "none+stage2_replay"
    stage2 = (
        pd.concat([weak_df, replay], ignore_index=True)
        .sample(frac=1.0, random_state=args.seed)
        .reset_index(drop=True)
    )
    stage2["stage2_role"] = stage2["target_layer"].apply(lambda x: "weak" if x in weak_layers else "replay")

    val = val.copy()
    val["stage2_role"] = "validation"
    train_path = out / "train.parquet"
    val_path = out / "validation.parquet"
    stage2.to_parquet(train_path, index=False)
    val.to_parquet(val_path, index=False)
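
    # Record exactly how this dataset was built so the run is reproducible.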
    manifest = {
        "source_dataset": args.dataset,
        "train_split": args.train_split,
        "validation_split": args.validation_split,
        "output_dir": str(out),
        "weak_layers": args.weak_layers,
        "rare_min_per_layer": args.rare_min_per_layer,
        "replay_ratio": args.replay_ratio,
        "seed": args.seed,
        "rows_train_stage2": int(len(stage2)),
        "rows_validation": int(len(val)),
        "weak_rows_total_after_duplication": int(len(weak_df)),
        "replay_rows": int(len(replay)),
        "layer_counts_before": layer_counts_before,
        "layer_counts_after": layer_counts_after,
        "stage2_role_counts": stage2["stage2_role"].value_counts().to_dict(),
        "target_layer_counts": stage2["target_layer"].value_counts().to_dict(),
    }
    (out / "manifest.json").write_text(json.dumps(manifest, indent=2, ensure_ascii=False))
    print(json.dumps(manifest, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()