| from __future__ import annotations
|
|
|
| import json
|
| import logging
|
| from pathlib import Path
|
| from typing import TypedDict, Dict, Union, List
|
|
|
| from langgraph.graph import StateGraph, END
|
| from langchain_openai import ChatOpenAI
|
| from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
|
|
| from .reflection import (
|
| DEMO_TEMP_DIR,
|
| DEMO_DATA_DIR,
|
| TEMP_DIR,
|
| _load_audiodescription_from_db,
|
| _write_casting_csv_from_db,
|
| _write_scenarios_csv_from_db,
|
| )
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
class MultiReflectionState(TypedDict):
    """State passed between the nodes of the multi-agent reflection graph."""

    # Number of completed narrator refinement passes (narrator_refine_agent
    # increments it; every other node passes it through).
    iteration: int
    # Path of the SRT file produced by the most recent node.
    current_srt_path: str
    # Critic output; critic_agent currently sets only "qualitative_critique" (str).
    critic_report: Dict[str, Union[float, str]]
    # Log of pipeline steps. Nodes append AIMessage entries, so the element type
    # must cover more than SystemMessage.
    history: List[Union[SystemMessage, HumanMessage, AIMessage]]
|
|
|
|
|
|
|
# Single chat-model client shared by every agent in this pipeline (low
# temperature). It is created once, when the module is imported.
_llm_ma = ChatOpenAI(temperature=0.2, model="gpt-4o-mini")
|
|
|
|
|
| def _read_text(path: Path) -> str:
|
| try:
|
| return path.read_text(encoding="utf-8")
|
| except Exception:
|
| return ""
|
|
|
|
|
| def _load_casting_for_sha1(sha1sum: str) -> str:
|
| db_path = DEMO_DATA_DIR / "casting.db"
|
| if not db_path.exists():
|
| return ""
|
| import sqlite3
|
|
|
| conn = sqlite3.connect(str(db_path))
|
| conn.row_factory = sqlite3.Row
|
| try:
|
| cur = conn.cursor()
|
| cur.execute("SELECT name, description FROM casting WHERE sha1sum=?", (sha1sum,))
|
| rows = cur.fetchall()
|
| if not rows:
|
| return ""
|
| data = [dict(r) for r in rows]
|
| return json.dumps(data, ensure_ascii=False, indent=2)
|
| finally:
|
| conn.close()
|
|
|
|
|
| def _load_scenarios_for_sha1(sha1sum: str) -> str:
|
| db_path = DEMO_DATA_DIR / "scenarios.db"
|
| if not db_path.exists():
|
| return ""
|
| import sqlite3
|
|
|
| conn = sqlite3.connect(str(db_path))
|
| conn.row_factory = sqlite3.Row
|
| try:
|
| cur = conn.cursor()
|
| cur.execute("SELECT name, description FROM scenarios WHERE sha1sum=?", (sha1sum,))
|
| rows = cur.fetchall()
|
| if not rows:
|
| return ""
|
| data = [dict(r) for r in rows]
|
| return json.dumps(data, ensure_ascii=False, indent=2)
|
| finally:
|
| conn.close()
|
|
|
|
|
def narrator_initial(state: MultiReflectionState) -> MultiReflectionState:
    """First narrator pass: take the initial SRT as-is.

    This pipeline assumes the input is already an initial UNE SRT. The node
    only checks that the file exists (warning if it does not) and records the
    step in the history; iteration, path and critic report pass through.
    """
    current_path = Path(state["current_srt_path"])
    # Later agents read the SRT themselves, so it is not loaded here; only
    # existence is checked so a missing input is reported early.
    if not current_path.exists():
        logger.warning("[reflection_ma] SRT inicial no trobat a %s", current_path)

    return {
        "iteration": state["iteration"],
        "current_srt_path": str(current_path),
        "critic_report": state.get("critic_report", {}),
        "history": state["history"] + [AIMessage(content="Narrador inicial: SRT de partida carregat.")],
    }
