#!/usr/bin/env python3
"""Build a weak-layer-focused second-stage SFT dataset.

Creates local parquet files:
    <output_dir>/train.parquet
    <output_dir>/validation.parquet
    <output_dir>/manifest.json

The train split contains:
- all examples from weak target layers,
- extra duplicated rows for very rare weak layers up to a configurable minimum,
- a replay buffer sampled from non-weak layers to reduce catastrophic forgetting.

All eval/test reporting should still use the official research-sota OOD splits.
"""
import argparse
import json
from pathlib import Path

import pandas as pd
from datasets import load_dataset

# Default set of under-performing target layers; overridable via --weak_layers.
WEAK_LAYERS_DEFAULT = [
    "o1_nrm",
    "a1_policy",
    "tmf921_lifecycle_report",
    "tmf921_lifecycle_monitor",
    "tmf921_lifecycle_scale",
]


def parse_args():
    """Parse command-line options for the stage-2 dataset builder."""
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    p.add_argument("--train_split", default="train_sota")
    p.add_argument("--validation_split", default="validation")
    p.add_argument("--output_dir", required=True)
    p.add_argument("--weak_layers", nargs="+", default=WEAK_LAYERS_DEFAULT)
    p.add_argument(
        "--rare_min_per_layer",
        type=int,
        default=1500,
        help="Duplicate weak layers with fewer rows to this minimum",
    )
    p.add_argument(
        "--replay_ratio",
        type=float,
        default=0.30,
        help="Replay rows as fraction of final weak rows",
    )
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()


def _upsample_layer(part: pd.DataFrame, layer: str, min_rows: int) -> pd.DataFrame:
    """Return ``part`` padded with duplicated rows up to at least ``min_rows``.

    Duplicates cycle through the original rows in order. Each duplicate gets a
    unique id derived from its source row's id, is flagged as augmented, and
    records its origin in ``source_id`` (preserving an existing column if any).

    Args:
        part: Rows of a single weak target layer (must be non-empty).
        layer: Layer name, embedded in the augmentation_type tag.
        min_rows: Target minimum row count after duplication.
    """
    needed = min_rows - len(part)
    if needed <= 0:
        return part
    # Vectorized duplication: one .iloc over cycled indices instead of building
    # one deep-copied Series per duplicate row in a Python loop.
    dup = part.iloc[[i % len(part) for i in range(needed)]].copy(deep=True)
    # Keep an existing source_id; otherwise record the original row id.
    # (Must run before the id column is overwritten below.)
    if "source_id" not in dup.columns:
        dup["source_id"] = dup["id"]
    dup["id"] = [f"{oid}-stage2weak-{i:05d}" for i, oid in enumerate(dup["id"])]
    dup["is_augmented"] = True
    dup["augmentation_type"] = f"stage2_weak_duplicate_{layer}"
    return pd.concat([part, dup], ignore_index=True)


def main():
    """Build and write the stage-2 train/validation parquets plus a manifest."""
    args = parse_args()
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    ds = load_dataset(args.dataset)
    train = ds[args.train_split].to_pandas()
    val = ds[args.validation_split].to_pandas()

    weak_layers = set(args.weak_layers)

    weak_parts = []
    layer_counts_before = {}
    layer_counts_after = {}
    for layer in args.weak_layers:
        part = train[train["target_layer"] == layer].copy()
        layer_counts_before[layer] = int(len(part))
        if part.empty:
            # Layer absent from this split: nothing to upsample, and no
            # "after" count is recorded (matches the manifest schema).
            continue
        part = _upsample_layer(part, layer, args.rare_min_per_layer)
        layer_counts_after[layer] = int(len(part))
        weak_parts.append(part)

    weak_df = (
        pd.concat(weak_parts, ignore_index=True)
        if weak_parts
        else pd.DataFrame(columns=train.columns)
    )

    # Replay buffer: sample of non-weak rows to reduce catastrophic forgetting.
    nonweak = train[~train["target_layer"].isin(weak_layers)].copy()
    replay_n = min(len(nonweak), int(len(weak_df) * args.replay_ratio))
    if replay_n > 0:
        replay = nonweak.sample(n=replay_n, random_state=args.seed).copy()
    else:
        replay = pd.DataFrame(columns=train.columns)
    # Tag replay rows. Guard the missing-column case explicitly:
    # DataFrame.get(...) returns the scalar default "none" when the column is
    # absent, and str has no .astype, so the previous one-liner crashed with
    # AttributeError on splits without an augmentation_type column.
    if "augmentation_type" in replay.columns:
        base = replay["augmentation_type"].astype(str)
    else:
        base = "none"
    replay["augmentation_type"] = base + "+stage2_replay"

    # Shuffle weak + replay together with a fixed seed for reproducibility.
    stage2 = (
        pd.concat([weak_df, replay], ignore_index=True)
        .sample(frac=1.0, random_state=args.seed)
        .reset_index(drop=True)
    )
    stage2["stage2_role"] = stage2["target_layer"].apply(
        lambda x: "weak" if x in weak_layers else "replay"
    )

    val = val.copy()
    val["stage2_role"] = "validation"

    train_path = out / "train.parquet"
    val_path = out / "validation.parquet"
    stage2.to_parquet(train_path, index=False)
    val.to_parquet(val_path, index=False)

    manifest = {
        "source_dataset": args.dataset,
        "train_split": args.train_split,
        "validation_split": args.validation_split,
        "output_dir": str(out),
        "weak_layers": args.weak_layers,
        "rare_min_per_layer": args.rare_min_per_layer,
        "replay_ratio": args.replay_ratio,
        "seed": args.seed,
        "rows_train_stage2": int(len(stage2)),
        "rows_validation": int(len(val)),
        "weak_rows_total_after_duplication": int(len(weak_df)),
        "replay_rows": int(len(replay)),
        "layer_counts_before": layer_counts_before,
        "layer_counts_after": layer_counts_after,
        "stage2_role_counts": stage2["stage2_role"].value_counts().to_dict(),
        "target_layer_counts": stage2["target_layer"].value_counts().to_dict(),
    }
    (out / "manifest.json").write_text(json.dumps(manifest, indent=2, ensure_ascii=False))
    print(json.dumps(manifest, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()