ClauseGuard-AI / clauseguard /services /model_service.py
muhammadbinmurtza
Restructure: clauseguard as package subfolder, app_file: clauseguard/app.py
913a064
"""Model service layer β€” unified Qwen/vLLM inference via OpenAI-compatible API.
Provides a single shared client and reusable inference functions for all
ClauseGuard agents and the copilot. Handles retries, timeouts, JSON cleaning,
and graceful error recovery.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Dict, List
from openai import AsyncOpenAI, OpenAI
from clauseguard.config.settings import (
API_KEY,
BASE_URL,
MAX_TOKENS,
MODEL_NAME,
TEMPERATURE,
TIMEOUT_SECONDS,
)
# Logger named after this module.
logger = logging.getLogger(__name__)
# Shared API clients. Both start unset; get_sync_client() / get_client()
# create them on first use and reset_client() clears them again.
_sync_client: OpenAI | None = None
_async_client: AsyncOpenAI | None = None
def get_client() -> AsyncOpenAI:
    """Return the module's cached AsyncOpenAI client, creating it on first use.

    The client is built from this module's API_KEY and BASE_URL. Later calls
    return the same instance until reset_client() clears it.
    """
    global _async_client
    if _async_client is not None:
        return _async_client
    _async_client = AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL)
    return _async_client
def get_sync_client() -> OpenAI:
    """Return the module's cached blocking OpenAI client, creating it on first use.

    The client is built from this module's API_KEY and BASE_URL. Later calls
    return the same instance until reset_client() clears it.
    """
    global _sync_client
    if _sync_client is not None:
        return _sync_client
    _sync_client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    return _sync_client
def reset_client() -> None:
    """Drop both cached clients so the next get_client()/get_sync_client() rebuilds them.

    Useful in tests, or after this module's API_KEY / BASE_URL names have
    been replaced. The settings are imported by value, so editing the
    settings module alone does not affect clients built here.
    """
    global _sync_client, _async_client
    _sync_client = _async_client = None
def _parses_as_json(text: str) -> bool:
    """Return True if *text* is one complete, valid JSON document."""
    try:
        json.loads(text)
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        return False
    return True


def clean_json_response(content: str) -> str:
    """Strip markdown fences and surrounding prose from LLM JSON output.

    Steps:
      1. Trim whitespace. Remove a leading ```json fence (the ``json`` tag is
         matched case-insensitively) or a bare ``` fence, plus a trailing
         ``` fence.
      2. If the result already parses as JSON, return it unchanged.
      3. Otherwise, try the span from the first ``{``/``[`` to the last
         ``}``/``]``. This drops chatty preambles such as "Here is the JSON:"
         and trailing remarks. The span is returned only if it parses.
      4. Otherwise, return the fence-stripped text from step 1.

    Args:
        content: Raw model output. ``None`` is treated as an empty string.

    Returns:
        The cleaned text, possibly empty. It is not guaranteed to be valid
        JSON; callers still need to parse it.
    """
    text = (content or "").strip()
    # Models sometimes emit ```JSON / ```Json; valid JSON can never start
    # with those letters, so matching case-insensitively is safe.
    if text[:7].lower() == "```json":
        text = text[7:]
    elif text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    # Leave anything that already parses untouched, so a JSON string value
    # that happens to contain brackets or fences is never sliced apart.
    if not text or _parses_as_json(text):
        return text

    openers = [i for i in (text.find("{"), text.find("[")) if i != -1]
    if openers:
        start = min(openers)
        end = max(text.rfind("}"), text.rfind("]"))
        if end > start:
            candidate = text[start:end + 1]
            if _parses_as_json(candidate):
                return candidate
    return text
