# agentic-market-research/src/workflows/market_analysis.py
# Committed by: pkgprateek
# Commit: feat(f1): Research Type Selection (#10)
# Revision: c895509 (unverified)
"""Main LangGraph workflow for market intelligence."""
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.checkpoint.memory import MemorySaver
from src.workflows.types import IntelligenceState, ResearchType
from src.agents.researcher import ResearchAgent
from src.agents.analyst import AnalysisAgent
from src.agents.writer import WriterAgent
from src.utils.cost_tracker import CostTracker, BudgetExceededError
from src.utils.logging import setup_logger
# Module-level logger used by every workflow node and by ``run``.
logger = setup_logger(__name__)
class MarketIntelligenceWorkflow:
    """
    LangGraph workflow orchestrating research, analysis, and writing agents.

    Node order: research -> analysis -> writing -> human_review.
    After research the run ends early if research reported errors or produced
    no data. After human review the run either ends (approved / revision
    limit reached) or loops back to research for a revision pass.

    Features:
    - Multi-agent coordination
    - State persistence with checkpointing
    - Cost tracking and budget enforcement
    - Human-in-the-loop approval
    - Error recovery
    """

    # Number of revision passes allowed before the review loop is forced to
    # end via the "max_revisions" route.
    MAX_REVISIONS = 2

    def __init__(
        self,
        checkpoint_path: str = "./checkpoints.db",
        max_budget: float = 2.0,
        model_name: str | None = None,
    ):
        """
        Initialize workflow.

        Args:
            checkpoint_path: Path to SQLite checkpoint database, or
                ":memory:" to use an in-process MemorySaver instead.
            max_budget: Maximum cost per run in USD
            model_name: Name of the LLM model to use (passed to every agent)
        """
        self.max_budget = max_budget
        # NOTE(review): this tracker is created once per workflow instance and
        # is not reset in ``run``. Unless CostTracker resets itself, repeated
        # ``run()`` calls on one instance accumulate cost, so the "per run"
        # budget check and the reported ``total_cost`` become cumulative.
        self.cost_tracker = CostTracker()
        self.checkpoint_path = checkpoint_path
        self.model_name = model_name

        # Initialize agents (all share the same cost tracker so the budget
        # check sees the combined spend).
        self.research_agent = ResearchAgent(
            cost_tracker=self.cost_tracker, model=model_name
        )
        self.analysis_agent = AnalysisAgent(
            cost_tracker=self.cost_tracker, model=model_name
        )
        self.writer_agent = WriterAgent(
            cost_tracker=self.cost_tracker, model=model_name
        )

        # Build the (uncompiled) graph blueprint; it is compiled in ``run``
        # once a checkpointer is available.
        self.graph_builder = self._build_graph()

        logger.info("Market Intelligence Workflow initialized")

    def _build_graph(self) -> StateGraph:
        """Build (but do not compile) the LangGraph workflow.

        Returns:
            A StateGraph over IntelligenceState with the research, analysis,
            writing, and human_review nodes wired together.
        """
        graph = StateGraph(IntelligenceState)

        # Nodes (thin async wrappers around the agents)
        graph.add_node("research", self._research_node)
        graph.add_node("analysis", self._analysis_node)
        graph.add_node("writing", self._writing_node)
        graph.add_node("human_review", self._human_review_node)

        graph.set_entry_point("research")

        # Stop right after research if it failed or produced nothing.
        graph.add_conditional_edges(
            "research",
            self._should_continue_to_analysis,
            {"analysis": "analysis", "end": END},
        )
        # Unconditional: writing runs even if analysis reported an error, which
        # is why the writing node re-checks the budget itself.
        graph.add_edge("analysis", "writing")
        graph.add_edge("writing", "human_review")
        graph.add_conditional_edges(
            "human_review",
            self._check_approval,
            {"approved": END, "revise": "research", "max_revisions": END},
        )

        return graph

    async def _research_node(self, state: IntelligenceState) -> dict:
        """Run the research agent and return its results as a state update.

        Args:
            state: Current workflow state. ``company_name`` is required;
                ``industry`` and ``research_depth`` are optional.

        Returns:
            Partial state update. On success it carries the research results,
            increments ``iteration``, and increments ``revision_count`` when
            this is a revision pass. On failure it carries an ``errors`` entry.
        """
        logger.info(f"Research node: {state['company_name']}")

        # ``_check_approval`` only routes back here when ``human_feedback`` is
        # set, so a truthy value on entry marks a revision pass. Counting it
        # here is what lets ``_check_approval`` eventually end the loop.
        revision_count = state.get("revision_count", 0)
        if state.get("human_feedback"):
            revision_count += 1

        try:
            research_results = await self.research_agent.run(
                company_name=state["company_name"],
                industry=state.get("industry"),
                research_depth=state.get("research_depth", "comprehensive"),
            )

            return {
                "current_agent": "research",
                "research_data": research_results,
                "competitors": research_results.get("competitors", ""),
                "market_trends": research_results.get("market_trends", ""),
                "raw_sources": research_results.get("raw_sources", []),
                "iteration": state.get("iteration", 0) + 1,
                "revision_count": revision_count,
            }

        except Exception as e:
            logger.error(f"Research node failed: {e}")
            return {
                "errors": [f"Research failed: {str(e)}"],
                "current_agent": "research",
                "revision_count": revision_count,
            }

    async def _analysis_node(self, state: IntelligenceState) -> dict:
        """Check the budget, then run the analysis agent on the research data.

        Args:
            state: Current workflow state; ``research_data`` is required.

        Returns:
            Partial state update with the SWOT, competitive matrix,
            positioning, and strategic recommendations, or an ``errors`` entry
            if the budget is exceeded or the agent fails.
        """
        logger.info(f"Analysis node: {state['company_name']}")

        try:
            # Check budget before expensive analysis
            self.cost_tracker.check_budget(self.max_budget)

            analysis_results = await self.analysis_agent.run(
                research_data=state["research_data"]
            )

            return {
                "current_agent": "analysis",
                "swot": analysis_results.get("swot", ""),
                "competitive_matrix": analysis_results.get("competitive_matrix", ""),
                "positioning": analysis_results.get("positioning", ""),
                "strategic_recommendations": analysis_results.get(
                    "strategic_recommendations", ""
                ),
            }

        except BudgetExceededError as e:
            logger.error(f"Budget exceeded: {e}")
            return {
                "errors": [f"Budget exceeded: {str(e)}"],
                "current_agent": "analysis",
            }
        except Exception as e:
            logger.error(f"Analysis node failed: {e}")
            return {
                "errors": [f"Analysis failed: {str(e)}"],
                "current_agent": "analysis",
            }

    async def _writing_node(self, state: IntelligenceState) -> dict:
        """Check the budget, then run the writer agent to produce the report.

        Args:
            state: Current workflow state; ``research_data`` and
                ``company_name`` are required, analysis fields are optional.

        Returns:
            Partial state update with the executive summary, full report,
            report metadata, and the cost tracker's totals, or an ``errors``
            entry if the budget is exceeded or the agent fails.
        """
        logger.info(f"Writing node: {state['company_name']}")

        try:
            # The analysis -> writing edge is unconditional, so this node runs
            # even after analysis hit the budget; re-check so writing does not
            # spend past the limit.
            self.cost_tracker.check_budget(self.max_budget)

            report_results = await self.writer_agent.run(
                research_data=state["research_data"],
                analysis_data={
                    "company_name": state["company_name"],
                    "swot": state.get("swot", ""),
                    "competitive_matrix": state.get("competitive_matrix", ""),
                    "positioning": state.get("positioning", ""),
                    "strategic_recommendations": state.get(
                        "strategic_recommendations", ""
                    ),
                },
            )

            cost_summary = self.cost_tracker.get_summary()

            return {
                "current_agent": "writing",
                "executive_summary": report_results.get("executive_summary", ""),
                "full_report": report_results.get("full_report", ""),
                "report_metadata": report_results.get("metadata", {}),
                "total_cost": cost_summary["total_cost"],
                "total_tokens": cost_summary["total_tokens"],
            }

        except BudgetExceededError as e:
            logger.error(f"Budget exceeded: {e}")
            return {
                "errors": [f"Budget exceeded: {str(e)}"],
                "current_agent": "writing",
            }
        except Exception as e:
            logger.error(f"Writing node failed: {e}")
            return {
                "errors": [f"Writing failed: {str(e)}"],
                "current_agent": "writing",
            }

    async def _human_review_node(self, state: IntelligenceState) -> dict:
        """Human review node (placeholder: always auto-approves).

        Returns:
            Partial state update marking the report approved with no feedback.
        """
        logger.info(f"Human review node: {state['company_name']}")

        # For now, auto-approve
        # In Phase 5, this will connect to the Gradio UI
        return {
            "current_agent": "human_review",
            "approved": True,  # Auto-approve for testing
            "human_feedback": None,
        }

    def _should_continue_to_analysis(self, state: IntelligenceState) -> str:
        """Route after research.

        Returns:
            "end" if any errors were recorded or there is no research data,
            otherwise "analysis".
        """
        if state.get("errors"):
            logger.warning("Research had errors, ending workflow")
            return "end"

        if not state.get("research_data"):
            logger.warning("No research data, ending workflow")
            return "end"

        return "analysis"

    def _check_approval(self, state: IntelligenceState) -> str:
        """Route after human review.

        Returns:
            "approved" if the report is approved, "max_revisions" once
            ``revision_count`` reaches ``MAX_REVISIONS``, "revise" if feedback
            requests another pass, and "approved" by default.
        """
        # Approval is checked first so an approved report is not logged as a
        # revision-limit stop (both routes lead to END).
        if state.get("approved"):
            return "approved"

        if state.get("revision_count", 0) >= self.MAX_REVISIONS:
            logger.warning("Max revisions reached")
            return "max_revisions"

        # Revision requested
        if state.get("human_feedback"):
            return "revise"

        # Default to approved
        return "approved"

    async def run(
        self,
        company_name: str,
        industry: str | None = None,
        thread_id: str | None = None,
        research_depth: str = "comprehensive",
        research_type: ResearchType = ResearchType.COMPANY_ANALYSIS,
    ) -> dict:
        """
        Run the complete workflow.

        Args:
            company_name: Target company name
            industry: Optional industry context
            thread_id: Optional thread ID for checkpointing (defaults to
                "default")
            research_depth: Depth hint forwarded to the research agent
            research_type: Research type stored in the initial state.
                NOTE(review): no node in this class forwards it to an agent
                yet; confirm whether the agents should receive it.

        Returns:
            Final state dictionary as returned by the compiled graph.

        Raises:
            Exception: Any error raised while compiling or running the graph
                is logged and re-raised.
        """
        logger.info(f"Starting workflow for: {company_name}")

        initial_state: IntelligenceState = {
            "research_type": research_type,
            "company_name": company_name,
            "industry": industry,
            "research_depth": research_depth,
            "research_data": {
                "company_name": company_name,
                "industry": industry,
                "company_overview": "",
                "competitors": "",
                "market_trends": "",
                "raw_sources": [],
            },
            "competitors": "",
            "market_trends": "",
            "raw_sources": [],
            "swot": "",
            "competitive_matrix": "",
            "positioning": "",
            "strategic_recommendations": "",
            "executive_summary": "",
            "full_report": "",
            "report_metadata": {},
            "current_agent": "research",
            "iteration": 0,
            "total_cost": 0.0,
            "total_tokens": 0,
            "errors": [],
            "human_feedback": None,
            "approved": False,
            "revision_count": 0,
        }

        # NOTE(review): without an explicit thread_id every run shares the
        # "default" thread, which with the SQLite checkpointer means runs
        # share checkpoint history; consider passing a unique id per run.
        config = {"configurable": {"thread_id": thread_id or "default"}}

        try:
            if self.checkpoint_path == ":memory:":
                memory_checkpointer = MemorySaver()
                workflow = self.graph_builder.compile(checkpointer=memory_checkpointer)
                final_state = await workflow.ainvoke(initial_state, config)  # type: ignore[arg-type]
            else:
                # The async SQLite saver is a context manager; the graph must
                # run while the connection is open.
                async with AsyncSqliteSaver.from_conn_string(
                    self.checkpoint_path
                ) as checkpointer:
                    workflow = self.graph_builder.compile(checkpointer=checkpointer)
                    final_state = await workflow.ainvoke(initial_state, config)  # type: ignore[arg-type]

            logger.info(f"Workflow complete. Cost: ${final_state['total_cost']:.4f}")
            return final_state

        except Exception as e:
            logger.error(f"Workflow failed: {e}")
            raise