# Spaces: Sleeping
| from pydantic import BaseModel, Field, validator | |
| from typing import Optional, Dict, Any, List, Union | |
| from enum import Enum | |
class ActionType(str, Enum):
    """Action kinds an agent may submit at a step of the email workflow.

    Because the enum mixes in ``str``, each member compares equal to its raw
    value, e.g. ``ActionType.CLASSIFY == "classify"``.
    """

    CLASSIFY = "classify"
    PRIORITIZE = "prioritize"
    DECIDE_STRATEGY = "decide_strategy"
    RESPOND = "respond"
    ESCALATE = "escalate"
    USE_TOOL = "use_tool"
class StrategyType(str, Enum):
    """Handling strategies the agent can pick for an email.

    Member order is significant: callers that list the allowed values
    (e.g. in validation error messages) iterate the enum in this order.
    """

    AUTO_RESOLVE = "auto_resolve"
    REQUEST_MORE_INFO = "request_more_info"
    OFFER_REFUND = "offer_refund"
    ESCALATE_TO_HUMAN = "escalate_to_human"
class ToolType(str, Enum):
    """Identifiers for the tools an agent is allowed to invoke."""

    LOOKUP_CUSTOMER = "lookup_customer"
    SEARCH_HISTORY = "search_history"
    CHECK_POLICY = "check_policy"
class ToolAction(BaseModel):
    """Request to run one agent tool.

    ``tool_type`` selects the tool. ``parameters`` holds tool-specific inputs
    (the schema example passes ``customer_id`` to ``lookup_customer``) and
    defaults to a fresh empty dict for each instance.
    """

    tool_type: ToolType
    parameters: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        # Extra data merged into the generated JSON schema (a sample payload).
        json_schema_extra = {
            "example": {
                "tool_type": "lookup_customer",
                "parameters": {"customer_id": "12345"},
            }
        }
class ToolResult(BaseModel):
    """Outcome of one tool execution.

    ``success`` reports whether the call worked. ``data`` carries its output
    and defaults to an empty dict. ``error`` is optional text, presumably set
    when ``success`` is False; confirm against the tool executor.
    """

    tool_type: ToolType
    success: bool
    data: Dict[str, Any] = Field(default_factory=dict)
    error: Optional[str] = None
class EmailObservation(BaseModel):
    """What the agent sees at each step: the email plus workflow context.

    Contents:
    - the raw email (id, subject, body) and a customer-history summary;
    - the agent's position in the workflow (``step_count``, ``workflow_step``);
    - the actions and tools currently on offer, and earlier decisions from
      this episode;
    - detected sentiment and urgency cues;
    - the result of the most recent tool call, if any.
    """

    # Email content
    email_id: str = Field(..., description="Unique email identifier")
    subject: str = Field(..., description="Email subject line")
    body: str = Field(..., description="Email body content")
    customer_history: str = Field(
        ...,
        description="Summary of customer interaction history",
    )

    # Workflow position and options
    step_count: int = Field(
        default=0,
        description="Current step in workflow (0-5)",
    )
    workflow_step: str = Field(..., description="Current workflow step name")
    available_actions: List[str] = Field(
        ...,
        description="List of valid action types for current step",
    )
    available_tools: List[str] = Field(
        default_factory=list,
        description="List of available tools for agent use",
    )
    previous_decisions: Dict[str, Any] = Field(
        default_factory=dict,
        description="Previous agent decisions in this episode",
    )

    # Signals extracted from the email
    customer_sentiment: str = Field(
        ...,
        description="Detected customer sentiment: positive, neutral, negative, angry",
    )
    urgency_indicators: List[str] = Field(
        default_factory=list,
        description="Detected urgency indicators from email",
    )

    # Feedback from the last tool call (None when no tool has run)
    tool_result: Optional[ToolResult] = Field(
        default=None,
        description="Result from last tool execution",
    )

    class Config:
        # Sample observation at the start of an episode, for the JSON schema.
        json_schema_extra = {
            "example": {
                "email_id": "email_001",
                "subject": "Refund request - duplicate charge",
                "body": "I was charged twice. Please refund.",
                "customer_history": "Good customer, first complaint",
                "step_count": 0,
                "workflow_step": "classification",
                "available_actions": ["classify"],
                "previous_decisions": {},
                "customer_sentiment": "neutral",
                "urgency_indicators": ["refund", "immediately"],
            }
        }
