PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
File size: 4,713 Bytes
7474a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python3
"""Build a weak-layer-focused second-stage SFT dataset.

Creates local parquet files:
  <output_dir>/train.parquet
  <output_dir>/validation.parquet
  <output_dir>/manifest.json

The train split contains:
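- all examples from weak target layers,
- extra duplicated rows for any weak layer with fewer than `--rare_min_per_layer`
  rows, topping that layer up to the minimum,
- a replay buffer sampled from non-weak layers (sized as `--replay_ratio` times the
  weak row count after duplication) to reduce catastrophic forgetting.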

All eval/test reporting should still use the official research-sota OOD splits.
"""
import argparse
import json
from pathlib import Path

import pandas as pd
from datasets import load_dataset

# Default for --weak_layers: `target_layer` values whose training rows are all
# kept, topped up by duplication when below --rare_min_per_layer, and excluded
# from the replay sample. (Presumably the layers the stage-1 model handled
# worst — confirm against the evaluation that produced this list.)
WEAK_LAYERS_DEFAULT = [
    "o1_nrm",
    "a1_policy",
    "tmf921_lifecycle_report",
    "tmf921_lifecycle_monitor",
    "tmf921_lifecycle_scale",
]


def parse_args():
    """Read the command-line options for building the stage-2 dataset.

    Returns:
        argparse.Namespace holding the source dataset and split names, the
        output directory, the weak layer names, the per-layer duplication
        floor, the replay ratio and the RNG seed.
    """
    parser = argparse.ArgumentParser()

    # Which Hub dataset to read and which of its splits feed train / validation.
    parser.add_argument("--dataset", default="nraptisss/TMF921-intent-to-config-research-sota")
    parser.add_argument("--train_split", default="train_sota")
    parser.add_argument("--validation_split", default="validation")

    # Destination folder for train.parquet, validation.parquet and manifest.json.
    parser.add_argument("--output_dir", required=True)

    # Stage-2 mixing knobs.
    parser.add_argument("--weak_layers", nargs="+", default=WEAK_LAYERS_DEFAULT)
    parser.add_argument(
        "--rare_min_per_layer",
        type=int,
        default=1500,
        help="Duplicate weak layers with fewer rows to this minimum",
    )
    parser.add_argument(
        "--replay_ratio",
        type=float,
        default=0.30,
        help="Replay rows as fraction of final weak rows",
    )
    parser.add_argument("--seed", type=int, default=42)

    namespace = parser.parse_args()
    return namespace


def _is_missing(value) -> bool:
    """Return True for None, pd.NA, or a float NaN (pandas' markers for absent scalars)."""
    return value is None or value is pd.NA or (isinstance(value, float) and pd.isna(value))


def _duplicate_to_minimum(part: pd.DataFrame, layer: str, minimum: int) -> pd.DataFrame:
    """Top ``part`` up to ``minimum`` rows by cycling through its rows in order.

    Extra row ``i`` is a copy of ``part.iloc[i % len(part)]`` with:
      - ``id`` = ``"<original id>-stage2weak-<i:05d>"`` (unique per layer),
      - ``is_augmented`` = True,
      - ``augmentation_type`` = ``"stage2_weak_duplicate_<layer>"``,
      - ``source_id`` = the copied row's existing ``source_id``, or its original
        ``id`` when that column is absent or the value is missing.

    Returns ``part`` unchanged when it is empty or already has >= ``minimum`` rows.
    """
    n_rows = len(part)
    needed = minimum - n_rows
    if n_rows == 0 or needed <= 0:
        return part

    # Positional take keeps the original column dtypes (building rows as
    # individual Series would upcast mixed columns to object).
    extra = part.iloc[[i % n_rows for i in range(needed)]].reset_index(drop=True)
    original_ids = extra["id"].tolist()

    if "source_id" in extra.columns:
        # Preserve existing lineage pointers; fill only the missing ones.
        extra["source_id"] = [
            oid if _is_missing(sid) else sid
            for sid, oid in zip(extra["source_id"].tolist(), original_ids)
        ]
    else:
        extra["source_id"] = original_ids

    extra["id"] = [f"{oid}-stage2weak-{i:05d}" for i, oid in enumerate(original_ids)]
    extra["is_augmented"] = True
    extra["augmentation_type"] = f"stage2_weak_duplicate_{layer}"
    return pd.concat([part, extra], ignore_index=True)


def _build_replay(nonweak: pd.DataFrame, n_rows: int, seed: int, columns) -> pd.DataFrame:
    """Sample ``n_rows`` non-weak rows and append ``+stage2_replay`` to their augmentation_type.

    When ``n_rows`` <= 0 an empty frame with ``columns`` is returned (still
    carrying an ``augmentation_type`` column).
    """
    if n_rows > 0:
        replay = nonweak.sample(n=n_rows, random_state=seed).copy()
    else:
        replay = pd.DataFrame(columns=columns)

    if "augmentation_type" in replay.columns:
        # Missing values become "none" instead of the strings "nan"/"None".
        prior = replay["augmentation_type"].fillna("none").astype(str)
    else:
        # DataFrame.get(..., "none") would hand back a plain str here, which has
        # no .astype — so handle the absent column explicitly.
        prior = "none"
    replay["augmentation_type"] = prior + "+stage2_replay"
    return replay


def main():
    """Build the stage-2 train/validation parquet files and a JSON manifest.

    Reads options via ``parse_args``, loads ``args.dataset`` with
    ``load_dataset``, and writes ``train.parquet``, ``validation.parquet`` and
    ``manifest.json`` into ``args.output_dir`` (created if needed). The
    manifest is also printed to stdout.
    """
    args = parse_args()
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    ds = load_dataset(args.dataset)
    train = ds[args.train_split].to_pandas()
    val = ds[args.validation_split].to_pandas()

    # De-duplicate while keeping CLI order: a repeated layer name would otherwise
    # be collected twice, doubling its rows and emitting duplicate ids.
    weak_layer_list = list(dict.fromkeys(args.weak_layers))
    weak_layers = set(weak_layer_list)
    weak_parts = []
    layer_counts_before = {}
    layer_counts_after = {}

    for layer in weak_layer_list:
        part = train[train["target_layer"] == layer].copy()
        layer_counts_before[layer] = len(part)
        part = _duplicate_to_minimum(part, layer, args.rare_min_per_layer)
        # Recorded even when 0 so the manifest lists every requested layer.
        layer_counts_after[layer] = len(part)
        if len(part) > 0:
            weak_parts.append(part)

    weak_df = pd.concat(weak_parts, ignore_index=True) if weak_parts else pd.DataFrame(columns=train.columns)
    nonweak = train[~train["target_layer"].isin(weak_layers)]
    # Replay size is relative to the weak rows *after* duplication, capped by availability.
    replay_n = max(0, min(len(nonweak), int(len(weak_df) * args.replay_ratio)))
    replay = _build_replay(nonweak, replay_n, args.seed, train.columns)

    stage2 = (
        pd.concat([weak_df, replay], ignore_index=True)
        .sample(frac=1.0, random_state=args.seed)
        .reset_index(drop=True)
    )
    # Replay rows come only from non-weak layers, so target_layer alone identifies the role.
    stage2["stage2_role"] = stage2["target_layer"].apply(lambda x: "weak" if x in weak_layers else "replay")
    val = val.copy()
    val["stage2_role"] = "validation"

    train_path = out / "train.parquet"
    val_path = out / "validation.parquet"
    stage2.to_parquet(train_path, index=False)
    val.to_parquet(val_path, index=False)

    manifest = {
        "source_dataset": args.dataset,
        "train_split": args.train_split,
        "validation_split": args.validation_split,
        "output_dir": str(out),
        "weak_layers": weak_layer_list,
        "rare_min_per_layer": args.rare_min_per_layer,
        "replay_ratio": args.replay_ratio,
        "seed": args.seed,
        "rows_train_stage2": int(len(stage2)),
        "rows_validation": int(len(val)),
        "weak_rows_total_after_duplication": int(len(weak_df)),
        "replay_rows": int(len(replay)),
        "layer_counts_before": layer_counts_before,
        "layer_counts_after": layer_counts_after,
        "stage2_role_counts": stage2["stage2_role"].value_counts().to_dict(),
        "target_layer_counts": stage2["target_layer"].value_counts().to_dict(),
    }
    manifest_text = json.dumps(manifest, indent=2, ensure_ascii=False)
    # Explicit utf-8: ensure_ascii=False may emit non-ASCII that the locale codec can't encode.
    (out / "manifest.json").write_text(manifest_text, encoding="utf-8")
    print(manifest_text)


if __name__ == "__main__":
    main()