File size: 3,766 Bytes
4fbc241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python3
import argparse
import json
import os
import sys
from pathlib import Path

import pandas as pd
import numpy as np
from scipy import stats

def _find_column(df: pd.DataFrame, needle: str, default: str) -> str:
    """Return the first column whose lowercased name contains *needle*, else *default*."""
    return next((c for c in df.columns if needle in str(c).lower()), default)


def _fit_arrival_gamma(timestamps: pd.Series):
    """Fit a gamma distribution (location pinned to 0) to the positive gaps of *timestamps*.

    *timestamps* must already be in ascending order. Returns a dict with
    ``alpha`` (the fitted shape) and ``beta`` (the fitted scale, i.e. the third
    value returned by ``stats.gamma.fit``), or ``None`` when fewer than two
    positive gaps are available to fit.
    """
    gaps = timestamps.diff().dropna().to_numpy(dtype=float)
    gaps = gaps[gaps > 0]  # zero gaps (identical timestamps) are excluded from the fit
    if gaps.size < 2:
        return None
    shape, _loc, scale = stats.gamma.fit(gaps, floc=0)
    return {"alpha": float(shape), "beta": float(scale)}


def main() -> int:
    """Convert a raw BurstGPT CSV into arrival parameters, prompt samples and traces.

    Command-line options:
        --raw-csv      path of the raw CSV (default ``data/BurstGPT.csv``)
        --output-dir   directory for ``arrival_params.json`` and the
                       ``<workload>_prompts.parquet`` files (default ``data/burstgpt``)

    The trace parquet files are always written to ``data/traces`` relative to
    the current working directory.

    Returns:
        0 on success; 1 when the CSV does not exist or lacks a required column
        (``Timestamp`` and the request/response token columns).
    """
    parser = argparse.ArgumentParser(description="Process BurstGPT raw data into InferenceGym traces.")
    parser.add_argument("--raw-csv", type=str, default="data/BurstGPT.csv", help="Path to raw BurstGPT CSV dump")
    parser.add_argument("--output-dir", type=str, default="data/burstgpt")
    args = parser.parse_args()

    print("[INFO] Processing BurstGPT Dataset...")
    raw_path = Path(args.raw_csv)
    if not raw_path.exists():
        print(f"[ERROR] Raw CSV not found at {raw_path}")
        return 1

    df = pd.read_csv(raw_path)

    # Robust column detection: tolerate differences in header capitalisation.
    log_col = _find_column(df, "log type", "Log Type")
    req_col = _find_column(df, "request tokens", "Request tokens")
    res_col = _find_column(df, "response tokens", "Response tokens")

    # Fail with an exit code (like the missing-file case) instead of a raw
    # KeyError, and do so before any output directory or file is created.
    missing = [c for c in ("Timestamp", req_col, res_col) if c not in df.columns]
    if missing:
        print(f"[ERROR] Raw CSV is missing required column(s): {', '.join(missing)}")
        return 1

    df = df.sort_values("Timestamp")

    # Gap to the previous request of ANY type; kept as a column in the traces.
    df["arrival_delta"] = df["Timestamp"].diff().fillna(0)

    # Separate by log type. The column is optional: without it both subsets
    # start empty and the model-name fallback below takes over.
    if log_col in df.columns:
        log_type = df[log_col]
        chat_df = df[log_type.str.contains("Conversation", na=False, case=False)].copy()
        api_df = df[log_type.str.contains("API", na=False, case=False)].copy()
    else:
        chat_df = df.iloc[0:0].copy()
        api_df = df.iloc[0:0].copy()

    if api_df.empty:
        print(f"[WARN] No records found for '{log_col}' containing 'API'")
        # Fallback to model name if log type fails
        if "Model" in df.columns:
            api_df = df[df["Model"].str.contains("API", na=False, case=False)].copy()
        chat_df = df[~df.index.isin(api_df.index)].copy()

    params = {}
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # 1. Generate Arrival Params & Prompt Samples
    for name, subset in [("chat", chat_df), ("api", api_df)]:
        if len(subset) < 2:
            continue

        # Inter-arrival gaps are measured between consecutive requests of the
        # SAME workload; the global arrival_delta column mixes both workloads.
        fitted = _fit_arrival_gamma(subset["Timestamp"])
        if fitted is None:
            print(f"[WARN] Not enough positive inter-arrival gaps to fit '{name}' arrivals")
        else:
            params[name] = fitted

        # Use the detected column names so the header-tolerant lookup above
        # actually takes effect.
        token_pairs = subset[[req_col, res_col]].rename(
            columns={req_col: "request_tokens", res_col: "response_tokens"}
        )
        token_pairs.to_parquet(out_dir / f"{name}_prompts.parquet", index=False, engine="pyarrow")
        print(f"[SUCCESS] Processed {name} workload: {len(subset)} records")

    with open(out_dir / "arrival_params.json", "w", encoding="utf-8") as f:
        json.dump(params, f, indent=4)

    # 2. Generate Legacy Traces to satisfy workload_configs.json
    trace_dir = Path("data/traces")
    trace_dir.mkdir(parents=True, exist_ok=True)

    # Static trace: the first 100 rows of the time-sorted data.
    static_trace = df.head(100).copy()
    static_trace.to_parquet(trace_dir / "static_workload_trace.parquet", index=False, engine="pyarrow")

    # Bursty trace: 200 rows starting at the midpoint.
    midpoint = len(df) // 2
    bursty_trace = df.iloc[midpoint : midpoint + 200].copy()
    bursty_trace.to_parquet(trace_dir / "bursty_workload_trace.parquet", index=False, engine="pyarrow")

    # Adversarial trace: the last 300 rows.
    adv_trace = df.tail(300).copy()
    adv_trace.to_parquet(trace_dir / "adversarial_multitenant_trace.parquet", index=False, engine="pyarrow")

    # ShareGPT prompt lengths for medium task. Cap the sample at the number of
    # available rows: sampling without replacement raises when n > len(df).
    sample_size = min(50000, len(df))
    sharegpt_prompts = (
        df[[req_col]]
        .rename(columns={req_col: "prompt_length"})
        .sample(n=sample_size, random_state=42)
    )
    sharegpt_prompts.to_parquet(trace_dir / "sharegpt_prompt_lengths.parquet", index=False, engine="pyarrow")

    print(f"[SUCCESS] Generated traces in {trace_dir}/")

    return 0

if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the exit status.
    raise SystemExit(main())