| |
| """Build a weak-layer-focused second-stage SFT dataset. |
| |
| Creates local parquet files: |
| <output_dir>/train.parquet |
| <output_dir>/validation.parquet |
| <output_dir>/manifest.json |
| |
| The train split contains: |
| - all examples from weak target layers, |
| - extra duplicated rows for very rare weak layers up to a configurable minimum, |
| - a replay buffer sampled from non-weak layers to reduce catastrophic forgetting. |
| |
| All eval/test reporting should still use the official research-sota OOD splits. |
| """ |
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import pandas as pd |
| from datasets import load_dataset |
|
|
# Target layers that underperformed in the first-stage SFT run; used as the
# default for --weak_layers. NOTE(review): presumably derived from per-layer
# eval metrics of the stage-1 model — confirm against the stage-1 eval report.
WEAK_LAYERS_DEFAULT = [
    "o1_nrm",
    "a1_policy",
    "tmf921_lifecycle_report",
    "tmf921_lifecycle_monitor",
    "tmf921_lifecycle_scale",
]
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the stage-2 dataset builder.

    Args:
        argv: Optional explicit argument list (useful for tests). ``None``
            — the default, preserving the original zero-argument call —
            makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    p = argparse.ArgumentParser(
        description="Build a weak-layer-focused second-stage SFT dataset."
    )
    p.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota",
                   help="Source Hugging Face dataset repository")
    p.add_argument("--train_split", default="train_sota",
                   help="Split used to build the stage-2 train set")
    p.add_argument("--validation_split", default="validation",
                   help="Split copied through unchanged as validation")
    p.add_argument("--output_dir", required=True,
                   help="Directory receiving train/validation parquet and manifest.json")
    p.add_argument("--weak_layers", nargs="+", default=WEAK_LAYERS_DEFAULT,
                   help="Target layers to oversample in the train split")
    p.add_argument("--rare_min_per_layer", type=int, default=1500,
                   help="Duplicate weak layers with fewer rows to this minimum")
    p.add_argument("--replay_ratio", type=float, default=0.30,
                   help="Replay rows as fraction of final weak rows")
    p.add_argument("--seed", type=int, default=42,
                   help="RNG seed for replay sampling and the final shuffle")
    args = p.parse_args(argv)
    # Fail fast on nonsensical values instead of silently building a
    # corrupt dataset (e.g. a zero minimum would disable duplication,
    # a negative ratio would crash pandas sampling later).
    if args.rare_min_per_layer < 1:
        p.error("--rare_min_per_layer must be >= 1")
    if args.replay_ratio < 0:
        p.error("--replay_ratio must be >= 0")
    return args
|
|
|
|
def _oversample_weak_layers(train, weak_layers, min_per_layer):
    """Collect weak-layer rows, duplicating rare layers up to ``min_per_layer``.

    Duplicated rows get a fresh unique ``id``, ``is_augmented=True``, an
    ``augmentation_type`` tag, and a ``source_id`` pointing at the original row.

    Returns:
        (weak_df, counts_before, counts_after): the combined weak DataFrame and
        per-layer row counts before/after duplication. A layer with zero source
        rows appears only in ``counts_before`` (original behavior preserved).
    """
    weak_parts = []
    counts_before = {}
    counts_after = {}
    for layer in weak_layers:
        part = train[train["target_layer"] == layer].copy()
        counts_before[layer] = int(len(part))
        if len(part) == 0:
            continue
        if len(part) < min_per_layer:
            reps = []
            needed = min_per_layer - len(part)
            for i in range(needed):
                # Cycle through the available rows so duplicates are spread
                # evenly over the layer's originals.
                r = part.iloc[i % len(part)].copy(deep=True)
                original_id = r["id"]
                r["id"] = f"{original_id}-stage2weak-{i:05d}"
                r["is_augmented"] = True
                r["augmentation_type"] = f"stage2_weak_duplicate_{layer}"
                # Preserve an existing provenance chain; otherwise point the
                # duplicate back at the row it was copied from.
                r["source_id"] = r.get("source_id", original_id)
                reps.append(r)
            if reps:
                part = pd.concat([part, pd.DataFrame(reps)], ignore_index=True)
        counts_after[layer] = int(len(part))
        weak_parts.append(part)
    if weak_parts:
        weak_df = pd.concat(weak_parts, ignore_index=True)
    else:
        weak_df = pd.DataFrame(columns=train.columns)
    return weak_df, counts_before, counts_after