async def call_model(
    system_prompt: str,
    user_prompt: str,
    *,
    agent_name: str = "Agent",
    temperature: float | None = None,
    max_tokens: int | None = None,
    timeout: int | None = None,
    max_retries: int = 1,
    validate_json: bool = True,
) -> str | None:
    """Call the Qwen model with a timeout, JSON validation, and JSON-only retries.

    Each attempt sends one system message and one user message. Only two
    outcomes trigger another attempt:
      * an empty response (after fence stripping), or
      * a response that fails ``json.loads``.
    On a malformed-JSON retry, the user prompt is re-sent with a single
    "Output ONLY raw JSON" reminder appended. A timeout or any other error
    from the request ends the call immediately with ``None`` (no retry).

    Args:
        system_prompt: The system-level instruction.
        user_prompt: The user-level query.
        agent_name: Label used in log messages.
        temperature: Sampling temperature (defaults to config TEMPERATURE).
        max_tokens: Max tokens for the response (defaults to config MAX_TOKENS).
        timeout: Per-call timeout in seconds (defaults to config TIMEOUT_SECONDS).
        max_retries: Number of additional attempts after an empty or
            malformed-JSON response. Negative values are treated as 0.
        validate_json: Whether to validate the response as JSON. When False,
            the first response is returned as-is, possibly as an empty string.

    Returns:
        The model's raw (uncleaned) text response, or None if all attempts
        fail. Callers that need the parsed JSON should run
        clean_json_response() on the result.
    """
    client = get_client()
    temp = temperature if temperature is not None else TEMPERATURE
    mt = max_tokens if max_tokens is not None else MAX_TOKENS
    tout = timeout if timeout is not None else TIMEOUT_SECONDS
    # Guarantee at least one attempt even if a caller passes a negative value.
    total_attempts = max(max_retries, 0) + 1
    json_reminder = "\n\nIMPORTANT: Output ONLY raw JSON. No markdown, no explanation."
    prompt = user_prompt
    last_error: str | None = None
    for attempt in range(1, total_attempts + 1):
        retries_left = attempt < total_attempts
        # Bound up-front so the JSON-error branch can always preview it.
        content = ""
        try:
            response = await asyncio.wait_for(
                client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=temp,
                    max_tokens=mt,
                ),
                timeout=tout,
            )
            content = response.choices[0].message.content or ""
            logger.info("%s received %d chars in %d attempt(s)", agent_name, len(content), attempt)
            if not validate_json:
                return content
            cleaned = clean_json_response(content)
            if not cleaned.strip():
                # Handled explicitly rather than by raising ValueError, so that
                # unrelated ValueErrors from the client are not mislabeled as
                # an empty reply and silently retried.
                last_error = "Empty response"
                if retries_left:
                    logger.warning("%s returned empty response, retrying...", agent_name)
                continue
            json.loads(cleaned)
            logger.info("%s produced valid JSON", agent_name)
            return content
        except json.JSONDecodeError as e:
            last_error = str(e)
            logger.warning(
                "%s returned malformed JSON (attempt %d): %s | preview: %s",
                agent_name, attempt, e, content[:200],
            )
            if retries_left:
                logger.warning("%s returned malformed JSON, retrying...", agent_name)
                # Rebuild from the original prompt so the reminder appears
                # once, however many retries are configured.
                prompt = user_prompt + json_reminder
        except asyncio.TimeoutError:
            logger.error("%s agent timed out after %ds", agent_name, tout)
            return None
        except Exception as e:
            logger.error("%s agent failed: %s", agent_name, e)
            return None
    logger.error("%s failed to produce valid JSON: %s", agent_name, last_error)
    return None
async def call_model_chat(
    messages: List[Dict[str, str]],
    *,
    temperature: float | None = None,
    max_tokens: int | None = None,
    timeout: int = 60,
) -> str:
    """Send a complete multi-turn message list to the model and return its reply.

    Errors raised while making the request or reading the reply are logged
    and turned into a friendly apology string instead of propagating.
    get_client() is called before that guard, so a failure while building
    the client still raises.

    Args:
        messages: Full message list (system + history + user).
        temperature: Sampling temperature; falls back to config TEMPERATURE.
        max_tokens: Response token cap; falls back to config MAX_TOKENS.
        timeout: Seconds to wait for the completion before giving up.

    Returns:
        The assistant's text, or a fallback/apology message.
    """
    client = get_client()
    request_kwargs = {
        "model": MODEL_NAME,
        "messages": messages,
        "temperature": TEMPERATURE if temperature is None else temperature,
        "max_tokens": MAX_TOKENS if max_tokens is None else max_tokens,
    }
    try:
        completion = await asyncio.wait_for(
            client.chat.completions.create(**request_kwargs),
            timeout=timeout,
        )
        reply = completion.choices[0].message.content
    except asyncio.TimeoutError:
        logger.error("Chat call timed out after %ds", timeout)
        return "I'm sorry, the request timed out. Please try a shorter question or try again."
    except Exception as e:
        logger.error("Chat call failed: %s", e)
        return f"I'm sorry, something went wrong: {e}"
    if reply:
        return reply
    # None or "" from the model: fall back to a polite retry prompt.
    return "I'm sorry, I couldn't generate a response. Please try again."
