#!/usr/bin/env python3
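"""Convert the raw BurstGPT CSV dump into InferenceGym-ready artifacts:
fitted gamma arrival parameters, per-workload prompt token samples, and
legacy workload traces under data/traces/."""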
import argparse
import json
import sys
from pathlib import Path
import pandas as pd
import numpy as np
from scipy import stats

def main() -> int:
    parser = argparse.ArgumentParser(description="Process BurstGPT raw data into InferenceGym traces.")
    parser.add_argument("--raw-csv", type=str, default="data/BurstGPT.csv", help="Path to raw BurstGPT CSV dump")
    parser.add_argument("--output-dir", type=str, default="data/burstgpt")
    args = parser.parse_args()

    print("[INFO] Processing BurstGPT dataset...")
    raw_path = Path(args.raw_csv)
    if not raw_path.exists():
        print(f"[ERROR] Raw CSV not found at {raw_path}")
        return 1
    # Load and sort chronologically.
    df = pd.read_csv(raw_path)
    df = df.sort_values("Timestamp")

    # Robust column detection: tolerate header variations, falling back to canonical names.
    log_col = next((c for c in df.columns if "log type" in c.lower()), "Log Type")
    req_col = next((c for c in df.columns if "request tokens" in c.lower()), "Request tokens")
    res_col = next((c for c in df.columns if "response tokens" in c.lower()), "Response tokens")

    # Inter-arrival deltas across the merged stream (kept as a column for the traces below).
    df["arrival_delta"] = df["Timestamp"].diff().fillna(0)
    # Separate by log type into chat and API workloads.
    chat_df = df[df[log_col].str.contains("Conversation", na=False, case=False)].copy()
    api_df = df[df[log_col].str.contains("API", na=False, case=False)].copy()
    if len(api_df) == 0:
        print(f"[WARN] No records found for '{log_col}' containing 'API'")
        # Fall back to the model name if the log-type split finds nothing.
        api_df = df[df["Model"].str.contains("API", na=False, case=False)].copy()
        chat_df = df[~df.index.isin(api_df.index)].copy()

    params = {}
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # 1. Fit gamma arrival parameters and export prompt/response token samples per workload.
    for name, subset in [("chat", chat_df), ("api", api_df)]:
        if len(subset) < 2:
            continue
        # Recompute deltas within the subset so each workload is fitted on its
        # own inter-arrival times rather than on gaps in the merged stream.
        deltas = subset["Timestamp"].diff().dropna().values
        # scipy returns (shape, loc, scale); with loc pinned at 0, "beta" below is the scale.
        a, loc, b = stats.gamma.fit(deltas[deltas > 0], floc=0)
        params[name] = {"alpha": float(a), "beta": float(b)}
        # Use the detected column names rather than hardcoded headers.
        token_pairs = subset[[req_col, res_col]].rename(
            columns={req_col: "request_tokens", res_col: "response_tokens"}
        )
        token_pairs.to_parquet(out_dir / f"{name}_prompts.parquet", index=False, engine="pyarrow")
        print(f"[SUCCESS] Processed {name} workload: {len(subset)} records")
    with open(out_dir / "arrival_params.json", "w") as f:
        json.dump(params, f, indent=4)
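    # Resulting arrival_params.json, one entry per workload (values illustrative):
    #   {"chat": {"alpha": 0.45, "beta": 1.8}, "api": {"alpha": 0.32, "beta": 2.4}}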
    # 2. Generate legacy traces to satisfy workload_configs.json.
    trace_dir = Path("data/traces")
    trace_dir.mkdir(parents=True, exist_ok=True)

    # Static trace: a sample from the start of the raw data.
    static_trace = df.head(100).copy()
    static_trace.to_parquet(trace_dir / "static_workload_trace.parquet", index=False, engine="pyarrow")

    # Bursty trace: a section from the middle of the stream.
    bursty_trace = df.iloc[len(df) // 2 : len(df) // 2 + 200].copy()
    bursty_trace.to_parquet(trace_dir / "bursty_workload_trace.parquet", index=False, engine="pyarrow")

    # Adversarial trace: the final section.
    adv_trace = df.tail(300).copy()
    adv_trace.to_parquet(trace_dir / "adversarial_multitenant_trace.parquet", index=False, engine="pyarrow")

    # ShareGPT-style prompt lengths for the medium task; cap the sample at the dataset size.
    sharegpt_prompts = (
        df[[req_col]]
        .rename(columns={req_col: "prompt_length"})
        .sample(n=min(50_000, len(df)), random_state=42)
    )
    sharegpt_prompts.to_parquet(trace_dir / "sharegpt_prompt_lengths.parquet", index=False, engine="pyarrow")
    print(f"[SUCCESS] Generated traces in {trace_dir}/")
    return 0
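
# Hypothetical consumer sketch (illustrative only, not an InferenceGym API):
# shows how the fitted parameters written above could be replayed into
# synthetic arrival timestamps. Assumes "beta" is the scipy scale parameter,
# exactly as stored by main(); the name and signature are this sketch's own.
def sample_arrival_times(params_path, workload, n, seed=0):
    """Draw n gamma inter-arrival deltas and return cumulative arrival times."""
    with open(params_path) as f:
        p = json.load(f)[workload]
    rng = np.random.default_rng(seed)
    deltas = rng.gamma(shape=p["alpha"], scale=p["beta"], size=n)
    return np.cumsum(deltas)
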
if __name__ == "__main__":
    sys.exit(main())