def _sample_replay(train, weak_layers, n_weak_rows, replay_ratio, seed):
    """Sample a tagged replay buffer from non-weak layers.

    The buffer size is ``replay_ratio`` times the (post-duplication) weak row
    count, capped at the number of available non-weak rows.
    """
    nonweak = train[~train["target_layer"].isin(weak_layers)].copy()
    replay_n = min(len(nonweak), int(n_weak_rows * replay_ratio))
    if replay_n > 0:
        replay = nonweak.sample(n=replay_n, random_state=seed).copy()
    else:
        replay = pd.DataFrame(columns=train.columns)
    # BUG FIX: the original did `replay.get("augmentation_type", "none")
    # .astype(str)`, but DataFrame.get returns the plain *default string* when
    # the column is missing, and `"none".astype(str)` raises AttributeError.
    # Keep the intended "none" base value for that case.
    if "augmentation_type" in replay.columns:
        base = replay["augmentation_type"].astype(str)
    else:
        base = "none"
    replay["augmentation_type"] = base + "+stage2_replay"
    return replay


def main():
    """Build the stage-2 parquet files and manifest described in the module docstring."""
    args = parse_args()
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    ds = load_dataset(args.dataset)
    train = ds[args.train_split].to_pandas()
    val = ds[args.validation_split].to_pandas()

    weak_layers = set(args.weak_layers)
    weak_df, layer_counts_before, layer_counts_after = _oversample_weak_layers(
        train, args.weak_layers, args.rare_min_per_layer
    )
    replay = _sample_replay(train, weak_layers, len(weak_df), args.replay_ratio, args.seed)

    # Single deterministic shuffle so weak and replay rows are interleaved.
    stage2 = (
        pd.concat([weak_df, replay], ignore_index=True)
        .sample(frac=1.0, random_state=args.seed)
        .reset_index(drop=True)
    )
    stage2["stage2_role"] = stage2["target_layer"].apply(
        lambda x: "weak" if x in weak_layers else "replay"
    )
    val = val.copy()
    val["stage2_role"] = "validation"

    train_path = out / "train.parquet"
    val_path = out / "validation.parquet"
    stage2.to_parquet(train_path, index=False)
    val.to_parquet(val_path, index=False)

    # Cast counts to plain Python int/str: value_counts() may yield numpy
    # scalars, which json.dumps cannot serialize on some pandas versions.
    manifest = {
        "source_dataset": args.dataset,
        "train_split": args.train_split,
        "validation_split": args.validation_split,
        "output_dir": str(out),
        "weak_layers": args.weak_layers,
        "rare_min_per_layer": args.rare_min_per_layer,
        "replay_ratio": args.replay_ratio,
        "seed": args.seed,
        "rows_train_stage2": int(len(stage2)),
        "rows_validation": int(len(val)),
        "weak_rows_total_after_duplication": int(len(weak_df)),
        "replay_rows": int(len(replay)),
        "layer_counts_before": layer_counts_before,
        "layer_counts_after": layer_counts_after,
        "stage2_role_counts": {str(k): int(v) for k, v in stage2["stage2_role"].value_counts().items()},
        "target_layer_counts": {str(k): int(v) for k, v in stage2["target_layer"].value_counts().items()},
    }
    (out / "manifest.json").write_text(json.dumps(manifest, indent=2, ensure_ascii=False))
    print(json.dumps(manifest, indent=2, ensure_ascii=False))
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|