| """ |
| LangChain Ollama Client for SPARKNET |
| Integrates Ollama with LangChain for multi-model complexity routing |
| Provides unified interface for chat, embeddings, and GPU monitoring |
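
Typical usage (illustrative sketch; assumes a local Ollama server at the
default URL with the configured models already pulled):

    client = get_langchain_client()
    llm = client.get_llm("standard")
    reply = llm.invoke("Hello!")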
| """ |
|
|
| from typing import Optional, Dict, Any, List, Literal |
| from loguru import logger |
| from langchain_ollama import ChatOllama, OllamaEmbeddings |
| from langchain_core.callbacks import BaseCallbackHandler |
| from langchain_core.messages import BaseMessage |
| from langchain_core.outputs import LLMResult |
|
|
| from ..utils.gpu_manager import get_gpu_manager |
|
|
|
|
| |
| ComplexityLevel = Literal["simple", "standard", "complex", "analysis"] |


class SparknetCallbackHandler(BaseCallbackHandler):
    """
    Custom callback handler for SPARKNET.
    Monitors GPU usage and token counts across LLM calls.
    """

    def __init__(self):
        super().__init__()
        self.gpu_manager = get_gpu_manager()
        self.token_count = 0
        self.llm_calls = 0

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any
    ) -> None:
        """Called when the LLM starts processing."""
        self.llm_calls += 1
        gpu_status = self.gpu_manager.monitor()
        logger.debug(f"LLM call #{self.llm_calls} started")
        # Guard against hosts where the GPU manager reports no GPUs.
        gpus = gpu_status.get('gpus') or []
        if gpus:
            logger.debug(f"GPU Status: {gpus[0]['memory_used']:.2f} GB used")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Called when the LLM finishes processing."""
        if response.llm_output:
            token_usage = response.llm_output.get('token_usage', {})
            if token_usage:
                self.token_count += token_usage.get('total_tokens', 0)
                logger.debug(f"Tokens used: {token_usage.get('total_tokens', 0)}")

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when the LLM encounters an error."""
        logger.error(f"LLM error: {error}")

    def get_stats(self) -> Dict[str, Any]:
        """Get accumulated statistics."""
        return {
            'llm_calls': self.llm_calls,
            'total_tokens': self.token_count,
            'gpu_status': self.gpu_manager.monitor(),
        }


class LangChainOllamaClient:
    """
    LangChain-powered Ollama client with intelligent model routing.

    Manages multiple Ollama models for different complexity levels:
    - simple: Fast, lightweight tasks (gemma2:2b)
    - standard: General-purpose tasks (llama3.1:8b)
    - complex: Advanced reasoning and planning (qwen2.5:14b)
    - analysis: Critical analysis and validation (mistral:latest)

    Features:
    - Automatic model selection based on task complexity
    - GPU monitoring via custom callbacks
    - Embedding generation for vector search
    - Streaming and non-streaming support
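
    Example (illustrative; assumes the models above are pulled locally):

        client = LangChainOllamaClient()
        llm = client.get_llm("complex")  # routes to qwen2.5:14b
        print(client.get_model_info("complex")["description"])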
| """ |
|
|
| |
| MODEL_CONFIG: Dict[ComplexityLevel, Dict[str, Any]] = { |
| "simple": { |
| "model": "gemma2:2b", |
| "temperature": 0.3, |
| "max_tokens": 512, |
| "description": "Fast classification, routing, simple Q&A", |
| "size_gb": 1.6, |
| }, |
| "standard": { |
| "model": "llama3.1:8b", |
| "temperature": 0.7, |
| "max_tokens": 1024, |
| "description": "General tasks, code generation, summarization", |
| "size_gb": 4.9, |
| }, |
| "complex": { |
| "model": "qwen2.5:14b", |
| "temperature": 0.7, |
| "max_tokens": 2048, |
| "description": "Complex reasoning, planning, multi-step tasks", |
| "size_gb": 9.0, |
| }, |
| "analysis": { |
| "model": "mistral:latest", |
| "temperature": 0.6, |
| "max_tokens": 1024, |
| "description": "Critical analysis, validation, quality assessment", |
| "size_gb": 4.4, |
| }, |
| } |

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        default_complexity: ComplexityLevel = "standard",
        enable_monitoring: bool = True,
    ):
        """
        Initialize the LangChain Ollama client.

        Args:
            base_url: Ollama server URL
            default_complexity: Default model complexity level
            enable_monitoring: Enable GPU monitoring callbacks
        """
        self.base_url = base_url
        self.default_complexity = default_complexity
        self.enable_monitoring = enable_monitoring

        # Optional GPU/token monitoring callback.
        self.callback_handler = SparknetCallbackHandler() if enable_monitoring else None
        self.callbacks = [self.callback_handler] if self.callback_handler else []

        # One ChatOllama instance per complexity level.
        self.llms: Dict[ComplexityLevel, ChatOllama] = {}
        self._initialize_models()

        # Embedding model for vector operations.
        self.embeddings = OllamaEmbeddings(
            base_url=base_url,
            model="nomic-embed-text:latest",
        )

        logger.info(f"Initialized LangChainOllamaClient with {len(self.llms)} models")
        logger.info(f"Default complexity: {default_complexity}")

    def _initialize_models(self) -> None:
        """Initialize ChatOllama instances for each complexity level."""
        for complexity, config in self.MODEL_CONFIG.items():
            try:
                self.llms[complexity] = ChatOllama(
                    base_url=self.base_url,
                    model=config["model"],
                    temperature=config["temperature"],
                    num_predict=config["max_tokens"],
                    callbacks=self.callbacks,
                )
                logger.debug(f"Initialized {complexity} model: {config['model']}")
            except Exception as e:
                logger.error(f"Failed to initialize {complexity} model: {e}")

    def get_llm(
        self,
        complexity: Optional[ComplexityLevel] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> ChatOllama:
        """
        Get the LLM for the specified complexity level.

        Args:
            complexity: Complexity level (simple, standard, complex, analysis)
            temperature: Override default temperature
            max_tokens: Override default max tokens

        Returns:
            ChatOllama instance
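
        Example (illustrative):

            llm = client.get_llm("simple", temperature=0.0)  # one-off, deterministic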
| """ |
| complexity = complexity or self.default_complexity |
|
|
| if complexity not in self.llms: |
| logger.warning(f"Unknown complexity '{complexity}', using default") |
| complexity = self.default_complexity |
|
|
| |
| if temperature is None and max_tokens is None: |
| return self.llms[complexity] |
|
|
| |
| config = self.MODEL_CONFIG[complexity] |
| return ChatOllama( |
| base_url=self.base_url, |
| model=config["model"], |
| temperature=temperature if temperature is not None else config["temperature"], |
| num_predict=max_tokens if max_tokens is not None else config["max_tokens"], |
| callbacks=self.callbacks, |
| ) |

    def get_embeddings(self) -> OllamaEmbeddings:
        """
        Get the embedding model for vector operations.

        Returns:
            OllamaEmbeddings instance
        """
        return self.embeddings

    async def ainvoke(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Asynchronously invoke the LLM with messages.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Returns:
            AI response message
        """
        llm = self.get_llm(complexity)
        return await llm.ainvoke(messages, **kwargs)

    def invoke(
        self,
        messages: List[BaseMessage],
        complexity: Optional[ComplexityLevel] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Synchronously invoke the LLM with messages.

        Args:
            messages: List of messages for the conversation
            complexity: Model complexity level
            **kwargs: Additional arguments for the LLM

        Returns:
            AI response message
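
        Example (illustrative):

            from langchain_core.messages import HumanMessage

            reply = client.invoke([HumanMessage(content="Summarize: ...")])
            print(reply.content)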
| """ |
| llm = self.get_llm(complexity) |
| response = llm.invoke(messages, **kwargs) |
| return response |
|
|
| async def astream( |
| self, |
| messages: List[BaseMessage], |
| complexity: Optional[ComplexityLevel] = None, |
| **kwargs: Any, |
| ): |
| """ |
| Async stream LLM responses. |
| |
| Args: |
| messages: List of messages for the conversation |
| complexity: Model complexity level |
| **kwargs: Additional arguments for the LLM |
| |
| Yields: |
| Chunks of AI response |
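
        Example (illustrative):

            async for chunk in client.astream([HumanMessage(content="Hi")]):
                print(chunk.content, end="", flush=True)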
| """ |
| llm = self.get_llm(complexity) |
| async for chunk in llm.astream(messages, **kwargs): |
| yield chunk |
|
|
| async def embed_text(self, text: str) -> List[float]: |
| """ |
| Generate embedding for text. |
| |
| Args: |
| text: Text to embed |
| |
| Returns: |
| Embedding vector |
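
        Example (illustrative):

            vector = await client.embed_text("hello world")
            print(len(vector))  # embedding dimensionality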
| """ |
| embedding = await self.embeddings.aembed_query(text) |
| return embedding |
|
|
| async def embed_documents(self, documents: List[str]) -> List[List[float]]: |
| """ |
| Generate embeddings for multiple documents. |
| |
| Args: |
| documents: List of documents to embed |
| |
| Returns: |
| List of embedding vectors |
| """ |
| embeddings = await self.embeddings.aembed_documents(documents) |
| return embeddings |
|
|
| def get_model_info(self, complexity: Optional[ComplexityLevel] = None) -> Dict[str, Any]: |
| """ |
| Get information about a model. |
| |
| Args: |
| complexity: Complexity level (defaults to current default) |
| |
| Returns: |
| Model configuration dictionary |
| """ |
| complexity = complexity or self.default_complexity |
| return self.MODEL_CONFIG.get(complexity, {}) |
|
|
| def list_models(self) -> Dict[ComplexityLevel, Dict[str, Any]]: |
| """ |
| List all available models and their configurations. |
| |
| Returns: |
| Dictionary mapping complexity levels to model configs |
| """ |
| return self.MODEL_CONFIG.copy() |
|
|
| def get_stats(self) -> Dict[str, Any]: |
| """ |
| Get client statistics. |
| |
| Returns: |
| Statistics dictionary |
| """ |
| if self.callback_handler: |
| return self.callback_handler.get_stats() |
| return {} |
|
|
| def recommend_complexity(self, task_description: str) -> ComplexityLevel: |
| """ |
| Recommend complexity level based on task description. |
| |
| Uses simple heuristics to suggest appropriate model: |
| - Keywords like "plan", "analyze", "complex" → complex |
| - Keywords like "validate", "critique", "assess" → analysis |
| - Keywords like "classify", "route", "simple" → simple |
| - Default → standard |
| |
| Args: |
| task_description: Natural language task description |
| |
| Returns: |
| Recommended complexity level |
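
        Example (illustrative):

            client.recommend_complexity("Plan a multi-step workflow")  # -> "complex"
            client.recommend_complexity("Classify this ticket")        # -> "simple"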
| """ |
| task_lower = task_description.lower() |
|
|
| |
| if any(kw in task_lower for kw in ["plan", "strategy", "decompose", "workflow", "multi-step"]): |
| return "complex" |
|
|
| |
| if any(kw in task_lower for kw in ["validate", "critique", "assess", "review", "quality"]): |
| return "analysis" |
|
|
| |
| if any(kw in task_lower for kw in ["classify", "route", "yes/no", "binary", "simple"]): |
| return "simple" |
|
|
| |
| return "standard" |
|
|
|
|
| |
| def get_langchain_client( |
| base_url: str = "http://localhost:11434", |
| default_complexity: ComplexityLevel = "standard", |
| enable_monitoring: bool = True, |
| ) -> LangChainOllamaClient: |
| """ |
| Get a LangChain Ollama client instance. |
| |
| Args: |
| base_url: Ollama server URL |
| default_complexity: Default model complexity |
| enable_monitoring: Enable GPU monitoring |
| |
| Returns: |
| LangChainOllamaClient instance |
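
    Example (illustrative; assumes a reachable Ollama server):

        client = get_langchain_client(default_complexity="simple")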
| """ |
| return LangChainOllamaClient( |
| base_url=base_url, |
| default_complexity=default_complexity, |
| enable_monitoring=enable_monitoring, |
| ) |
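

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the public API).
# Assumes a local Ollama server at http://localhost:11434 with the models in
# MODEL_CONFIG already pulled. Run as a module (python -m) from the package
# root so the relative gpu_manager import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    client = get_langchain_client()
    task = "Classify this support ticket"
    level = client.recommend_complexity(task)
    print(f"Recommended complexity for {task!r}: {level}")

    reply = client.invoke(
        [HumanMessage(content="Say hello in one word.")],
        complexity=level,
    )
    print(reply.content)
    print(client.get_stats())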