| """ |
| Service LLM pour l'intégration avec Groq et autres fournisseurs. |
| Gère les appels aux modèles de langage pour le résumé et l'analyse. |
| """ |
|
|
| import asyncio |
| import aiohttp |
| import json |
| from typing import List, Dict, Any, Optional, Union |
| from datetime import datetime |
| import time |
|
|
| from config.settings import api_config |
| from src.core.logging import setup_logger |
| import traceback |


class LLMError(Exception):
    """Exception raised for LLM errors."""
    pass


class LLMRateLimitError(LLMError):
    """Exception raised when the API rate limit is exceeded."""
    pass


class LLMService:
    """
    Service for calling language models.

    Features:
    - Groq API support
    - Rate-limit handling
    - Automatic retry with backoff
    - Optional streaming
    - Response validation
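
    Example (illustrative sketch; assumes a valid GROQ_API_KEY in config):
        service = LLMService()
        answer = await service.generate_completion("Summarize this text...")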
| """ |
| |
| def __init__(self): |
| self.config = api_config |
| self.logger = setup_logger("llm_service") |
| |
| |
| self.groq_api_key = self.config.GROQ_API_KEY |
| self.groq_base_url = "https://api.groq.com/openai/v1" |
| self.default_model = getattr(self.config, 'GROQ_MODEL', "llama-3.1-8b-instant") |
| |
| |
| self.rate_limit_requests = 30 |
| self.rate_limit_tokens = 6000 |
| self.request_timestamps = [] |
| |
| |
| self.default_params = { |
| "temperature": 0.3, |
| "max_tokens": 2000, |
| "top_p": 0.9, |
| "frequency_penalty": 0.1, |
| "presence_penalty": 0.1 |
| } |
| |
| |
| self.headers = { |
| "Authorization": f"Bearer {self.groq_api_key}", |
| "Content-Type": "application/json" |
| } |

    async def generate_completion(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        model: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Generates a text completion.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            model: Model to use (defaults to the configured model)
            **kwargs: Additional API parameters

        Returns:
            The response generated by the model

        Raises:
            LLMError: On API errors
            LLMRateLimitError: When the rate limit is exceeded
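
        Example (illustrative; assumes an async context and a configured key):
            text = await service.generate_completion(
                "Summarize this article in three sentences.",
                system_prompt="You are a concise summarizer.",
                temperature=0.2,
            )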
| """ |
| |
| messages = [] |
| if system_prompt: |
| messages.append({"role": "system", "content": system_prompt}) |
| messages.append({"role": "user", "content": prompt}) |
| |
| |
| params = {**self.default_params, **kwargs} |
| payload = { |
| "model": model or self.default_model, |
| "messages": messages, |
| **params |
| } |
| |
| |
| await self._check_rate_limits() |
| |
| |
| return await self._make_api_call(payload) |

    async def generate_batch_completions(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        model: Optional[str] = None,
        max_concurrent: int = 3,
        **kwargs
    ) -> List[str]:
        """
        Generates several completions in parallel.

        Args:
            prompts: List of prompts
            system_prompt: Optional system prompt
            model: Model to use
            max_concurrent: Maximum number of simultaneous requests
            **kwargs: Additional parameters

        Returns:
            List of responses in the same order as the prompts
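
        Example (sketch; prompts and limits here are illustrative):
            answers = await service.generate_batch_completions(
                ["What is Python?", "What is Rust?"],
                max_concurrent=2,
            )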
| """ |
| self.logger.info(f"Génération batch de {len(prompts)} complétions") |
| |
| |
| semaphore = asyncio.Semaphore(max_concurrent) |
| |
| async def generate_single(prompt: str, index: int) -> tuple: |
| async with semaphore: |
| try: |
| |
| await asyncio.sleep(index * 0.5) |
| |
| result = await self.generate_completion( |
| prompt, system_prompt, model, **kwargs |
| ) |
| return index, result |
| except Exception as e: |
| self.logger.error(f"Erreur completion {index}: {e}") |
| return index, f"ERREUR: {str(e)}" |
| |
| |
| tasks = [generate_single(prompt, i) for i, prompt in enumerate(prompts)] |
| results = await asyncio.gather(*tasks, return_exceptions=True) |
| |
| |
| ordered_results = [""] * len(prompts) |
| for result in results: |
| if isinstance(result, tuple): |
| index, content = result |
| ordered_results[index] = content |
| else: |
| |
| ordered_results.append(f"EXCEPTION: {str(result)}") |
| |
| success_count = sum(1 for r in ordered_results if not r.startswith("ERREUR")) |
| self.logger.info(f"Batch terminé: {success_count}/{len(prompts)} succès") |
| |
| return ordered_results |

    async def _make_api_call(self, payload: Dict[str, Any], max_retries: int = 3) -> str:
        """Performs the API call with automatic retries."""
        url = f"{self.groq_base_url}/chat/completions"

        for attempt in range(max_retries + 1):
            try:
                async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=60)) as session:
                    async with session.post(url, json=payload, headers=self.headers) as response:

                        # Record the request for client-side rate limiting
                        self.request_timestamps.append(time.time())

                        if response.status == 200:
                            data = await response.json()
                            content = data["choices"][0]["message"]["content"]

                            # Reject empty responses
                            if not content or content.strip() == "":
                                raise LLMError("Empty response from the model")

                            return content.strip()

                        elif response.status == 429:
                            # Honor the server-provided retry delay when present
                            retry_after = int(response.headers.get("retry-after", "60"))
                            self.logger.warning(f"Rate limit hit, waiting {retry_after}s")

                            if attempt < max_retries:
                                await asyncio.sleep(retry_after)
                                continue
                            else:
                                raise LLMRateLimitError("API rate limit exceeded")

                        else:
                            # Other HTTP errors: retry with exponential backoff
                            error_text = await response.text()
                            error_msg = f"API error {response.status}: {error_text}"

                            if attempt < max_retries:
                                self.logger.warning(f"{error_msg} - attempt {attempt + 1}/{max_retries}")
                                await asyncio.sleep(2 ** attempt)
                                continue
                            else:
                                raise LLMError(error_msg)

            except LLMError:
                # Re-raise our own errors (including LLMRateLimitError) so the
                # generic handler below does not rewrap them as connection errors
                raise

            except asyncio.TimeoutError:
                if attempt < max_retries:
                    self.logger.warning(f"API timeout - attempt {attempt + 1}/{max_retries}")
                    await asyncio.sleep(2 ** attempt)
                    continue
                else:
                    raise LLMError("API timeout after several attempts")

            except Exception as e:
                if attempt < max_retries:
                    self.logger.warning(f"Network error: {e} - attempt {attempt + 1}/{max_retries}")
                    await asyncio.sleep(2 ** attempt)
                    continue
                else:
                    raise LLMError(f"Connection error: {str(e)}")

        raise LLMError("All attempts failed")

    async def _check_rate_limits(self):
        """Checks and enforces client-side rate limits."""
        current_time = time.time()

        # Keep only the requests made within the last minute
        self.request_timestamps = [
            ts for ts in self.request_timestamps
            if current_time - ts < 60
        ]

        # If the window is full, wait until the oldest request expires
        if len(self.request_timestamps) >= self.rate_limit_requests:
            oldest_request = min(self.request_timestamps)
            wait_time = 60 - (current_time - oldest_request)

            if wait_time > 0:
                self.logger.info(f"Rate limit: waiting {wait_time:.1f}s")
                await asyncio.sleep(wait_time)

    def estimate_tokens(self, text: str) -> int:
        """Estimates the number of tokens in a text."""
        # Rough heuristic: roughly four characters per token
        return len(text) // 4

    def validate_input_length(self, text: str, max_tokens: int = 6000) -> bool:
        """Checks that the text does not exceed the token limit."""
        estimated_tokens = self.estimate_tokens(text)
        return estimated_tokens <= max_tokens

    def truncate_text(self, text: str, max_tokens: int = 6000) -> str:
        """Truncates a text to fit within the token limit."""
        estimated_tokens = self.estimate_tokens(text)

        if estimated_tokens <= max_tokens:
            return text

        # Target length with a 10% safety margin
        ratio = max_tokens / estimated_tokens
        target_length = int(len(text) * ratio * 0.9)

        # Cut on sentence boundaries when possible
        sentences = text.split('. ')
        truncated = ""

        for sentence in sentences:
            if len(truncated) + len(sentence) + 2 <= target_length:
                truncated += sentence + ". "
            else:
                break

        # Fall back to a hard cut if the first sentence alone is too long
        if not truncated:
            truncated = text[:target_length]

        self.logger.info(f"Text truncated: {len(text)} → {len(truncated)} characters")
        return truncated.strip()

    async def test_connection(self) -> bool:
        """Tests the connection to the API."""
        try:
            result = await self.generate_completion(
                "Connection test. Just reply 'OK'.",
                system_prompt="You are a test assistant."
            )

            if "ok" in result.lower():
                self.logger.info("LLM connection test succeeded")
                return True
            else:
                self.logger.warning(f"Unexpected connection test response: {result}")
                return False

        except Exception as e:
            self.logger.error(f"LLM connection test failed: {e}")
            return False


class LLMManager:
    """
    Manager for LLM services with multiple strategies.
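
    Example (sketch; "groq" is the only service registered by default):
        manager = LLMManager()
        reply = await manager.get_completion("Hello!", service="groq")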
| """ |
| |
| def __init__(self): |
| self.logger = setup_logger("llm_manager") |
| self.primary_service = LLMService() |
| self.services = { |
| "groq": self.primary_service |
| } |
| |
| async def get_completion( |
| self, |
| prompt: str, |
| system_prompt: Optional[str] = None, |
| service: str = "groq", |
| **kwargs |
| ) -> str: |
| """ |
| Obtient une complétion en utilisant le service spécifié. |
| |
| Args: |
| prompt: Prompt utilisateur |
| system_prompt: Prompt système |
| service: Service LLM à utiliser |
| **kwargs: Paramètres supplémentaires |
| |
| Returns: |
| Réponse du modèle |
| """ |
| if service not in self.services: |
| raise ValueError(f"Service LLM inconnu: {service}") |
| |
| llm_service = self.services[service] |
| return await llm_service.generate_completion(prompt, system_prompt, **kwargs) |
| |
| async def get_batch_completions( |
| self, |
| prompts: List[str], |
| system_prompt: Optional[str] = None, |
| service: str = "groq", |
| **kwargs |
| ) -> List[str]: |
| """Obtient des complétions en batch.""" |
| if service not in self.services: |
| raise ValueError(f"Service LLM inconnu: {service}") |
| |
| llm_service = self.services[service] |
| return await llm_service.generate_batch_completions( |
| prompts, system_prompt, **kwargs |
| ) |

    async def test_all_services(self) -> Dict[str, bool]:
        """Tests all available LLM services."""
        results = {}

        for name, service in self.services.items():
            try:
                results[name] = await service.test_connection()
            except Exception as e:
                self.logger.error(f"Service {name} test failed: {e}")
                results[name] = False

        return results


async def example_usage():
    """Example usage of the LLM service."""

    # Connection test
    print("=== Connection test ===")
    llm_service = LLMService()

    connection_ok = await llm_service.test_connection()
    print(f"LLM connection: {'✓ OK' if connection_ok else '✗ Failed'}")

    if not connection_ok:
        print("Cannot continue without a connection")
        return

    # Simple generation
    print("\n=== Simple generation ===")
    try:
        response = await llm_service.generate_completion(
            prompt="Explain in two sentences what artificial intelligence is.",
            system_prompt="You are an AI expert who explains things simply."
        )
        print(f"Response: {response}")
    except Exception as e:
        print(f"Error: {e}")

    # Generation with custom parameters
    print("\n=== Generation with parameters ===")
    try:
        response = await llm_service.generate_completion(
            prompt="Write a haiku about technology.",
            system_prompt="You are a poet who specializes in haikus.",
            temperature=0.8,
            max_tokens=100
        )
        print(f"Haiku: {response}")
    except Exception as e:
        print(f"Error: {e}")

    # Batch generation
    print("\n=== Batch generation ===")
    prompts = [
        "What is Python?",
        "What is JavaScript?",
        "What is Rust?"
    ]

    try:
        responses = await llm_service.generate_batch_completions(
            prompts=prompts,
            system_prompt="Answer in one short sentence.",
            max_concurrent=2
        )

        for i, (prompt, response) in enumerate(zip(prompts, responses)):
            print(f"{i+1}. {prompt}")
            print(f"   → {response}\n")
    except Exception as e:
        print(f"Batch error: {e}")

    # Utility helpers (the sample sentence is long enough to exceed the limit)
    print("\n=== Utility tests ===")
    long_text = "This is a very long sample text to repeat. " * 1000
    print(f"Original text: {len(long_text)} characters")
    print(f"Estimated tokens: {llm_service.estimate_tokens(long_text)}")

    is_valid = llm_service.validate_input_length(long_text, max_tokens=7000)
    print(f"Text valid (max 7000 tokens): {is_valid}")

    if not is_valid:
        truncated = llm_service.truncate_text(long_text, max_tokens=7000)
        print(f"Truncated text: {len(truncated)} characters")
        print(f"Content: {truncated[:200]}...")


async def example_manager_usage():
    """Example usage of the LLM manager."""

    print("\n=== LLM manager test ===")

    manager = LLMManager()

    # Check every registered service
    service_status = await manager.test_all_services()
    print("Service status:")
    for service, status in service_status.items():
        print(f"  {service}: {'✓' if status else '✗'}")

    # Simple completion through the manager
    try:
        response = await manager.get_completion(
            prompt="Hi! How are you?",
            system_prompt="You are a friendly assistant.",
            service="groq"
        )
        print(f"\nManager response: {response}")
    except Exception as e:
        print(f"Manager error: {e}")


async def main():
    """Main test entry point."""
    try:
        await example_usage()
        await example_manager_usage()
    except KeyboardInterrupt:
        print("\n\nTest interrupted by the user")
    except Exception as e:
        print(f"\nUnexpected error: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    print("🚀 Starting the LLM service test...")
    asyncio.run(main())