# qwen-trainer-scripts / agentic_data_gen.py
import os
import pandas as pd
import re
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
try:
import data_designer.config as dd
from data_designer.config.column_configs import Score
from data_designer.interface import DataDesigner
except ImportError:
dd = None
Score = None
DataDesigner = None
@dataclass
class AgenticDataConfig:
    """Configuration for one agentic synthetic-data generation run."""
name: str = "agentic_dataset"
num_records: int = 10
task_description: str = "SQL-to-Natural-Language conversion"
scenarios_path: Optional[str] = None # Optional path to a JSONL file with 'scenario' column
model_alias: str = "llm-text"
judge_model_alias: str = "llm-judge"
output_path: str = "agentic_synthetic_data.jsonl"
    min_quality_score: int = 2 # Accuracy threshold for keeping a record; 2 tolerates minor issues (e.g. judges penalizing Perplexity-style citation artifacts even when the content is accurate)
generate_dpo: bool = False # Whether to generate 'rejected' responses for DPO
generate_reasoning: bool = False # Whether to generate <reasoning>...<answer> format
num_instructions_per_scenario: int = 1 # Number of instructions per scenario for diversity
max_tokens: int = 4096 # Max tokens for generation
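
# Illustrative example of the scenarios file expected by `scenarios_path`
# (JSONL with a 'scenario' key per line; the content below is made up for documentation):
#   {"scenario": "A warehouse schema where stock levels are updated by concurrent transfers."}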
class AgenticDataGenerator:
    """Multi-phase agentic pipeline: brainstorm, verify, instruct, draft, critique, refine, judge, filter."""
    def __init__(self, designer: Optional[DataDesigner] = None):
        if dd is None:
            raise ImportError("data_designer is not installed; it is required to build the generation pipeline.")
        if not designer:
            # Configure providers from whichever API keys are set (OpenAI, Perplexity, Paperclip)
            model_providers = []
if os.environ.get("OPENAI_API_KEY"):
model_providers.append(dd.ModelProvider(
name="openai",
provider_type="openai",
api_key="OPENAI_API_KEY",
endpoint="https://api.openai.com/v1"
))
if os.environ.get("PERPLEXITY_API_KEY"):
model_providers.append(dd.ModelProvider(
name="perplexity",
provider_type="openai",
api_key="PERPLEXITY_API_KEY",
endpoint="https://api.perplexity.ai"
))
if os.environ.get("PAPERCLIP_API_KEY"):
model_providers.append(dd.ModelProvider(
name="paperclip",
provider_type="openai",
api_key="PAPERCLIP_API_KEY",
endpoint=os.environ.get("PAPERCLIP_API_URL", "") + "/v1"
))
            if not model_providers:
                raise ValueError("No model provider configured: set at least one of OPENAI_API_KEY, PERPLEXITY_API_KEY, or PAPERCLIP_API_KEY.")
designer = DataDesigner(model_providers=model_providers)
self.designer = designer
def strip_citations(self, text: str) -> str:
"""Removes Perplexity-style citations like [1], [2], etc."""
if not isinstance(text, str):
return text
return re.sub(r'\[\d+\]', '', text).strip()
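    # Illustrative behaviour (not an original doctest):
    #   strip_citations("The answer is 42. [1][2]")  ->  "The answer is 42."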
    def generate(self, config: AgenticDataConfig) -> pd.DataFrame:
        """Builds and runs the multi-phase pipeline, then post-processes, judges, and filters the results."""
        print(f"Starting advanced agentic data generation for task: {config.task_description}")
        # Default provider and model: the locally hosted Paperclip endpoint is used,
        # so PAPERCLIP_API_KEY (and PAPERCLIP_API_URL) must be set for generation to work.
        provider_name = "paperclip"
        model_name = "gpt-4o"
llm_model = dd.ModelConfig(
alias=config.model_alias,
model=model_name,
provider=provider_name,
inference_parameters=dd.ChatCompletionInferenceParams(
max_parallel_requests=1,
max_tokens=config.max_tokens
)
)
builder = dd.DataDesignerConfigBuilder(model_configs=[llm_model])
if config.scenarios_path and os.path.exists(config.scenarios_path):
print(f"Loading scenarios from: {config.scenarios_path}")
scenarios_df = pd.read_json(config.scenarios_path, orient="records", lines=True)
if "scenario" not in scenarios_df.columns:
raise ValueError(f"Input file {config.scenarios_path} must contain a 'scenario' column.")
            # Provide the fixed task description and the pre-loaded scenarios via category samplers
builder.add_column(
dd.SamplerColumnConfig(
name="task",
sampler_type="category",
params=dd.CategorySamplerParams(values=[config.task_description])
)
)
scenarios = scenarios_df["scenario"].tolist()[:config.num_records]
builder.add_column(
dd.SamplerColumnConfig(
name="scenario",
sampler_type="category",
params=dd.CategorySamplerParams(values=scenarios)
)
)
else:
# Add task description as a sampler column
builder.add_column(
dd.SamplerColumnConfig(
name="task",
sampler_type="category",
params=dd.CategorySamplerParams(values=[config.task_description])
)
)
# Phase 1: Brainstorming Scenarios
builder.add_column(
dd.LLMTextColumnConfig(
name="scenario",
model_alias=config.model_alias,
prompt="Brainstorm a highly complex and challenging scenario for the task: '{{ task }}'. Focus on realistic edge cases, multi-step logic, and potential pitfalls. DO NOT use search. DO NOT use citations. Output a detailed scenario description."
)
)
# Phase 1.1: Solvability & Constraint Verification
builder.add_column(
dd.LLMTextColumnConfig(
name="scenario_verification",
model_alias=config.model_alias,
prompt="Review the scenario: '{{ scenario }}'. Is it clearly defined and solvable without external information? Identify any ambiguities or missing constraints. Output 'VERIFIED' if good, or a list of required clarifications. NO citations."
)
)
# Phase 2: Instruction Generation
instruction_prompt = "Based on the scenario: '{{ scenario }}', create a natural language request that a user might make for the task: '{{ task }}'. Output ONLY the request text. NO citations."
if config.num_instructions_per_scenario > 1:
# In a real production system, we'd use a seed dataset expansion here.
# For simplicity in this script, we'll just generate one instruction,
# as DataDesigner processes row-by-row.
pass
builder.add_column(
dd.LLMTextColumnConfig(
name="instruction",
model_alias=config.model_alias,
prompt=instruction_prompt
)
)
# Phase 2.1: Reasoning Output
output_prompt = "Based on the instruction: '{{ instruction }}', provide the expected output for the task: '{{ task }}'. Output ONLY the direct answer/code, no conversational filler. NO citations."
if config.generate_reasoning:
output_prompt = "Based on the instruction: '{{ instruction }}', provide the expected output for the task: '{{ task }}'. Use the following format: <reasoning>STEP BY STEP REASONING HERE</reasoning><answer>DIRECT ANSWER HERE</answer>. Ensure the reasoning is rigorous, comprehensive, and logically flawless."
builder.add_column(
dd.LLMTextColumnConfig(
name="initial_output",
model_alias=config.model_alias,
prompt=output_prompt
)
)
# Phase 2.2: Critique (Expert Review)
builder.add_column(
dd.LLMTextColumnConfig(
name="critique",
model_alias=config.model_alias,
prompt="Act as an expert reviewer. Critique the initial_output: '{{ initial_output }}' for the instruction: '{{ instruction }}' within scenario: '{{ scenario }}'. Identify any inaccuracies, logical gaps, mathematical errors, or formatting issues. Be extremely critical. DO NOT use search. DO NOT use citations."
)
)
# Phase 2.3: Refinement (Self-Correction)
format_instruction = "Use the following format: <reasoning>STEP BY STEP REASONING HERE</reasoning><answer>DIRECT ANSWER HERE</answer>." if config.generate_reasoning else "Output ONLY the direct answer/code, no conversational filler."
builder.add_column(
dd.LLMTextColumnConfig(
name="output",
model_alias=config.model_alias,
prompt="Based on the original instruction: '{{ instruction }}', the initial_output: '{{ initial_output }}', and the critique: '{{ critique }}', provide a final, verified, and highly accurate version of the output. " + format_instruction + " Ensure every logical step is explicit. NO citations."
)
)
# Phase 2.4: Rejected Generation (for DPO) - Targeted Failure
if config.generate_dpo:
rejected_prompt = "Based on the instruction: '{{ instruction }}' and the critique: '{{ critique }}', provide a response that is WRONG. Specifically, ignore one of the points from the critique or introduce a subtle logical error that a person might miss. " + format_instruction + " NO citations."
builder.add_column(
dd.LLMTextColumnConfig(
name="rejected",
model_alias=config.model_alias,
prompt=rejected_prompt
)
)
        # Phase 3: Judging (LLM-as-a-Judge). The judge reuses the generation model alias,
        # since only that ModelConfig is registered; judge_model_alias is currently unused.
builder.add_column(
dd.LLMJudgeColumnConfig(
name="quality_score",
model_alias=config.model_alias,
prompt="Evaluate the final output: '{{ output }}' based on the instruction: '{{ instruction }}' and scenario: '{{ scenario }}'.",
scores=[
Score(
name="accuracy",
description="Is the output accurate and correct based on the instruction?",
options={1: "Incorrect", 2: "Partially correct / minor issues", 3: "Fully correct"}
),
Score(
name="reasoning",
description="Is the reasoning step-by-step and logically sound?",
options={1: "None/Poor", 2: "Decent but sparse", 3: "Rigorous and detailed"}
)
]
)
)
# Run creation
result = self.designer.create(config_builder=builder, num_records=config.num_records, dataset_name=config.name)
df = result.load_dataset()
# Post-process: Strip citations from all generated text columns
cols_to_strip = ["scenario", "instruction", "initial_output", "critique", "output", "scenario_verification"]
if config.generate_dpo:
cols_to_strip.append("rejected")
for col in cols_to_strip:
if col in df.columns:
df[col] = df[col].apply(self.strip_citations)
# Phase 4: Filtering
if "quality_score" in df.columns:
def extract_score(val, key="accuracy"):
if isinstance(val, dict) and key in val:
return val[key].get("score", 0)
return 0
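            # Assumed shape of the judge output (depends on data_designer's LLMJudgeColumnConfig),
            # e.g. {"accuracy": {"score": 3, ...}, "reasoning": {"score": 2, ...}}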
df["accuracy_score"] = df["quality_score"].apply(lambda x: extract_score(x, "accuracy"))
df["reasoning_score"] = df["quality_score"].apply(lambda x: extract_score(x, "reasoning"))
print("Quality Scores (Accuracy):", df["accuracy_score"].tolist())
print("Reasoning Scores:", df["reasoning_score"].tolist())
# Save raw before filtering
df.to_json("raw_" + config.output_path, orient="records", lines=True)
# Filter by accuracy AND reasoning if reasoning was requested
if config.generate_reasoning:
filtered_df = df[(df["accuracy_score"] >= config.min_quality_score) & (df["reasoning_score"] >= 2)].copy()
else:
filtered_df = df[df["accuracy_score"] >= config.min_quality_score].copy()
print(f"Filtered dataset: {len(filtered_df)}/{len(df)} records passed quality threshold.")
df = filtered_df
# Save to JSONL
df.to_json(config.output_path, orient="records", lines=True)
print(f"Advanced agentic synthetic data saved to {config.output_path}")
return df
def format_for_qwen(self, df: pd.DataFrame) -> List[Dict[str, str]]:
"""Formats the dataframe into ChatML for Qwen training."""
chatml_data = []
for _, row in df.iterrows():
chatml_data.append({
"text": f"<|im_start|>user\n{row['instruction']}<|im_end|>\n<|im_start|>assistant\n{row['output']}<|im_end|>"
})
return chatml_data
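
    # The helper below is NOT part of the original script: it is a minimal sketch of
    # writing the ChatML records produced by format_for_qwen to a JSONL file for SFT.
    # The method name, signature, and default filename are assumptions.
    def save_chatml(self, df: pd.DataFrame, path: str = "agentic_chatml.jsonl") -> str:
        """Writes ChatML-formatted records to a JSONL file and returns the path."""
        import json
        records = self.format_for_qwen(df)
        with open(path, "w", encoding="utf-8") as f:
            for record in records:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
        return path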
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Agentic Synthetic Data Generation for Qwen Fine-tuning")
parser.add_argument("--task", type=str, default="SQL-to-Natural-Language conversion", help="Description of the task")
parser.add_argument("--scenarios", type=str, default=None, help="Path to JSONL with scenarios")
parser.add_argument("--num", type=int, default=2, help="Number of records to generate")
parser.add_argument("--output", type=str, default="agentic_synthetic_data.jsonl", help="Output path for the JSONL file")
parser.add_argument("--dpo", action="store_true", help="Generate rejected responses for DPO")
parser.add_argument("--reasoning", action="store_true", help="Generate <reasoning>...<answer> format")
parser.add_argument("--max-tokens", type=int, default=4096, help="Max tokens for generation")
args = parser.parse_args()
config = AgenticDataConfig(
num_records=args.num,
task_description=args.task,
scenarios_path=args.scenarios,
output_path=args.output,
generate_dpo=args.dpo,
generate_reasoning=args.reasoning,
max_tokens=args.max_tokens
)
generator = AgenticDataGenerator()
df = generator.generate(config)
if not df.empty:
print(f"Generated {len(df)} records.")
print("Sample record:")
print(df.iloc[0].to_dict())
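        # Optional: also export a ChatML-formatted copy via the save_chatml sketch above.
        # The "chatml_" filename prefix mirrors the "raw_" prefix used for the unfiltered
        # dump and is an assumption, not part of the original CLI.
        chatml_path = generator.save_chatml(df, "chatml_" + config.output_path)
        print(f"ChatML-formatted copy saved to {chatml_path}")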
else:
print("No records generated.")