|
|
|
|
|
| def identity_manager_agent(state: MultiReflectionState, *, sha1sum: str, info_ad: str) -> MultiReflectionState:
|
| """Agent que revisa identitats/personatges a partir del casting i info_ad."""
|
|
|
| srt_path = Path(state["current_srt_path"])
|
| srt_content = _read_text(srt_path)
|
| casting_json = _load_casting_for_sha1(sha1sum)
|
|
|
| prompt = (
|
| "Ets un gestor d'identitats per audiodescripcions. Se't proporciona un SRT "
|
| "i informació de casting (personatges) i un JSON de context (info_ad). "
|
| "La teva tasca és revisar si els noms i rols dels personatges al SRT són "
|
| "coherents amb el casting i el context. Si cal, corregeix els noms/rols "
|
| "perquè siguin consistents. Mantén el format SRT i retorna únicament el SRT modificat."
|
| )
|
|
|
| content = {
|
| "srt": srt_content,
|
| "casting": json.loads(casting_json) if casting_json else [],
|
| "info_ad": json.loads(info_ad) if info_ad else {},
|
| }
|
|
|
| resp = _llm_ma.invoke(
|
| [
|
| SystemMessage(content=prompt),
|
| HumanMessage(content=json.dumps(content, ensure_ascii=False)),
|
| ]
|
| )
|
|
|
| new_srt = resp.content if isinstance(resp.content, str) else str(resp.content)
|
| new_path = TEMP_DIR / "une_ad_ma_identity.srt"
|
| new_path.write_text(new_srt, encoding="utf-8")
|
|
|
| history = state["history"] + [AIMessage(content="Identity manager: SRT actualitzat amb identitats coherents.")]
|
| return {
|
| "iteration": state["iteration"],
|
| "current_srt_path": str(new_path),
|
| "critic_report": state.get("critic_report", {}),
|
| "history": history,
|
| }
|
|
|
|
|
| def background_descriptor_agent(state: MultiReflectionState, *, sha1sum: str) -> MultiReflectionState:
|
| """Agent que revisa la descripció d'escenaris a partir de scenarios.db."""
|
|
|
| srt_path = Path(state["current_srt_path"])
|
| srt_content = _read_text(srt_path)
|
| scenarios_json = _load_scenarios_for_sha1(sha1sum)
|
|
|
| prompt = (
|
| "Ets un expert en escenaris per audiodescripcions. Se't proporciona un SRT "
|
| "i una llista d'escenaris amb noms oficials. La teva tasca és revisar les "
|
| "descripcions de llocs al SRT i substituir referències genèriques per aquests "
|
| "noms quan millorin la claredat, sense afegir informació inventada. Mantén el "
|
| "format SRT i retorna únicament el SRT actualitzat."
|
| )
|
|
|
| content = {
|
| "srt": srt_content,
|
| "scenarios": json.loads(scenarios_json) if scenarios_json else [],
|
| }
|
|
|
| resp = _llm_ma.invoke(
|
| [
|
| SystemMessage(content=prompt),
|
| HumanMessage(content=json.dumps(content, ensure_ascii=False)),
|
| ]
|
| )
|
|
|
| new_srt = resp.content if isinstance(resp.content, str) else str(resp.content)
|
| new_path = TEMP_DIR / "une_ad_ma_background.srt"
|
| new_path.write_text(new_srt, encoding="utf-8")
|
|
|
| history = state["history"] + [AIMessage(content="Background descriptor: SRT actualitzat amb escenaris contextualitzats.")]
|
| return {
|
| "iteration": state["iteration"],
|
| "current_srt_path": str(new_path),
|
| "critic_report": state.get("critic_report", {}),
|
| "history": history,
|
| }
|
|
|
|
|
| def narrator_refine_agent(state: MultiReflectionState, *, info_ad: str) -> MultiReflectionState:
|
| """Segon pas del narrador: reescriu el SRT tenint en compte identitats i escenaris."""
|
|
|
| srt_path = Path(state["current_srt_path"])
|
| srt_content = _read_text(srt_path)
|
|
|
| prompt = (
|
| "Ets un Narrador d'audiodescripció UNE-153010. Has rebut un SRT on ja s'han "
|
| "revisat les identitats dels personatges i els escenaris. La teva tasca és "
|
| "refinar el text d'audiodescripció perquè sigui clar, coherent i ajustat al "
|
| "temps disponible, mantenint el format SRT i sense alterar els diàlegs. "
|
| "Retorna únicament el SRT final."
|
| )
|
|
|
| content = {
|
| "srt": srt_content,
|
| "info_ad": json.loads(info_ad) if info_ad else {},
|
| }
|
|
|
| resp = _llm_ma.invoke(
|
| [
|
| SystemMessage(content=prompt),
|
| HumanMessage(content=json.dumps(content, ensure_ascii=False)),
|
| ]
|
| )
|
|
|
| new_srt = resp.content if isinstance(resp.content, str) else str(resp.content)
|
| new_path = TEMP_DIR / "une_ad_ma_final.srt"
|
| new_path.write_text(new_srt, encoding="utf-8")
|
|
|
| history = state["history"] + [AIMessage(content="Narrador: SRT refinat després de gestió d'identitats i escenaris.")]
|
| return {
|
| "iteration": state["iteration"] + 1,
|
| "current_srt_path": str(new_path),
|
| "critic_report": state.get("critic_report", {}),
|
| "history": history,
|
| }
|
|
|
|
|
def critic_agent(state: MultiReflectionState) -> MultiReflectionState:
    """Ask the model for a short qualitative review of the current SRT.

    Deliberately simple: no CSV and no weighted averages, only the model's
    free-text assessment, stored as ``critic_report["qualitative_critique"]``.
    Iteration and SRT path pass through unchanged.
    """
    srt_text = _read_text(Path(state["current_srt_path"]))

    system_prompt = (
        "Ets un crític d'audiodescripcions UNE-153010. Avalua breument la qualitat "
        "del SRT proporcionat en termes de precisió descriptiva, sincronització "
        "temporal, claredat i adequació dels noms de personatges i escenaris. "
        "Retorna un text breu en català amb la teva valoració general."
    )
    answer = _llm_ma.invoke([SystemMessage(content=system_prompt), HumanMessage(content=srt_text)])

    if isinstance(answer.content, str):
        critique = answer.content
    else:
        critique = str(answer.content)

    return {
        "iteration": state["iteration"],
        "current_srt_path": state["current_srt_path"],
        "critic_report": {"qualitative_critique": critique},
        "history": [*state["history"], AIMessage(content="Crític: valoració final generada.")],
    }
