Commit Β·
7219c67
1
Parent(s): 6240099
feat: Add Text Story module with iMessage-style chat video generation
Browse files- modules/bar_race/__init__.py +0 -35
- modules/bar_race/assets/fonts/.gitkeep +0 -1
- modules/bar_race/assets/images/.gitkeep +0 -1
- modules/bar_race/assets/music/.gitkeep +0 -3
- modules/bar_race/deep_researcher/__init__.py +0 -13
- modules/bar_race/deep_researcher/configuration.py +0 -81
- modules/bar_race/deep_researcher/graph.py +0 -456
- modules/bar_race/deep_researcher/prompts.py +0 -112
- modules/bar_race/deep_researcher/rate_limiter.py +0 -120
- modules/bar_race/deep_researcher/state.py +0 -25
- modules/bar_race/deep_researcher/utils.py +0 -333
- modules/bar_race/router.py +0 -265
- modules/bar_race/schemas.py +0 -59
- modules/bar_race/services/__init__.py +0 -1
- modules/bar_race/services/analyst.py +0 -517
- modules/bar_race/services/artist.py +0 -301
- modules/bar_race/services/brain.py +0 -365
- modules/bar_race/services/director.py +0 -438
- modules/text_story/__init__.py +66 -0
- modules/text_story/router.py +344 -0
- modules/text_story/schemas.py +69 -0
- modules/text_story/services/__init__.py +1 -0
- modules/text_story/services/background.py +231 -0
- modules/text_story/services/renderer.py +295 -0
- modules/text_story/services/tts_handler.py +134 -0
- modules/text_story/services/video_composer.py +236 -0
- requirements.txt +0 -12
- static/index.html +263 -47
modules/bar_race/__init__.py
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Bar Race Module
|
| 3 |
-
Intelligent Bar Chart Race Video Generator.
|
| 4 |
-
|
| 5 |
-
Architecture:
|
| 6 |
-
- Brain: LLM Planner (Gemini)
|
| 7 |
-
- Scout: Data Fetcher (APIs + Scraping)
|
| 8 |
-
- Surgeon: Data Cleaner
|
| 9 |
-
- Artist: Image Processor
|
| 10 |
-
- Director: Video Generator
|
| 11 |
-
|
| 12 |
-
100% standalone - no dependency on other modules.
|
| 13 |
-
"""
|
| 14 |
-
import logging
|
| 15 |
-
from fastapi import FastAPI
|
| 16 |
-
|
| 17 |
-
logger = logging.getLogger(__name__)
|
| 18 |
-
|
| 19 |
-
# Module metadata for auto-discovery
|
| 20 |
-
MODULE_NAME = "bar_race"
|
| 21 |
-
MODULE_PREFIX = "/api/bar-race"
|
| 22 |
-
MODULE_DESCRIPTION = "Bar Chart Race Video Generator"
|
| 23 |
-
|
| 24 |
-
_app = None
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def register(app: FastAPI, config=None):
|
| 28 |
-
"""Register Bar Race module routes"""
|
| 29 |
-
global _app
|
| 30 |
-
_app = app
|
| 31 |
-
|
| 32 |
-
from .router import router
|
| 33 |
-
app.include_router(router, prefix="/api/bar-race", tags=["Bar Race"])
|
| 34 |
-
|
| 35 |
-
logger.info("Bar Race module registered at /api/bar-race")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/assets/fonts/.gitkeep
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# Custom fonts for video rendering
|
|
|
|
|
|
modules/bar_race/assets/images/.gitkeep
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# Entity images will be downloaded here during video generation
|
|
|
|
|
|
modules/bar_race/assets/music/.gitkeep
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
# Optional background music files
|
| 2 |
-
# Supported formats: .mp3, .wav, .m4a, .ogg
|
| 3 |
-
# Music will be automatically added if files exist here
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/deep_researcher/__init__.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
"""Deep Researcher Module - LangGraph-based web research agent"""
|
| 2 |
-
|
| 3 |
-
from .graph import graph
|
| 4 |
-
from .state import SummaryState, SummaryStateInput, SummaryStateOutput
|
| 5 |
-
from .configuration import Configuration
|
| 6 |
-
|
| 7 |
-
__all__ = [
|
| 8 |
-
"graph",
|
| 9 |
-
"SummaryState",
|
| 10 |
-
"SummaryStateInput",
|
| 11 |
-
"SummaryStateOutput",
|
| 12 |
-
"Configuration",
|
| 13 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/deep_researcher/configuration.py
DELETED
|
@@ -1,81 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from enum import Enum
|
| 3 |
-
from pydantic import BaseModel, Field
|
| 4 |
-
from typing import Any, Optional, Literal
|
| 5 |
-
|
| 6 |
-
from langchain_core.runnables import RunnableConfig
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class SearchAPI(Enum):
|
| 10 |
-
PERPLEXITY = "perplexity"
|
| 11 |
-
TAVILY = "tavily"
|
| 12 |
-
DUCKDUCKGO = "duckduckgo"
|
| 13 |
-
# SEARXNG removed - requires langchain-community
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class Configuration(BaseModel):
|
| 17 |
-
"""The configurable fields for the research assistant."""
|
| 18 |
-
|
| 19 |
-
max_web_research_loops: int = Field(
|
| 20 |
-
default=3,
|
| 21 |
-
title="Research Depth",
|
| 22 |
-
description="Number of research iterations to perform",
|
| 23 |
-
)
|
| 24 |
-
local_llm: str = Field(
|
| 25 |
-
default="llama3.2",
|
| 26 |
-
title="LLM Model Name",
|
| 27 |
-
description="Name of the LLM model to use",
|
| 28 |
-
)
|
| 29 |
-
llm_provider: Literal["ollama", "lmstudio"] = Field(
|
| 30 |
-
default="ollama",
|
| 31 |
-
title="LLM Provider",
|
| 32 |
-
description="Provider for the LLM (Ollama or LMStudio)",
|
| 33 |
-
)
|
| 34 |
-
search_api: Literal["perplexity", "tavily", "duckduckgo"] = Field(
|
| 35 |
-
default="tavily", title="Search API", description="Web search API to use"
|
| 36 |
-
)
|
| 37 |
-
fetch_full_page: bool = Field(
|
| 38 |
-
default=True,
|
| 39 |
-
title="Fetch Full Page",
|
| 40 |
-
description="Include the full page content in the search results",
|
| 41 |
-
)
|
| 42 |
-
ollama_base_url: str = Field(
|
| 43 |
-
default="http://localhost:11434/",
|
| 44 |
-
title="Ollama Base URL",
|
| 45 |
-
description="Base URL for Ollama API",
|
| 46 |
-
)
|
| 47 |
-
lmstudio_base_url: str = Field(
|
| 48 |
-
default="http://localhost:1234/v1",
|
| 49 |
-
title="LMStudio Base URL",
|
| 50 |
-
description="Base URL for LMStudio OpenAI-compatible API",
|
| 51 |
-
)
|
| 52 |
-
strip_thinking_tokens: bool = Field(
|
| 53 |
-
default=True,
|
| 54 |
-
title="Strip Thinking Tokens",
|
| 55 |
-
description="Whether to strip <think> tokens from model responses",
|
| 56 |
-
)
|
| 57 |
-
use_tool_calling: bool = Field(
|
| 58 |
-
default=False,
|
| 59 |
-
title="Use Tool Calling",
|
| 60 |
-
description="Use tool calling instead of JSON mode for structured output",
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
@classmethod
|
| 64 |
-
def from_runnable_config(
|
| 65 |
-
cls, config: Optional[RunnableConfig] = None
|
| 66 |
-
) -> "Configuration":
|
| 67 |
-
"""Create a Configuration instance from a RunnableConfig."""
|
| 68 |
-
configurable = (
|
| 69 |
-
config["configurable"] if config and "configurable" in config else {}
|
| 70 |
-
)
|
| 71 |
-
|
| 72 |
-
# Get raw values from environment or config
|
| 73 |
-
raw_values: dict[str, Any] = {
|
| 74 |
-
name: os.environ.get(name.upper(), configurable.get(name))
|
| 75 |
-
for name in cls.model_fields.keys()
|
| 76 |
-
}
|
| 77 |
-
|
| 78 |
-
# Filter out None values
|
| 79 |
-
values = {k: v for k, v in raw_values.items() if v is not None}
|
| 80 |
-
|
| 81 |
-
return cls(**values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/deep_researcher/graph.py
DELETED
|
@@ -1,456 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
from pydantic import BaseModel, Field
|
| 5 |
-
from typing_extensions import Literal
|
| 6 |
-
|
| 7 |
-
from langchain_core.messages import HumanMessage, SystemMessage
|
| 8 |
-
from langchain_core.runnables import RunnableConfig
|
| 9 |
-
from langchain_core.tools import tool
|
| 10 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 11 |
-
from langgraph.graph import START, END, StateGraph
|
| 12 |
-
|
| 13 |
-
from .configuration import Configuration, SearchAPI
|
| 14 |
-
from .utils import (
|
| 15 |
-
deduplicate_and_format_sources,
|
| 16 |
-
tavily_search,
|
| 17 |
-
format_sources,
|
| 18 |
-
perplexity_search,
|
| 19 |
-
duckduckgo_search,
|
| 20 |
-
strip_thinking_tokens,
|
| 21 |
-
get_config_value,
|
| 22 |
-
)
|
| 23 |
-
from .state import (
|
| 24 |
-
SummaryState,
|
| 25 |
-
SummaryStateInput,
|
| 26 |
-
SummaryStateOutput,
|
| 27 |
-
)
|
| 28 |
-
from .prompts import (
|
| 29 |
-
query_writer_instructions,
|
| 30 |
-
summarizer_instructions,
|
| 31 |
-
reflection_instructions,
|
| 32 |
-
get_current_date,
|
| 33 |
-
json_mode_query_instructions,
|
| 34 |
-
tool_calling_query_instructions,
|
| 35 |
-
json_mode_reflection_instructions,
|
| 36 |
-
tool_calling_reflection_instructions,
|
| 37 |
-
)
|
| 38 |
-
from .rate_limiter import gemini_rate_limiter
|
| 39 |
-
|
| 40 |
-
# Constants
|
| 41 |
-
MAX_TOKENS_PER_SOURCE = 1000
|
| 42 |
-
CHARS_PER_TOKEN = 4
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def convert_messages_for_gemma(messages: list) -> list:
|
| 46 |
-
"""
|
| 47 |
-
Convert SystemMessage to HumanMessage for Gemma model compatibility.
|
| 48 |
-
|
| 49 |
-
Gemma models don't support 'developer instructions' (SystemMessage) through
|
| 50 |
-
the LangChain interface. This function converts all SystemMessages to
|
| 51 |
-
HumanMessages with a clear instruction prefix.
|
| 52 |
-
|
| 53 |
-
Args:
|
| 54 |
-
messages: List of LangChain messages
|
| 55 |
-
|
| 56 |
-
Returns:
|
| 57 |
-
List of messages with SystemMessages converted to HumanMessages
|
| 58 |
-
"""
|
| 59 |
-
converted = []
|
| 60 |
-
for msg in messages:
|
| 61 |
-
if isinstance(msg, SystemMessage):
|
| 62 |
-
# Convert SystemMessage to HumanMessage with instruction prefix
|
| 63 |
-
converted.append(HumanMessage(
|
| 64 |
-
content=f"[INSTRUCTIONS]\n{msg.content}\n[/INSTRUCTIONS]"
|
| 65 |
-
))
|
| 66 |
-
else:
|
| 67 |
-
converted.append(msg)
|
| 68 |
-
return converted
|
| 69 |
-
|
| 70 |
-
def generate_search_query_with_structured_output(
|
| 71 |
-
configurable: Configuration,
|
| 72 |
-
messages: list,
|
| 73 |
-
tool_class,
|
| 74 |
-
fallback_query: str,
|
| 75 |
-
tool_query_field: str,
|
| 76 |
-
json_query_field: str,
|
| 77 |
-
):
|
| 78 |
-
"""Helper function to generate search queries using either tool calling or JSON mode.
|
| 79 |
-
|
| 80 |
-
Args:
|
| 81 |
-
configurable: Configuration object
|
| 82 |
-
messages: List of messages to send to LLM
|
| 83 |
-
tool_class: Tool class for tool calling mode
|
| 84 |
-
fallback_query: Fallback search query if extraction fails
|
| 85 |
-
tool_query_field: Field name in tool args containing the query
|
| 86 |
-
json_query_field: Field name in JSON response containing the query
|
| 87 |
-
|
| 88 |
-
Returns:
|
| 89 |
-
Dictionary with "search_query" key
|
| 90 |
-
"""
|
| 91 |
-
# Convert messages for Gemma compatibility (no SystemMessage)
|
| 92 |
-
messages = convert_messages_for_gemma(messages)
|
| 93 |
-
|
| 94 |
-
if configurable.use_tool_calling:
|
| 95 |
-
llm = get_llm(configurable).bind_tools([tool_class])
|
| 96 |
-
gemini_rate_limiter.acquire() # Rate limit before API call
|
| 97 |
-
result = llm.invoke(messages)
|
| 98 |
-
|
| 99 |
-
if not result.tool_calls:
|
| 100 |
-
return {"search_query": fallback_query}
|
| 101 |
-
|
| 102 |
-
try:
|
| 103 |
-
tool_data = result.tool_calls[0]["args"]
|
| 104 |
-
search_query = tool_data.get(tool_query_field)
|
| 105 |
-
return {"search_query": search_query}
|
| 106 |
-
except (IndexError, KeyError):
|
| 107 |
-
return {"search_query": fallback_query}
|
| 108 |
-
|
| 109 |
-
else:
|
| 110 |
-
# Use JSON mode
|
| 111 |
-
llm = get_llm(configurable)
|
| 112 |
-
gemini_rate_limiter.acquire() # Rate limit before API call
|
| 113 |
-
result = llm.invoke(messages)
|
| 114 |
-
print(f"result: {result}")
|
| 115 |
-
content = result.content
|
| 116 |
-
|
| 117 |
-
try:
|
| 118 |
-
parsed_json = json.loads(content)
|
| 119 |
-
search_query = parsed_json.get(json_query_field)
|
| 120 |
-
if not search_query:
|
| 121 |
-
return {"search_query": fallback_query}
|
| 122 |
-
return {"search_query": search_query}
|
| 123 |
-
except (json.JSONDecodeError, KeyError):
|
| 124 |
-
if configurable.strip_thinking_tokens:
|
| 125 |
-
content = strip_thinking_tokens(content)
|
| 126 |
-
return {"search_query": fallback_query}
|
| 127 |
-
|
| 128 |
-
def get_llm(configurable: Configuration):
|
| 129 |
-
"""Helper function to initialize LLM based on configuration.
|
| 130 |
-
|
| 131 |
-
Uses Gemini API for all operations.
|
| 132 |
-
|
| 133 |
-
Args:
|
| 134 |
-
configurable: Configuration object containing LLM settings
|
| 135 |
-
|
| 136 |
-
Returns:
|
| 137 |
-
Configured LLM instance
|
| 138 |
-
"""
|
| 139 |
-
# Use Gemini for all providers
|
| 140 |
-
# Using gemma-3-27b-it for higher rate limits (30 req/min vs 10 req/min)
|
| 141 |
-
return ChatGoogleGenerativeAI(
|
| 142 |
-
model=os.getenv("GEMINI_MODEL", "gemma-3-27b-it"),
|
| 143 |
-
google_api_key=os.getenv("GEMINI_API_KEY"),
|
| 144 |
-
temperature=0,
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
# Nodes
|
| 148 |
-
def generate_query(state: SummaryState, config: RunnableConfig):
|
| 149 |
-
"""LangGraph node that generates a search query based on the research topic.
|
| 150 |
-
|
| 151 |
-
Uses an LLM to create an optimized search query for web research based on
|
| 152 |
-
the user's research topic. Supports both LMStudio and Ollama as LLM providers.
|
| 153 |
-
|
| 154 |
-
Args:
|
| 155 |
-
state: Current graph state containing the research topic
|
| 156 |
-
config: Configuration for the runnable, including LLM provider settings
|
| 157 |
-
|
| 158 |
-
Returns:
|
| 159 |
-
Dictionary with state update, including search_query key containing the generated query
|
| 160 |
-
"""
|
| 161 |
-
|
| 162 |
-
# Format the prompt
|
| 163 |
-
current_date = get_current_date()
|
| 164 |
-
formatted_prompt = query_writer_instructions.format(
|
| 165 |
-
current_date=current_date, research_topic=state.research_topic
|
| 166 |
-
)
|
| 167 |
-
|
| 168 |
-
# Generate a query
|
| 169 |
-
configurable = Configuration.from_runnable_config(config)
|
| 170 |
-
|
| 171 |
-
@tool
|
| 172 |
-
class Query(BaseModel):
|
| 173 |
-
"""
|
| 174 |
-
This tool is used to generate a query for web search.
|
| 175 |
-
"""
|
| 176 |
-
|
| 177 |
-
query: str = Field(description="The actual search query string")
|
| 178 |
-
rationale: str = Field(
|
| 179 |
-
description="Brief explanation of why this query is relevant"
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
messages = [
|
| 183 |
-
SystemMessage(
|
| 184 |
-
content=formatted_prompt + (
|
| 185 |
-
tool_calling_query_instructions if configurable.use_tool_calling
|
| 186 |
-
else json_mode_query_instructions
|
| 187 |
-
)
|
| 188 |
-
),
|
| 189 |
-
HumanMessage(content="Generate a query for web search:"),
|
| 190 |
-
]
|
| 191 |
-
|
| 192 |
-
return generate_search_query_with_structured_output(
|
| 193 |
-
configurable=configurable,
|
| 194 |
-
messages=messages,
|
| 195 |
-
tool_class=Query,
|
| 196 |
-
fallback_query=f"Tell me more about {state.research_topic}",
|
| 197 |
-
tool_query_field="query",
|
| 198 |
-
json_query_field="query",
|
| 199 |
-
)
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def web_research(state: SummaryState, config: RunnableConfig):
|
| 203 |
-
"""LangGraph node that performs web research using the generated search query.
|
| 204 |
-
|
| 205 |
-
Executes a web search using the configured search API (tavily, perplexity,
|
| 206 |
-
duckduckgo, or searxng) and formats the results for further processing.
|
| 207 |
-
|
| 208 |
-
Args:
|
| 209 |
-
state: Current graph state containing the search query and research loop count
|
| 210 |
-
config: Configuration for the runnable, including search API settings
|
| 211 |
-
|
| 212 |
-
Returns:
|
| 213 |
-
Dictionary with state update, including sources_gathered, research_loop_count, and web_research_results
|
| 214 |
-
"""
|
| 215 |
-
|
| 216 |
-
# Configure
|
| 217 |
-
configurable = Configuration.from_runnable_config(config)
|
| 218 |
-
|
| 219 |
-
# Get the search API
|
| 220 |
-
search_api = get_config_value(configurable.search_api)
|
| 221 |
-
|
| 222 |
-
# Search the web
|
| 223 |
-
if search_api == "tavily":
|
| 224 |
-
search_results = tavily_search(
|
| 225 |
-
state.search_query,
|
| 226 |
-
fetch_full_page=configurable.fetch_full_page,
|
| 227 |
-
max_results=1,
|
| 228 |
-
)
|
| 229 |
-
search_str = deduplicate_and_format_sources(
|
| 230 |
-
search_results,
|
| 231 |
-
max_tokens_per_source=MAX_TOKENS_PER_SOURCE,
|
| 232 |
-
fetch_full_page=configurable.fetch_full_page,
|
| 233 |
-
)
|
| 234 |
-
elif search_api == "perplexity":
|
| 235 |
-
search_results = perplexity_search(
|
| 236 |
-
state.search_query, state.research_loop_count
|
| 237 |
-
)
|
| 238 |
-
search_str = deduplicate_and_format_sources(
|
| 239 |
-
search_results,
|
| 240 |
-
max_tokens_per_source=MAX_TOKENS_PER_SOURCE,
|
| 241 |
-
fetch_full_page=configurable.fetch_full_page,
|
| 242 |
-
)
|
| 243 |
-
elif search_api == "duckduckgo":
|
| 244 |
-
search_results = duckduckgo_search(
|
| 245 |
-
state.search_query,
|
| 246 |
-
max_results=3,
|
| 247 |
-
fetch_full_page=configurable.fetch_full_page,
|
| 248 |
-
)
|
| 249 |
-
search_str = deduplicate_and_format_sources(
|
| 250 |
-
search_results,
|
| 251 |
-
max_tokens_per_source=MAX_TOKENS_PER_SOURCE,
|
| 252 |
-
fetch_full_page=configurable.fetch_full_page,
|
| 253 |
-
)
|
| 254 |
-
# Note: searxng removed - use tavily or duckduckgo instead
|
| 255 |
-
else:
|
| 256 |
-
raise ValueError(f"Unsupported search API: {configurable.search_api}")
|
| 257 |
-
|
| 258 |
-
return {
|
| 259 |
-
"sources_gathered": [format_sources(search_results)],
|
| 260 |
-
"research_loop_count": state.research_loop_count + 1,
|
| 261 |
-
"web_research_results": [search_str],
|
| 262 |
-
}
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
def summarize_sources(state: SummaryState, config: RunnableConfig):
|
| 266 |
-
"""LangGraph node that summarizes web research results.
|
| 267 |
-
|
| 268 |
-
Uses an LLM to create or update a running summary based on the newest web research
|
| 269 |
-
results, integrating them with any existing summary.
|
| 270 |
-
|
| 271 |
-
Args:
|
| 272 |
-
state: Current graph state containing research topic, running summary,
|
| 273 |
-
and web research results
|
| 274 |
-
config: Configuration for the runnable, including LLM provider settings
|
| 275 |
-
|
| 276 |
-
Returns:
|
| 277 |
-
Dictionary with state update, including running_summary key containing the updated summary
|
| 278 |
-
"""
|
| 279 |
-
|
| 280 |
-
# Existing summary
|
| 281 |
-
existing_summary = state.running_summary
|
| 282 |
-
|
| 283 |
-
# Most recent web research
|
| 284 |
-
most_recent_web_research = state.web_research_results[-1]
|
| 285 |
-
|
| 286 |
-
# Build the human message
|
| 287 |
-
if existing_summary:
|
| 288 |
-
human_message_content = (
|
| 289 |
-
f"<Existing Summary> \n {existing_summary} \n <Existing Summary>\n\n"
|
| 290 |
-
f"<New Context> \n {most_recent_web_research} \n <New Context>"
|
| 291 |
-
f"Update the Existing Summary with the New Context on this topic: \n <User Input> \n {state.research_topic} \n <User Input>\n\n"
|
| 292 |
-
)
|
| 293 |
-
else:
|
| 294 |
-
human_message_content = (
|
| 295 |
-
f"<Context> \n {most_recent_web_research} \n <Context>"
|
| 296 |
-
f"Create a Summary using the Context on this topic: \n <User Input> \n {state.research_topic} \n <User Input>\n\n"
|
| 297 |
-
)
|
| 298 |
-
|
| 299 |
-
# Run the LLM
|
| 300 |
-
configurable = Configuration.from_runnable_config(config)
|
| 301 |
-
|
| 302 |
-
# Use Gemini via get_llm helper
|
| 303 |
-
llm = get_llm(configurable)
|
| 304 |
-
|
| 305 |
-
# Build messages and convert for Gemma compatibility
|
| 306 |
-
messages = convert_messages_for_gemma([
|
| 307 |
-
SystemMessage(content=summarizer_instructions),
|
| 308 |
-
HumanMessage(content=human_message_content),
|
| 309 |
-
])
|
| 310 |
-
|
| 311 |
-
gemini_rate_limiter.acquire() # Rate limit before API call
|
| 312 |
-
result = llm.invoke(messages)
|
| 313 |
-
|
| 314 |
-
# Strip thinking tokens if configured
|
| 315 |
-
running_summary = result.content
|
| 316 |
-
if configurable.strip_thinking_tokens:
|
| 317 |
-
running_summary = strip_thinking_tokens(running_summary)
|
| 318 |
-
|
| 319 |
-
return {"running_summary": running_summary}
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
def reflect_on_summary(state: SummaryState, config: RunnableConfig):
|
| 323 |
-
"""LangGraph node that identifies knowledge gaps and generates follow-up queries.
|
| 324 |
-
|
| 325 |
-
Analyzes the current summary to identify areas for further research and generates
|
| 326 |
-
a new search query to address those gaps. Uses structured output to extract
|
| 327 |
-
the follow-up query in JSON format.
|
| 328 |
-
|
| 329 |
-
Args:
|
| 330 |
-
state: Current graph state containing the running summary and research topic
|
| 331 |
-
config: Configuration for the runnable, including LLM provider settings
|
| 332 |
-
|
| 333 |
-
Returns:
|
| 334 |
-
Dictionary with state update, including search_query key containing the generated follow-up query
|
| 335 |
-
"""
|
| 336 |
-
|
| 337 |
-
# Generate a query
|
| 338 |
-
configurable = Configuration.from_runnable_config(config)
|
| 339 |
-
formatted_prompt = reflection_instructions.format(
|
| 340 |
-
research_topic=state.research_topic
|
| 341 |
-
)
|
| 342 |
-
|
| 343 |
-
@tool
|
| 344 |
-
class FollowUpQuery(BaseModel):
|
| 345 |
-
"""
|
| 346 |
-
This tool is used to generate a follow-up query to address a knowledge gap.
|
| 347 |
-
"""
|
| 348 |
-
|
| 349 |
-
follow_up_query: str = Field(
|
| 350 |
-
description="Write a specific question to address this gap"
|
| 351 |
-
)
|
| 352 |
-
knowledge_gap: str = Field(
|
| 353 |
-
description="Describe what information is missing or needs clarification"
|
| 354 |
-
)
|
| 355 |
-
|
| 356 |
-
messages = [
|
| 357 |
-
SystemMessage(
|
| 358 |
-
content=formatted_prompt + (
|
| 359 |
-
tool_calling_reflection_instructions if configurable.use_tool_calling
|
| 360 |
-
else json_mode_reflection_instructions
|
| 361 |
-
)
|
| 362 |
-
),
|
| 363 |
-
HumanMessage(
|
| 364 |
-
content=f"Reflect on our existing knowledge: \n === \n {state.running_summary}, \n === \n And now identify a knowledge gap and generate a follow-up web search query:"
|
| 365 |
-
),
|
| 366 |
-
]
|
| 367 |
-
|
| 368 |
-
return generate_search_query_with_structured_output(
|
| 369 |
-
configurable=configurable,
|
| 370 |
-
messages=messages,
|
| 371 |
-
tool_class=FollowUpQuery,
|
| 372 |
-
fallback_query=f"Tell me more about {state.research_topic}",
|
| 373 |
-
tool_query_field="follow_up_query",
|
| 374 |
-
json_query_field="follow_up_query",
|
| 375 |
-
)
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
def finalize_summary(state: SummaryState):
|
| 379 |
-
"""LangGraph node that finalizes the research summary.
|
| 380 |
-
|
| 381 |
-
Prepares the final output by deduplicating and formatting sources, then
|
| 382 |
-
combining them with the running summary to create a well-structured
|
| 383 |
-
research report with proper citations.
|
| 384 |
-
|
| 385 |
-
Args:
|
| 386 |
-
state: Current graph state containing the running summary and sources gathered
|
| 387 |
-
|
| 388 |
-
Returns:
|
| 389 |
-
Dictionary with state update, including running_summary key containing the formatted final summary with sources
|
| 390 |
-
"""
|
| 391 |
-
|
| 392 |
-
# Deduplicate sources before joining
|
| 393 |
-
seen_sources = set()
|
| 394 |
-
unique_sources = []
|
| 395 |
-
|
| 396 |
-
for source in state.sources_gathered:
|
| 397 |
-
# Split the source into lines and process each individually
|
| 398 |
-
for line in source.split("\n"):
|
| 399 |
-
# Only process non-empty lines
|
| 400 |
-
if line.strip() and line not in seen_sources:
|
| 401 |
-
seen_sources.add(line)
|
| 402 |
-
unique_sources.append(line)
|
| 403 |
-
|
| 404 |
-
# Join the deduplicated sources
|
| 405 |
-
all_sources = "\n".join(unique_sources)
|
| 406 |
-
state.running_summary = (
|
| 407 |
-
f"## Summary\n{state.running_summary}\n\n ### Sources:\n{all_sources}"
|
| 408 |
-
)
|
| 409 |
-
return {"running_summary": state.running_summary}
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
def route_research(
|
| 413 |
-
state: SummaryState, config: RunnableConfig
|
| 414 |
-
) -> Literal["finalize_summary", "web_research"]:
|
| 415 |
-
"""LangGraph routing function that determines the next step in the research flow.
|
| 416 |
-
|
| 417 |
-
Controls the research loop by deciding whether to continue gathering information
|
| 418 |
-
or to finalize the summary based on the configured maximum number of research loops.
|
| 419 |
-
|
| 420 |
-
Args:
|
| 421 |
-
state: Current graph state containing the research loop count
|
| 422 |
-
config: Configuration for the runnable, including max_web_research_loops setting
|
| 423 |
-
|
| 424 |
-
Returns:
|
| 425 |
-
String literal indicating the next node to visit ("web_research" or "finalize_summary")
|
| 426 |
-
"""
|
| 427 |
-
|
| 428 |
-
configurable = Configuration.from_runnable_config(config)
|
| 429 |
-
if state.research_loop_count <= configurable.max_web_research_loops:
|
| 430 |
-
return "web_research"
|
| 431 |
-
else:
|
| 432 |
-
return "finalize_summary"
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
# Add nodes and edges
|
| 436 |
-
builder = StateGraph(
|
| 437 |
-
SummaryState,
|
| 438 |
-
input=SummaryStateInput,
|
| 439 |
-
output=SummaryStateOutput,
|
| 440 |
-
config_schema=Configuration,
|
| 441 |
-
)
|
| 442 |
-
builder.add_node("generate_query", generate_query)
|
| 443 |
-
builder.add_node("web_research", web_research)
|
| 444 |
-
builder.add_node("summarize_sources", summarize_sources)
|
| 445 |
-
builder.add_node("reflect_on_summary", reflect_on_summary)
|
| 446 |
-
builder.add_node("finalize_summary", finalize_summary)
|
| 447 |
-
|
| 448 |
-
# Add edges
|
| 449 |
-
builder.add_edge(START, "generate_query")
|
| 450 |
-
builder.add_edge("generate_query", "web_research")
|
| 451 |
-
builder.add_edge("web_research", "summarize_sources")
|
| 452 |
-
builder.add_edge("summarize_sources", "reflect_on_summary")
|
| 453 |
-
builder.add_conditional_edges("reflect_on_summary", route_research)
|
| 454 |
-
builder.add_edge("finalize_summary", END)
|
| 455 |
-
|
| 456 |
-
graph = builder.compile()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/deep_researcher/prompts.py
DELETED
|
@@ -1,112 +0,0 @@
|
|
| 1 |
-
from datetime import datetime
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
# Get current date in a readable format
|
| 5 |
-
def get_current_date():
|
| 6 |
-
return datetime.now().strftime("%B %d, %Y")
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
query_writer_instructions = """Your goal is to generate a targeted web search query.
|
| 10 |
-
|
| 11 |
-
<CONTEXT>
|
| 12 |
-
Current date: {current_date}
|
| 13 |
-
Please ensure your queries account for the most current information available as of this date.
|
| 14 |
-
</CONTEXT>
|
| 15 |
-
|
| 16 |
-
<TOPIC>
|
| 17 |
-
{research_topic}
|
| 18 |
-
</TOPIC>
|
| 19 |
-
|
| 20 |
-
<EXAMPLE>
|
| 21 |
-
Example output:
|
| 22 |
-
{{
|
| 23 |
-
"query": "machine learning transformer architecture explained",
|
| 24 |
-
"rationale": "Understanding the fundamental structure of transformer models"
|
| 25 |
-
}}
|
| 26 |
-
</EXAMPLE>"""
|
| 27 |
-
|
| 28 |
-
json_mode_query_instructions = """<FORMAT>
|
| 29 |
-
Format your response as a JSON object with ALL three of these exact keys:
|
| 30 |
-
- "query": The actual search query string
|
| 31 |
-
- "rationale": Brief explanation of why this query is relevant
|
| 32 |
-
</FORMAT>
|
| 33 |
-
|
| 34 |
-
Provide your response in JSON format:"""
|
| 35 |
-
|
| 36 |
-
tool_calling_query_instructions = """<INSTRUCTIONS >
|
| 37 |
-
Call the Query tool to format your response with the following keys:
|
| 38 |
-
- "query": The actual search query string
|
| 39 |
-
- "rationale": Brief explanation of why this query is relevant
|
| 40 |
-
</INSTRUCTIONS>
|
| 41 |
-
|
| 42 |
-
Call the Query Tool to generate a query for this request:"""
|
| 43 |
-
|
| 44 |
-
summarizer_instructions = """
|
| 45 |
-
<GOAL>
|
| 46 |
-
Generate a high-quality summary of the provided context.
|
| 47 |
-
</GOAL>
|
| 48 |
-
|
| 49 |
-
<REQUIREMENTS>
|
| 50 |
-
When creating a NEW summary:
|
| 51 |
-
1. Highlight the most relevant information related to the user topic from the search results
|
| 52 |
-
2. Ensure a coherent flow of information
|
| 53 |
-
|
| 54 |
-
When EXTENDING an existing summary:
|
| 55 |
-
1. Read the existing summary and new search results carefully.
|
| 56 |
-
2. Compare the new information with the existing summary.
|
| 57 |
-
3. For each piece of new information:
|
| 58 |
-
a. If it's related to existing points, integrate it into the relevant paragraph.
|
| 59 |
-
b. If it's entirely new but relevant, add a new paragraph with a smooth transition.
|
| 60 |
-
c. If it's not relevant to the user topic, skip it.
|
| 61 |
-
4. Ensure all additions are relevant to the user's topic.
|
| 62 |
-
5. Verify that your final output differs from the input summary.
|
| 63 |
-
< /REQUIREMENTS >
|
| 64 |
-
|
| 65 |
-
< FORMATTING >
|
| 66 |
-
- Start directly with the updated summary, without preamble or titles. Do not use XML tags in the output.
|
| 67 |
-
< /FORMATTING >
|
| 68 |
-
|
| 69 |
-
<Task>
|
| 70 |
-
Think carefully about the provided Context first. Then generate a summary of the context to address the User Input.
|
| 71 |
-
</Task>
|
| 72 |
-
"""
|
| 73 |
-
|
| 74 |
-
reflection_instructions = """You are an expert research assistant analyzing a summary about {research_topic}.
|
| 75 |
-
|
| 76 |
-
<GOAL>
|
| 77 |
-
1. Identify knowledge gaps or areas that need deeper exploration
|
| 78 |
-
2. Generate a follow-up question that would help expand your understanding
|
| 79 |
-
3. Focus on technical details, implementation specifics, or emerging trends that weren't fully covered
|
| 80 |
-
</GOAL>
|
| 81 |
-
|
| 82 |
-
<REQUIREMENTS>
|
| 83 |
-
Ensure the follow-up question is self-contained and includes necessary context for web search.
|
| 84 |
-
</REQUIREMENTS>"""
|
| 85 |
-
|
| 86 |
-
json_mode_reflection_instructions = """<FORMAT>
|
| 87 |
-
Format your response as a JSON object with these exact keys:
|
| 88 |
-
- knowledge_gap: Describe what information is missing or needs clarification
|
| 89 |
-
- follow_up_query: Write a specific question to address this gap
|
| 90 |
-
</FORMAT>
|
| 91 |
-
|
| 92 |
-
<Task>
|
| 93 |
-
Reflect carefully on the Summary to identify knowledge gaps and produce a follow-up query. Then, produce your output following this JSON format:
|
| 94 |
-
{{
|
| 95 |
-
"knowledge_gap": "The summary lacks information about performance metrics and benchmarks",
|
| 96 |
-
"follow_up_query": "What are typical performance benchmarks and metrics used to evaluate [specific technology]?"
|
| 97 |
-
}}
|
| 98 |
-
</Task>
|
| 99 |
-
|
| 100 |
-
Provide your analysis in JSON format:"""
|
| 101 |
-
|
| 102 |
-
tool_calling_reflection_instructions = """<INSTRUCTIONS>
|
| 103 |
-
Call the FollowUpQuery tool to format your response with the following keys:
|
| 104 |
-
- follow_up_query: Write a specific question to address this gap
|
| 105 |
-
- knowledge_gap: Describe what information is missing or needs clarification
|
| 106 |
-
</INSTRUCTIONS>
|
| 107 |
-
|
| 108 |
-
<Task>
|
| 109 |
-
Reflect carefully on the Summary to identify knowledge gaps and produce a follow-up query.
|
| 110 |
-
</Task>
|
| 111 |
-
|
| 112 |
-
Call the FollowUpQuery Tool to generate a reflection for this request:"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/deep_researcher/rate_limiter.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Rate Limiter for Gemini API calls.
|
| 3 |
-
|
| 4 |
-
Gemma-3-27b-it has a limit of 30 requests per minute.
|
| 5 |
-
This module ensures we don't exceed that limit.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import time
|
| 9 |
-
import threading
|
| 10 |
-
import logging
|
| 11 |
-
from collections import deque
|
| 12 |
-
from typing import Optional
|
| 13 |
-
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class RateLimiter:
|
| 18 |
-
"""
|
| 19 |
-
Thread-safe rate limiter for API calls.
|
| 20 |
-
|
| 21 |
-
Ensures no more than `max_requests` are made within `time_window` seconds.
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
def __init__(self, max_requests: int = 29, time_window: int = 62):
|
| 25 |
-
"""
|
| 26 |
-
Initialize the rate limiter.
|
| 27 |
-
|
| 28 |
-
Args:
|
| 29 |
-
max_requests: Maximum requests allowed in the time window (default 29)
|
| 30 |
-
time_window: Time window in seconds (default 60)
|
| 31 |
-
"""
|
| 32 |
-
self.max_requests = max_requests
|
| 33 |
-
self.time_window = time_window
|
| 34 |
-
self.requests: deque = deque()
|
| 35 |
-
self.lock = threading.Lock()
|
| 36 |
-
|
| 37 |
-
def wait_if_needed(self) -> float:
|
| 38 |
-
"""
|
| 39 |
-
Wait if we've exceeded the rate limit.
|
| 40 |
-
|
| 41 |
-
Returns:
|
| 42 |
-
Time waited in seconds (0 if no wait needed)
|
| 43 |
-
"""
|
| 44 |
-
with self.lock:
|
| 45 |
-
now = time.time()
|
| 46 |
-
|
| 47 |
-
# Remove old requests outside the time window
|
| 48 |
-
while self.requests and self.requests[0] < now - self.time_window:
|
| 49 |
-
self.requests.popleft()
|
| 50 |
-
|
| 51 |
-
# Check if we need to wait
|
| 52 |
-
if len(self.requests) >= self.max_requests:
|
| 53 |
-
# Wait until oldest request exits the window
|
| 54 |
-
wait_time = self.requests[0] + self.time_window - now + 0.1 # +0.1s buffer
|
| 55 |
-
|
| 56 |
-
if wait_time > 0:
|
| 57 |
-
logger.info(f"RateLimiter: Waiting {wait_time:.1f}s to avoid rate limit...")
|
| 58 |
-
self.lock.release() # Release lock while waiting
|
| 59 |
-
time.sleep(wait_time)
|
| 60 |
-
self.lock.acquire() # Re-acquire lock
|
| 61 |
-
|
| 62 |
-
# Clean up old requests after waiting
|
| 63 |
-
now = time.time()
|
| 64 |
-
while self.requests and self.requests[0] < now - self.time_window:
|
| 65 |
-
self.requests.popleft()
|
| 66 |
-
|
| 67 |
-
return wait_time
|
| 68 |
-
|
| 69 |
-
return 0.0
|
| 70 |
-
|
| 71 |
-
def record_request(self):
|
| 72 |
-
"""Record that a request was made."""
|
| 73 |
-
with self.lock:
|
| 74 |
-
self.requests.append(time.time())
|
| 75 |
-
|
| 76 |
-
def acquire(self) -> float:
|
| 77 |
-
"""
|
| 78 |
-
Acquire permission to make a request.
|
| 79 |
-
|
| 80 |
-
This is the main method to call before making an API request.
|
| 81 |
-
It will wait if necessary and record the request.
|
| 82 |
-
|
| 83 |
-
Returns:
|
| 84 |
-
Time waited in seconds
|
| 85 |
-
"""
|
| 86 |
-
wait_time = self.wait_if_needed()
|
| 87 |
-
self.record_request()
|
| 88 |
-
return wait_time
|
| 89 |
-
|
| 90 |
-
@property
|
| 91 |
-
def current_count(self) -> int:
|
| 92 |
-
"""Get current request count in the time window."""
|
| 93 |
-
with self.lock:
|
| 94 |
-
now = time.time()
|
| 95 |
-
# Remove old requests
|
| 96 |
-
while self.requests and self.requests[0] < now - self.time_window:
|
| 97 |
-
self.requests.popleft()
|
| 98 |
-
return len(self.requests)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
# Global rate limiter instance for Gemini API
|
| 102 |
-
# 29 requests per 62 seconds to stay safely under the 30/min limit
|
| 103 |
-
gemini_rate_limiter = RateLimiter(max_requests=29, time_window=62)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
def rate_limited_call(func):
|
| 107 |
-
"""
|
| 108 |
-
Decorator to rate limit function calls.
|
| 109 |
-
|
| 110 |
-
Usage:
|
| 111 |
-
@rate_limited_call
|
| 112 |
-
def my_api_call():
|
| 113 |
-
...
|
| 114 |
-
"""
|
| 115 |
-
def wrapper(*args, **kwargs):
|
| 116 |
-
wait_time = gemini_rate_limiter.acquire()
|
| 117 |
-
if wait_time > 0:
|
| 118 |
-
logger.info(f"RateLimiter: Waited {wait_time:.1f}s before API call")
|
| 119 |
-
return func(*args, **kwargs)
|
| 120 |
-
return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/deep_researcher/state.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
import operator
|
| 2 |
-
from dataclasses import dataclass, field
|
| 3 |
-
from typing_extensions import Annotated
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
# Note: kw_only=True removed for Python 3.9 compatibility
|
| 7 |
-
@dataclass
|
| 8 |
-
class SummaryState:
|
| 9 |
-
research_topic: str = field(default=None) # Report topic
|
| 10 |
-
search_query: str = field(default=None) # Search query
|
| 11 |
-
web_research_results: Annotated[list, operator.add] = field(default_factory=list)
|
| 12 |
-
sources_gathered: Annotated[list, operator.add] = field(default_factory=list)
|
| 13 |
-
research_loop_count: int = field(default=0) # Research loop count
|
| 14 |
-
running_summary: str = field(default=None) # Final report
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
@dataclass
|
| 18 |
-
class SummaryStateInput:
|
| 19 |
-
research_topic: str = field(default=None) # Report topic
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
@dataclass
|
| 23 |
-
class SummaryStateOutput:
|
| 24 |
-
running_summary: str = field(default=None) # Final report
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/deep_researcher/utils.py
DELETED
|
@@ -1,333 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import httpx
|
| 3 |
-
import requests
|
| 4 |
-
from typing import Dict, Any, List, Union, Optional
|
| 5 |
-
|
| 6 |
-
from markdownify import markdownify
|
| 7 |
-
from langsmith import traceable
|
| 8 |
-
from tavily import TavilyClient
|
| 9 |
-
from duckduckgo_search import DDGS
|
| 10 |
-
|
| 11 |
-
# Note: SearxSearchWrapper removed to avoid langchain-community dependency
|
| 12 |
-
# We use Tavily and DuckDuckGo for search instead
|
| 13 |
-
|
| 14 |
-
# Constants
|
| 15 |
-
CHARS_PER_TOKEN = 4
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def get_config_value(value: Any) -> str:
|
| 19 |
-
"""
|
| 20 |
-
Convert configuration values to string format, handling both string and enum types.
|
| 21 |
-
|
| 22 |
-
Args:
|
| 23 |
-
value (Any): The configuration value to process. Can be a string or an Enum.
|
| 24 |
-
|
| 25 |
-
Returns:
|
| 26 |
-
str: The string representation of the value.
|
| 27 |
-
|
| 28 |
-
Examples:
|
| 29 |
-
>>> get_config_value("tavily")
|
| 30 |
-
'tavily'
|
| 31 |
-
>>> get_config_value(SearchAPI.TAVILY)
|
| 32 |
-
'tavily'
|
| 33 |
-
"""
|
| 34 |
-
return value if isinstance(value, str) else value.value
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def strip_thinking_tokens(text: str) -> str:
|
| 38 |
-
"""
|
| 39 |
-
Remove <think> and </think> tags and their content from the text.
|
| 40 |
-
|
| 41 |
-
Iteratively removes all occurrences of content enclosed in thinking tokens.
|
| 42 |
-
|
| 43 |
-
Args:
|
| 44 |
-
text (str): The text to process
|
| 45 |
-
|
| 46 |
-
Returns:
|
| 47 |
-
str: The text with thinking tokens and their content removed
|
| 48 |
-
"""
|
| 49 |
-
while "<think>" in text and "</think>" in text:
|
| 50 |
-
start = text.find("<think>")
|
| 51 |
-
end = text.find("</think>") + len("</think>")
|
| 52 |
-
text = text[:start] + text[end:]
|
| 53 |
-
return text
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def deduplicate_and_format_sources(
|
| 57 |
-
search_response: Union[Dict[str, Any], List[Dict[str, Any]]],
|
| 58 |
-
max_tokens_per_source: int,
|
| 59 |
-
fetch_full_page: bool = False,
|
| 60 |
-
) -> str:
|
| 61 |
-
"""
|
| 62 |
-
Format and deduplicate search responses from various search APIs.
|
| 63 |
-
|
| 64 |
-
Takes either a single search response or list of responses from search APIs,
|
| 65 |
-
deduplicates them by URL, and formats them into a structured string.
|
| 66 |
-
|
| 67 |
-
Args:
|
| 68 |
-
search_response (Union[Dict[str, Any], List[Dict[str, Any]]]): Either:
|
| 69 |
-
- A dict with a 'results' key containing a list of search results
|
| 70 |
-
- A list of dicts, each containing search results
|
| 71 |
-
max_tokens_per_source (int): Maximum number of tokens to include for each source's content
|
| 72 |
-
fetch_full_page (bool, optional): Whether to include the full page content. Defaults to False.
|
| 73 |
-
|
| 74 |
-
Returns:
|
| 75 |
-
str: Formatted string with deduplicated sources
|
| 76 |
-
|
| 77 |
-
Raises:
|
| 78 |
-
ValueError: If input is neither a dict with 'results' key nor a list of search results
|
| 79 |
-
"""
|
| 80 |
-
# Convert input to list of results
|
| 81 |
-
if isinstance(search_response, dict):
|
| 82 |
-
sources_list = search_response["results"]
|
| 83 |
-
elif isinstance(search_response, list):
|
| 84 |
-
sources_list = []
|
| 85 |
-
for response in search_response:
|
| 86 |
-
if isinstance(response, dict) and "results" in response:
|
| 87 |
-
sources_list.extend(response["results"])
|
| 88 |
-
else:
|
| 89 |
-
sources_list.extend(response)
|
| 90 |
-
else:
|
| 91 |
-
raise ValueError(
|
| 92 |
-
"Input must be either a dict with 'results' or a list of search results"
|
| 93 |
-
)
|
| 94 |
-
|
| 95 |
-
# Deduplicate by URL
|
| 96 |
-
unique_sources = {}
|
| 97 |
-
for source in sources_list:
|
| 98 |
-
if source["url"] not in unique_sources:
|
| 99 |
-
unique_sources[source["url"]] = source
|
| 100 |
-
|
| 101 |
-
# Format output
|
| 102 |
-
formatted_text = "Sources:\n\n"
|
| 103 |
-
for i, source in enumerate(unique_sources.values(), 1):
|
| 104 |
-
formatted_text += f"Source: {source['title']}\n===\n"
|
| 105 |
-
formatted_text += f"URL: {source['url']}\n===\n"
|
| 106 |
-
formatted_text += (
|
| 107 |
-
f"Most relevant content from source: {source['content']}\n===\n"
|
| 108 |
-
)
|
| 109 |
-
if fetch_full_page:
|
| 110 |
-
# Using rough estimate of characters per token
|
| 111 |
-
char_limit = max_tokens_per_source * CHARS_PER_TOKEN
|
| 112 |
-
# Handle None raw_content
|
| 113 |
-
raw_content = source.get("raw_content", "")
|
| 114 |
-
if raw_content is None:
|
| 115 |
-
raw_content = ""
|
| 116 |
-
print(f"Warning: No raw_content found for source {source['url']}")
|
| 117 |
-
if len(raw_content) > char_limit:
|
| 118 |
-
raw_content = raw_content[:char_limit] + "... [truncated]"
|
| 119 |
-
formatted_text += f"Full source content limited to {max_tokens_per_source} tokens: {raw_content}\n\n"
|
| 120 |
-
|
| 121 |
-
return formatted_text.strip()
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def format_sources(search_results: Dict[str, Any]) -> str:
|
| 125 |
-
"""
|
| 126 |
-
Format search results into a bullet-point list of sources with URLs.
|
| 127 |
-
|
| 128 |
-
Creates a simple bulleted list of search results with title and URL for each source.
|
| 129 |
-
|
| 130 |
-
Args:
|
| 131 |
-
search_results (Dict[str, Any]): Search response containing a 'results' key with
|
| 132 |
-
a list of search result objects
|
| 133 |
-
|
| 134 |
-
Returns:
|
| 135 |
-
str: Formatted string with sources as bullet points in the format "* title : url"
|
| 136 |
-
"""
|
| 137 |
-
return "\n".join(
|
| 138 |
-
f"* {source['title']} : {source['url']}" for source in search_results["results"]
|
| 139 |
-
)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def fetch_raw_content(url: str) -> Optional[str]:
|
| 143 |
-
"""
|
| 144 |
-
Fetch HTML content from a URL and convert it to markdown format.
|
| 145 |
-
|
| 146 |
-
Uses a 10-second timeout to avoid hanging on slow sites or large pages.
|
| 147 |
-
|
| 148 |
-
Args:
|
| 149 |
-
url (str): The URL to fetch content from
|
| 150 |
-
|
| 151 |
-
Returns:
|
| 152 |
-
Optional[str]: The fetched content converted to markdown if successful,
|
| 153 |
-
None if any error occurs during fetching or conversion
|
| 154 |
-
"""
|
| 155 |
-
try:
|
| 156 |
-
# Create a client with reasonable timeout
|
| 157 |
-
with httpx.Client(timeout=10.0) as client:
|
| 158 |
-
response = client.get(url)
|
| 159 |
-
response.raise_for_status()
|
| 160 |
-
return markdownify(response.text)
|
| 161 |
-
except Exception as e:
|
| 162 |
-
print(f"Warning: Failed to fetch full page content for {url}: {str(e)}")
|
| 163 |
-
return None
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
@traceable
|
| 167 |
-
def duckduckgo_search(
|
| 168 |
-
query: str, max_results: int = 3, fetch_full_page: bool = False
|
| 169 |
-
) -> Dict[str, List[Dict[str, Any]]]:
|
| 170 |
-
"""
|
| 171 |
-
Search the web using DuckDuckGo and return formatted results.
|
| 172 |
-
|
| 173 |
-
Uses the DDGS library to perform web searches through DuckDuckGo.
|
| 174 |
-
|
| 175 |
-
Args:
|
| 176 |
-
query (str): The search query to execute
|
| 177 |
-
max_results (int, optional): Maximum number of results to return. Defaults to 3.
|
| 178 |
-
fetch_full_page (bool, optional): Whether to fetch full page content from result URLs.
|
| 179 |
-
Defaults to False.
|
| 180 |
-
Returns:
|
| 181 |
-
Dict[str, List[Dict[str, Any]]]: Search response containing:
|
| 182 |
-
- results (list): List of search result dictionaries, each containing:
|
| 183 |
-
- title (str): Title of the search result
|
| 184 |
-
- url (str): URL of the search result
|
| 185 |
-
- content (str): Snippet/summary of the content
|
| 186 |
-
- raw_content (str or None): Full page content if fetch_full_page is True,
|
| 187 |
-
otherwise same as content
|
| 188 |
-
"""
|
| 189 |
-
try:
|
| 190 |
-
with DDGS() as ddgs:
|
| 191 |
-
results = []
|
| 192 |
-
search_results = list(ddgs.text(query, max_results=max_results))
|
| 193 |
-
|
| 194 |
-
for r in search_results:
|
| 195 |
-
url = r.get("href")
|
| 196 |
-
title = r.get("title")
|
| 197 |
-
content = r.get("body")
|
| 198 |
-
|
| 199 |
-
if not all([url, title, content]):
|
| 200 |
-
print(f"Warning: Incomplete result from DuckDuckGo: {r}")
|
| 201 |
-
continue
|
| 202 |
-
|
| 203 |
-
raw_content = content
|
| 204 |
-
if fetch_full_page:
|
| 205 |
-
raw_content = fetch_raw_content(url)
|
| 206 |
-
|
| 207 |
-
# Add result to list
|
| 208 |
-
result = {
|
| 209 |
-
"title": title,
|
| 210 |
-
"url": url,
|
| 211 |
-
"content": content,
|
| 212 |
-
"raw_content": raw_content,
|
| 213 |
-
}
|
| 214 |
-
results.append(result)
|
| 215 |
-
|
| 216 |
-
return {"results": results}
|
| 217 |
-
except Exception as e:
|
| 218 |
-
print(f"Error in DuckDuckGo search: {str(e)}")
|
| 219 |
-
print(f"Full error details: {type(e).__name__}")
|
| 220 |
-
return {"results": []}
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
@traceable
|
| 224 |
-
def tavily_search(
|
| 225 |
-
query: str, fetch_full_page: bool = True, max_results: int = 3
|
| 226 |
-
) -> Dict[str, List[Dict[str, Any]]]:
|
| 227 |
-
"""
|
| 228 |
-
Search the web using the Tavily API and return formatted results.
|
| 229 |
-
|
| 230 |
-
Uses the TavilyClient to perform searches. Tavily API key must be configured
|
| 231 |
-
in the environment.
|
| 232 |
-
|
| 233 |
-
Args:
|
| 234 |
-
query (str): The search query to execute
|
| 235 |
-
fetch_full_page (bool, optional): Whether to include raw content from sources.
|
| 236 |
-
Defaults to True.
|
| 237 |
-
max_results (int, optional): Maximum number of results to return. Defaults to 3.
|
| 238 |
-
|
| 239 |
-
Returns:
|
| 240 |
-
Dict[str, List[Dict[str, Any]]]: Search response containing:
|
| 241 |
-
- results (list): List of search result dictionaries, each containing:
|
| 242 |
-
- title (str): Title of the search result
|
| 243 |
-
- url (str): URL of the search result
|
| 244 |
-
- content (str): Snippet/summary of the content
|
| 245 |
-
- raw_content (str or None): Full content of the page if available and
|
| 246 |
-
fetch_full_page is True
|
| 247 |
-
"""
|
| 248 |
-
|
| 249 |
-
tavily_client = TavilyClient()
|
| 250 |
-
return tavily_client.search(
|
| 251 |
-
query, max_results=max_results, include_raw_content=fetch_full_page
|
| 252 |
-
)
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
@traceable
|
| 256 |
-
def perplexity_search(
|
| 257 |
-
query: str, perplexity_search_loop_count: int = 0
|
| 258 |
-
) -> Dict[str, Any]:
|
| 259 |
-
"""
|
| 260 |
-
Search the web using the Perplexity API and return formatted results.
|
| 261 |
-
|
| 262 |
-
Uses the Perplexity API to perform searches with the 'sonar-pro' model.
|
| 263 |
-
Requires a PERPLEXITY_API_KEY environment variable to be set.
|
| 264 |
-
|
| 265 |
-
Args:
|
| 266 |
-
query (str): The search query to execute
|
| 267 |
-
perplexity_search_loop_count (int, optional): The loop step for perplexity search
|
| 268 |
-
(used for source labeling). Defaults to 0.
|
| 269 |
-
|
| 270 |
-
Returns:
|
| 271 |
-
Dict[str, Any]: Search response containing:
|
| 272 |
-
- results (list): List of search result dictionaries, each containing:
|
| 273 |
-
- title (str): Title of the search result (includes search counter)
|
| 274 |
-
- url (str): URL of the citation source
|
| 275 |
-
- content (str): Content of the response or reference to main content
|
| 276 |
-
- raw_content (str or None): Full content for the first source, None for additional
|
| 277 |
-
citation sources
|
| 278 |
-
|
| 279 |
-
Raises:
|
| 280 |
-
requests.exceptions.HTTPError: If the API request fails
|
| 281 |
-
"""
|
| 282 |
-
|
| 283 |
-
headers = {
|
| 284 |
-
"accept": "application/json",
|
| 285 |
-
"content-type": "application/json",
|
| 286 |
-
"Authorization": f"Bearer {os.getenv('PERPLEXITY_API_KEY')}",
|
| 287 |
-
}
|
| 288 |
-
|
| 289 |
-
payload = {
|
| 290 |
-
"model": "sonar-pro",
|
| 291 |
-
"messages": [
|
| 292 |
-
{
|
| 293 |
-
"role": "system",
|
| 294 |
-
"content": "Search the web and provide factual information with sources.",
|
| 295 |
-
},
|
| 296 |
-
{"role": "user", "content": query},
|
| 297 |
-
],
|
| 298 |
-
}
|
| 299 |
-
|
| 300 |
-
response = requests.post(
|
| 301 |
-
"https://api.perplexity.ai/chat/completions", headers=headers, json=payload
|
| 302 |
-
)
|
| 303 |
-
response.raise_for_status() # Raise exception for bad status codes
|
| 304 |
-
|
| 305 |
-
# Parse the response
|
| 306 |
-
data = response.json()
|
| 307 |
-
content = data["choices"][0]["message"]["content"]
|
| 308 |
-
|
| 309 |
-
# Perplexity returns a list of citations for a single search result
|
| 310 |
-
citations = data.get("citations", ["https://perplexity.ai"])
|
| 311 |
-
|
| 312 |
-
# Return first citation with full content, others just as references
|
| 313 |
-
results = [
|
| 314 |
-
{
|
| 315 |
-
"title": f"Perplexity Search {perplexity_search_loop_count + 1}, Source 1",
|
| 316 |
-
"url": citations[0],
|
| 317 |
-
"content": content,
|
| 318 |
-
"raw_content": content,
|
| 319 |
-
}
|
| 320 |
-
]
|
| 321 |
-
|
| 322 |
-
# Add additional citations without duplicating content
|
| 323 |
-
for i, citation in enumerate(citations[1:], start=2):
|
| 324 |
-
results.append(
|
| 325 |
-
{
|
| 326 |
-
"title": f"Perplexity Search {perplexity_search_loop_count + 1}, Source {i}",
|
| 327 |
-
"url": citation,
|
| 328 |
-
"content": "See above for full content",
|
| 329 |
-
"raw_content": None,
|
| 330 |
-
}
|
| 331 |
-
)
|
| 332 |
-
|
| 333 |
-
return {"results": results}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/router.py
DELETED
|
@@ -1,265 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Bar Race Router
|
| 3 |
-
API endpoints for bar chart race video generation.
|
| 4 |
-
"""
|
| 5 |
-
import logging
|
| 6 |
-
import os
|
| 7 |
-
import uuid
|
| 8 |
-
import shutil
|
| 9 |
-
import traceback
|
| 10 |
-
from typing import Dict
|
| 11 |
-
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
| 12 |
-
from fastapi.responses import FileResponse, RedirectResponse
|
| 13 |
-
|
| 14 |
-
from .schemas import BarRaceRequest, JobResponse, JobStatus
|
| 15 |
-
|
| 16 |
-
logger = logging.getLogger(__name__)
|
| 17 |
-
|
| 18 |
-
router = APIRouter()
|
| 19 |
-
|
| 20 |
-
# Job storage
|
| 21 |
-
jobs: Dict[str, dict] = {}
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def update_job(
|
| 25 |
-
job_id: str,
|
| 26 |
-
status: str,
|
| 27 |
-
progress: int = 0,
|
| 28 |
-
current_step: str = None,
|
| 29 |
-
video_url: str = None,
|
| 30 |
-
error: str = None
|
| 31 |
-
):
|
| 32 |
-
"""Update job status"""
|
| 33 |
-
if job_id in jobs:
|
| 34 |
-
jobs[job_id].update({
|
| 35 |
-
"status": status,
|
| 36 |
-
"progress": progress,
|
| 37 |
-
"current_step": current_step,
|
| 38 |
-
"video_url": video_url,
|
| 39 |
-
"error": error
|
| 40 |
-
})
|
| 41 |
-
logger.debug(f"Job {job_id}: {status} ({progress}%) - {current_step}")
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
async def generate_bar_race_video(job_id: str, request: BarRaceRequest):
    """Background task to generate bar race video using the Agentic AI Pipeline.

    Pipeline stages: Brain (topic enhancement) -> Deep Researcher (Tavily
    search) -> Analyst (CSV extraction/cleaning, with a Groq fallback) ->
    Director (video rendering) -> Hugging Face upload (optional).

    Progress and terminal status are reported through ``update_job``. The
    temp directory is removed on success and deliberately kept on failure
    so the intermediate artifacts can be inspected.
    """
    temp_dir = f"temp/bar_race_{job_id}"

    try:
        os.makedirs(temp_dir, exist_ok=True)

        # API keys must come from the environment. Never hard-code
        # credentials in source: a previous revision embedded a live
        # Tavily key here, which is a secret leak and must be rotated.
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if not gemini_api_key:
            raise Exception("GEMINI_API_KEY environment variable is not set")
        if not os.getenv("TAVILY_API_KEY"):
            # Deep Researcher reads TAVILY_API_KEY itself; warn early so the
            # failure is attributable to configuration, not the search step.
            logger.warning("TAVILY_API_KEY is not set; Deep Researcher search may fail")

        # ============ BRAIN: Enhance Topic ============
        update_job(job_id, "processing", 5, "Brain: Enhancing topic...")

        from .services.brain import Brain
        brain = Brain(gemini_api_key=gemini_api_key)

        # Enhance raw topic into research-ready prompt
        enhanced_topic = brain.enhance_topic(request.topic)
        logger.info(f"Brain: Enhanced topic: {enhanced_topic[:100]}...")

        # ============ DEEP RESEARCHER: Tavily Search ============
        update_job(job_id, "processing", 15, "Researcher: Searching for data...")

        try:
            from .deep_researcher import graph

            # Run the deep research graph with ENHANCED topic
            research_result = await graph.ainvoke(
                {"research_topic": enhanced_topic},
                config={"configurable": {"search_api": "tavily", "max_web_research_loops": 2}}
            )

            research_text = research_result.get("running_summary", "")

            if not research_text:
                raise Exception("Deep Researcher returned no data")

            logger.info(f"Researcher: Got research text ({len(research_text)} chars)")

        except Exception as e:
            logger.error(f"Deep Researcher failed: {e}")
            raise Exception(f"Deep Researcher failed: {e}")

        # ============ ANALYST: Extract CSV ============
        update_job(job_id, "processing", 40, "Analyst: Extracting data...")

        from .services.analyst import Analyst
        analyst = Analyst(gemini_api_key=gemini_api_key)

        # Extract CSV from research text
        wide_df = analyst.extract_csv(research_text, request.topic)

        if wide_df is None or wide_df.empty:
            raise Exception("Analyst failed to extract CSV from research")

        logger.info(f"Analyst: Extracted {len(wide_df)} rows, {len(wide_df.columns)} columns")

        # ============ ANALYST: Find & Fill Gaps ============
        update_job(job_id, "processing", 50, "Analyst: Filling data gaps...")

        gaps = analyst.find_gaps(wide_df)

        if gaps:
            logger.info(f"Analyst: Found {len(gaps)} gaps, filling with Brain knowledge...")
            # Use Brain to fill gaps
            wide_df = brain.fill_data_gaps(wide_df, gaps, enhanced_topic)

        # ============ ANALYST: Clean Data ============
        update_job(job_id, "processing", 55, "Analyst: Cleaning data...")

        # Save raw CSV for potential Groq fixing
        raw_csv_backup = wide_df.to_csv(index=True)  # Keep index in case it has years

        wide_df = analyst.clean_data(wide_df)

        # ============ GROQ FALLBACK: Fix Data if clean_data failed ============
        if wide_df is None or wide_df.empty:
            logger.warning("Analyst: Initial cleaning failed, attempting Groq fix...")
            update_job(job_id, "processing", 58, "Analyst: Fixing data with Groq AI...")

            # Try Groq to fix the data
            fixed_df = analyst.fix_with_groq(raw_csv_backup, enhanced_topic)

            if fixed_df is not None and not fixed_df.empty:
                logger.info("Analyst: Groq fixed the data, re-cleaning...")
                wide_df = analyst.clean_data(fixed_df)

        # Convert to long format for bar_chart_race
        clean_df = analyst.convert_to_long_format(wide_df)

        if clean_df is None or clean_df.empty:
            raise Exception("Analyst failed to produce clean data")

        # Save for debugging
        clean_df.to_csv(os.path.join(temp_dir, "clean_data.csv"), index=False)
        logger.info(f"Analyst: Clean data - {len(clean_df)} rows, {clean_df['name'].nunique()} entities")

        # ============ DIRECTOR ============
        update_job(job_id, "processing", 65, "Director: Generating video...")

        from .services.director import Director
        director = Director(temp_dir=temp_dir)

        # Create simple video metadata (replaces old 'plan')
        video_meta = {
            "video_meta": {"title": request.topic},
            "value_intent": {"unit": ""},
            "visualization": {"top_n": 10}
        }

        video_path = director.generate_video(
            df=clean_df,
            plan=video_meta,
            image_paths={},
            duration_seconds=request.duration_seconds,
            job_id=job_id
        )

        if not video_path or not os.path.exists(video_path):
            raise Exception("Director failed to generate video")

        logger.info(f"Director: Generated video at {video_path}")

        # ============ UPLOAD TO HF ============
        update_job(job_id, "processing", 85, "Uploading to cloud storage...")

        video_url = None
        try:
            from modules.shared.services.hf_storage import get_hf_storage
            hf_storage = get_hf_storage()

            if hf_storage and hf_storage.enabled:
                # Upload video
                from pathlib import Path
                uploaded_url = hf_storage.upload_video(
                    local_path=Path(video_path),
                    video_id=job_id,
                    folder="bar_race"
                )
                if uploaded_url:
                    video_url = uploaded_url
                    logger.info(f"Uploaded to HF: {video_url}")
        except Exception as e:
            # Upload is best-effort; fall back to serving the local file.
            logger.warning(f"HF upload failed, using local: {e}")

        # Fallback to local URL
        if not video_url:
            video_url = f"/api/bar-race/video/{job_id}"

        # ============ SUCCESS ============
        update_job(job_id, "ready", 100, "Complete", video_url=video_url)
        logger.info(f"Bar race video ready: {video_url}")

        # Cleanup temp files (only on success)
        try:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
                logger.info(f"Cleaned up temp directory: {temp_dir}")
        except Exception as e:
            logger.warning(f"Cleanup failed: {e}")

    except Exception as e:
        logger.error(f"Bar race generation failed: {e}")
        logger.error(traceback.format_exc())
        update_job(job_id, "failed", error=str(e))

        # Keep temp files for debugging on failure
        logger.info(f"Keeping temp directory for debugging: {temp_dir}")
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
@router.post("/generate", response_model=JobResponse)
async def generate_bar_race(request: BarRaceRequest, background_tasks: BackgroundTasks):
    """
    Generate a bar chart race video.

    Takes a topic and duration, returns job_id to track progress.
    """
    job_id = str(uuid.uuid4())[:8]

    # Register the job record before scheduling so status polls never 404.
    initial_state = {
        "job_id": job_id,
        "status": "queued",
        "progress": 0,
        "current_step": "Initializing...",
        "video_url": None,
        "error": None,
    }
    jobs[job_id] = initial_state

    # Kick off the heavy pipeline after the response is sent.
    background_tasks.add_task(generate_bar_race_video, job_id, request)

    return JobResponse(
        job_id=job_id,
        status="queued",
        message=f"Bar race generation started for topic: {request.topic}",
    )
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
@router.get("/{job_id}/status", response_model=JobStatus)
async def get_job_status(job_id: str):
    """Get status of a bar race generation job"""
    # Single dict lookup instead of a membership test plus indexing.
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(404, f"Job not found: {job_id}")
    return JobStatus(**job)
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
@router.get("/video/{job_id}")
async def get_video(job_id: str):
    """Download the generated bar race video"""
    filename = f"bar_race_{job_id}.mp4"
    video_path = f"videos/bar_race/{filename}"

    # Guard clause: only serve files the Director actually produced.
    if not os.path.exists(video_path):
        raise HTTPException(404, "Video not found")

    return FileResponse(video_path, media_type="video/mp4", filename=filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/schemas.py
DELETED
|
@@ -1,59 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Bar Race Schemas
|
| 3 |
-
Pydantic models for bar chart race video generation.
|
| 4 |
-
"""
|
| 5 |
-
from pydantic import BaseModel, Field
|
| 6 |
-
from typing import Optional, List, Dict, Any
|
| 7 |
-
from enum import Enum
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class EntityType(str, Enum):
    """Type of entities in the bar chart.

    Subclasses ``str`` so member values serialize directly as plain JSON
    strings in API payloads.
    """
    PERSON = "person"    # individual people (athletes, celebrities, ...)
    COUNTRY = "country"  # nation states / territories
    COMPANY = "company"  # businesses and brands
    GENERAL = "general"  # catch-all for anything else
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class BarRaceRequest(BaseModel):
    """Request to generate a bar chart race video."""
    # Free-form prompt describing the ranking to animate.
    topic: str = Field(..., description="Topic/prompt for video (e.g., 'Top 10 richest cricketers')")
    # Output length; pydantic rejects values outside 30-120 seconds.
    duration_seconds: int = Field(60, ge=30, le=120, description="Video duration in seconds")

    class Config:
        # Example payload surfaced in the OpenAPI/Swagger UI.
        json_schema_extra = {
            "example": {
                "topic": "Top 10 richest countries by GDP 2000-2024",
                "duration_seconds": 60
            }
        }
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class BrainPlan(BaseModel):
    """JSON plan generated by Brain (LLM).

    Loosely-typed container for the planner's output; the nested dict
    shapes are defined by the Brain service's prompt, not here.
    """
    topic: str                              # original user topic
    entity_type: EntityType                 # what kind of competitors the chart ranks
    time_config: Dict[str, Any]             # presumably time range/step — confirm against Brain prompt
    value_intent: Dict[str, Any]            # metric/unit description (e.g. {"unit": "..."})
    search_strategies: List[Dict[str, Any]] # ordered search approaches for the Scout
    source_priority: List[str]              # preferred data sources, highest first
    data_expectation: Dict[str, Any]        # expected shape/quality of fetched data
    visualization: Dict[str, Any]           # chart settings (e.g. {"top_n": 10})
    video_meta: Dict[str, Any]              # title and other video-level metadata
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class JobResponse(BaseModel):
    """Response when job is created."""
    job_id: str  # short id used to poll status / fetch the video
    status: str  # initial state, always "queued" at creation time
    message: str # human-readable confirmation text
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
class JobStatus(BaseModel):
    """Job status response returned by the status-poll endpoint."""
    job_id: str
    status: str  # queued, brain, scout, surgeon, artist, director, uploading, ready, failed
    progress: int = 0                   # 0-100 percentage
    current_step: Optional[str] = None  # human-readable description of the active stage
    video_url: Optional[str] = None     # set once status is "ready"
    error: Optional[str] = None         # set once status is "failed"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/services/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# Services package
|
|
|
|
|
|
modules/bar_race/services/analyst.py
DELETED
|
@@ -1,517 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Analyst Service - Extracts structured CSV data from research text.
|
| 3 |
-
|
| 4 |
-
This service uses Gemini to parse unstructured research results into
|
| 5 |
-
clean, structured CSV data suitable for bar chart race videos.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import os
|
| 9 |
-
import re
|
| 10 |
-
import logging
|
| 11 |
-
from io import StringIO
|
| 12 |
-
from typing import Optional, Dict, Any, List
|
| 13 |
-
|
| 14 |
-
import pandas as pd
|
| 15 |
-
|
| 16 |
-
from modules.bar_race.deep_researcher.rate_limiter import gemini_rate_limiter
|
| 17 |
-
|
| 18 |
-
logger = logging.getLogger(__name__)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class Analyst:
    """
    Data Analyst for Bar Race video generation.

    Responsibilities:
    - Extract structured CSV from research text (LLM)
    - Find gaps in data
    - Clean and format final data
    - Fix formatting issues using Groq fallback
    """

    # Groq Model for data fixing (OpenAI GPT-OSS 120B via Groq)
    GROQ_MODEL = "openai/gpt-oss-120b"

    # Universal Data Fixer System Prompt for Groq
    DATA_FIXER_PROMPT = '''You are an expert Data Engineer specialized in fixing and formatting
time-series data for Bar Chart Race animations.

INPUT: You will receive malformed CSV data that failed initial processing.

REQUIRED OUTPUT FORMAT (STRICT - ALWAYS THE SAME):
The CSV must be in WIDE FORMAT with exactly this structure:
- First column: "year" (integer years like 2010, 2011, 2012...)
- Remaining columns: Entity names (countries, companies, people, etc.)
- Each cell: Numeric value (float or integer, no symbols)

EXAMPLE OUTPUT CSV:
year,USA,China,India,Japan,Germany
2010,14992,6087,1708,5759,3417
2011,15543,7552,1823,6157,3757
2012,16197,8532,1827,6203,3543
2013,16785,9635,1857,5156,3752

CRITICAL RULES:
1. YEAR COLUMN: First column MUST be named "year" with valid years (1900-2100)
2. ENTITY COLUMNS: Column headers = entity names (NOT numbers or years)
3. VALUES: Pure numbers only (no commas, no symbols, no text)
4. WIDE FORMAT: Years in rows, Entities in columns

COMMON FIXES YOU MUST APPLY:
1. If year data is in the index/row labels → Move it to first column
2. If data is transposed (years as columns) → Transpose to correct format
3. If year column contains large numbers (population/GDP) → Find real years elsewhere
4. If columns are shifted → Realign correctly
5. If values have commas or symbols → Clean to pure numbers

OUTPUT FORMAT (STRICT JSON):
{
    "status": "ok",
    "summary": "Fixed: [describe what was wrong and how you fixed it]",
    "csv": "year,Entity1,Entity2,...\\n2010,100,200,...\\n2011,110,210,..."
}

If unfixable, return:
{
    "status": "abort",
    "reason": "Cannot fix because: [specific reason]"
}'''

    def __init__(self, gemini_api_key: str = None, groq_api_key: str = None):
        """Create the analyst; keys fall back to GEMINI_API_KEY / GROQ_API_KEY env vars."""
        self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY")
        self.groq_api_key = groq_api_key or os.getenv("GROQ_API_KEY")
        self.gemini_client = None
        self.groq_client = None

        # Initialize Gemini client (primary extraction LLM)
        if self.gemini_api_key:
            try:
                from google import genai
                self.gemini_client = genai.Client(api_key=self.gemini_api_key)
                logger.info("Analyst: Gemini client initialized")
            except Exception as e:
                logger.warning(f"Analyst: Gemini init failed: {e}")

        # Initialize Groq client (fallback for data fixing)
        if self.groq_api_key:
            try:
                from groq import Groq
                self.groq_client = Groq(api_key=self.groq_api_key)
                logger.info("Analyst: Groq client initialized (data fixer fallback)")
            except Exception as e:
                logger.warning(f"Analyst: Groq init failed: {e}")

    def extract_csv(self, research_text: str, topic: str) -> Optional[pd.DataFrame]:
        """
        Extract structured CSV data from research text using Gemini.

        Args:
            research_text: Raw text from Deep Researcher
            topic: Original research topic for context

        Returns:
            DataFrame in Wide Format (year as rows, entities as columns),
            or None if the client is unavailable or parsing fails.
        """
        if not self.gemini_client:
            logger.error("Analyst: No Gemini client available")
            return None

        # Very explicit prompt for consistent Wide Format
        prompt = f"""Act as a Data Extraction Expert for Bar Chart Race Animation.

TOPIC: {topic}

RESEARCH TEXT:
{research_text[:8000]}

---

## CRITICAL OUTPUT FORMAT (Wide Format CSV)

The CSV MUST follow this EXACT structure:

```
year,Entity1,Entity2,Entity3,...
2000,100,200,150,...
2001,120,210,160,...
2002,140,220,180,...
```

### RULES (STRICT):

1. FIRST COLUMN must be named 'year' (lowercase)
2. FIRST COLUMN contains TIME values (2000, 2001, 2002... OR Over 1, Over 2...)
3. OTHER COLUMNS are ENTITY NAMES (USA, China, India OR Virat Kohli, Sachin...)
4. CELLS contain NUMERIC VALUES only (no currency symbols, no commas)
5. Each ROW = One time period
6. Each COLUMN (after year) = One competitor/entity

### EXAMPLE OUTPUT:

year,United States,China,Japan,Germany
2010,14992,6087,5700,3396
2011,15543,7552,5893,3757
2012,16197,8532,5954,3543

---

OUTPUT ONLY THE CSV DATA (no markdown, no explanations, no backticks):
"""

        try:
            gemini_rate_limiter.acquire()  # Rate limit before API call
            response = self.gemini_client.models.generate_content(
                model="gemma-3-27b-it",
                contents=prompt
            )

            csv_content = response.text.strip()

            # Clean markdown formatting (the model sometimes fences its output)
            if csv_content.startswith("```csv"):
                csv_content = csv_content[6:]
            elif csv_content.startswith("```"):
                csv_content = csv_content[3:]
            if csv_content.endswith("```"):
                csv_content = csv_content[:-3]
            csv_content = csv_content.strip()

            if not csv_content:
                logger.warning("Analyst: Empty CSV response from Gemini")
                return None

            # Parse CSV
            df = pd.read_csv(StringIO(csv_content))
            logger.info(f"Analyst: Extracted CSV with {len(df)} rows, {len(df.columns)} columns")

            # Validate and fix format if needed
            df = self._validate_and_fix_format(df)

            return df

        except Exception as e:
            logger.error(f"Analyst: CSV extraction failed: {e}")
            return None

    def _validate_and_fix_format(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Validate CSV is in correct Wide Format and fix if needed.

        Expected Wide Format:
        - First column: year/time (2000, 2001, 2002...)
        - Other columns: entity names (USA, China, India...)
        - Values: numeric

        Detects and fixes:
        - Transposed format (entities in rows, years in columns)
        - Wrong column order
        """
        if df is None or df.empty:
            return df

        first_col = df.columns[0]
        first_col_lower = str(first_col).lower()

        # Check if first column looks like year/time
        is_first_col_time = False

        if first_col_lower in ['year', 'date', 'time', 'month', 'period', 'over']:
            is_first_col_time = True
        else:
            # Check if values look like years (1900-2100) or sequential numbers
            sample_values = df[first_col].dropna().head(10)
            try:
                numeric_values = pd.to_numeric(sample_values, errors='coerce')
                valid_years = numeric_values.apply(lambda x: 1900 <= x <= 2100 if pd.notna(x) else False)
                # Guard against an all-NaN first column (previously a ZeroDivisionError)
                if len(valid_years) > 0 and valid_years.sum() / len(valid_years) > 0.8:
                    is_first_col_time = True
                    logger.info(f"Analyst: Detected first column '{first_col}' contains year values")
            except Exception:
                pass

        if is_first_col_time:
            # Format is correct (Wide Format)
            logger.info("Analyst: Data is in correct Wide Format")
            return df

        # Check if COLUMNS look like years (Transposed Format)
        year_like_columns = []
        for col in df.columns[1:]:
            try:
                year_val = int(col)
                if 1900 <= year_val <= 2100:
                    year_like_columns.append(col)
            except Exception:
                pass

        if len(year_like_columns) > 5:
            # Data is transposed! Entities in rows, Years in columns
            logger.warning("Analyst: Detected Transposed Format, converting to Wide Format...")

            # First column contains entity names
            df = df.rename(columns={first_col: 'entity'})

            # Melt to long format first
            df_long = df.melt(
                id_vars=['entity'],
                var_name='year',
                value_name='value'
            )

            # Pivot to correct Wide Format
            df_wide = df_long.pivot(index='year', columns='entity', values='value').reset_index()

            logger.info(f"Analyst: Converted from Transposed to Wide Format: {df_wide.shape}")
            return df_wide

        # If we can't determine format, assume it's correct
        logger.warning("Analyst: Could not determine format, assuming Wide Format")
        return df

    def find_gaps(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
        """
        Find missing data points in the DataFrame.

        Args:
            df: DataFrame with potential gaps

        Returns:
            List of gap descriptions (year, entity, missing_type)
        """
        gaps = []

        if df is None or df.empty:
            return gaps

        # Assume first column is 'year'
        year_col = df.columns[0]
        entity_cols = df.columns[1:]

        for col in entity_cols:
            for _, row in df.iterrows():
                value = row[col]
                if pd.isna(value) or value == "":
                    gaps.append({
                        "year": row[year_col],
                        "entity": col,
                        "type": "missing_value"
                    })

        logger.info(f"Analyst: Found {len(gaps)} data gaps")
        return gaps

    def clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Final cleaning and formatting of data for video generation.

        Args:
            df: DataFrame to clean

        Returns:
            Cleaned DataFrame ready for Director (may be empty if no valid years)
        """
        if df is None or df.empty:
            return df

        # Log first few rows for debugging
        logger.info(f"Analyst: Raw data before cleaning:\n{df.head(3).to_string()}")

        # Check if years are in the index instead of first column.
        # This happens when LLM outputs data with year as row index.
        first_col = df.columns[0]
        first_col_values = df[first_col].head(3).tolist()

        # Check if first column values are too large to be years (likely population data)
        try:
            first_values_numeric = [float(v) for v in first_col_values if v is not None and str(v).replace('.','').isdigit()]
            if first_values_numeric and min(first_values_numeric) > 100000:
                # First column contains large numbers (likely population/GDP), not years.
                # Check if index contains valid years instead.
                index_values = df.index.tolist()[:5]
                try:
                    index_years = [int(float(v)) for v in index_values if str(v).replace('.','').isdigit()]
                    if index_years and all(1900 <= y <= 2100 for y in index_years):
                        logger.info(f"Analyst: Detected years in index ({index_years[:3]}...), fixing data structure")
                        # Reset index to make year a column
                        df = df.reset_index()
                        df = df.rename(columns={'index': 'year'})
                except Exception:
                    pass
        except Exception:
            pass

        # Ensure year column is named correctly (str() guards non-string headers)
        first_col = df.columns[0]
        if str(first_col).lower() != 'year':
            df = df.rename(columns={first_col: 'year'})

        # Try to extract year from various formats
        def extract_year(val):
            """Extract year from various formats like '2010', 2010, '2010-01-01', etc."""
            try:
                # If already a number
                if isinstance(val, (int, float)):
                    return int(val)

                # If string, try to extract 4-digit year
                val_str = str(val).strip()

                # Try direct conversion first
                if val_str.isdigit() and len(val_str) == 4:
                    return int(val_str)

                # Try to find 4-digit year pattern (re is imported at module level)
                year_match = re.search(r'(19|20)\d{2}', val_str)
                if year_match:
                    return int(year_match.group())

                # Last resort: try float conversion
                return int(float(val_str))
            except Exception:
                return None

        # Apply year extraction
        df['year'] = df['year'].apply(extract_year)

        # Log after extraction
        logger.info(f"Analyst: Years extracted: {df['year'].tolist()[:5]}...")

        # Drop rows where year extraction failed
        df = df.dropna(subset=['year'])
        df['year'] = df['year'].astype(int)

        # CRITICAL: Filter out invalid years (only 1900-2100 allowed)
        original_len = len(df)
        df = df[(df['year'] >= 1900) & (df['year'] <= 2100)]

        if len(df) < original_len:
            logger.warning(f"Analyst: Removed {original_len - len(df)} rows with invalid years (outside 1900-2100)")

        if df.empty:
            logger.error("Analyst: No valid years found in data!")
            return df

        # Convert all value columns to numeric
        for col in df.columns[1:]:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Sort by year
        df = df.sort_values('year').reset_index(drop=True)

        # Interpolate missing values
        for col in df.columns[1:]:
            df[col] = df[col].interpolate(method='linear')

        logger.info(f"Analyst: Cleaned data - {len(df)} rows, years {df['year'].min()}-{df['year'].max()}")

        return df

    def convert_to_long_format(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Convert Wide Format to Long Format for bar_chart_race.

        Wide: year, USA, China, India
        Long: name, year, value

        Args:
            df: Wide format DataFrame

        Returns:
            Long format DataFrame
        """
        if df is None or df.empty:
            return df

        df_long = df.melt(
            id_vars=['year'],
            var_name='name',
            value_name='value'
        )

        # Ensure correct types
        df_long['year'] = df_long['year'].astype(int)
        df_long['value'] = pd.to_numeric(df_long['value'], errors='coerce')
        df_long = df_long.dropna(subset=['value'])

        logger.info(f"Analyst: Converted to long format - {len(df_long)} records")

        return df_long

    def fix_with_groq(self, raw_csv: str, topic: str) -> Optional[pd.DataFrame]:
        """
        Use Groq GPT-OSS 120B to fix malformed CSV data.

        This is a fallback when the initial clean_data fails due to:
        - Year column issues
        - Column misalignment
        - Format confusion (wide/long)

        Args:
            raw_csv: The raw CSV string that failed cleaning
            topic: Topic context for better understanding

        Returns:
            Fixed DataFrame or None if unfixable
        """
        if not self.groq_client:
            logger.warning("Analyst: Groq client not available for data fixing")
            return None

        logger.info("Analyst: Attempting to fix data with Groq...")

        try:
            user_message = f"""Topic: {topic}

The following CSV data has formatting issues. Please fix it:

```csv
{raw_csv}
```

Analyze the data structure and return a properly formatted CSV ready for bar_chart_race animation."""

            completion = self.groq_client.chat.completions.create(
                model=self.GROQ_MODEL,
                messages=[
                    {"role": "system", "content": self.DATA_FIXER_PROMPT},
                    {"role": "user", "content": user_message}
                ],
                temperature=0.1,
                max_tokens=4000,
            )

            response_text = completion.choices[0].message.content.strip()
            logger.info(f"Analyst: Groq response received ({len(response_text)} chars)")

            # Parse JSON response
            import json

            # Try to extract JSON from response
            json_match = re.search(r'\{[\s\S]*\}', response_text)
            if not json_match:
                logger.error("Analyst: Groq response is not valid JSON")
                return None

            result = json.loads(json_match.group())

            if result.get("status") == "abort":
                logger.warning(f"Analyst: Groq aborted - {result.get('reason', 'unknown')}")
                return None

            if result.get("status") == "ok":
                logger.info(f"Analyst: Groq fixed data - {result.get('summary', 'no summary')}")
                csv_content = result.get("csv", "")

                if csv_content:
                    # Parse the fixed CSV
                    df = pd.read_csv(StringIO(csv_content))
                    logger.info(f"Analyst: Groq produced {len(df)} rows, {len(df.columns)} columns")
                    return df

            logger.warning("Analyst: Groq response did not contain valid data")
            return None

        except Exception as e:
            logger.error(f"Analyst: Groq fix failed - {e}")
            return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/services/artist.py
DELETED
|
@@ -1,301 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Artist - Image Processor
|
| 3 |
-
Downloads and processes entity images for bar chart race.
|
| 4 |
-
"""
|
| 5 |
-
import logging
|
| 6 |
-
import requests
|
| 7 |
-
import os
|
| 8 |
-
from PIL import Image, ImageDraw
|
| 9 |
-
from typing import Dict, Any, List, Optional
|
| 10 |
-
from io import BytesIO
|
| 11 |
-
|
| 12 |
-
logger = logging.getLogger(__name__)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class Artist:
    """
    Image Processor for Bar Race video generation.

    Responsibilities:
    - Search and download entity images
    - Background removal (optional, if rembg available)
    - Face detection for person entities
    - Circular mask application
    """

    # Browser-like User-Agent so image hosts don't reject plain requests.
    HEADERS = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
    }

    # Edge length (px) of the square icon rendered next to each bar.
    IMAGE_SIZE = 80

    def __init__(self, temp_dir: str):
        """
        Args:
            temp_dir: Working directory; processed images go in <temp_dir>/images.
        """
        self.temp_dir = temp_dir
        self.images_dir = os.path.join(temp_dir, "images")
        os.makedirs(self.images_dir, exist_ok=True)

        # Probe once for the optional rembg dependency (background removal).
        self.rembg_available = False
        try:
            import rembg  # noqa: F401
            self.rembg_available = True
            logger.info("Artist: rembg available for background removal")
        except ImportError:
            logger.info("Artist: rembg not available, skipping background removal")

    def process_entities(self, entities: List[str], entity_type: str) -> Dict[str, str]:
        """
        Download and process images for all entities.

        Args:
            entities: List of entity names
            entity_type: Type of entity (person, country, company, general)

        Returns:
            Dict mapping entity name to processed image path; entities whose
            image could not be produced are omitted.
        """
        logger.info(f"Artist: Processing images for {len(entities)} entities (type: {entity_type})")

        image_paths = {}

        for entity in entities:
            try:
                image_path = self._process_entity(entity, entity_type)
                if image_path:
                    image_paths[entity] = image_path
                    logger.debug(f"Artist: Processed image for {entity}")
                else:
                    logger.warning(f"Artist: No image found for {entity}")
            except Exception as e:
                # Best effort: one broken entity must not abort the batch.
                logger.warning(f"Artist: Failed to process {entity}: {e}")

        logger.info(f"Artist: Processed {len(image_paths)}/{len(entities)} images")
        return image_paths

    def _process_entity(self, entity: str, entity_type: str) -> Optional[str]:
        """Fetch, normalise, mask and save one entity's image; return its path."""
        image = self._get_image(entity, entity_type)

        if image is None:
            return None

        try:
            # Normalise to RGBA and a fixed square size.
            image = image.convert("RGBA")
            image = self._resize_to_square(image)

            # Background removal only makes sense for portraits.
            if self.rembg_available and entity_type == "person":
                image = self._remove_background(image)

            image = self._apply_circular_mask(image)

            # Sanitise the entity name into a safe filename.
            safe_name = "".join(c if c.isalnum() else "_" for c in entity)
            output_path = os.path.join(self.images_dir, f"{safe_name}.png")
            image.save(output_path, "PNG")

            return output_path

        except Exception as e:
            logger.error(f"Artist: Error processing image for {entity}: {e}")
            return None

    def _get_image(self, entity: str, entity_type: str) -> Optional[Image.Image]:
        """Get an image for an entity, trying sources in priority order."""

        # Priority 1: Wikipedia Commons
        image = self._search_wikipedia_commons(entity, entity_type)
        if image:
            return image

        # Priority 2: DuckDuckGo image search
        image = self._search_duckduckgo(entity, entity_type)
        if image:
            return image

        # Priority 3: Generated placeholder (never fails)
        return self._generate_placeholder(entity)

    def _search_wikipedia_commons(self, entity: str, entity_type: str) -> Optional[Image.Image]:
        """Search Wikipedia for the entity's page thumbnail (flag for countries)."""
        try:
            if entity_type == "country":
                search_query = f"Flag of {entity}"
            else:
                search_query = entity

            # Wikipedia API search for the page's lead image thumbnail.
            search_url = "https://en.wikipedia.org/w/api.php"
            params = {
                "action": "query",
                "titles": search_query,
                "prop": "pageimages",
                "format": "json",
                "pithumbsize": 200
            }

            response = requests.get(search_url, params=params, headers=self.HEADERS, timeout=10)
            if response.status_code == 200:
                data = response.json()
                pages = data.get("query", {}).get("pages", {})

                for page_id, page_data in pages.items():
                    if "thumbnail" in page_data:
                        image_url = page_data["thumbnail"]["source"]
                        return self._download_image(image_url)

        except Exception as e:
            logger.debug(f"Artist: Wikipedia Commons search failed for {entity}: {e}")

        return None

    def _search_duckduckgo(self, entity: str, entity_type: str) -> Optional[Image.Image]:
        """Search DuckDuckGo images; return the first result that downloads."""
        try:
            from ddgs import DDGS

            # Tailor the query to the entity type for better hits.
            if entity_type == "country":
                query = f"{entity} flag icon"
            elif entity_type == "person":
                query = f"{entity} portrait photo"
            else:
                query = f"{entity} logo"

            with DDGS() as ddgs:
                results = list(ddgs.images(query, max_results=3))

            for result in results:
                image_url = result.get("image")
                if image_url:
                    image = self._download_image(image_url)
                    if image:
                        return image

        except ImportError:
            logger.debug("Artist: duckduckgo-search not available")
        except Exception as e:
            logger.debug(f"Artist: DuckDuckGo search failed for {entity}: {e}")

        return None

    def _download_image(self, url: str) -> Optional[Image.Image]:
        """Download an image from a URL; return None on any failure."""
        try:
            response = requests.get(url, headers=self.HEADERS, timeout=10)
            if response.status_code == 200:
                return Image.open(BytesIO(response.content))
        except Exception as e:
            logger.debug(f"Artist: Failed to download image: {e}")

        return None

    def _resize_to_square(self, image: Image.Image) -> Image.Image:
        """Center-crop to a square, then resize to IMAGE_SIZE x IMAGE_SIZE."""
        width, height = image.size

        # Determine the centered square crop box.
        if width > height:
            left = (width - height) // 2
            top = 0
            right = left + height
            bottom = height
        else:
            left = 0
            top = (height - width) // 2
            right = width
            bottom = top + width

        image = image.crop((left, top, right, bottom))
        return image.resize((self.IMAGE_SIZE, self.IMAGE_SIZE), Image.Resampling.LANCZOS)

    def _remove_background(self, image: Image.Image) -> Image.Image:
        """Remove background using rembg; fall back to the original on error."""
        try:
            import rembg

            # rembg operates on encoded bytes, so round-trip through PNG.
            img_bytes = BytesIO()
            image.save(img_bytes, format="PNG")
            img_bytes.seek(0)

            output = rembg.remove(img_bytes.getvalue())
            return Image.open(BytesIO(output))

        except Exception as e:
            logger.warning(f"Artist: Background removal failed: {e}")
            return image

    def _apply_circular_mask(self, image: Image.Image) -> Image.Image:
        """Crop the image into a circle with transparent corners."""
        if image.mode != "RGBA":
            image = image.convert("RGBA")

        size = image.size[0]

        # Create circular alpha mask.
        mask = Image.new("L", (size, size), 0)
        draw = ImageDraw.Draw(mask)
        draw.ellipse((0, 0, size, size), fill=255)

        # Composite onto a fully transparent canvas.
        output = Image.new("RGBA", (size, size), (0, 0, 0, 0))
        output.paste(image, (0, 0), mask)
        return output

    def _generate_placeholder(self, entity: str) -> Image.Image:
        """Generate a coloured placeholder tile showing the entity's initial."""
        size = self.IMAGE_SIZE

        colors = [
            (74, 222, 128),   # Green
            (251, 191, 36),   # Yellow
            (239, 68, 68),    # Red
            (59, 130, 246),   # Blue
            (168, 85, 247),   # Purple
            (20, 184, 166),   # Teal
        ]

        # BUGFIX: built-in hash() on str is salted per process (PYTHONHASHSEED),
        # so the same entity could get a different colour on every run. Use a
        # stable digest of the name instead.
        digest = sum(entity.encode("utf-8")) if entity else 0
        color = colors[digest % len(colors)]

        image = Image.new("RGBA", (size, size), color)
        draw = ImageDraw.Draw(image)

        initial = entity[0].upper() if entity else "?"

        # BUGFIX: the import used to live inside the try block with a bare
        # except, so a failed import left ImageFont unbound in the handler.
        # Import first, then fall back only when the font file is missing.
        from PIL import ImageFont
        try:
            font = ImageFont.truetype("arial.ttf", size // 2)
        except OSError:
            font = ImageFont.load_default()

        # Center the initial on the tile.
        bbox = draw.textbbox((0, 0), initial, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        x = (size - text_width) // 2
        y = (size - text_height) // 2 - bbox[1]

        draw.text((x, y), initial, fill=(255, 255, 255), font=font)

        return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/services/brain.py
DELETED
|
@@ -1,365 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Brain - Topic Enhancement Agent
|
| 3 |
-
|
| 4 |
-
Uses Gemini to convert vague user topics into research-ready prompts
|
| 5 |
-
for the Deep Researcher module.
|
| 6 |
-
|
| 7 |
-
Simplified pipeline:
|
| 8 |
-
User Topic β Brain.enhance_topic() β Deep Researcher β Analyst β Director
|
| 9 |
-
"""
|
| 10 |
-
import logging
|
| 11 |
-
import os
|
| 12 |
-
from typing import Optional
|
| 13 |
-
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
# Import rate limiter (lazy import to avoid circular dependency)
|
| 17 |
-
def get_rate_limiter():
    """Lazily import and return the shared Gemini rate limiter.

    The import is deferred to call time to avoid a circular dependency
    with the deep_researcher package.
    """
    from modules.bar_race.deep_researcher import rate_limiter as _rl
    return _rl.gemini_rate_limiter
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class Brain:
|
| 23 |
-
"""
|
| 24 |
-
Topic Enhancement Agent for Bar Race video generation.
|
| 25 |
-
|
| 26 |
-
Responsibilities:
|
| 27 |
-
- Convert vague user topics into research-ready prompts
|
| 28 |
-
- Fill data gaps using Gemini knowledge (if needed)
|
| 29 |
-
"""
|
| 30 |
-
|
| 31 |
-
GEMINI_MODEL = "gemma-3-27b-it"
|
| 32 |
-
|
| 33 |
-
# Universal Topic Enhancer Prompt
|
| 34 |
-
TOPIC_ENHANCER_PROMPT = '''You are an intelligent AI agent named "Brain".
|
| 35 |
-
|
| 36 |
-
Your task is to convert a raw, unclear, or incomplete user topic
|
| 37 |
-
into a precise, research-ready prompt for a Deep Research AI agent
|
| 38 |
-
(using Tavily or similar web research tools).
|
| 39 |
-
|
| 40 |
-
The ultimate goal is to generate data suitable for a Bar Chart Race animation video.
|
| 41 |
-
|
| 42 |
-
---
|
| 43 |
-
|
| 44 |
-
### USER INPUT
|
| 45 |
-
- A short, vague, or poorly defined topic
|
| 46 |
-
- The user may not specify time range, metrics, or competitors
|
| 47 |
-
- The user may not understand data structure requirements
|
| 48 |
-
|
| 49 |
-
---
|
| 50 |
-
|
| 51 |
-
### CORE DATA REQUIREMENTS (MANDATORY)
|
| 52 |
-
|
| 53 |
-
The data MUST follow these rules:
|
| 54 |
-
|
| 55 |
-
1. Time Axis (X-axis):
|
| 56 |
-
- Must represent time progression
|
| 57 |
-
- Can be: year, date, month, or sequential order (e.g., Over 1, Over 2)
|
| 58 |
-
- Must be continuous and sortable
|
| 59 |
-
|
| 60 |
-
2. Categories / Competitors:
|
| 61 |
-
- Entities competing over time
|
| 62 |
-
- Examples: countries, companies, platforms, individuals, teams
|
| 63 |
-
|
| 64 |
-
3. Values:
|
| 65 |
-
- Numeric values that change over time
|
| 66 |
-
- Examples: GDP, population, revenue, users, runs, points, sales
|
| 67 |
-
|
| 68 |
-
4. Data Format:
|
| 69 |
-
- MUST be in WIDE FORMAT
|
| 70 |
-
- Time must be represented as rows (index)
|
| 71 |
-
- Each competitor/category must be a separate column
|
| 72 |
-
- Values must be numeric and suitable for animation
|
| 73 |
-
|
| 74 |
-
---
|
| 75 |
-
|
| 76 |
-
### YOUR RESPONSIBILITIES
|
| 77 |
-
|
| 78 |
-
1. Understand the user's intent and visualization goal
|
| 79 |
-
2. Enhance the topic into a clear analytical research objective
|
| 80 |
-
3. Infer missing details intelligently:
|
| 81 |
-
- Time range β assume long-term historical data
|
| 82 |
-
- Metric β use the most standard and authoritative metric
|
| 83 |
-
- Scope β assume global unless specified otherwise
|
| 84 |
-
4. Ensure the research target will produce:
|
| 85 |
-
- Time-series numerical data
|
| 86 |
-
- Comparable values across competitors
|
| 87 |
-
- Data convertible into Wide Format
|
| 88 |
-
|
| 89 |
-
---
|
| 90 |
-
|
| 91 |
-
### OUTPUT RULES (STRICT)
|
| 92 |
-
|
| 93 |
-
- Output ONLY ONE enhanced research prompt
|
| 94 |
-
- Write in clear, professional English
|
| 95 |
-
- The prompt must be directly usable by a Deep Research AI agent
|
| 96 |
-
- Do NOT include:
|
| 97 |
-
- Explanations
|
| 98 |
-
- Bullet points
|
| 99 |
-
- Multiple options
|
| 100 |
-
- CSV or table
|
| 101 |
-
- Commentary or assumptions list
|
| 102 |
-
|
| 103 |
-
---
|
| 104 |
-
|
| 105 |
-
### EXAMPLE
|
| 106 |
-
|
| 107 |
-
User Topic:
|
| 108 |
-
"Cricket players performance"
|
| 109 |
-
|
| 110 |
-
Enhanced Output:
|
| 111 |
-
"Research year-by-year total international runs scored by the top 10 cricket players globally from 2000 to the most recent year, ensuring consistent annual time-series data suitable for wide-format bar chart race visualization."
|
| 112 |
-
|
| 113 |
-
---
|
| 114 |
-
|
| 115 |
-
### FINAL RULE
|
| 116 |
-
Your output must guarantee that the resulting dataset
|
| 117 |
-
can be transformed into a Wide Format table
|
| 118 |
-
with Time as rows and Competitors as columns,
|
| 119 |
-
ready for Bar Chart Race animation.'''
|
| 120 |
-
|
| 121 |
-
def __init__(self, gemini_api_key: str = None):
|
| 122 |
-
self.gemini_api_key = gemini_api_key or os.getenv("GEMINI_API_KEY")
|
| 123 |
-
self.gemini_client = None
|
| 124 |
-
|
| 125 |
-
if self.gemini_api_key:
|
| 126 |
-
try:
|
| 127 |
-
from google import genai
|
| 128 |
-
self.gemini_client = genai.Client(api_key=self.gemini_api_key)
|
| 129 |
-
logger.info("Brain: Gemini client initialized")
|
| 130 |
-
except ImportError:
|
| 131 |
-
logger.warning("google-genai package not installed")
|
| 132 |
-
else:
|
| 133 |
-
logger.warning("Brain: No Gemini API key, will use basic enhancement")
|
| 134 |
-
|
| 135 |
-
def enhance_topic(self, raw_topic: str) -> str:
|
| 136 |
-
"""
|
| 137 |
-
Convert raw user topic into research-ready prompt.
|
| 138 |
-
|
| 139 |
-
Uses TOPIC_ENHANCER_PROMPT to transform vague topics into
|
| 140 |
-
precise, research-friendly prompts for Deep Researcher.
|
| 141 |
-
|
| 142 |
-
Args:
|
| 143 |
-
raw_topic: User's raw topic string (may be vague/incomplete)
|
| 144 |
-
|
| 145 |
-
Returns:
|
| 146 |
-
Enhanced research prompt ready for Deep Researcher
|
| 147 |
-
"""
|
| 148 |
-
logger.info(f"Brain: Enhancing topic: {raw_topic}")
|
| 149 |
-
|
| 150 |
-
if not self.gemini_client:
|
| 151 |
-
# Fallback: basic enhancement
|
| 152 |
-
return f"Research historical year-by-year data for {raw_topic} from 2000 to present, suitable for bar chart race visualization."
|
| 153 |
-
|
| 154 |
-
try:
|
| 155 |
-
get_rate_limiter().acquire() # Rate limit before API call
|
| 156 |
-
response = self.gemini_client.models.generate_content(
|
| 157 |
-
model=self.GEMINI_MODEL,
|
| 158 |
-
contents=[
|
| 159 |
-
{"role": "user", "parts": [{"text": self.TOPIC_ENHANCER_PROMPT}]},
|
| 160 |
-
{"role": "user", "parts": [{"text": f"User Topic:\n{raw_topic}"}]}
|
| 161 |
-
]
|
| 162 |
-
)
|
| 163 |
-
|
| 164 |
-
enhanced = response.text.strip()
|
| 165 |
-
|
| 166 |
-
# Clean up any markdown formatting
|
| 167 |
-
if enhanced.startswith('"') and enhanced.endswith('"'):
|
| 168 |
-
enhanced = enhanced[1:-1]
|
| 169 |
-
|
| 170 |
-
logger.info(f"Brain: Enhanced topic to: {enhanced[:100]}...")
|
| 171 |
-
return enhanced
|
| 172 |
-
|
| 173 |
-
except Exception as e:
|
| 174 |
-
logger.warning(f"Brain: Topic enhancement failed: {e}")
|
| 175 |
-
return f"Research historical year-by-year data for {raw_topic} from 2000 to present, suitable for bar chart race visualization."
|
| 176 |
-
|
| 177 |
-
def fill_data_gaps(self, df, gaps: list, topic: str = ""):
|
| 178 |
-
"""
|
| 179 |
-
Fill missing data points using Gemini's knowledge + Tavily fallback.
|
| 180 |
-
|
| 181 |
-
Strategy:
|
| 182 |
-
1. First ask Gemini for all gaps (batch)
|
| 183 |
-
2. For remaining unfilled gaps, use Tavily search
|
| 184 |
-
|
| 185 |
-
Args:
|
| 186 |
-
df: DataFrame with gaps
|
| 187 |
-
gaps: List of gap descriptions from Analyst
|
| 188 |
-
topic: The research topic for context
|
| 189 |
-
|
| 190 |
-
Returns:
|
| 191 |
-
DataFrame with gaps filled
|
| 192 |
-
"""
|
| 193 |
-
if not gaps:
|
| 194 |
-
return df
|
| 195 |
-
|
| 196 |
-
import pandas as pd
|
| 197 |
-
import os
|
| 198 |
-
|
| 199 |
-
filled_count = 0
|
| 200 |
-
unfilled_gaps = []
|
| 201 |
-
|
| 202 |
-
# ============ STEP 1: Try Gemini first (batch) ============
|
| 203 |
-
if self.gemini_client:
|
| 204 |
-
try:
|
| 205 |
-
# Format gaps for Gemini
|
| 206 |
-
gap_text = "\n".join([
|
| 207 |
-
f"- Year {g['year']}, Entity {g['entity']}"
|
| 208 |
-
for g in gaps[:20] # Limit to 20 gaps
|
| 209 |
-
])
|
| 210 |
-
|
| 211 |
-
prompt = f"""Act as a Data Expert with historical knowledge.
|
| 212 |
-
|
| 213 |
-
TOPIC: {topic}
|
| 214 |
-
|
| 215 |
-
MISSING DATA POINTS:
|
| 216 |
-
{gap_text}
|
| 217 |
-
|
| 218 |
-
For each missing data point above, provide the approximate value based on your knowledge.
|
| 219 |
-
Output format (one per line): year,entity,value
|
| 220 |
-
|
| 221 |
-
Rules:
|
| 222 |
-
1. Use realistic values based on trends
|
| 223 |
-
2. If completely unknown, write: year,entity,UNKNOWN
|
| 224 |
-
3. Output ONLY the data in CSV format (no headers, no explanations)
|
| 225 |
-
"""
|
| 226 |
-
|
| 227 |
-
get_rate_limiter().acquire() # Rate limit before API call
|
| 228 |
-
response = self.gemini_client.models.generate_content(
|
| 229 |
-
model=self.GEMINI_MODEL,
|
| 230 |
-
contents=prompt
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
fill_data = response.text.strip()
|
| 234 |
-
|
| 235 |
-
# Parse and apply fill values
|
| 236 |
-
for line in fill_data.split("\n"):
|
| 237 |
-
try:
|
| 238 |
-
parts = line.strip().split(",")
|
| 239 |
-
if len(parts) >= 3:
|
| 240 |
-
year = int(parts[0].strip())
|
| 241 |
-
entity = parts[1].strip()
|
| 242 |
-
value_str = parts[2].strip()
|
| 243 |
-
|
| 244 |
-
# Check if Gemini doesn't know
|
| 245 |
-
if value_str.upper() == "UNKNOWN" or not value_str:
|
| 246 |
-
unfilled_gaps.append({"year": year, "entity": entity})
|
| 247 |
-
continue
|
| 248 |
-
|
| 249 |
-
value = float(value_str.replace(",", ""))
|
| 250 |
-
|
| 251 |
-
# Find and fill in DataFrame
|
| 252 |
-
if entity in df.columns:
|
| 253 |
-
mask = df['year'] == year
|
| 254 |
-
if mask.any():
|
| 255 |
-
df.loc[mask, entity] = value
|
| 256 |
-
filled_count += 1
|
| 257 |
-
except:
|
| 258 |
-
pass
|
| 259 |
-
|
| 260 |
-
logger.info(f"Brain: Gemini filled {filled_count} gaps, {len(unfilled_gaps)} remaining")
|
| 261 |
-
|
| 262 |
-
except Exception as e:
|
| 263 |
-
logger.warning(f"Brain: Gemini gap filling failed: {e}")
|
| 264 |
-
unfilled_gaps = gaps[:20] # All gaps need Tavily
|
| 265 |
-
else:
|
| 266 |
-
unfilled_gaps = gaps[:20] # No Gemini, use Tavily for all
|
| 267 |
-
|
| 268 |
-
# ============ STEP 2: Tavily fallback for unfilled gaps ============
|
| 269 |
-
if unfilled_gaps:
|
| 270 |
-
tavily_filled = self._fill_gaps_with_tavily(df, unfilled_gaps, topic)
|
| 271 |
-
filled_count += tavily_filled
|
| 272 |
-
|
| 273 |
-
logger.info(f"Brain: Total filled {filled_count} out of {len(gaps)} gaps")
|
| 274 |
-
return df
|
| 275 |
-
|
| 276 |
-
def _fill_gaps_with_tavily(self, df, gaps: list, topic: str) -> int:
|
| 277 |
-
"""
|
| 278 |
-
Use Tavily API to search for specific missing data points.
|
| 279 |
-
|
| 280 |
-
Args:
|
| 281 |
-
df: DataFrame to fill
|
| 282 |
-
gaps: List of unfilled gaps
|
| 283 |
-
topic: Research topic for context
|
| 284 |
-
|
| 285 |
-
Returns:
|
| 286 |
-
Number of gaps filled
|
| 287 |
-
"""
|
| 288 |
-
import os
|
| 289 |
-
|
| 290 |
-
tavily_api_key = os.getenv("TAVILY_API_KEY")
|
| 291 |
-
if not tavily_api_key:
|
| 292 |
-
logger.warning("Brain: No Tavily API key for gap fallback")
|
| 293 |
-
return 0
|
| 294 |
-
|
| 295 |
-
try:
|
| 296 |
-
from tavily import TavilyClient
|
| 297 |
-
tavily = TavilyClient(api_key=tavily_api_key)
|
| 298 |
-
except ImportError:
|
| 299 |
-
logger.warning("Brain: Tavily client not available")
|
| 300 |
-
return 0
|
| 301 |
-
|
| 302 |
-
filled_count = 0
|
| 303 |
-
|
| 304 |
-
for gap in gaps[:5]: # Limit to 5 Tavily searches (to save API quota)
|
| 305 |
-
try:
|
| 306 |
-
year = gap['year']
|
| 307 |
-
entity = gap['entity']
|
| 308 |
-
|
| 309 |
-
# Specific search query for this data point
|
| 310 |
-
query = f"{topic} {entity} {year} data statistics value"
|
| 311 |
-
|
| 312 |
-
result = tavily.search(query=query, max_results=3)
|
| 313 |
-
|
| 314 |
-
if result and result.get('results'):
|
| 315 |
-
# Try to extract numeric value from search results
|
| 316 |
-
for r in result['results']:
|
| 317 |
-
content = r.get('content', '')
|
| 318 |
-
value = self._extract_numeric_from_text(content, entity, year)
|
| 319 |
-
|
| 320 |
-
if value is not None:
|
| 321 |
-
if entity in df.columns:
|
| 322 |
-
mask = df['year'] == year
|
| 323 |
-
if mask.any():
|
| 324 |
-
df.loc[mask, entity] = value
|
| 325 |
-
filled_count += 1
|
| 326 |
-
logger.info(f"Brain: Tavily filled {entity} {year} = {value}")
|
| 327 |
-
break
|
| 328 |
-
|
| 329 |
-
except Exception as e:
|
| 330 |
-
logger.warning(f"Brain: Tavily search failed for {gap}: {e}")
|
| 331 |
-
|
| 332 |
-
return filled_count
|
| 333 |
-
|
| 334 |
-
def _extract_numeric_from_text(self, text: str, entity: str, year: int) -> float:
|
| 335 |
-
"""Extract a numeric value from text that relates to entity and year."""
|
| 336 |
-
import re
|
| 337 |
-
|
| 338 |
-
# Look for patterns like "India 2015: 2.1 trillion" or "GDP was 2100 billion"
|
| 339 |
-
patterns = [
|
| 340 |
-
rf'{year}[:\s]+[\$β¬]?([\d,.]+)\s*(trillion|billion|million)?',
|
| 341 |
-
rf'([\d,.]+)\s*(trillion|billion|million)?\s*(?:in|for)?\s*{year}',
|
| 342 |
-
rf'{entity}[:\s]+([\d,.]+)',
|
| 343 |
-
]
|
| 344 |
-
|
| 345 |
-
for pattern in patterns:
|
| 346 |
-
match = re.search(pattern, text, re.IGNORECASE)
|
| 347 |
-
if match:
|
| 348 |
-
try:
|
| 349 |
-
value_str = match.group(1).replace(",", "")
|
| 350 |
-
value = float(value_str)
|
| 351 |
-
|
| 352 |
-
# Apply multiplier if present
|
| 353 |
-
if match.lastindex >= 2 and match.group(2):
|
| 354 |
-
unit = match.group(2).lower()
|
| 355 |
-
if unit == 'trillion':
|
| 356 |
-
value *= 1000 # Convert to billions for consistency
|
| 357 |
-
elif unit == 'million':
|
| 358 |
-
value /= 1000 # Convert to billions
|
| 359 |
-
|
| 360 |
-
return value
|
| 361 |
-
except:
|
| 362 |
-
pass
|
| 363 |
-
|
| 364 |
-
return None
|
| 365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/bar_race/services/director.py
DELETED
|
@@ -1,438 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Director - Video Generator
|
| 3 |
-
Creates bar chart race animation and final video.
|
| 4 |
-
"""
|
| 5 |
-
import logging
|
| 6 |
-
import pandas as pd
|
| 7 |
-
import os
|
| 8 |
-
from typing import Dict, Any, Optional
|
| 9 |
-
import shutil
|
| 10 |
-
|
| 11 |
-
logger = logging.getLogger(__name__)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class Director:
|
| 15 |
-
"""
|
| 16 |
-
Video Generator for Bar Race.
|
| 17 |
-
|
| 18 |
-
Creates animated bar chart race video using:
|
| 19 |
-
- bar_chart_race library for animation
|
| 20 |
-
- Entity images overlay
|
| 21 |
-
- Background music
|
| 22 |
-
- 9:16 vertical format (1080x1920)
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
# Video dimensions (9:16)
|
| 26 |
-
VIDEO_WIDTH = 1080
|
| 27 |
-
VIDEO_HEIGHT = 1920
|
| 28 |
-
FPS = 30
|
| 29 |
-
|
| 30 |
-
def __init__(self, temp_dir: str, output_dir: str = "videos/bar_race"):
|
| 31 |
-
self.temp_dir = temp_dir
|
| 32 |
-
self.output_dir = output_dir
|
| 33 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 34 |
-
|
| 35 |
-
def generate_video(
|
| 36 |
-
self,
|
| 37 |
-
df: pd.DataFrame,
|
| 38 |
-
plan: Dict[str, Any],
|
| 39 |
-
image_paths: Dict[str, str],
|
| 40 |
-
duration_seconds: int = 60,
|
| 41 |
-
job_id: str = ""
|
| 42 |
-
) -> Optional[str]:
|
| 43 |
-
"""
|
| 44 |
-
Generate bar chart race video.
|
| 45 |
-
|
| 46 |
-
Args:
|
| 47 |
-
df: Cleaned data with columns: name, year, value
|
| 48 |
-
plan: Brain's plan with video_meta
|
| 49 |
-
image_paths: Dict mapping entity name to image path
|
| 50 |
-
duration_seconds: Video duration
|
| 51 |
-
job_id: Job ID for output filename
|
| 52 |
-
|
| 53 |
-
Returns:
|
| 54 |
-
Path to generated video, or None if failed
|
| 55 |
-
"""
|
| 56 |
-
logger.info(f"Director: Starting video generation for {duration_seconds}s video")
|
| 57 |
-
|
| 58 |
-
try:
|
| 59 |
-
# Prepare data for bar_chart_race
|
| 60 |
-
df_pivot = self._prepare_data(df)
|
| 61 |
-
|
| 62 |
-
if df_pivot is None or df_pivot.empty:
|
| 63 |
-
logger.error("Director: Failed to prepare data")
|
| 64 |
-
return None
|
| 65 |
-
|
| 66 |
-
# Generate animation
|
| 67 |
-
video_path = self._generate_bar_race(
|
| 68 |
-
df_pivot=df_pivot,
|
| 69 |
-
plan=plan,
|
| 70 |
-
duration_seconds=duration_seconds,
|
| 71 |
-
job_id=job_id
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
if video_path and os.path.exists(video_path):
|
| 75 |
-
# Try to add background music (optional)
|
| 76 |
-
video_with_music = self._add_background_music(video_path, duration_seconds)
|
| 77 |
-
if video_with_music:
|
| 78 |
-
return video_with_music
|
| 79 |
-
|
| 80 |
-
return video_path
|
| 81 |
-
|
| 82 |
-
except Exception as e:
|
| 83 |
-
logger.error(f"Director: Video generation failed: {e}")
|
| 84 |
-
import traceback
|
| 85 |
-
logger.error(traceback.format_exc())
|
| 86 |
-
return None
|
| 87 |
-
|
| 88 |
-
def _add_background_music(self, video_path: str, duration_seconds: int) -> Optional[str]:
|
| 89 |
-
"""Add background music if available in assets/music folder"""
|
| 90 |
-
music_dir = "modules/bar_race/assets/music"
|
| 91 |
-
|
| 92 |
-
# Check if music directory exists
|
| 93 |
-
if not os.path.exists(music_dir):
|
| 94 |
-
logger.info("Director: No music folder found, skipping background music")
|
| 95 |
-
return None
|
| 96 |
-
|
| 97 |
-
# Find music files
|
| 98 |
-
music_files = []
|
| 99 |
-
for ext in [".mp3", ".wav", ".m4a", ".ogg"]:
|
| 100 |
-
for f in os.listdir(music_dir):
|
| 101 |
-
if f.lower().endswith(ext):
|
| 102 |
-
music_files.append(os.path.join(music_dir, f))
|
| 103 |
-
|
| 104 |
-
if not music_files:
|
| 105 |
-
logger.info("Director: No music files found, skipping background music")
|
| 106 |
-
return None
|
| 107 |
-
|
| 108 |
-
try:
|
| 109 |
-
from moviepy.editor import VideoFileClip, AudioFileClip
|
| 110 |
-
import random
|
| 111 |
-
|
| 112 |
-
# Pick random music file
|
| 113 |
-
music_path = random.choice(music_files)
|
| 114 |
-
logger.info(f"Director: Adding background music: {music_path}")
|
| 115 |
-
|
| 116 |
-
# Load video and audio
|
| 117 |
-
video = VideoFileClip(video_path)
|
| 118 |
-
audio = AudioFileClip(music_path)
|
| 119 |
-
|
| 120 |
-
# Loop audio if shorter than video
|
| 121 |
-
if audio.duration < video.duration:
|
| 122 |
-
from moviepy.editor import concatenate_audioclips
|
| 123 |
-
loops_needed = int(video.duration / audio.duration) + 1
|
| 124 |
-
audio = concatenate_audioclips([audio] * loops_needed)
|
| 125 |
-
|
| 126 |
-
# Trim audio to video length and lower volume
|
| 127 |
-
audio = audio.subclip(0, video.duration).volumex(0.3)
|
| 128 |
-
|
| 129 |
-
# Add audio to video
|
| 130 |
-
video_with_audio = video.set_audio(audio)
|
| 131 |
-
|
| 132 |
-
# Save with music
|
| 133 |
-
output_path = video_path.replace(".mp4", "_music.mp4")
|
| 134 |
-
video_with_audio.write_videofile(
|
| 135 |
-
output_path,
|
| 136 |
-
codec="libx264",
|
| 137 |
-
audio_codec="aac",
|
| 138 |
-
fps=self.FPS,
|
| 139 |
-
logger=None
|
| 140 |
-
)
|
| 141 |
-
|
| 142 |
-
# Cleanup
|
| 143 |
-
video.close()
|
| 144 |
-
audio.close()
|
| 145 |
-
|
| 146 |
-
# Replace original with music version
|
| 147 |
-
os.remove(video_path)
|
| 148 |
-
os.rename(output_path, video_path)
|
| 149 |
-
|
| 150 |
-
logger.info(f"Director: Added background music to video")
|
| 151 |
-
return video_path
|
| 152 |
-
|
| 153 |
-
except Exception as e:
|
| 154 |
-
logger.warning(f"Director: Failed to add music: {e}")
|
| 155 |
-
return None
|
| 156 |
-
|
| 157 |
-
def _prepare_data(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
|
| 158 |
-
"""Prepare data for bar_chart_race (pivoted format)"""
|
| 159 |
-
try:
|
| 160 |
-
# Pivot: rows=year, columns=entity, values=value
|
| 161 |
-
df_pivot = df.pivot(index="year", columns="name", values="value")
|
| 162 |
-
|
| 163 |
-
# Sort by year
|
| 164 |
-
df_pivot = df_pivot.sort_index()
|
| 165 |
-
|
| 166 |
-
# Fill NaN with 0
|
| 167 |
-
df_pivot = df_pivot.fillna(0)
|
| 168 |
-
|
| 169 |
-
logger.info(f"Director: Prepared pivot table with shape {df_pivot.shape}")
|
| 170 |
-
return df_pivot
|
| 171 |
-
|
| 172 |
-
except Exception as e:
|
| 173 |
-
logger.error(f"Director: Data preparation failed: {e}")
|
| 174 |
-
return None
|
| 175 |
-
|
| 176 |
-
def _generate_bar_race(
|
| 177 |
-
self,
|
| 178 |
-
df_pivot: pd.DataFrame,
|
| 179 |
-
plan: Dict[str, Any],
|
| 180 |
-
duration_seconds: int,
|
| 181 |
-
job_id: str
|
| 182 |
-
) -> Optional[str]:
|
| 183 |
-
"""Generate bar chart race animation"""
|
| 184 |
-
|
| 185 |
-
# Get video metadata (handle both dict and string formats)
|
| 186 |
-
video_meta = plan.get("video_meta", {})
|
| 187 |
-
if isinstance(video_meta, str):
|
| 188 |
-
title = video_meta
|
| 189 |
-
else:
|
| 190 |
-
title = video_meta.get("title", "Bar Chart Race")
|
| 191 |
-
|
| 192 |
-
value_intent = plan.get("value_intent", {})
|
| 193 |
-
if isinstance(value_intent, str):
|
| 194 |
-
value_unit = value_intent
|
| 195 |
-
else:
|
| 196 |
-
value_unit = value_intent.get("unit", "")
|
| 197 |
-
|
| 198 |
-
visualization = plan.get("visualization", {})
|
| 199 |
-
if isinstance(visualization, str):
|
| 200 |
-
top_n = 10
|
| 201 |
-
else:
|
| 202 |
-
top_n = visualization.get("top_n", 10)
|
| 203 |
-
|
| 204 |
-
output_path = os.path.join(self.output_dir, f"bar_race_{job_id}.mp4")
|
| 205 |
-
|
| 206 |
-
try:
|
| 207 |
-
import bar_chart_race as bcr
|
| 208 |
-
|
| 209 |
-
# Calculate timing for exact user-requested duration
|
| 210 |
-
num_years = len(df_pivot)
|
| 211 |
-
|
| 212 |
-
# Total frames = duration * FPS (e.g., 30s * 30fps = 900 frames)
|
| 213 |
-
total_frames = duration_seconds * self.FPS
|
| 214 |
-
|
| 215 |
-
# Frames per period (year) = total_frames / num_years
|
| 216 |
-
# steps_per_period controls animation smoothness within each year
|
| 217 |
-
steps_per_period = max(10, total_frames // num_years)
|
| 218 |
-
|
| 219 |
-
# period_length (ms) = how long each year takes on screen
|
| 220 |
-
# To get exact duration: period_length = (duration_seconds * 1000) / num_years
|
| 221 |
-
period_length = int((duration_seconds * 1000) / num_years)
|
| 222 |
-
|
| 223 |
-
logger.info(f"Director: Duration={duration_seconds}s, Years={num_years}, "
|
| 224 |
-
f"period_length={period_length}ms, steps_per_period={steps_per_period}")
|
| 225 |
-
|
| 226 |
-
# Calculate figsize for 9:16 vertical format (1080x1920 at dpi=144)
|
| 227 |
-
# Resolution = figsize * dpi
|
| 228 |
-
# For 1080x1920: figsize = (1080/144, 1920/144) = (7.5, 13.33)
|
| 229 |
-
figsize_9x16 = (self.VIDEO_WIDTH / 144, self.VIDEO_HEIGHT / 144) # (7.5, 13.33)
|
| 230 |
-
|
| 231 |
-
logger.info(f"Director: Video resolution {self.VIDEO_WIDTH}x{self.VIDEO_HEIGHT} (9:16 YouTube Shorts)")
|
| 232 |
-
|
| 233 |
-
# Generate bar chart race with 9:16 vertical format
|
| 234 |
-
bcr.bar_chart_race(
|
| 235 |
-
df=df_pivot,
|
| 236 |
-
filename=output_path,
|
| 237 |
-
orientation='h', # Horizontal bars look better in vertical video
|
| 238 |
-
sort='desc',
|
| 239 |
-
n_bars=top_n,
|
| 240 |
-
fixed_order=False,
|
| 241 |
-
fixed_max=True,
|
| 242 |
-
steps_per_period=steps_per_period,
|
| 243 |
-
period_length=period_length,
|
| 244 |
-
interpolate_period=True,
|
| 245 |
-
title=title,
|
| 246 |
-
figsize=figsize_9x16, # 9:16 vertical format
|
| 247 |
-
cmap='dark24',
|
| 248 |
-
dpi=144 # Combined with figsize gives 1080x1920
|
| 249 |
-
)
|
| 250 |
-
|
| 251 |
-
logger.info(f"Director: Generated video at {output_path}")
|
| 252 |
-
return output_path
|
| 253 |
-
|
| 254 |
-
except ImportError:
|
| 255 |
-
logger.warning("Director: bar_chart_race not available, using fallback")
|
| 256 |
-
return self._generate_fallback_video(df_pivot, plan, duration_seconds, job_id)
|
| 257 |
-
except Exception as e:
|
| 258 |
-
logger.error(f"Director: bar_chart_race failed: {e}")
|
| 259 |
-
return self._generate_fallback_video(df_pivot, plan, duration_seconds, job_id)
|
| 260 |
-
|
| 261 |
-
def _generate_fallback_video(
|
| 262 |
-
self,
|
| 263 |
-
df_pivot: pd.DataFrame,
|
| 264 |
-
plan: Dict[str, Any],
|
| 265 |
-
duration_seconds: int,
|
| 266 |
-
job_id: str
|
| 267 |
-
) -> Optional[str]:
|
| 268 |
-
"""Fallback: Generate smooth bar race animation with racing positions"""
|
| 269 |
-
logger.info(f"Director: Using fallback matplotlib animation for {duration_seconds}s")
|
| 270 |
-
|
| 271 |
-
try:
|
| 272 |
-
import matplotlib
|
| 273 |
-
matplotlib.use('Agg')
|
| 274 |
-
import matplotlib.pyplot as plt
|
| 275 |
-
from matplotlib.animation import FuncAnimation
|
| 276 |
-
import numpy as np
|
| 277 |
-
|
| 278 |
-
video_meta = plan.get("video_meta", {})
|
| 279 |
-
if isinstance(video_meta, str):
|
| 280 |
-
title = video_meta
|
| 281 |
-
else:
|
| 282 |
-
title = video_meta.get("title", "Bar Chart Race")
|
| 283 |
-
|
| 284 |
-
visualization = plan.get("visualization", {})
|
| 285 |
-
if isinstance(visualization, str):
|
| 286 |
-
top_n = 10
|
| 287 |
-
else:
|
| 288 |
-
top_n = visualization.get("top_n", 10)
|
| 289 |
-
|
| 290 |
-
# 9:16 aspect ratio (portrait mode for TikTok/Reels)
|
| 291 |
-
fig, ax = plt.subplots(figsize=(5.625, 10), facecolor='#0f0f1a', dpi=144)
|
| 292 |
-
ax.set_facecolor('#0f0f1a')
|
| 293 |
-
|
| 294 |
-
years = df_pivot.index.tolist()
|
| 295 |
-
columns = df_pivot.columns.tolist()
|
| 296 |
-
|
| 297 |
-
# Create interpolated data for smooth animation
|
| 298 |
-
# Each year will have multiple intermediate frames
|
| 299 |
-
frames_per_year = max(20, (duration_seconds * self.FPS) // len(years))
|
| 300 |
-
total_frames = len(years) * frames_per_year
|
| 301 |
-
|
| 302 |
-
logger.info(f"Director: {len(years)} years, {frames_per_year} frames/year, {total_frames} total frames")
|
| 303 |
-
|
| 304 |
-
# Interpolate data between years
|
| 305 |
-
interpolated_data = []
|
| 306 |
-
for i in range(len(years) - 1):
|
| 307 |
-
year1, year2 = years[i], years[i + 1]
|
| 308 |
-
for j in range(frames_per_year):
|
| 309 |
-
t = j / frames_per_year # Interpolation factor 0->1
|
| 310 |
-
# Smooth easing
|
| 311 |
-
t = t * t * (3 - 2 * t) # Smoothstep
|
| 312 |
-
|
| 313 |
-
values = {}
|
| 314 |
-
for col in columns:
|
| 315 |
-
v1 = df_pivot.loc[year1, col] if not pd.isna(df_pivot.loc[year1, col]) else 0
|
| 316 |
-
v2 = df_pivot.loc[year2, col] if not pd.isna(df_pivot.loc[year2, col]) else 0
|
| 317 |
-
values[col] = v1 + (v2 - v1) * t
|
| 318 |
-
|
| 319 |
-
interpolated_data.append({
|
| 320 |
-
'year': year1 + (year2 - year1) * t,
|
| 321 |
-
'values': values
|
| 322 |
-
})
|
| 323 |
-
|
| 324 |
-
# Add last year
|
| 325 |
-
for j in range(frames_per_year):
|
| 326 |
-
interpolated_data.append({
|
| 327 |
-
'year': years[-1],
|
| 328 |
-
'values': {col: df_pivot.loc[years[-1], col] for col in columns}
|
| 329 |
-
})
|
| 330 |
-
|
| 331 |
-
total_frames = len(interpolated_data)
|
| 332 |
-
logger.info(f"Director: Created {total_frames} interpolated frames")
|
| 333 |
-
|
| 334 |
-
# Color palette - distinct colors for each entity
|
| 335 |
-
np.random.seed(42)
|
| 336 |
-
color_map = {}
|
| 337 |
-
vibrant_colors = plt.cm.tab20.colors + plt.cm.Set3.colors
|
| 338 |
-
for i, col in enumerate(columns):
|
| 339 |
-
color_map[col] = vibrant_colors[i % len(vibrant_colors)]
|
| 340 |
-
|
| 341 |
-
# Max value for consistent x-axis
|
| 342 |
-
max_val = df_pivot.max().max() * 1.15
|
| 343 |
-
|
| 344 |
-
def update(frame):
|
| 345 |
-
ax.clear()
|
| 346 |
-
ax.set_facecolor('#0f0f1a')
|
| 347 |
-
|
| 348 |
-
if frame >= len(interpolated_data):
|
| 349 |
-
frame = len(interpolated_data) - 1
|
| 350 |
-
|
| 351 |
-
data = interpolated_data[frame]
|
| 352 |
-
year = data['year']
|
| 353 |
-
values = data['values']
|
| 354 |
-
|
| 355 |
-
# Sort by value (descending) and take top N
|
| 356 |
-
sorted_items = sorted(values.items(), key=lambda x: x[1] if x[1] else 0, reverse=True)[:top_n]
|
| 357 |
-
|
| 358 |
-
# Reverse for bottom-to-top (biggest on top)
|
| 359 |
-
sorted_items = sorted_items[::-1]
|
| 360 |
-
|
| 361 |
-
names = [item[0] for item in sorted_items]
|
| 362 |
-
vals = [item[1] if item[1] else 0 for item in sorted_items]
|
| 363 |
-
colors = [color_map[name] for name in names]
|
| 364 |
-
|
| 365 |
-
# Draw horizontal bars with smooth animation
|
| 366 |
-
y_positions = np.arange(len(names))
|
| 367 |
-
bars = ax.barh(y_positions, vals, color=colors, height=0.75, alpha=0.9)
|
| 368 |
-
|
| 369 |
-
# Add value labels and entity names
|
| 370 |
-
for i, (name, val) in enumerate(zip(names, vals)):
|
| 371 |
-
# Value on bar end
|
| 372 |
-
if val > 0:
|
| 373 |
-
label = f'{val/1e12:.2f}T' if val >= 1e12 else f'{val/1e9:.1f}B'
|
| 374 |
-
ax.text(val + max_val * 0.01, i, label, va='center', ha='left',
|
| 375 |
-
fontsize=10, color='white', fontweight='bold')
|
| 376 |
-
|
| 377 |
-
# Entity name inside bar
|
| 378 |
-
display_name = name[:15] if len(name) > 15 else name
|
| 379 |
-
ax.text(max_val * 0.01, i, display_name, va='center', ha='left',
|
| 380 |
-
fontsize=11, color='white', fontweight='bold')
|
| 381 |
-
|
| 382 |
-
# Title
|
| 383 |
-
ax.set_title(title, fontsize=16, color='white', pad=20, fontweight='bold',
|
| 384 |
-
fontfamily='sans-serif')
|
| 385 |
-
|
| 386 |
-
# Large year display
|
| 387 |
-
ax.text(0.95, 0.12, f'{int(year)}', transform=ax.transAxes, fontsize=72,
|
| 388 |
-
ha='right', va='top', color='white', alpha=0.7, fontweight='bold',
|
| 389 |
-
fontfamily='sans-serif')
|
| 390 |
-
|
| 391 |
-
# Styling
|
| 392 |
-
ax.set_yticks([])
|
| 393 |
-
ax.set_xlim(0, max_val)
|
| 394 |
-
ax.spines['top'].set_visible(False)
|
| 395 |
-
ax.spines['right'].set_visible(False)
|
| 396 |
-
ax.spines['left'].set_visible(False)
|
| 397 |
-
ax.spines['bottom'].set_color('#333')
|
| 398 |
-
ax.tick_params(colors='#666', labelsize=9)
|
| 399 |
-
|
| 400 |
-
# X-axis formatting
|
| 401 |
-
ax.xaxis.set_major_formatter(plt.FuncFormatter(
|
| 402 |
-
lambda x, p: f'{x/1e12:.1f}T' if x >= 1e12 else f'{x/1e9:.0f}B'
|
| 403 |
-
))
|
| 404 |
-
|
| 405 |
-
plt.tight_layout()
|
| 406 |
-
|
| 407 |
-
# Create animation
|
| 408 |
-
anim = FuncAnimation(fig, update, frames=total_frames,
|
| 409 |
-
interval=1000/self.FPS, repeat=False)
|
| 410 |
-
|
| 411 |
-
# Save to temp file
|
| 412 |
-
temp_path = os.path.join(self.temp_dir, f"temp_animation_{job_id}.mp4")
|
| 413 |
-
logger.info(f"Director: Saving smooth racing animation with {total_frames} frames")
|
| 414 |
-
anim.save(temp_path, writer='ffmpeg', fps=self.FPS, dpi=144,
|
| 415 |
-
savefig_kwargs={'facecolor': '#0f0f1a'})
|
| 416 |
-
plt.close(fig)
|
| 417 |
-
|
| 418 |
-
# Move to output
|
| 419 |
-
output_path = os.path.join(self.output_dir, f"bar_race_{job_id}.mp4")
|
| 420 |
-
shutil.move(temp_path, output_path)
|
| 421 |
-
|
| 422 |
-
logger.info(f"Director: Generated smooth racing video at {output_path}")
|
| 423 |
-
return output_path
|
| 424 |
-
|
| 425 |
-
except Exception as e:
|
| 426 |
-
logger.error(f"Director: Fallback video generation failed: {e}")
|
| 427 |
-
import traceback
|
| 428 |
-
logger.error(traceback.format_exc())
|
| 429 |
-
return None
|
| 430 |
-
|
| 431 |
-
def cleanup(self):
|
| 432 |
-
"""Clean up temporary files"""
|
| 433 |
-
try:
|
| 434 |
-
if os.path.exists(self.temp_dir):
|
| 435 |
-
shutil.rmtree(self.temp_dir)
|
| 436 |
-
logger.info(f"Director: Cleaned up temp directory: {self.temp_dir}")
|
| 437 |
-
except Exception as e:
|
| 438 |
-
logger.warning(f"Director: Cleanup failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/text_story/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Text Story Module for NCAkit
|
| 3 |
+
Generates fake iMessage-style text conversation videos.
|
| 4 |
+
"""
|
| 5 |
+
from fastapi import FastAPI
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
# Module Metadata
|
| 9 |
+
MODULE_NAME = "text_story"
|
| 10 |
+
MODULE_PREFIX = "/api/text-story"
|
| 11 |
+
MODULE_DESCRIPTION = "Generate fake iMessage-style text story videos with TTS"
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def register(app: FastAPI, config):
|
| 17 |
+
"""
|
| 18 |
+
Register the text story module with FastAPI.
|
| 19 |
+
Initializes services and adds routes.
|
| 20 |
+
"""
|
| 21 |
+
from .router import router
|
| 22 |
+
|
| 23 |
+
logger.info("Registering text_story module...")
|
| 24 |
+
|
| 25 |
+
# Validate TTS config
|
| 26 |
+
if not config.hf_tts:
|
| 27 |
+
logger.warning("HF_TTS not configured! TTS generation will fail.")
|
| 28 |
+
|
| 29 |
+
# Create required directories
|
| 30 |
+
import os
|
| 31 |
+
os.makedirs("videos/text_story", exist_ok=True)
|
| 32 |
+
|
| 33 |
+
# Create gameplay backgrounds folder in persistent storage
|
| 34 |
+
if os.path.exists("/data"):
|
| 35 |
+
os.makedirs("/data/gameplay_backgrounds", exist_ok=True)
|
| 36 |
+
logger.info("Created /data/gameplay_backgrounds folder for gameplay videos")
|
| 37 |
+
|
| 38 |
+
# Check gameplay backgrounds
|
| 39 |
+
backgrounds_paths = [
|
| 40 |
+
"/data/gameplay_backgrounds",
|
| 41 |
+
"assets/gameplay_backgrounds"
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
bg_found = False
|
| 45 |
+
for path in backgrounds_paths:
|
| 46 |
+
if os.path.exists(path):
|
| 47 |
+
count = len([f for f in os.listdir(path) if f.endswith('.mp4')])
|
| 48 |
+
if count > 0:
|
| 49 |
+
logger.info(f"Found {count} gameplay backgrounds in {path}")
|
| 50 |
+
bg_found = True
|
| 51 |
+
break
|
| 52 |
+
|
| 53 |
+
if not bg_found:
|
| 54 |
+
logger.warning("No gameplay backgrounds found! Will use solid color background.")
|
| 55 |
+
logger.info("Add .mp4 files to /data/gameplay_backgrounds for video backgrounds")
|
| 56 |
+
|
| 57 |
+
# Register router
|
| 58 |
+
app.include_router(router)
|
| 59 |
+
|
| 60 |
+
logger.info("text_story module registered successfully")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Export router for direct import
|
| 64 |
+
from .router import router
|
| 65 |
+
|
| 66 |
+
__all__ = ["router", "register", "MODULE_NAME", "MODULE_PREFIX", "MODULE_DESCRIPTION"]
|
modules/text_story/router.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Text Story Router - FastAPI endpoints for fake iMessage chat video generation.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import uuid
|
| 7 |
+
import logging
|
| 8 |
+
import asyncio
|
| 9 |
+
from typing import Dict, Any
|
| 10 |
+
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
| 11 |
+
from fastapi.responses import FileResponse
|
| 12 |
+
|
| 13 |
+
from .schemas import TextStoryRequest, TextStoryResponse, TextStoryStatus
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
router = APIRouter(prefix="/api/text-story", tags=["Text Story"])
|
| 18 |
+
|
| 19 |
+
# Job storage (in-memory for now)
|
| 20 |
+
jobs: Dict[str, Dict[str, Any]] = {}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def update_job(job_id: str, status: str, progress: int, step: str = None,
|
| 24 |
+
video_url: str = None, error: str = None):
|
| 25 |
+
"""Update job status."""
|
| 26 |
+
if job_id in jobs:
|
| 27 |
+
jobs[job_id].update({
|
| 28 |
+
"status": status,
|
| 29 |
+
"progress": progress,
|
| 30 |
+
"current_step": step,
|
| 31 |
+
"video_url": video_url,
|
| 32 |
+
"error": error
|
| 33 |
+
})
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def generate_text_story_video(job_id: str, request: TextStoryRequest):
|
| 37 |
+
"""
|
| 38 |
+
Main video generation pipeline.
|
| 39 |
+
|
| 40 |
+
Pipeline:
|
| 41 |
+
1. Setup temp directory
|
| 42 |
+
2. Generate TTS for each message
|
| 43 |
+
3. Create chat UI frames
|
| 44 |
+
4. Load gameplay background
|
| 45 |
+
5. Compose final video
|
| 46 |
+
"""
|
| 47 |
+
try:
|
| 48 |
+
import tempfile
|
| 49 |
+
import shutil
|
| 50 |
+
|
| 51 |
+
temp_dir = tempfile.mkdtemp(prefix="text_story_")
|
| 52 |
+
logger.info(f"TextStory: Starting job {job_id}")
|
| 53 |
+
|
| 54 |
+
# ============ STEP 1: TTS Generation ============
|
| 55 |
+
update_job(job_id, "processing", 10, "Generating voices...")
|
| 56 |
+
|
| 57 |
+
from .services.tts_handler import TTSHandler
|
| 58 |
+
tts = TTSHandler()
|
| 59 |
+
|
| 60 |
+
audio_files = []
|
| 61 |
+
for i, msg in enumerate(request.messages):
|
| 62 |
+
voice = request.voice_a if msg.sender == "A" else request.voice_b
|
| 63 |
+
audio_path = os.path.join(temp_dir, f"msg_{i:03d}.wav")
|
| 64 |
+
|
| 65 |
+
duration = await tts.generate_tts(msg.text, voice, audio_path)
|
| 66 |
+
audio_files.append({
|
| 67 |
+
"path": audio_path,
|
| 68 |
+
"duration": duration,
|
| 69 |
+
"sender": msg.sender,
|
| 70 |
+
"text": msg.text
|
| 71 |
+
})
|
| 72 |
+
|
| 73 |
+
progress = 10 + int((i + 1) / len(request.messages) * 30)
|
| 74 |
+
update_job(job_id, "processing", progress, f"Voice {i+1}/{len(request.messages)}")
|
| 75 |
+
|
| 76 |
+
# ============ STEP 2: Create Chat Frames ============
|
| 77 |
+
update_job(job_id, "processing", 45, "Rendering chat UI...")
|
| 78 |
+
|
| 79 |
+
from .services.renderer import ChatRenderer
|
| 80 |
+
renderer = ChatRenderer(
|
| 81 |
+
person_a_name=request.person_a_name,
|
| 82 |
+
person_b_name=request.person_b_name,
|
| 83 |
+
person_b_avatar=request.person_b_avatar
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# ============ STEP 3: Load Background ============
|
| 87 |
+
update_job(job_id, "processing", 55, "Loading background...")
|
| 88 |
+
|
| 89 |
+
from .services.background import BackgroundHandler
|
| 90 |
+
bg_handler = BackgroundHandler()
|
| 91 |
+
|
| 92 |
+
# ============ STEP 4: Compose Video ============
|
| 93 |
+
update_job(job_id, "processing", 65, "Composing video...")
|
| 94 |
+
|
| 95 |
+
from .services.video_composer import VideoComposer
|
| 96 |
+
composer = VideoComposer(
|
| 97 |
+
renderer=renderer,
|
| 98 |
+
bg_handler=bg_handler,
|
| 99 |
+
tts_handler=tts
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
output_path = await composer.compose(
|
| 103 |
+
messages=audio_files,
|
| 104 |
+
ending_text=request.ending_text,
|
| 105 |
+
output_dir=temp_dir
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# ============ STEP 5: Save Final Video ============
|
| 109 |
+
update_job(job_id, "processing", 90, "Saving video...")
|
| 110 |
+
|
| 111 |
+
# Create output directory
|
| 112 |
+
videos_dir = os.path.join("videos", "text_story")
|
| 113 |
+
os.makedirs(videos_dir, exist_ok=True)
|
| 114 |
+
|
| 115 |
+
final_path = os.path.join(videos_dir, f"text_story_{job_id}.mp4")
|
| 116 |
+
shutil.copy2(output_path, final_path)
|
| 117 |
+
|
| 118 |
+
# Also save to persistent storage if available
|
| 119 |
+
persistent_dir = "/data/videos/text_story"
|
| 120 |
+
if os.path.exists("/data"):
|
| 121 |
+
os.makedirs(persistent_dir, exist_ok=True)
|
| 122 |
+
shutil.copy2(output_path, os.path.join(persistent_dir, f"text_story_{job_id}.mp4"))
|
| 123 |
+
|
| 124 |
+
# Cleanup temp
|
| 125 |
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
| 126 |
+
|
| 127 |
+
video_url = f"/api/text-story/{job_id}/video"
|
| 128 |
+
update_job(job_id, "ready", 100, "Complete!", video_url=video_url)
|
| 129 |
+
logger.info(f"TextStory: Job {job_id} completed successfully")
|
| 130 |
+
|
| 131 |
+
except Exception as e:
|
| 132 |
+
logger.error(f"TextStory: Job {job_id} failed - {e}")
|
| 133 |
+
update_job(job_id, "failed", 0, error=str(e))
|
| 134 |
+
|
| 135 |
+
# Cleanup on failure
|
| 136 |
+
if 'temp_dir' in locals():
|
| 137 |
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@router.post("/generate", response_model=TextStoryResponse)
|
| 141 |
+
async def generate_text_story(
|
| 142 |
+
request: TextStoryRequest,
|
| 143 |
+
background_tasks: BackgroundTasks
|
| 144 |
+
):
|
| 145 |
+
"""
|
| 146 |
+
Start text story video generation.
|
| 147 |
+
|
| 148 |
+
Returns job_id for status polling.
|
| 149 |
+
"""
|
| 150 |
+
job_id = uuid.uuid4().hex[:12]
|
| 151 |
+
|
| 152 |
+
# Initialize job
|
| 153 |
+
jobs[job_id] = {
|
| 154 |
+
"status": "processing",
|
| 155 |
+
"progress": 0,
|
| 156 |
+
"current_step": "Starting...",
|
| 157 |
+
"video_url": None,
|
| 158 |
+
"error": None,
|
| 159 |
+
"request": request.model_dump()
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
# Start background generation
|
| 163 |
+
background_tasks.add_task(generate_text_story_video, job_id, request)
|
| 164 |
+
|
| 165 |
+
logger.info(f"TextStory: Started job {job_id} with {len(request.messages)} messages")
|
| 166 |
+
|
| 167 |
+
return TextStoryResponse(
|
| 168 |
+
job_id=job_id,
|
| 169 |
+
status="processing",
|
| 170 |
+
message=f"Started generating text story with {len(request.messages)} messages"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
@router.get("/{job_id}/status", response_model=TextStoryStatus)
|
| 175 |
+
async def get_text_story_status(job_id: str):
|
| 176 |
+
"""Get status of a text story generation job."""
|
| 177 |
+
if job_id not in jobs:
|
| 178 |
+
raise HTTPException(status_code=404, detail="Job not found")
|
| 179 |
+
|
| 180 |
+
job = jobs[job_id]
|
| 181 |
+
return TextStoryStatus(
|
| 182 |
+
job_id=job_id,
|
| 183 |
+
status=job["status"],
|
| 184 |
+
progress=job["progress"],
|
| 185 |
+
current_step=job.get("current_step"),
|
| 186 |
+
video_url=job.get("video_url"),
|
| 187 |
+
error=job.get("error")
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@router.get("/{job_id}/video")
|
| 192 |
+
async def download_text_story_video(job_id: str):
|
| 193 |
+
"""Download the generated text story video."""
|
| 194 |
+
if job_id not in jobs:
|
| 195 |
+
raise HTTPException(status_code=404, detail="Job not found")
|
| 196 |
+
|
| 197 |
+
job = jobs[job_id]
|
| 198 |
+
if job["status"] != "ready":
|
| 199 |
+
raise HTTPException(status_code=400, detail="Video not ready yet")
|
| 200 |
+
|
| 201 |
+
# Check persistent storage first
|
| 202 |
+
persistent_path = f"/data/videos/text_story/text_story_{job_id}.mp4"
|
| 203 |
+
local_path = f"videos/text_story/text_story_{job_id}.mp4"
|
| 204 |
+
|
| 205 |
+
if os.path.exists(persistent_path):
|
| 206 |
+
video_path = persistent_path
|
| 207 |
+
elif os.path.exists(local_path):
|
| 208 |
+
video_path = local_path
|
| 209 |
+
else:
|
| 210 |
+
raise HTTPException(status_code=404, detail="Video file not found")
|
| 211 |
+
|
| 212 |
+
return FileResponse(
|
| 213 |
+
video_path,
|
| 214 |
+
media_type="video/mp4",
|
| 215 |
+
filename=f"text_story_{job_id}.mp4"
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ============================================
|
| 220 |
+
# AI CONVERSATION GENERATION
|
| 221 |
+
# ============================================
|
| 222 |
+
|
| 223 |
+
from pydantic import BaseModel
|
| 224 |
+
from typing import List, Optional
|
| 225 |
+
|
| 226 |
+
class AiGenerateRequest(BaseModel):
    """Request for AI-generated conversation."""
    # Topic/scenario the conversation should be about (fed to the LLM).
    prompt: str
    # Display name for Person A (user side).
    person_a_name: str = "You"
    # Display name for Person B (other side).
    person_b_name: str = "My Ex"
    # How many messages the model is asked to produce.
    message_count: int = 7
    # One of: emotional, funny, shocking, romantic, angry
    # (unknown values fall back to "emotional" in the endpoint).
    tone: str = "emotional"
|
| 234 |
+
|
| 235 |
+
class AiGenerateResponse(BaseModel):
    """Response with generated messages."""
    # Each item is {"sender": "A"|"B", "text": "..."} as parsed from the model.
    messages: List[dict]
    # Optional closing caption (e.g. "To be continued...").
    ending_text: Optional[str] = None
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
@router.post("/ai-generate", response_model=AiGenerateResponse)
|
| 242 |
+
async def ai_generate_conversation(request: AiGenerateRequest):
|
| 243 |
+
"""
|
| 244 |
+
Generate a fake conversation using Groq AI (openai/gpt-oss-120b).
|
| 245 |
+
|
| 246 |
+
Returns a list of messages for the text story.
|
| 247 |
+
"""
|
| 248 |
+
import aiohttp
|
| 249 |
+
import json
|
| 250 |
+
|
| 251 |
+
groq_api_key = os.getenv("GROQ_API_KEY", "")
|
| 252 |
+
if not groq_api_key:
|
| 253 |
+
raise HTTPException(status_code=500, detail="GROQ_API_KEY not configured")
|
| 254 |
+
|
| 255 |
+
# Tone descriptions
|
| 256 |
+
tone_prompts = {
|
| 257 |
+
"emotional": "Make it emotional and dramatic with deep feelings.",
|
| 258 |
+
"funny": "Make it funny and comedic with witty responses.",
|
| 259 |
+
"shocking": "Include a shocking plot twist at the end.",
|
| 260 |
+
"romantic": "Make it romantic with heartfelt messages.",
|
| 261 |
+
"angry": "Make it an angry argument with heated exchanges."
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
tone_instruction = tone_prompts.get(request.tone, tone_prompts["emotional"])
|
| 265 |
+
|
| 266 |
+
system_prompt = f"""You are a viral content script writer. Generate a fake text message conversation.
|
| 267 |
+
|
| 268 |
+
RULES:
|
| 269 |
+
1. Create exactly {request.message_count} messages
|
| 270 |
+
2. Alternate between Person A ({request.person_a_name}) and Person B ({request.person_b_name})
|
| 271 |
+
3. {tone_instruction}
|
| 272 |
+
4. Make it engaging and viral-worthy
|
| 273 |
+
5. Keep messages short (1-3 sentences each)
|
| 274 |
+
6. End with impact (twist or emotional ending)
|
| 275 |
+
|
| 276 |
+
OUTPUT FORMAT (strict JSON):
|
| 277 |
+
{{
|
| 278 |
+
"messages": [
|
| 279 |
+
{{"sender": "B", "text": "message from {request.person_b_name}"}},
|
| 280 |
+
{{"sender": "A", "text": "message from {request.person_a_name}"}},
|
| 281 |
+
...
|
| 282 |
+
],
|
| 283 |
+
"ending_text": "Optional ending text like 'To be continued...'"
|
| 284 |
+
}}
|
| 285 |
+
|
| 286 |
+
Only output valid JSON, nothing else."""
|
| 287 |
+
|
| 288 |
+
user_prompt = f"Create a conversation about: {request.prompt}"
|
| 289 |
+
|
| 290 |
+
try:
|
| 291 |
+
async with aiohttp.ClientSession() as session:
|
| 292 |
+
payload = {
|
| 293 |
+
"model": "meta-llama/llama-4-scout-17b-16e-instruct",
|
| 294 |
+
"messages": [
|
| 295 |
+
{"role": "system", "content": system_prompt},
|
| 296 |
+
{"role": "user", "content": user_prompt}
|
| 297 |
+
],
|
| 298 |
+
"temperature": 0.8,
|
| 299 |
+
"max_tokens": 2000
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
headers = {
|
| 303 |
+
"Authorization": f"Bearer {groq_api_key}",
|
| 304 |
+
"Content-Type": "application/json"
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
async with session.post(
|
| 308 |
+
"https://api.groq.com/openai/v1/chat/completions",
|
| 309 |
+
json=payload,
|
| 310 |
+
headers=headers,
|
| 311 |
+
timeout=aiohttp.ClientTimeout(total=30)
|
| 312 |
+
) as response:
|
| 313 |
+
if response.status != 200:
|
| 314 |
+
error_text = await response.text()
|
| 315 |
+
logger.error(f"Groq API error: {error_text}")
|
| 316 |
+
raise HTTPException(status_code=500, detail=f"Groq API error: {response.status}")
|
| 317 |
+
|
| 318 |
+
data = await response.json()
|
| 319 |
+
content = data["choices"][0]["message"]["content"]
|
| 320 |
+
|
| 321 |
+
# Parse JSON response
|
| 322 |
+
try:
|
| 323 |
+
# Clean up content (remove markdown code blocks if present)
|
| 324 |
+
content = content.strip()
|
| 325 |
+
if content.startswith("```"):
|
| 326 |
+
content = content.split("```")[1]
|
| 327 |
+
if content.startswith("json"):
|
| 328 |
+
content = content[4:]
|
| 329 |
+
content = content.strip()
|
| 330 |
+
|
| 331 |
+
result = json.loads(content)
|
| 332 |
+
|
| 333 |
+
return AiGenerateResponse(
|
| 334 |
+
messages=result.get("messages", []),
|
| 335 |
+
ending_text=result.get("ending_text")
|
| 336 |
+
)
|
| 337 |
+
except json.JSONDecodeError as e:
|
| 338 |
+
logger.error(f"Failed to parse AI response: {content}")
|
| 339 |
+
raise HTTPException(status_code=500, detail="AI returned invalid JSON")
|
| 340 |
+
|
| 341 |
+
except aiohttp.ClientError as e:
|
| 342 |
+
logger.error(f"Groq API request failed: {e}")
|
| 343 |
+
raise HTTPException(status_code=500, detail=f"AI request failed: {str(e)}")
|
| 344 |
+
|
modules/text_story/schemas.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic schemas for Text Story module.
|
| 3 |
+
Defines input/output models for the API.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
from typing import List, Optional, Literal
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Message(BaseModel):
    """Single chat message in a text story conversation."""
    # Which side of the chat the bubble renders on.
    sender: Literal["A", "B"] = Field(
        description="A = User (right, blue bubble), B = Other person (left, gray bubble)"
    )
    # Pydantic enforces the 1-500 character bounds at validation time.
    text: str = Field(
        description="Message text content",
        min_length=1,
        max_length=500
    )
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TextStoryRequest(BaseModel):
    """
    Request to generate a text story video.

    Validation limits: 2-50 messages (each 1-500 chars, see Message).
    Voice defaults are Kokoro TTS preset names.
    """
    person_a_name: str = Field(
        default="You",
        description="Name for Person A (user, right side)"
    )
    person_b_name: str = Field(
        default="Unknown",
        description="Name for Person B (other, left side)"
    )
    person_b_avatar: Optional[str] = Field(
        default=None,
        description="Avatar letter or emoji for Person B header (e.g., 'M' or 'π')"
    )
    messages: List[Message] = Field(
        description="List of chat messages in order",
        min_length=2,
        max_length=50
    )
    ending_text: Optional[str] = Field(
        default=None,
        description="Emotional ending text (e.g., 'To be continued...')"
    )
    voice_a: str = Field(
        default="af_heart",
        description="Kokoro TTS voice for Person A"
    )
    voice_b: str = Field(
        default="am_fenrir",
        description="Kokoro TTS voice for Person B"
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class TextStoryResponse(BaseModel):
    """Response after starting video generation."""
    # Poll /api/text-story/{job_id}/status with this id.
    job_id: str
    # Always "processing" at creation time.
    status: str = "processing"
    # Human-readable confirmation for the client UI.
    message: str = "Text story generation started"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TextStoryStatus(BaseModel):
    """Status of a text story generation job."""
    job_id: str
    # Lifecycle: processing -> ready | failed.
    status: Literal["processing", "ready", "failed"]
    # Percent complete, clamped to 0-100 by validation.
    progress: int = Field(default=0, ge=0, le=100)
    # Human-readable description of the current pipeline step.
    current_step: Optional[str] = None
    # Set only once status == "ready".
    video_url: Optional[str] = None
    # Set only when status == "failed".
    error: Optional[str] = None
|
modules/text_story/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Services init
|
modules/text_story/services/background.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Background Handler for Text Story module.
|
| 3 |
+
Handles gameplay video loading from HuggingFace Dataset storage.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import logging
|
| 9 |
+
from moviepy.editor import VideoFileClip, vfx
|
| 10 |
+
from typing import Optional, List
|
| 11 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
# Folder name in HF Dataset for gameplay backgrounds
|
| 16 |
+
HF_BACKGROUNDS_FOLDER = "gameplay_backgrounds"
|
| 17 |
+
|
| 18 |
+
# Local cache path
|
| 19 |
+
LOCAL_CACHE_DIR = "cache/gameplay_backgrounds"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class BackgroundHandler:
    """
    Handles gameplay background video processing.
    Downloads from HuggingFace Dataset (HF_REPO env variable).

    Features:
    - Download videos from HF Dataset
    - Random video selection
    - Audio removal
    - Slow motion (0.7x)
    - Dark overlay
    - Seamless looping
    """

    def __init__(self):
        # Get repo from environment variable (e.g., robiul487/NCAkit)
        self.repo_id = os.getenv("HF_REPO", "")
        self.folder = HF_BACKGROUNDS_FOLDER
        self.cache_dir = LOCAL_CACHE_DIR

        if not self.repo_id:
            logger.warning("BackgroundHandler: HF_REPO not set! Using solid background.")
            self.available_videos = []
            return

        # Ensure cache directory exists
        os.makedirs(self.cache_dir, exist_ok=True)

        # Listing happens once at construction; get_random_video() retries later.
        self.available_videos = self._list_available_videos()

        if self.available_videos:
            logger.info(f"BackgroundHandler: Found {len(self.available_videos)} videos in {self.repo_id}/{self.folder}")
        else:
            logger.warning(f"BackgroundHandler: No videos found in {self.repo_id}/{self.folder}")

    def _list_available_videos(self) -> List[str]:
        """List video files (mp4/mov/avi/webm) in the HF Dataset folder."""
        if not self.repo_id:
            return []

        try:
            all_files = list_repo_files(
                repo_id=self.repo_id,
                repo_type="dataset"
            )

            # Keep only videos under the gameplay_backgrounds folder
            return [
                f for f in all_files
                if f.startswith(f"{self.folder}/")
                and f.lower().endswith(('.mp4', '.mov', '.avi', '.webm'))
            ]

        except Exception as e:
            logger.error(f"BackgroundHandler: Failed to list files - {e}")
            return []

    def _download_video(self, filename: str) -> Optional[str]:
        """
        Download a video from the HF Dataset to the local cache.

        Args:
            filename: Repo-relative path, e.g. "gameplay_backgrounds/clip.mp4"

        Returns:
            Local path to the file, or None on failure.
        """
        try:
            # FIX: hf_hub_download(local_dir=...) mirrors the repo-relative
            # path, so the cached copy lives at cache_dir/<folder>/<name>.
            # Checking that exact path makes the cache actually hit on repeat
            # calls (the old basename-only check always missed).
            cached_path = os.path.join(self.cache_dir, filename)

            if os.path.exists(cached_path):
                logger.info(f"BackgroundHandler: Using cached {os.path.basename(filename)}")
                return cached_path

            # FIX: log the actual filename (the f-string had no placeholder)
            logger.info(f"BackgroundHandler: Downloading {filename}...")

            downloaded_path = hf_hub_download(
                repo_id=self.repo_id,
                filename=filename,
                repo_type="dataset",
                local_dir=self.cache_dir,
                local_dir_use_symlinks=False  # NOTE: deprecated/ignored in newer huggingface_hub
            )

            logger.info(f"BackgroundHandler: Downloaded to {downloaded_path}")
            return downloaded_path

        except Exception as e:
            logger.error(f"BackgroundHandler: Download failed - {e}")
            return None

    def get_random_video(self) -> Optional[str]:
        """Pick a random video from the listing and ensure it is downloaded."""
        if not self.available_videos:
            # Retry listing (the attempt at construction may have failed)
            self.available_videos = self._list_available_videos()

        if not self.available_videos:
            logger.warning("BackgroundHandler: No videos available")
            return None

        selected = random.choice(self.available_videos)
        logger.info(f"BackgroundHandler: Selected {selected}")

        return self._download_video(selected)

    def load_and_process(self,
                         target_duration: float,
                         video_path: Optional[str] = None) -> Optional[VideoFileClip]:
        """
        Load and process a background video.

        Args:
            target_duration: Required duration in seconds
            video_path: Optional specific video path (or random if None)

        Returns:
            Processed VideoFileClip, or a solid-color fallback clip.
        """
        if video_path is None:
            video_path = self.get_random_video()

        if not video_path or not os.path.exists(video_path):
            logger.warning("BackgroundHandler: No video available, creating solid background")
            return self._create_solid_background(target_duration)

        try:
            # Load video WITHOUT audio (gameplay sound is never wanted)
            clip = VideoFileClip(video_path).without_audio()
            logger.info(f"BackgroundHandler: Loaded {video_path}, duration: {clip.duration:.1f}s")

            # Apply slow motion (0.7x speed)
            clip = clip.fx(vfx.speedx, 0.7)

            # Loop if needed to match target duration
            clip = self._loop_to_duration(clip, target_duration)

            # Resize/crop to 9:16, darken, desaturate
            clip = self._apply_visual_effects(clip)

            return clip

        except Exception as e:
            logger.error(f"BackgroundHandler: Failed to process video - {e}")
            return self._create_solid_background(target_duration)

    def _loop_to_duration(self, clip: VideoFileClip, target_duration: float) -> VideoFileClip:
        """Loop (or trim) the clip so it lasts exactly target_duration seconds."""
        if clip.duration >= target_duration:
            return clip.subclip(0, target_duration)

        # +1 guarantees enough material before trimming back down
        loops_needed = int(target_duration / clip.duration) + 1
        looped = clip.loop(n=loops_needed)
        return looped.subclip(0, target_duration)

    def _apply_visual_effects(self, clip: VideoFileClip) -> VideoFileClip:
        """
        Apply resize, crop (if needed), dark overlay, and saturation reduction.

        - 9:16 videos: just resize (no crop needed)
        - 16:9 videos: center crop to 9:16
        """
        target_w, target_h = 1080, 1920

        # Target aspect is 9:16 (~0.5625); anything under 0.7 (~11:16)
        # is treated as already vertical.
        clip_ratio = clip.w / clip.h
        is_vertical = clip_ratio < 0.7

        if is_vertical:
            if clip.w == target_w and clip.h == target_h:
                # Perfect match, no resize needed
                logger.info(f"BackgroundHandler: Video is already {target_w}x{target_h}, no resize")
            else:
                logger.info(f"BackgroundHandler: Video is vertical ({clip.w}x{clip.h}), resizing to {target_w}x{target_h}")
                clip = clip.resize(newsize=(target_w, target_h))
        else:
            logger.info(f"BackgroundHandler: Video is horizontal ({clip.w}x{clip.h}), center cropping to 9:16")

            # Scale to match height, then center crop width
            new_h = target_h
            new_w = int(clip_ratio * new_h)
            clip = clip.resize(height=new_h)

            # Center crop
            x_center = new_w // 2
            clip = clip.crop(x_center=x_center, width=target_w, height=target_h)

        # Dark overlay (reduce brightness by 40%)
        clip = clip.fx(vfx.colorx, 0.6)

        # Saturation/contrast reduction so the chat overlay stays readable
        clip = clip.fx(vfx.lum_contrast, lum=-10, contrast=-0.1)

        return clip

    def _create_solid_background(self, duration: float) -> VideoFileClip:
        """Create a solid dark 1080x1920 background clip as a fallback."""
        from moviepy.editor import ColorClip

        return ColorClip(
            size=(1080, 1920),
            color=(15, 15, 25),
            duration=duration
        )
|
modules/text_story/services/renderer.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chat UI Renderer for Text Story module.
|
| 3 |
+
Creates iMessage-style chat bubbles and UI.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 9 |
+
from typing import List, Tuple, Optional
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# Canvas dimensions (9:16 vertical)
|
| 14 |
+
CANVAS_WIDTH = 1080
|
| 15 |
+
CANVAS_HEIGHT = 1920
|
| 16 |
+
|
| 17 |
+
# Colors (iMessage style)
|
| 18 |
+
COLORS = {
|
| 19 |
+
"header_bg": (28, 28, 30), # #1C1C1E - Dark header
|
| 20 |
+
"bubble_user": (0, 122, 255), # #007AFF - Blue (Person A/right)
|
| 21 |
+
"bubble_other": (58, 58, 60), # #3A3A3C - Gray (Person B/left)
|
| 22 |
+
"text_white": (255, 255, 255), # White text
|
| 23 |
+
"text_gray": (142, 142, 147), # #8E8E93 - Secondary text
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
# UI Measurements
|
| 27 |
+
UI = {
|
| 28 |
+
"header_height": 100,
|
| 29 |
+
"margin_side": 30,
|
| 30 |
+
"bubble_max_width_ratio": 0.75, # 75% of screen
|
| 31 |
+
"bubble_padding_h": 16,
|
| 32 |
+
"bubble_padding_v": 12,
|
| 33 |
+
"bubble_radius": 20,
|
| 34 |
+
"bubble_gap": 10,
|
| 35 |
+
"font_size": 34,
|
| 36 |
+
"header_font_size": 22,
|
| 37 |
+
"avatar_size": 50,
|
| 38 |
+
"max_visible_messages": 7,
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ChatRenderer:
|
| 43 |
+
"""
|
| 44 |
+
Renders iMessage-style chat UI frames.
|
| 45 |
+
Handles dynamic box sizing and message bubbles.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self,
|
| 49 |
+
person_a_name: str = "You",
|
| 50 |
+
person_b_name: str = "Unknown",
|
| 51 |
+
person_b_avatar: str = None):
|
| 52 |
+
self.person_a_name = person_a_name
|
| 53 |
+
self.person_b_name = person_b_name
|
| 54 |
+
self.person_b_avatar = person_b_avatar or person_b_name[0].upper()
|
| 55 |
+
|
| 56 |
+
# Load fonts
|
| 57 |
+
self.font = self._load_font(UI["font_size"])
|
| 58 |
+
self.font_small = self._load_font(UI["header_font_size"])
|
| 59 |
+
self.font_avatar = self._load_font(28)
|
| 60 |
+
|
| 61 |
+
# Track visible messages for scroll behavior
|
| 62 |
+
self.visible_messages: List[dict] = []
|
| 63 |
+
|
| 64 |
+
def _load_font(self, size: int) -> ImageFont.FreeTypeFont:
|
| 65 |
+
"""Load font with fallback."""
|
| 66 |
+
font_paths = [
|
| 67 |
+
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
|
| 68 |
+
"/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
|
| 69 |
+
"C:/Windows/Fonts/arial.ttf",
|
| 70 |
+
"/System/Library/Fonts/SFNS.ttf",
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
for path in font_paths:
|
| 74 |
+
if os.path.exists(path):
|
| 75 |
+
try:
|
| 76 |
+
return ImageFont.truetype(path, size)
|
| 77 |
+
except:
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
# Fallback to default
|
| 81 |
+
return ImageFont.load_default()
|
| 82 |
+
|
| 83 |
+
def _wrap_text(self, text: str, max_width: int) -> List[str]:
|
| 84 |
+
"""Wrap text to fit within max width."""
|
| 85 |
+
words = text.split()
|
| 86 |
+
lines = []
|
| 87 |
+
current_line = []
|
| 88 |
+
|
| 89 |
+
for word in words:
|
| 90 |
+
test_line = " ".join(current_line + [word])
|
| 91 |
+
bbox = self.font.getbbox(test_line)
|
| 92 |
+
width = bbox[2] - bbox[0]
|
| 93 |
+
|
| 94 |
+
if width <= max_width:
|
| 95 |
+
current_line.append(word)
|
| 96 |
+
else:
|
| 97 |
+
if current_line:
|
| 98 |
+
lines.append(" ".join(current_line))
|
| 99 |
+
current_line = [word]
|
| 100 |
+
|
| 101 |
+
if current_line:
|
| 102 |
+
lines.append(" ".join(current_line))
|
| 103 |
+
|
| 104 |
+
return lines if lines else [text]
|
| 105 |
+
|
| 106 |
+
def _calculate_bubble_size(self, text: str) -> Tuple[int, int, List[str]]:
|
| 107 |
+
"""Calculate bubble size based on text."""
|
| 108 |
+
max_text_width = int(CANVAS_WIDTH * UI["bubble_max_width_ratio"]) - UI["bubble_padding_h"] * 2
|
| 109 |
+
lines = self._wrap_text(text, max_text_width)
|
| 110 |
+
|
| 111 |
+
# Calculate text dimensions
|
| 112 |
+
line_height = self.font.getbbox("Ay")[3] + 4
|
| 113 |
+
text_height = line_height * len(lines)
|
| 114 |
+
|
| 115 |
+
max_line_width = 0
|
| 116 |
+
for line in lines:
|
| 117 |
+
bbox = self.font.getbbox(line)
|
| 118 |
+
max_line_width = max(max_line_width, bbox[2] - bbox[0])
|
| 119 |
+
|
| 120 |
+
# Add padding
|
| 121 |
+
bubble_width = max_line_width + UI["bubble_padding_h"] * 2
|
| 122 |
+
bubble_height = text_height + UI["bubble_padding_v"] * 2
|
| 123 |
+
|
| 124 |
+
return bubble_width, bubble_height, lines
|
| 125 |
+
|
| 126 |
+
    def _draw_header(self, draw: ImageDraw.Draw, img: Image.Image):
        """Draw the iMessage-style header: bar, centered avatar circle with
        initial/emoji, contact name, back chevron, and video-call icon.

        NOTE(review): `img` is currently unused - all drawing goes through
        `draw`.
        """
        # Header background bar across the full canvas width
        draw.rectangle(
            [0, 0, CANVAS_WIDTH, UI["header_height"]],
            fill=COLORS["header_bg"]
        )

        # Avatar circle, horizontally centered in the header
        avatar_x = CANVAS_WIDTH // 2
        avatar_y = 35
        avatar_r = UI["avatar_size"] // 2

        draw.ellipse(
            [avatar_x - avatar_r, avatar_y - avatar_r,
             avatar_x + avatar_r, avatar_y + avatar_r],
            fill=(100, 100, 105)  # Gray circle
        )

        # Avatar letter, centered on the circle (-2 is a small visual nudge up)
        bbox = self.font_avatar.getbbox(self.person_b_avatar)
        text_w = bbox[2] - bbox[0]
        text_h = bbox[3] - bbox[1]
        draw.text(
            (avatar_x - text_w // 2, avatar_y - text_h // 2 - 2),
            self.person_b_avatar,
            fill=COLORS["text_white"],
            font=self.font_avatar
        )

        # Name below avatar, horizontally centered under the circle
        name_bbox = self.font_small.getbbox(self.person_b_name)
        name_w = name_bbox[2] - name_bbox[0]
        draw.text(
            (avatar_x - name_w // 2, avatar_y + avatar_r + 8),
            self.person_b_name,
            fill=COLORS["text_white"],
            font=self.font_small
        )

        # Left chevron (back button), iOS accent blue
        draw.text((20, 30), "βΉ", fill=(0, 122, 255), font=self.font)

        # Right video icon
        draw.text((CANVAS_WIDTH - 50, 30), "πΉ", fill=(0, 122, 255), font=self.font_small)
|
| 171 |
+
|
| 172 |
+
def _draw_bubble(self, draw: ImageDraw.Draw,
|
| 173 |
+
x: int, y: int,
|
| 174 |
+
width: int, height: int,
|
| 175 |
+
lines: List[str],
|
| 176 |
+
is_user: bool) -> int:
|
| 177 |
+
"""
|
| 178 |
+
Draw a chat bubble.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
Bottom Y position of bubble
|
| 182 |
+
"""
|
| 183 |
+
# Bubble color
|
| 184 |
+
color = COLORS["bubble_user"] if is_user else COLORS["bubble_other"]
|
| 185 |
+
|
| 186 |
+
# Draw rounded rectangle
|
| 187 |
+
radius = UI["bubble_radius"]
|
| 188 |
+
draw.rounded_rectangle(
|
| 189 |
+
[x, y, x + width, y + height],
|
| 190 |
+
radius=radius,
|
| 191 |
+
fill=color
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# Draw text
|
| 195 |
+
text_x = x + UI["bubble_padding_h"]
|
| 196 |
+
text_y = y + UI["bubble_padding_v"]
|
| 197 |
+
line_height = self.font.getbbox("Ay")[3] + 4
|
| 198 |
+
|
| 199 |
+
for line in lines:
|
| 200 |
+
draw.text((text_x, text_y), line, fill=COLORS["text_white"], font=self.font)
|
| 201 |
+
text_y += line_height
|
| 202 |
+
|
| 203 |
+
return y + height
|
| 204 |
+
|
| 205 |
+
def render_frame(self, messages: List[dict], show_typing: bool = False) -> Image.Image:
|
| 206 |
+
"""
|
| 207 |
+
Render a single frame with current messages.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
messages: List of {"sender": "A"/"B", "text": "..."} dicts
|
| 211 |
+
show_typing: Whether to show typing indicator
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
PIL Image of the frame
|
| 215 |
+
"""
|
| 216 |
+
# Create transparent image (gameplay will be behind)
|
| 217 |
+
img = Image.new("RGBA", (CANVAS_WIDTH, CANVAS_HEIGHT), (0, 0, 0, 0))
|
| 218 |
+
draw = ImageDraw.Draw(img)
|
| 219 |
+
|
| 220 |
+
# Calculate total height needed for messages
|
| 221 |
+
message_heights = []
|
| 222 |
+
for msg in messages:
|
| 223 |
+
_, height, _ = self._calculate_bubble_size(msg["text"])
|
| 224 |
+
message_heights.append(height + UI["bubble_gap"])
|
| 225 |
+
|
| 226 |
+
total_msg_height = sum(message_heights)
|
| 227 |
+
|
| 228 |
+
# Calculate UI box height (dynamic)
|
| 229 |
+
ui_height = UI["header_height"] + total_msg_height + 20 # 20px bottom padding
|
| 230 |
+
|
| 231 |
+
# Draw semi-transparent black background for chat area
|
| 232 |
+
draw.rectangle(
|
| 233 |
+
[0, 0, CANVAS_WIDTH, ui_height],
|
| 234 |
+
fill=(0, 0, 0, 220) # Semi-transparent black
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# Draw header
|
| 238 |
+
self._draw_header(draw, img)
|
| 239 |
+
|
| 240 |
+
# Draw messages
|
| 241 |
+
current_y = UI["header_height"] + 15
|
| 242 |
+
|
| 243 |
+
# Only show last N messages if too many
|
| 244 |
+
visible_messages = messages[-UI["max_visible_messages"]:]
|
| 245 |
+
|
| 246 |
+
for msg in visible_messages:
|
| 247 |
+
width, height, lines = self._calculate_bubble_size(msg["text"])
|
| 248 |
+
|
| 249 |
+
# Position: A (user) = right, B (other) = left
|
| 250 |
+
if msg["sender"] == "A":
|
| 251 |
+
x = CANVAS_WIDTH - UI["margin_side"] - width
|
| 252 |
+
else:
|
| 253 |
+
x = UI["margin_side"]
|
| 254 |
+
|
| 255 |
+
current_y = self._draw_bubble(draw, x, current_y, width, height, lines, msg["sender"] == "A")
|
| 256 |
+
current_y += UI["bubble_gap"]
|
| 257 |
+
|
| 258 |
+
# Draw typing indicator if needed
|
| 259 |
+
if show_typing:
|
| 260 |
+
typing_y = current_y + 5
|
| 261 |
+
self._draw_typing_indicator(draw, typing_y)
|
| 262 |
+
|
| 263 |
+
return img
|
| 264 |
+
|
| 265 |
+
def _draw_typing_indicator(self, draw: ImageDraw.Draw, y: int):
|
| 266 |
+
"""Draw typing indicator (βββ)."""
|
| 267 |
+
x = UI["margin_side"]
|
| 268 |
+
|
| 269 |
+
# Background bubble
|
| 270 |
+
bubble_width = 70
|
| 271 |
+
bubble_height = 40
|
| 272 |
+
draw.rounded_rectangle(
|
| 273 |
+
[x, y, x + bubble_width, y + bubble_height],
|
| 274 |
+
radius=15,
|
| 275 |
+
fill=COLORS["bubble_other"]
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
# Three dots
|
| 279 |
+
dot_y = y + bubble_height // 2
|
| 280 |
+
for i, dx in enumerate([20, 35, 50]):
|
| 281 |
+
draw.ellipse(
|
| 282 |
+
[x + dx - 4, dot_y - 4, x + dx + 4, dot_y + 4],
|
| 283 |
+
fill=COLORS["text_gray"]
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
def get_ui_height(self, messages: List[dict]) -> int:
|
| 287 |
+
"""Calculate the height of the chat UI for given messages."""
|
| 288 |
+
message_heights = []
|
| 289 |
+
visible_messages = messages[-UI["max_visible_messages"]:]
|
| 290 |
+
|
| 291 |
+
for msg in visible_messages:
|
| 292 |
+
_, height, _ = self._calculate_bubble_size(msg["text"])
|
| 293 |
+
message_heights.append(height + UI["bubble_gap"])
|
| 294 |
+
|
| 295 |
+
return UI["header_height"] + sum(message_heights) + 20
|
modules/text_story/services/tts_handler.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TTS Handler for Text Story module.
|
| 3 |
+
Handles voice generation and audio processing.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
import aiohttp
|
| 9 |
+
from pydub import AudioSegment
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TTSHandler:
    """
    Handles Text-to-Speech generation using Kokoro TTS.
    Also handles silence trimming and duration detection.
    """

    def __init__(self):
        # Base URL of the Kokoro TTS service (e.g. a HF Space endpoint)
        self.tts_url = os.getenv("HF_TTS", "")
        if not self.tts_url:
            logger.warning("TTSHandler: HF_TTS not configured, TTS will fail")

    async def generate_tts(self, text: str, voice: str, output_path: str) -> float:
        """
        Generate TTS audio for text.

        Args:
            text: Text to speak
            voice: Kokoro voice ID (e.g., 'af_heart', 'am_fenrir')
            output_path: Path to save WAV file

        Returns:
            Duration in seconds (after leading/trailing silence is trimmed)

        Raises:
            ValueError: If the HF_TTS environment variable is not set.
            RuntimeError: If the TTS service responds with a non-200 status.
        """
        if not self.tts_url:
            raise ValueError("HF_TTS environment variable not set")

        # Raw service output lands here before trimming
        temp_path = output_path + ".temp.wav"
        try:
            async with aiohttp.ClientSession() as session:
                payload = {
                    "text": text,
                    "voice": voice
                }

                async with session.post(
                    f"{self.tts_url}/tts",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=60)
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        # Specific exception type instead of a generic Exception
                        raise RuntimeError(f"TTS failed: {error_text}")

                    audio_data = await response.read()

            # Save raw audio, then trim silence into the final output path
            with open(temp_path, "wb") as f:
                f.write(audio_data)

            duration = self.trim_silence(temp_path, output_path)

            logger.info(f"TTS: Generated {len(text)} chars, {duration:.2f}s")
            return duration

        except Exception as e:
            logger.error(f"TTS generation failed: {e}")
            raise
        finally:
            # Always remove the temp file, even if trimming raised —
            # previously it leaked on any failure after download.
            if os.path.exists(temp_path):
                os.remove(temp_path)

    def trim_silence(self, input_path: str, output_path: str,
                     silence_thresh: int = -40, min_silence_len: int = 100) -> float:
        """
        Trim leading and trailing silence from audio.

        Args:
            input_path: Input audio file
            output_path: Output audio file (written as WAV)
            silence_thresh: Silence threshold in dBFS
            min_silence_len: Minimum silence length in ms

        Returns:
            Duration of trimmed audio in seconds. On any failure the input is
            copied through unchanged and its full duration is returned.
        """
        try:
            audio = AudioSegment.from_file(input_path)

            # Local import keeps pydub.silence off the hot import path
            from pydub.silence import detect_nonsilent

            nonsilent_ranges = detect_nonsilent(
                audio,
                min_silence_len=min_silence_len,
                silence_thresh=silence_thresh
            )

            if nonsilent_ranges:
                # Keep a little padding so speech doesn't start/stop abruptly
                start_ms = max(0, nonsilent_ranges[0][0] - 50)      # 50ms lead-in
                end_ms = min(len(audio), nonsilent_ranges[-1][1] + 100)  # 100ms tail
                trimmed = audio[start_ms:end_ms]
            else:
                # No speech detected, use original
                trimmed = audio

            trimmed.export(output_path, format="wav")

            return len(trimmed) / 1000.0  # ms -> seconds

        except Exception as e:
            logger.error(f"Silence trim failed: {e}")
            # Fallback: just copy the file through unmodified
            import shutil
            shutil.copy2(input_path, output_path)
            audio = AudioSegment.from_file(output_path)
            return len(audio) / 1000.0

    def get_duration(self, audio_path: str) -> float:
        """Get duration of audio file in seconds (2.0s fallback on error)."""
        try:
            audio = AudioSegment.from_file(audio_path)
            return len(audio) / 1000.0
        except Exception as e:
            logger.error(f"Failed to get audio duration: {e}")
            return 2.0  # Default fallback
|
modules/text_story/services/video_composer.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Video Composer for Text Story module.
|
| 3 |
+
Assembles final video with realistic timing and effects.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
from typing import List, Dict, Optional
|
| 9 |
+
from moviepy.editor import (
|
| 10 |
+
VideoFileClip, ImageClip, AudioFileClip,
|
| 11 |
+
CompositeVideoClip, concatenate_videoclips,
|
| 12 |
+
ColorClip, TextClip
|
| 13 |
+
)
|
| 14 |
+
from moviepy.video.fx.all import fadein, fadeout
|
| 15 |
+
import numpy as np
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
from .renderer import ChatRenderer, CANVAS_WIDTH, CANVAS_HEIGHT
|
| 19 |
+
from .background import BackgroundHandler
|
| 20 |
+
from .tts_handler import TTSHandler
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
# Timing configurations (realistic chat behavior).
# All values are in seconds; compose() consumes them in this order per message:
# typing indicator -> voice_delay -> TTS audio -> micro_silence.
TIMING = {
    "typing_base": 0.5,       # Base typing indicator duration
    "typing_per_char": 0.008, # Additional time per character of the message
    "typing_max": 1.2,        # Hard cap on typing duration
    "human_pause_min": 0.3,   # Minimum pause before a message
    "human_pause_max": 0.8,   # Maximum pause before a message
    "voice_delay": 0.15,      # Gap between bubble appearing and voice starting
    "micro_silence": 0.3,     # Silence between consecutive messages
    "last_msg_pause": 1.5,    # Hold on the final chat state after the last message
    "ending_duration": 3.0,   # How long the ending text overlay is shown
    "ending_fade": 0.5,       # Ending fade in/out length
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class VideoComposer:
    """
    Composes the final text story video.

    Pipeline:
    1. Load gameplay background
    2. For each message:
        - Show typing indicator
        - Pop in message bubble
        - Play TTS audio
        - Add micro silence
    3. Add ending text
    4. Export final video
    """

    def __init__(self,
                 renderer: ChatRenderer,
                 bg_handler: BackgroundHandler,
                 tts_handler: TTSHandler):
        # Collaborators: chat-frame renderer, gameplay background loader, TTS helper
        self.renderer = renderer
        self.bg_handler = bg_handler
        self.tts_handler = tts_handler

    def _calculate_typing_duration(self, text: str) -> float:
        """Typing indicator duration: base + per-character, capped at typing_max."""
        duration = TIMING["typing_base"] + len(text) * TIMING["typing_per_char"]
        return min(duration, TIMING["typing_max"])

    def _calculate_human_pause(self, msg_index: int, total: int) -> float:
        """Return a human-like pause (seconds) before message `msg_index`."""
        # First message: shorter, fixed pause
        if msg_index == 0:
            return TIMING["human_pause_min"]

        # Later messages: randomized pause within the configured band
        import random  # local import; only needed on this path
        return random.uniform(TIMING["human_pause_min"], TIMING["human_pause_max"])

    def _pil_to_moviepy(self, pil_image: Image.Image, duration: float) -> ImageClip:
        """Convert a PIL Image into a fixed-duration MoviePy ImageClip."""
        return ImageClip(np.array(pil_image), duration=duration)

    async def compose(self,
                      messages: List[Dict],
                      ending_text: Optional[str],
                      output_dir: str) -> str:
        """
        Compose the full text story video.

        Args:
            messages: List of {path, duration, sender, text} dicts, where
                      `path` is the per-message TTS WAV and `duration` its
                      length in seconds
            ending_text: Optional ending text overlay
            output_dir: Directory for output files

        Returns:
            Path to final video file
        """
        logger.info(f"VideoComposer: Starting composition with {len(messages)} messages")

        # --- 1. Pre-compute total duration (mirrors the schedule built below) ---
        total_duration = 0
        for msg in messages:  # index not needed in this pass
            total_duration += self._calculate_typing_duration(msg["text"])
            total_duration += TIMING["voice_delay"]
            total_duration += msg["duration"]
            total_duration += TIMING["micro_silence"]

        total_duration += TIMING["last_msg_pause"]

        if ending_text:
            total_duration += TIMING["ending_duration"]

        logger.info(f"VideoComposer: Total duration: {total_duration:.1f}s")

        # --- 2. Gameplay background sized to the whole video ---
        background = self.bg_handler.load_and_process(total_duration)

        # --- 3. Build the timed overlay clips ---
        clips = []
        current_time = 0
        displayed_messages = []

        for i, msg in enumerate(messages):
            msg_dict = {"sender": msg["sender"], "text": msg["text"]}

            # 3a. Typing indicator over the previously shown messages
            typing_duration = self._calculate_typing_duration(msg["text"])
            typing_frame = self.renderer.render_frame(displayed_messages, show_typing=True)
            typing_clip = self._pil_to_moviepy(typing_frame, typing_duration)
            clips.append(typing_clip.set_start(current_time))
            current_time += typing_duration

            # 3b. Bubble pops in; short gap before the voice starts
            displayed_messages.append(msg_dict)
            msg_frame = self.renderer.render_frame(displayed_messages)
            voice_delay_clip = self._pil_to_moviepy(msg_frame, TIMING["voice_delay"])
            clips.append(voice_delay_clip.set_start(current_time))
            current_time += TIMING["voice_delay"]

            # 3c. Same frame held while the TTS audio plays
            msg_clip = self._pil_to_moviepy(msg_frame, msg["duration"])
            msg_clip = msg_clip.set_start(current_time)
            audio = AudioFileClip(msg["path"])
            msg_clip = msg_clip.set_audio(audio.set_start(0))
            clips.append(msg_clip)
            current_time += msg["duration"]

            # 3d. Micro silence between messages
            silence_clip = self._pil_to_moviepy(msg_frame, TIMING["micro_silence"])
            clips.append(silence_clip.set_start(current_time))
            current_time += TIMING["micro_silence"]

            logger.info(f"VideoComposer: Message {i+1}/{len(messages)} at {current_time:.1f}s")

        # --- 4. Hold on the final chat state ---
        final_frame = self.renderer.render_frame(displayed_messages)
        pause_clip = self._pil_to_moviepy(final_frame, TIMING["last_msg_pause"])
        clips.append(pause_clip.set_start(current_time))
        current_time += TIMING["last_msg_pause"]

        # --- 5. Optional ending overlay ---
        if ending_text:
            ending_clip = self._create_ending_clip(ending_text, displayed_messages)
            clips.append(ending_clip.set_start(current_time))

        # --- 6. Composite all clips over the background and export ---
        final = CompositeVideoClip([background] + clips, size=(CANVAS_WIDTH, CANVAS_HEIGHT))
        final = final.set_duration(total_duration)

        output_path = os.path.join(output_dir, "text_story_output.mp4")

        final.write_videofile(
            output_path,
            fps=30,
            codec="libx264",
            audio_codec="aac",
            preset="medium",
            threads=4,
            logger=None  # Suppress MoviePy logs
        )

        # Release readers / file handles
        final.close()
        background.close()
        for clip in clips:
            clip.close()

        logger.info(f"VideoComposer: Saved to {output_path}")
        return output_path

    def _create_ending_clip(self, text: str, messages: List[dict]) -> ImageClip:
        """Create ending text overlay clip (fades in and out)."""
        # Render the final chat state as the backdrop for the ending text
        frame = self.renderer.render_frame(messages)

        from PIL import ImageDraw, ImageFont
        draw = ImageDraw.Draw(frame)

        # Load a bold font; fall back to PIL's built-in default if missing
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 48)
        except OSError:  # was a bare `except:`; only font-loading errors belong here
            font = ImageFont.load_default()

        # Center the text horizontally in the lower area of the canvas
        bbox = font.getbbox(text)
        text_w = bbox[2] - bbox[0]
        x = (CANVAS_WIDTH - text_w) // 2
        y = CANVAS_HEIGHT - 300

        # Drop shadow under white text for readability over gameplay footage
        draw.text((x + 2, y + 2), text, fill=(0, 0, 0, 200), font=font)
        draw.text((x, y), text, fill=(255, 255, 255, 255), font=font)

        clip = self._pil_to_moviepy(frame, TIMING["ending_duration"])
        # TIMING["ending_fade"] is documented as "fade in/out" and `fadeout`
        # was imported at module top but never applied — apply both fades.
        clip = clip.fx(fadein, TIMING["ending_fade"])
        clip = clip.fx(fadeout, TIMING["ending_fade"])

        return clip
|
requirements.txt
CHANGED
|
@@ -29,15 +29,3 @@ imageio-ffmpeg>=0.4.9
|
|
| 29 |
# Trends Analysis
|
| 30 |
pytrends
|
| 31 |
pandas
|
| 32 |
-
|
| 33 |
-
# Bar Race Module
|
| 34 |
-
bar_chart_race
|
| 35 |
-
ddgs
|
| 36 |
-
tavily-python
|
| 37 |
-
|
| 38 |
-
# Deep Researcher (LangGraph)
|
| 39 |
-
langchain-core
|
| 40 |
-
langchain-google-genai>=2.0.0,<2.1.0
|
| 41 |
-
langgraph
|
| 42 |
-
markdownify
|
| 43 |
-
duckduckgo-search
|
|
|
|
| 29 |
# Trends Analysis
|
| 30 |
pytrends
|
| 31 |
pandas
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static/index.html
CHANGED
|
@@ -279,8 +279,8 @@
|
|
| 279 |
<button class="tab-btn" data-tab="quiz">
|
| 280 |
π― Quiz Reel
|
| 281 |
</button>
|
| 282 |
-
<button class="tab-btn" data-tab="
|
| 283 |
-
|
| 284 |
</button>
|
| 285 |
</div>
|
| 286 |
|
|
@@ -646,38 +646,141 @@
|
|
| 646 |
</div>
|
| 647 |
</div>
|
| 648 |
|
| 649 |
-
<!--
|
| 650 |
-
<div id="
|
| 651 |
<div class="card">
|
| 652 |
-
<h2>
|
| 653 |
<p style="color: var(--text-secondary); margin-bottom: 1.5rem;">
|
| 654 |
-
Create
|
| 655 |
</p>
|
| 656 |
|
| 657 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 658 |
<div class="form-group">
|
| 659 |
-
<label>
|
| 660 |
-
<input type="text" id="
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
</div>
|
| 666 |
|
| 667 |
<div class="form-group">
|
| 668 |
-
<label>
|
| 669 |
-
<
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
</div>
|
| 676 |
|
| 677 |
-
<button type="submit" class="btn btn-primary" style="width: 100%;">
|
| 678 |
</form>
|
| 679 |
|
| 680 |
-
<div id="
|
| 681 |
</div>
|
| 682 |
</div>
|
| 683 |
|
|
@@ -1205,59 +1308,172 @@
|
|
| 1205 |
}
|
| 1206 |
}, 100); // End of setTimeout
|
| 1207 |
|
| 1208 |
-
//
|
| 1209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1210 |
e.preventDefault();
|
| 1211 |
-
const status = document.getElementById('
|
| 1212 |
status.className = 'status processing';
|
| 1213 |
-
status.
|
| 1214 |
|
| 1215 |
-
const
|
| 1216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1217 |
|
| 1218 |
try {
|
| 1219 |
-
const
|
| 1220 |
method: 'POST',
|
| 1221 |
headers: { 'Content-Type': 'application/json' },
|
| 1222 |
-
body: JSON.stringify(
|
| 1223 |
-
topic: topic,
|
| 1224 |
-
duration_seconds: duration
|
| 1225 |
-
})
|
| 1226 |
});
|
|
|
|
| 1227 |
|
| 1228 |
-
|
| 1229 |
-
|
| 1230 |
-
|
| 1231 |
-
|
| 1232 |
-
|
| 1233 |
-
|
|
|
|
| 1234 |
} catch (err) {
|
| 1235 |
status.className = 'status error';
|
| 1236 |
-
status.innerHTML = 'β ' + err.message;
|
| 1237 |
}
|
| 1238 |
});
|
| 1239 |
|
| 1240 |
-
|
| 1241 |
-
|
|
|
|
| 1242 |
const poll = async () => {
|
| 1243 |
try {
|
| 1244 |
-
const res = await fetch(`/api/
|
| 1245 |
const data = await res.json();
|
| 1246 |
|
| 1247 |
if (data.status === 'ready') {
|
| 1248 |
status.className = 'status success';
|
| 1249 |
-
status.innerHTML = `β
Video ready! <a href="${data.video_url}" target="_blank" style="
|
| 1250 |
} else if (data.status === 'failed') {
|
| 1251 |
status.className = 'status error';
|
| 1252 |
status.innerHTML = 'β Failed: ' + (data.error || 'Unknown error');
|
| 1253 |
} else {
|
| 1254 |
-
const step = data.current_step ||
|
| 1255 |
status.innerHTML = `β³ ${step} (${data.progress}%)`;
|
| 1256 |
setTimeout(poll, 2000);
|
| 1257 |
}
|
| 1258 |
} catch (err) {
|
| 1259 |
-
|
| 1260 |
-
status.innerHTML = 'β Status check failed: ' + err.message;
|
| 1261 |
}
|
| 1262 |
};
|
| 1263 |
poll();
|
|
|
|
| 279 |
<button class="tab-btn" data-tab="quiz">
|
| 280 |
π― Quiz Reel
|
| 281 |
</button>
|
| 282 |
+
<button class="tab-btn" data-tab="textstory">
|
| 283 |
+
π± Text Story
|
| 284 |
</button>
|
| 285 |
</div>
|
| 286 |
|
|
|
|
| 646 |
</div>
|
| 647 |
</div>
|
| 648 |
|
| 649 |
+
<!-- Text Story Tab -->
|
| 650 |
+
<div id="textstory-tab" class="tab-content">
|
| 651 |
<div class="card">
|
| 652 |
+
<h2>π± Fake Text Story Generator</h2>
|
| 653 |
<p style="color: var(--text-secondary); margin-bottom: 1.5rem;">
|
| 654 |
+
Create viral iMessage-style fake conversation videos with AI voice
|
| 655 |
</p>
|
| 656 |
|
| 657 |
+
<!-- Mode Toggle -->
|
| 658 |
+
<div class="form-group" style="margin-bottom: 1.5rem;">
|
| 659 |
+
<label>Mode</label>
|
| 660 |
+
<div style="display: flex; gap: 1rem;">
|
| 661 |
+
<label style="display: flex; align-items: center; gap: 0.5rem; cursor: pointer;">
|
| 662 |
+
<input type="radio" name="tsMode" value="manual" checked onchange="toggleTsMode()">
|
| 663 |
+
βοΈ Manual
|
| 664 |
+
</label>
|
| 665 |
+
<label style="display: flex; align-items: center; gap: 0.5rem; cursor: pointer;">
|
| 666 |
+
<input type="radio" name="tsMode" value="ai" onchange="toggleTsMode()">
|
| 667 |
+
π€ AI Generate
|
| 668 |
+
</label>
|
| 669 |
+
</div>
|
| 670 |
+
</div>
|
| 671 |
+
|
| 672 |
+
<form id="textStoryForm">
|
| 673 |
+
<!-- AI Mode Section (Hidden by default) -->
|
| 674 |
+
<div id="tsAiSection" style="display: none;">
|
| 675 |
+
<div class="form-group"
|
| 676 |
+
style="background: linear-gradient(135deg, rgba(99,102,241,0.1), rgba(168,85,247,0.1)); padding: 1.5rem; border-radius: 12px; margin-bottom: 1rem;">
|
| 677 |
+
<label>π€ AI Prompt - Describe the conversation</label>
|
| 678 |
+
<textarea id="tsAiPrompt" rows="3"
|
| 679 |
+
placeholder="e.g., A breakup conversation where the ex wants to get back together but gets rejected. Emotional and dramatic. End with a plot twist."></textarea>
|
| 680 |
+
<small style="color: var(--text-secondary); display: block; margin-top: 0.5rem;">
|
| 681 |
+
AI will generate a realistic conversation based on your prompt
|
| 682 |
+
</small>
|
| 683 |
+
</div>
|
| 684 |
+
|
| 685 |
+
<div class="form-row">
|
| 686 |
+
<div class="form-group">
|
| 687 |
+
<label>Number of Messages</label>
|
| 688 |
+
<select id="tsAiMsgCount">
|
| 689 |
+
<option value="5">5 messages</option>
|
| 690 |
+
<option value="7" selected>7 messages</option>
|
| 691 |
+
<option value="10">10 messages</option>
|
| 692 |
+
<option value="15">15 messages</option>
|
| 693 |
+
</select>
|
| 694 |
+
</div>
|
| 695 |
+
<div class="form-group">
|
| 696 |
+
<label>Conversation Tone</label>
|
| 697 |
+
<select id="tsAiTone">
|
| 698 |
+
<option value="emotional">Emotional / Dramatic</option>
|
| 699 |
+
<option value="funny">Funny / Comedy</option>
|
| 700 |
+
<option value="shocking">Shocking / Twist</option>
|
| 701 |
+
<option value="romantic">Romantic</option>
|
| 702 |
+
<option value="angry">Angry / Fight</option>
|
| 703 |
+
</select>
|
| 704 |
+
</div>
|
| 705 |
+
</div>
|
| 706 |
+
</div>
|
| 707 |
+
|
| 708 |
+
<!-- Common Fields -->
|
| 709 |
+
<div class="form-row">
|
| 710 |
+
<div class="form-group">
|
| 711 |
+
<label>Person A Name (You - Right/Blue)</label>
|
| 712 |
+
<input type="text" id="tsPersonA" value="You" placeholder="Your name">
|
| 713 |
+
</div>
|
| 714 |
+
<div class="form-group">
|
| 715 |
+
<label>Person B Name (Other - Left/Gray)</label>
|
| 716 |
+
<input type="text" id="tsPersonB" value="My Ex" placeholder="Other person name">
|
| 717 |
+
</div>
|
| 718 |
+
</div>
|
| 719 |
+
|
| 720 |
<div class="form-group">
|
| 721 |
+
<label>Person B Avatar (1 letter or emoji)</label>
|
| 722 |
+
<input type="text" id="tsAvatar" maxlength="2" placeholder="M" style="width: 80px;">
|
| 723 |
+
</div>
|
| 724 |
+
|
| 725 |
+
<!-- Manual Mode Section -->
|
| 726 |
+
<div id="tsManualSection">
|
| 727 |
+
<div id="tsMessagesContainer">
|
| 728 |
+
<div class="ts-message-item"
|
| 729 |
+
style="background: var(--bg-secondary); padding: 1rem; border-radius: 8px; margin-bottom: 0.5rem;">
|
| 730 |
+
<div class="form-row" style="align-items: flex-end;">
|
| 731 |
+
<div class="form-group" style="flex: 0 0 100px;">
|
| 732 |
+
<label>Sender</label>
|
| 733 |
+
<select class="ts-sender">
|
| 734 |
+
<option value="B">B (Other)</option>
|
| 735 |
+
<option value="A">A (You)</option>
|
| 736 |
+
</select>
|
| 737 |
+
</div>
|
| 738 |
+
<div class="form-group" style="flex: 1;">
|
| 739 |
+
<label>Message</label>
|
| 740 |
+
<input type="text" class="ts-text" placeholder="Type message...">
|
| 741 |
+
</div>
|
| 742 |
+
<button type="button" class="btn btn-secondary ts-remove"
|
| 743 |
+
style="height: 42px;">β</button>
|
| 744 |
+
</div>
|
| 745 |
+
</div>
|
| 746 |
+
</div>
|
| 747 |
+
|
| 748 |
+
<button type="button" id="tsAddMessage" class="btn btn-secondary"
|
| 749 |
+
style="width: 100%; margin-bottom: 1rem;">
|
| 750 |
+
β Add Message
|
| 751 |
+
</button>
|
| 752 |
</div>
|
| 753 |
|
| 754 |
<div class="form-group">
|
| 755 |
+
<label>Ending Text (Optional)</label>
|
| 756 |
+
<input type="text" id="tsEnding" placeholder="e.g., To be continued...">
|
| 757 |
+
</div>
|
| 758 |
+
|
| 759 |
+
<div class="form-row">
|
| 760 |
+
<div class="form-group">
|
| 761 |
+
<label>Voice A (You)</label>
|
| 762 |
+
<select id="tsVoiceA">
|
| 763 |
+
<option value="af_heart">Female - Heart</option>
|
| 764 |
+
<option value="af_bella">Female - Bella</option>
|
| 765 |
+
<option value="am_fenrir">Male - Fenrir</option>
|
| 766 |
+
<option value="am_michael">Male - Michael</option>
|
| 767 |
+
</select>
|
| 768 |
+
</div>
|
| 769 |
+
<div class="form-group">
|
| 770 |
+
<label>Voice B (Other)</label>
|
| 771 |
+
<select id="tsVoiceB">
|
| 772 |
+
<option value="am_fenrir">Male - Fenrir</option>
|
| 773 |
+
<option value="am_michael">Male - Michael</option>
|
| 774 |
+
<option value="af_heart">Female - Heart</option>
|
| 775 |
+
<option value="af_bella">Female - Bella</option>
|
| 776 |
+
</select>
|
| 777 |
+
</div>
|
| 778 |
</div>
|
| 779 |
|
| 780 |
+
<button type="submit" class="btn btn-primary" style="width: 100%;">π± Generate Text Story Video</button>
|
| 781 |
</form>
|
| 782 |
|
| 783 |
+
<div id="textStoryStatus" class="status hidden"></div>
|
| 784 |
</div>
|
| 785 |
</div>
|
| 786 |
|
|
|
|
| 1308 |
}
|
| 1309 |
}, 100); // End of setTimeout
|
| 1310 |
|
| 1311 |
+
// ==========================================
|
| 1312 |
+
// TEXT STORY MODULE
|
| 1313 |
+
// ==========================================
|
| 1314 |
+
|
| 1315 |
+
// Add message row
|
| 1316 |
+
// Append a new sender/message row to the manual-mode message list.
// `count` is only used to label the new input ("Message N"); removing rows
// does not renumber existing labels.
document.getElementById('tsAddMessage').addEventListener('click', () => {
    const container = document.getElementById('tsMessagesContainer');
    const count = container.querySelectorAll('.ts-message-item').length + 1;
    const html = `
        <div class="ts-message-item" style="background: var(--bg-secondary); padding: 1rem; border-radius: 8px; margin-bottom: 0.5rem;">
            <div class="form-row" style="align-items: flex-end;">
                <div class="form-group" style="flex: 0 0 100px;">
                    <label>Sender</label>
                    <select class="ts-sender">
                        <option value="B">B (Other)</option>
                        <option value="A">A (You)</option>
                    </select>
                </div>
                <div class="form-group" style="flex: 1;">
                    <label>Message ${count}</label>
                    <input type="text" class="ts-text" placeholder="Type message..." required>
                </div>
                <button type="button" class="btn btn-secondary ts-remove" style="height: 42px;">β</button>
            </div>
        </div>
    `;
    // Parse and append the new row markup in one step
    container.insertAdjacentHTML('beforeend', html);
});
|
| 1339 |
+
|
| 1340 |
+
// Delegated handler: clicking a row's remove button deletes that row,
// but never the last remaining one (the form needs at least one row).
document.getElementById('tsMessagesContainer').addEventListener('click', (event) => {
    if (!event.target.classList.contains('ts-remove')) return;
    const rows = document.querySelectorAll('.ts-message-item');
    if (rows.length <= 1) return;
    event.target.closest('.ts-message-item').remove();
});
|
| 1349 |
+
|
| 1350 |
+
// Toggle Manual/AI mode: show the section matching the checked
// "tsMode" radio and hide the other one.
function toggleTsMode() {
    const selected = document.querySelector('input[name="tsMode"]:checked').value;
    const aiPanel = document.getElementById('tsAiSection');
    const manualPanel = document.getElementById('tsManualSection');
    if (selected === 'ai') {
        aiPanel.style.display = 'block';
    } else {
        aiPanel.style.display = 'none';
    }
    if (selected === 'manual') {
        manualPanel.style.display = 'block';
    } else {
        manualPanel.style.display = 'none';
    }
}
|
| 1356 |
+
|
| 1357 |
+
// ---- Text Story form submission ----
// Flow: (1) resolve the message list — either generated by the AI endpoint
// or collected from the manual rows — then (2) POST the full payload to
// /api/text-story/generate and start polling the returned job.
// Fixes vs. original: parseInt now passes an explicit radix of 10, and the
// three distinct phases are split into named helpers for readability.
// Interface (element IDs, endpoints, request/response shapes) is unchanged.

// Collect {sender, text} pairs from the manual-entry rows, skipping blanks.
function _tsCollectManualMessages() {
    const messages = [];
    document.querySelectorAll('.ts-message-item').forEach(item => {
        const sender = item.querySelector('.ts-sender').value;
        const text = item.querySelector('.ts-text').value.trim();
        if (text) {
            messages.push({ sender, text });
        }
    });
    return messages;
}

// Ask the backend AI to generate a conversation. Returns the message array
// on success, or null after writing an error into `status`.
async function _tsGenerateAiMessages(status) {
    const prompt = document.getElementById('tsAiPrompt').value.trim();
    if (!prompt) {
        status.className = 'status error';
        status.innerHTML = 'β Please enter a prompt for AI!';
        return null;
    }

    status.innerHTML = 'π€ AI generating conversation...';

    try {
        const aiRes = await fetch('/api/text-story/ai-generate', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                prompt: prompt,
                person_a_name: document.getElementById('tsPersonA').value || 'You',
                person_b_name: document.getElementById('tsPersonB').value || 'My Ex',
                // Explicit radix: parseInt without it is a classic footgun.
                message_count: parseInt(document.getElementById('tsAiMsgCount').value, 10),
                tone: document.getElementById('tsAiTone').value
            })
        });
        const aiData = await aiRes.json();

        if (aiData.messages) {
            status.innerHTML = `π€ Generated ${aiData.messages.length} messages. Now creating video...`;
            return aiData.messages;
        }
        status.className = 'status error';
        status.innerHTML = 'β AI failed: ' + (aiData.detail || 'Unknown error');
        return null;
    } catch (err) {
        status.className = 'status error';
        status.innerHTML = 'β AI Error: ' + err.message;
        return null;
    }
}

// Form submit
document.getElementById('textStoryForm').addEventListener('submit', async (e) => {
    e.preventDefault();
    const status = document.getElementById('textStoryStatus');
    status.className = 'status processing';
    status.classList.remove('hidden');

    const mode = document.querySelector('input[name="tsMode"]:checked').value;
    let messages = [];

    if (mode === 'ai') {
        // AI Mode - Generate conversation first
        messages = await _tsGenerateAiMessages(status);
        if (!messages) return; // helper already reported the error
    } else {
        // Manual Mode - Collect messages from form
        messages = _tsCollectManualMessages();
        if (messages.length < 2) {
            status.className = 'status error';
            status.innerHTML = 'β Need at least 2 messages!';
            return;
        }
    }

    status.innerHTML = 'β³ Starting video generation...';

    const data = {
        person_a_name: document.getElementById('tsPersonA').value || 'You',
        person_b_name: document.getElementById('tsPersonB').value || 'My Ex',
        person_b_avatar: document.getElementById('tsAvatar').value || null,
        messages: messages,
        ending_text: document.getElementById('tsEnding').value || null,
        voice_a: document.getElementById('tsVoiceA').value,
        voice_b: document.getElementById('tsVoiceB').value
    };

    try {
        const res = await fetch('/api/text-story/generate', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(data)
        });
        const result = await res.json();

        if (result.job_id) {
            status.innerHTML = `β³ Job started: ${result.job_id}`;
            pollTextStoryStatus(result.job_id);
        } else {
            status.className = 'status error';
            status.innerHTML = `β Error: ${result.detail || 'Failed to start'}`;
        }
    } catch (err) {
        status.className = 'status error';
        status.innerHTML = 'β Error: ' + err.message;
    }
});
|
| 1455 |
|
| 1456 |
+
// Poll status
|
| 1457 |
+
async function pollTextStoryStatus(jobId) {
|
| 1458 |
+
const status = document.getElementById('textStoryStatus');
|
| 1459 |
const poll = async () => {
|
| 1460 |
try {
|
| 1461 |
+
const res = await fetch(`/api/text-story/${jobId}/status`);
|
| 1462 |
const data = await res.json();
|
| 1463 |
|
| 1464 |
if (data.status === 'ready') {
|
| 1465 |
status.className = 'status success';
|
| 1466 |
+
status.innerHTML = `β
Video ready! <a href="${data.video_url}" target="_blank" class="btn btn-primary" style="margin-left: 1rem;">π₯ Download</a>`;
|
| 1467 |
} else if (data.status === 'failed') {
|
| 1468 |
status.className = 'status error';
|
| 1469 |
status.innerHTML = 'β Failed: ' + (data.error || 'Unknown error');
|
| 1470 |
} else {
|
| 1471 |
+
const step = data.current_step || 'Processing';
|
| 1472 |
status.innerHTML = `β³ ${step} (${data.progress}%)`;
|
| 1473 |
setTimeout(poll, 2000);
|
| 1474 |
}
|
| 1475 |
} catch (err) {
|
| 1476 |
+
setTimeout(poll, 3000);
|
|
|
|
| 1477 |
}
|
| 1478 |
};
|
| 1479 |
poll();
|