# ── Synchronous wrappers for use in Streamlit callbacks ──
def call_model_chat_sync(
    messages: List[Dict[str, str]],
    *,
    temperature: float | None = None,
    max_tokens: int | None = None,
    timeout: int = 60,
) -> str:
    """Run call_model_chat to completion from synchronous code (e.g. Streamlit callbacks).

    asyncio.run() creates a fresh event loop and runs the coroutine. When it
    finishes, it also shuts down async generators, closes the loop, and
    clears the thread's current-loop setting. A hand-managed
    new_event_loop()/close() pair that never resets the current loop leaves
    a *closed* loop installed, which breaks any later asyncio.get_event_loop()
    in the same thread.

    Must not be called from a thread that already has a running event loop.
    In that case asyncio.run() raises RuntimeError, which is logged and
    returned as an error message like any other failure.

    NOTE(review): get_client() caches one AsyncOpenAI instance across calls,
    while each call here runs on a new event loop. Confirm that the client's
    connection pool tolerates being used from successive loops.

    Args:
        messages: Full message list (system + history + user).
        temperature: Sampling temperature (None -> config default).
        max_tokens: Max tokens for the response (None -> config default).
        timeout: Per-call timeout in seconds.

    Returns:
        The assistant's reply, or an apology string if anything fails.
    """
    try:
        return asyncio.run(
            call_model_chat(
                messages,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout=timeout,
            )
        )
    except Exception as e:
        logger.error("call_model_chat_sync failed: %s", e)
        return f"Sorry, an unexpected error occurred: {e}"
# ── Higher-level domain functions ──
async def analyze_clause(
    clause_text: str,
    clause_type: str = "",
    additional_context: str = "",
    system_prompt: str = "",
    user_prompt_template: str = "",
    agent_name: str = "Analyzer",
) -> str | None:
    """Run one pipeline agent over a single contract clause.

    If *user_prompt_template* is non-empty, it is rendered with str.format
    using the placeholders {clause_text}, {clause_type} and {context}. Literal
    braces in the template must therefore be doubled, and any other
    placeholder raises KeyError. If the template is empty, the clause text
    itself is sent as the user prompt.

    The request goes through call_model with its defaults (JSON validation on,
    one retry).

    Args:
        clause_text: The clause raw text to analyze.
        clause_type: Optional pre-classified clause type.
        additional_context: Extra context, exposed to the template as {context}.
        system_prompt: The agent-specific system prompt.
        user_prompt_template: Optional template for the user prompt.
        agent_name: Label for logging.

    Returns:
        The raw model response string, or None on failure.
    """
    if user_prompt_template:
        user_prompt = user_prompt_template.format(
            clause_text=clause_text,
            clause_type=clause_type,
            context=additional_context,
        )
    else:
        user_prompt = clause_text
    return await call_model(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        agent_name=agent_name,
    )
async def generate_negotiation_message(
    clause_text: str,
    risk_reason: str,
    safer_version: str = "",
) -> str:
    """Ask the model for a short, polite email requesting a clause revision.

    The prompt quotes the clause, explains the risk, and includes the safer
    wording when one is supplied. JSON validation is disabled because the
    reply is free text.

    Returns:
        The drafted message, or "" if the model call failed or returned nothing.
    """
    negotiator_instructions = (
        "You are a professional contract negotiator. Write a short, polite email "
        "message requesting a change to a contract clause. Keep it professional, "
        "concise, and non-confrontational. Maximum 4-5 sentences."
    )
    prompt_parts = [
        f"The risky clause is:\n\"{clause_text}\"\n\n",
        f"Why it's risky:\n{risk_reason}\n\n",
    ]
    if safer_version:
        prompt_parts.append(f"Suggested safer version:\n\"{safer_version}\"\n\n")
    prompt_parts.append("Write a single email-style negotiation message requesting a fair revision.")
    reply = await call_model(
        system_prompt=negotiator_instructions,
        user_prompt="".join(prompt_parts),
        agent_name="NegotiationGenerator",
        validate_json=False,
    )
    return reply if reply else ""
async def contract_chat(
    contract_context: str,
    chat_history: List[Dict[str, str]],
    user_message: str,
    system_prompt: str,
    timeout: int = 60,
) -> str:
    """Answer a user question about a contract, with the contract in the system prompt.

    The system message is *system_prompt* followed by a "## CONTRACT CONTEXT"
    section holding *contract_context*. The prior turns come next, then the
    new user message. Only the "role" and "content" keys of each history
    entry are forwarded; an entry missing either key raises KeyError.

    Args:
        contract_context: The formatted contract + analysis context.
        chat_history: Previous messages (role/content dicts), oldest first.
        user_message: The user's new question.
        system_prompt: The copilot system prompt.
        timeout: Per-call timeout in seconds.

    Returns:
        The assistant's reply (or the apology text from call_model_chat).
    """
    system_message = {
        "role": "system",
        "content": f"{system_prompt}\n\n---\n\n## CONTRACT CONTEXT\n\n{contract_context}",
    }
    prior_turns = [
        {"role": turn["role"], "content": turn["content"]} for turn in chat_history
    ]
    messages: List[Dict[str, str]] = [
        system_message,
        *prior_turns,
        {"role": "user", "content": user_message},
    ]
    return await call_model_chat(messages, timeout=timeout)