|
|
|
|
|
|
|
def _build_graph(sha1sum: str, info_ad: str) -> StateGraph:
    """Build the pipeline graph with *sha1sum* and *info_ad* bound into its nodes.

    Flow: narrator_initial -> identity_manager -> background_descriptor ->
    narrator_refine -> critic -> END. The parameters are captured by this
    graph's own node closures, so graphs built for different videos never
    share them.
    """
    graph = StateGraph(MultiReflectionState)
    graph.add_node("narrator_initial", narrator_initial)
    graph.add_node(
        "identity_manager",
        lambda s: identity_manager_agent(s, sha1sum=sha1sum, info_ad=info_ad),
    )
    graph.add_node(
        "background_descriptor",
        lambda s: background_descriptor_agent(s, sha1sum=sha1sum),
    )
    graph.add_node("narrator_refine", lambda s: narrator_refine_agent(s, info_ad=info_ad))
    graph.add_node("critic", critic_agent)

    graph.set_entry_point("narrator_initial")
    graph.add_edge("narrator_initial", "identity_manager")
    graph.add_edge("identity_manager", "background_descriptor")
    graph.add_edge("background_descriptor", "narrator_refine")
    graph.add_edge("narrator_refine", "critic")
    graph.add_edge("critic", END)
    return graph


# Uncompiled graph bound to empty parameters, kept only so existing references
# to this module-level name still resolve. Run the pipeline via _compile_app.
_graph = _build_graph(sha1sum="", info_ad="{}")


def _compile_app(sha1sum: str, info_ad: str):
    """Compile a LangGraph app whose agents use this video's *sha1sum* and *info_ad*.

    A fresh graph is built on every call, with the parameters bound per graph
    rather than stored on a shared module-level object. Concurrent or
    successive pipelines therefore cannot change each other's parameters after
    compilation.
    """
    return _build_graph(sha1sum, info_ad).compile()
|
|
|
|
|
def _run_pipeline(srt_content: str, *, sha1sum: str, info_ad: str) -> str:
    """Write *srt_content* as the initial SRT, run the multi-agent graph and return the final SRT.

    The critic's qualitative assessment is logged, since it is not part of the
    returned value.

    NOTE(review): all intermediate files use fixed names inside TEMP_DIR, so two
    pipelines running at the same time overwrite each other's files.
    """
    TEMP_DIR.mkdir(exist_ok=True, parents=True)
    initial_path = TEMP_DIR / "une_ad_ma_0.srt"
    initial_path.write_text(srt_content or "", encoding="utf-8")

    app = _compile_app(sha1sum, info_ad)
    initial_state: MultiReflectionState = {
        "iteration": 0,
        "current_srt_path": str(initial_path),
        "critic_report": {},
        "history": [],
    }

    final_state = app.invoke(initial_state)

    critique = (final_state.get("critic_report") or {}).get("qualitative_critique")
    if critique:
        logger.info("[reflection_ma] Valoració del crític: %s", critique)

    return _read_text(Path(final_state["current_srt_path"]))


def refine_video_with_reflection_ma(sha1sum: str, version: str) -> str:
    """Refine a video (sha1sum, version) with the 4-agent multi-agent pipeline.

    - Reads une_ad and info_ad from audiodescriptions.db (demo/temp).
    - Reads casting/scenarios for the same sha1sum.
    - Runs narrator -> identity_manager -> background_descriptor -> narrator -> critic.
    - Returns the final generated SRT.
    """
    une_ad, info_ad = _load_audiodescription_from_db(sha1sum, version)
    return _run_pipeline(une_ad or "", sha1sum=sha1sum, info_ad=info_ad or "")


def refine_srt_with_reflection_ma(srt_content: str) -> str:
    """Simplified variant that receives only an SRT (no database metadata).

    Runs the same pipeline with an empty sha1sum and an empty ``info_ad``
    object. The casting/scenario lookups still run, but with an empty sha1sum,
    so they contribute no per-video context. The LLM is still called, so this
    is useful for manual checks rather than offline unit tests.
    """
    return _run_pipeline(srt_content or "", sha1sum="", info_ad="{}")
|
|
|