class EmailAction(BaseModel):
    """Agent action for one step of the multi-step email workflow.

    ``action_type`` selects the step. ``content`` carries the payload, which
    ``validate_content`` checks against ``action_type``:

    * ``classify``        -> one of "billing", "tech", "complaint", "spam"
    * ``prioritize``      -> one of "low", "medium", "high"
    * ``decide_strategy`` -> a ``StrategyType`` value
    * ``respond``         -> a string whose whitespace-stripped length is >= 10
    * ``escalate``        -> a dict containing a ``"reason"`` key
    * ``use_tool``        -> unrestricted; ``tool_action`` describes the tool call
    """

    action_type: ActionType = Field(..., description="Type of action being taken")
    content: Union[str, Dict[str, Any]] = Field(
        ...,
        description="Action content (string for responses, dict for structured data)",
    )
    tool_action: Optional[ToolAction] = Field(
        default=None,
        description="Tool action if using a tool",
    )

    # Without this decorator pydantic never invokes the method, and content
    # would be accepted unchecked for every action_type.
    @validator("content")
    def validate_content(cls, v, values):
        """Validate ``content`` against the already-parsed ``action_type``.

        Args:
            v: Candidate value for ``content``.
            values: Fields validated before ``content``. ``action_type`` is
                absent when that field itself failed validation.

        Returns:
            ``v`` unchanged when it fits the action type.

        Raises:
            ValueError: If ``v`` does not have the shape required by
                ``action_type``.
        """
        action_type = values.get("action_type")
        if action_type is None:
            # action_type failed its own validation; pydantic reports that error,
            # so there is nothing to check content against.
            return v

        if action_type == ActionType.CLASSIFY:
            if not isinstance(v, str) or v not in ("billing", "tech", "complaint", "spam"):
                raise ValueError("Classification content must be one of: billing, tech, complaint, spam")
        elif action_type == ActionType.PRIORITIZE:
            if not isinstance(v, str) or v not in ("low", "medium", "high"):
                raise ValueError("Priority content must be one of: low, medium, high")
        elif action_type == ActionType.DECIDE_STRATEGY:
            strategies = [s.value for s in StrategyType]
            if not isinstance(v, str) or v not in strategies:
                raise ValueError(f"Strategy content must be one of: {strategies}")
        elif action_type == ActionType.RESPOND:
            # Strip first so whitespace padding cannot satisfy the minimum length.
            if not isinstance(v, str) or len(v.strip()) < 10:
                raise ValueError("Response content must be string with at least 10 characters")
        elif action_type == ActionType.ESCALATE:
            if not isinstance(v, dict) or "reason" not in v:
                raise ValueError("Escalation content must be dict with 'reason' key")
        # ActionType.USE_TOOL: content is free-form.
        # NOTE(review): tool_action is not required for use_tool; confirm
        # whether a model-level check should enforce it.
        return v

    class Config:
        json_schema_extra = {
            "example": {
                "action_type": "classify",
                "content": "billing",
            }
        }
class EmailState(BaseModel):
    """Per-episode bookkeeping for the email workflow.

    Groups three kinds of data:
    - episode identity and progress: step count, done flag, current email,
      running reward;
    - the decision recorded for each workflow stage, each defaulting to
      ``None``;
    - validation counters and a workflow-completion flag.
    """

    # Episode identity and progress
    episode_id: str = Field(..., description="Unique episode identifier")
    step_count: int = Field(
        default=0,
        description="Number of steps taken (0-5)",
    )
    done: bool = Field(default=False, description="Whether episode is complete")
    current_email: Optional[str] = Field(
        default=None,
        description="Current email ID being processed",
    )
    total_reward: float = Field(
        default=0.0,
        description="Cumulative episode reward",
    )

    # Decisions recorded per workflow stage
    classification: Optional[str] = Field(
        default=None,
        description="Agent's classification decision",
    )
    priority: Optional[str] = Field(
        default=None,
        description="Agent's priority decision",
    )
    strategy: Optional[str] = Field(
        default=None,
        description="Agent's strategy decision",
    )
    response: Optional[str] = Field(
        default=None,
        description="Agent's response text",
    )
    escalation: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Escalation decision if taken",
    )

    # Validation bookkeeping
    invalid_actions: int = Field(
        default=0,
        description="Count of invalid actions taken",
    )
    workflow_completed: bool = Field(
        default=False,
        description="Whether full workflow was completed",
    )

    class Config:
        # Sample mid-episode state, for the JSON schema.
        json_schema_extra = {
            "example": {
                "episode_id": "ep-123-456",
                "step_count": 4,
                "done": False,
                "current_email": "email_001",
                "total_reward": 0.65,
                "classification": "billing",
                "priority": "high",
                "strategy": "auto_resolve",
                "response": "Thank you for reporting...",
                "escalation": None,
                "invalid_actions": 0,
                "workflow_completed": False,
            }
        }
class StepReturn(BaseModel):
    """Payload returned by ``step()``.

    Contains the next observation, the incremental reward for this step, the
    termination flag, free-form ``info`` metadata, and a per-component
    breakdown of the step reward.
    """

    observation: EmailObservation = Field(..., description="New observation")
    reward: float = Field(..., description="Reward for this step (incremental)")
    done: bool = Field(..., description="Whether episode is complete")
    info: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional info and score breakdown",
    )
    step_reward_breakdown: Dict[str, float] = Field(
        default_factory=dict,
        description="Breakdown of reward components for this step",
    )
class ResetReturn(BaseModel):
    """Payload returned by ``reset()``: the first observation plus episode metadata."""

    observation: EmailObservation = Field(
        ...,
        description="Initial observation for new episode",
    )
    info: Dict[str, Any] = Field(
        default_factory=dict,
        description="Metadata about episode",
    )
class WorkflowStep:
    """String names for the stages of the email workflow.

    Plain class attributes, so they are used as ``WorkflowStep.CLASSIFICATION``
    without instantiation.
    """

    CLASSIFICATION = "classification"
    PRIORITIZATION = "prioritization"
    STRATEGY_DECISION = "strategy_decision"
    RESPONSE_GENERATION = "response_generation"
    ESCALATION_DECISION = "escalation_decision"
    # Terminal marker once no further stage remains.
    COMPLETED = "completed"
class RewardWeights:
    """Weights and penalties used when scoring an episode.

    The five per-stage weights sum to 1.0. The four response-quality
    sub-weights also sum to 1.0, presumably so that they split the
    ``RESPONSE_WEIGHT`` share; confirm against the reward code.
    """

    # Per-stage weights (sum to 1.0)
    CLASSIFICATION_WEIGHT = 0.3
    PRIORITY_WEIGHT = 0.2
    STRATEGY_WEIGHT = 0.2
    RESPONSE_WEIGHT = 0.2
    ESCALATION_WEIGHT = 0.1

    # Response-quality sub-weights (sum to 1.0)
    RESPONSE_LENGTH_WEIGHT = 0.4
    RESPONSE_POLITENESS_WEIGHT = 0.3
    RESPONSE_RELEVANCE_WEIGHT = 0.2
    RESPONSE_MEMORY_WEIGHT = 0.1  # Bonus for using customer history

    # Penalties (negative values)
    INVALID_ACTION_PENALTY = -0.1