| """ |
| Module that implements containers for specific LLM bindings. |
| |
| This module provides container implementations for various Large Language Model |
| bindings and integrations. |
| """ |
|
|
| from argparse import ArgumentParser, Namespace |
| import argparse |
| import json |
| from dataclasses import asdict, dataclass, field |
| from typing import Any, ClassVar, List |
|
|
| from lightrag.utils import get_env_value |
| from lightrag.constants import DEFAULT_TEMPERATURE |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
@dataclass
class BindingOptions:
    """Base class for LLM binding options.

    Subclasses describe one binding's configuration either as dataclass
    fields or as plain class variables, and must define two class variables:

    - ``_binding_name``: prefix used to derive CLI argument names and
      environment variable names (e.g. ``ollama_llm``).
    - ``_help``: mapping of option name to its help text.
    """

    # Binding identifier; prefixes every CLI flag and environment variable.
    _binding_name: ClassVar[str]

    # Help strings keyed by option name.
    _help: ClassVar[dict[str, str]]

    @staticmethod
    def _all_class_vars(klass: type, include_inherited: bool = True) -> dict[str, Any]:
        """Collect the public, non-callable class variables of *klass*.

        Args:
            klass: Class to inspect.
            include_inherited: When True, walk the MRO (excluding ``object``)
                so inherited class variables are included, with derived
                classes overriding their bases.

        Returns:
            Mapping of variable name to value.
        """

        def public_vars(namespace: dict[str, Any]) -> dict[str, Any]:
            # Skip private names, callables, and classmethod objects.
            return {
                name: value
                for name, value in namespace.items()
                if (
                    not name.startswith("_")
                    and not callable(value)
                    and not isinstance(value, classmethod)
                )
            }

        if not include_inherited:
            return public_vars(klass.__dict__)

        # Iterate from most-base to most-derived so derived values win.
        vars_dict: dict[str, Any] = {}
        for base in reversed(klass.__mro__[:-1]):
            vars_dict.update(public_vars(base.__dict__))
        return vars_dict

    @staticmethod
    def _json_arg_parser(expected_type: type):
        """Build an argparse ``type=`` callable that parses JSON text and
        validates the decoded value's container type.

        Args:
            expected_type: ``list`` or ``dict``.

        Returns:
            Callable suitable for ``add_argument(type=...)`` that raises
            ``argparse.ArgumentTypeError`` on invalid input.
        """
        kind = "array" if expected_type is list else "object"

        def parse(value):
            try:
                parsed = json.loads(value)
            except json.JSONDecodeError as e:
                raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
            if not isinstance(parsed, expected_type):
                raise argparse.ArgumentTypeError(
                    f"Expected JSON {kind}, got {type(parsed).__name__}"
                )
            return parsed

        return parse

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> None:
        """Register this binding's options as CLI arguments on *parser*.

        Every option becomes ``--<binding-prefix>-<name>``.  The default is
        taken from the matching environment variable when set and valid;
        otherwise ``argparse.SUPPRESS`` is used so the attribute is simply
        absent from the parsed namespace.
        """
        group = parser.add_argument_group(f"{cls._binding_name} binding options")
        for arg_item in cls.args_env_name_type_value():
            arg_type = arg_item["type"]
            env_name = arg_item["env_name"]
            flag = f"--{arg_item['argname']}"

            if arg_type is List[str] or arg_type is dict:
                # JSON-encoded options (lists and dicts) share one code path.
                json_parser = cls._json_arg_parser(
                    list if arg_type is List[str] else dict
                )

                # Pre-validate the environment value; fall back to SUPPRESS
                # when it is missing or not valid JSON of the right shape.
                env_value = get_env_value(env_name, argparse.SUPPRESS)
                if env_value is not argparse.SUPPRESS:
                    try:
                        env_value = json_parser(env_value)
                    except argparse.ArgumentTypeError:
                        env_value = argparse.SUPPRESS

                group.add_argument(
                    flag,
                    type=json_parser,
                    default=env_value,
                    help=arg_item["help"],
                )
            elif arg_type is bool:

                def bool_parser(value):
                    """Accept common string spellings of booleans."""
                    if isinstance(value, bool):
                        return value
                    if isinstance(value, str):
                        return value.lower() in ("true", "1", "yes", "t", "on")
                    return bool(value)

                group.add_argument(
                    flag,
                    type=bool_parser,
                    default=get_env_value(env_name, argparse.SUPPRESS, bool),
                    help=arg_item["help"],
                )
            else:
                # Plain scalar options: argparse coerces with the type itself.
                group.add_argument(
                    flag,
                    type=arg_type,
                    default=get_env_value(env_name, argparse.SUPPRESS),
                    help=arg_item["help"],
                )

    @classmethod
    def args_env_name_type_value(cls):
        """Yield one option descriptor per configurable option.

        Each descriptor is a dict with keys ``argname`` (CLI name),
        ``env_name`` (environment variable), ``type``, ``default`` and
        ``help``.  Dataclass subclasses are introspected via their fields;
        plain subclasses via their public class variables.
        """
        import dataclasses

        args_prefix = cls._binding_name.replace("_", "-")
        env_var_prefix = f"{cls._binding_name}_".upper()
        # `help_map` rather than `help`, which would shadow the builtin.
        help_map = cls._help

        if dataclasses.is_dataclass(cls):
            # `fld` rather than `field`, which shadows dataclasses.field.
            for fld in dataclasses.fields(cls):
                # Private fields are implementation details, not options.
                if fld.name.startswith("_"):
                    continue

                # Resolve the default from either form a dataclass allows.
                if fld.default is not dataclasses.MISSING:
                    default_value = fld.default
                elif fld.default_factory is not dataclasses.MISSING:
                    default_value = fld.default_factory()
                else:
                    default_value = None

                yield {
                    "argname": f"{args_prefix}-{fld.name}",
                    "env_name": f"{env_var_prefix}{fld.name.upper()}",
                    "type": fld.type,
                    "default": default_value,
                    "help": f"{cls._binding_name} -- " + help_map.get(fld.name, ""),
                }
        else:
            class_vars = {
                key: value
                for key, value in cls._all_class_vars(cls).items()
                if not callable(value) and not key.startswith("_")
            }

            # Gather annotations across the whole MRO.  NOTE(review): the
            # forward iteration lets base-class annotations override derived
            # ones; kept as-is to preserve existing behavior.
            type_hints: dict[str, Any] = {}
            for base in cls.__mro__:
                if hasattr(base, "__annotations__"):
                    type_hints.update(base.__annotations__)

            for class_var in class_vars:
                # Fall back to the runtime type when no annotation exists.
                var_type = type_hints.get(class_var, type(class_vars[class_var]))

                yield {
                    "argname": f"{args_prefix}-{class_var}",
                    "env_name": f"{env_var_prefix}{class_var.upper()}",
                    "type": var_type,
                    "default": class_vars[class_var],
                    "help": f"{cls._binding_name} -- " + help_map.get(class_var, ""),
                }

    @classmethod
    def generate_dot_env_sample(cls) -> str:
        """
        Generate a sample .env file for all LightRAG binding options.

        This method creates a .env file that includes all the binding options
        defined by the subclasses of BindingOptions. It uses the args_env_name_type_value()
        method to get the list of all options and their default values.

        Returns:
            str: A string containing the contents of the sample .env file.
        """
        from io import StringIO

        sample_top = (
            "#" * 80
            + "\n"
            + (
                "# Autogenerated .env entries list for LightRAG binding options\n"
                "#\n"
                "# To generate run:\n"
                "# $ python -m lightrag.llm.binding_options\n"
            )
            + "#" * 80
            + "\n"
        )

        sample_bottom = (
            ("#\n# End of .env entries for LightRAG binding options\n")
            + "#" * 80
            + "\n"
        )

        sample_stream = StringIO()
        sample_stream.write(sample_top)
        # NOTE(review): only direct subclasses are scanned; deeper subclasses
        # would be missed — acceptable for the current two-level hierarchy.
        for klass in cls.__subclasses__():
            for arg_item in klass.args_env_name_type_value():
                if arg_item["help"]:
                    sample_stream.write(f"# {arg_item['help']}\n")

                # JSON-typed defaults are serialized so the sample value
                # round-trips through the JSON argument parsers.
                if arg_item["type"] is List[str] or arg_item["type"] is dict:
                    default_value = json.dumps(arg_item["default"])
                else:
                    default_value = arg_item["default"]

                sample_stream.write(f"# {arg_item['env_name']}={default_value}\n\n")

        sample_stream.write(sample_bottom)
        return sample_stream.getvalue()

    @classmethod
    def options_dict(cls, args: Namespace) -> dict[str, Any]:
        """
        Extract options dictionary for a specific binding from parsed arguments.

        This method filters the parsed command-line arguments to return only those
        that belong to the specific binding class. It removes the binding prefix
        from argument names to create a clean options dictionary.

        Args:
            args (Namespace): Parsed command-line arguments containing all binding options

        Returns:
            dict[str, Any]: Dictionary mapping option names (without prefix) to their values

        Example:
            If args contains {'ollama_num_ctx': 512, 'other_option': 'value'}
            and this is called on OllamaOptions, it returns {'num_ctx': 512}
        """
        prefix = cls._binding_name + "_"
        skipchars = len(prefix)
        return {
            key[skipchars:]: value
            for key, value in vars(args).items()
            if key.startswith(prefix)
        }

    def asdict(self) -> dict[str, Any]:
        """
        Convert an instance of binding options to a dictionary.

        This method uses dataclasses.asdict() to convert the dataclass instance
        into a dictionary representation, including all its fields and values.

        Returns:
            dict[str, Any]: Dictionary representation of the binding options instance
        """
        return asdict(self)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
@dataclass
class _OllamaOptionsMixin:
    """Shared option set for the Ollama LLM and embedding bindings."""

    # Context window and prediction limits.
    num_ctx: int = 32768
    num_predict: int = 128
    num_keep: int = 0
    seed: int = -1

    # Core sampling parameters.
    temperature: float = DEFAULT_TEMPERATURE
    top_k: int = 40
    top_p: float = 0.9
    tfs_z: float = 1.0
    typical_p: float = 1.0
    min_p: float = 0.0

    # Repetition control.
    repeat_last_n: int = 64
    repeat_penalty: float = 1.1
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0

    # Mirostat adaptive sampling (0 disables it).
    mirostat: int = 0
    mirostat_tau: float = 5.0
    mirostat_eta: float = 0.1

    # Hardware and scheduling knobs.
    numa: bool = False
    num_batch: int = 512
    num_gpu: int = -1
    main_gpu: int = 0
    low_vram: bool = False
    num_thread: int = 0

    # Model-loading behavior.
    f16_kv: bool = True
    logits_all: bool = False
    vocab_only: bool = False
    use_mmap: bool = True
    use_mlock: bool = False
    embedding_only: bool = False

    # Output shaping.
    penalize_newline: bool = True
    stop: List[str] = field(default_factory=list)

    # Per-option help text used when registering CLI arguments.
    _help: ClassVar[dict[str, str]] = {
        "num_ctx": "Context window size (number of tokens)",
        "num_predict": "Maximum number of tokens to predict",
        "num_keep": "Number of tokens to keep from the initial prompt",
        "seed": "Random seed for generation (-1 for random)",
        "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
        "top_k": "Top-k sampling parameter (0 = disabled)",
        "top_p": "Top-p (nucleus) sampling parameter (0.0-1.0)",
        "tfs_z": "Tail free sampling parameter (1.0 = disabled)",
        "typical_p": "Typical probability mass (1.0 = disabled)",
        "min_p": "Minimum probability threshold (0.0 = disabled)",
        "repeat_last_n": "Number of tokens to consider for repetition penalty",
        "repeat_penalty": "Penalty for repetition (1.0 = no penalty)",
        "presence_penalty": "Penalty for token presence (-2.0 to 2.0)",
        "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0)",
        "mirostat": "Mirostat sampling algorithm (0=disabled, 1=Mirostat 1.0, 2=Mirostat 2.0)",
        "mirostat_tau": "Mirostat target entropy",
        "mirostat_eta": "Mirostat learning rate",
        "numa": "Enable NUMA optimization",
        "num_batch": "Batch size for processing",
        "num_gpu": "Number of GPUs to use (-1 for auto)",
        "main_gpu": "Main GPU index",
        "low_vram": "Optimize for low VRAM",
        "num_thread": "Number of CPU threads (0 for auto)",
        "f16_kv": "Use half-precision for key/value cache",
        "logits_all": "Return logits for all tokens",
        "vocab_only": "Only load vocabulary",
        "use_mmap": "Use memory mapping for model files",
        "use_mlock": "Lock model in memory",
        "embedding_only": "Only use for embeddings",
        "penalize_newline": "Penalize newline tokens",
        "stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
    }
|
|
|
|
@dataclass
class OllamaEmbeddingOptions(_OllamaOptionsMixin, BindingOptions):
    """Options for Ollama embeddings with specialized configuration for embedding tasks."""

    # CLI/env prefix: produces flags like --ollama-embedding-num_ctx and
    # environment variables like OLLAMA_EMBEDDING_NUM_CTX.
    _binding_name: ClassVar[str] = "ollama_embedding"
|
|
|
|
@dataclass
class OllamaLLMOptions(_OllamaOptionsMixin, BindingOptions):
    """Options for Ollama LLM with specialized configuration for LLM tasks."""

    # CLI/env prefix: produces flags like --ollama-llm-num_ctx and
    # environment variables like OLLAMA_LLM_NUM_CTX.
    _binding_name: ClassVar[str] = "ollama_llm"
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
@dataclass
class OpenAILLMOptions(BindingOptions):
    """Options for OpenAI LLM with configuration for OpenAI and Azure OpenAI API calls.

    NOTE: the field annotations double as the argparse ``type=`` callables
    (see BindingOptions.add_args), so optional integers are deliberately
    annotated ``int`` with a ``None`` default rather than ``int | None`` —
    an ``Optional`` annotation would not be callable for CLI coercion.
    """

    # CLI/env prefix, e.g. --openai-llm-temperature / OPENAI_LLM_TEMPERATURE.
    _binding_name: ClassVar[str] = "openai_llm"

    # API request parameters; None defaults mean "omit from the request".
    frequency_penalty: float = 0.0
    max_completion_tokens: int = None
    presence_penalty: float = 0.0
    reasoning_effort: str = "medium"
    safety_identifier: str = ""
    service_tier: str = ""
    stop: List[str] = field(default_factory=list)
    temperature: float = DEFAULT_TEMPERATURE
    top_p: float = 1.0
    max_tokens: int = None  # deprecated upstream; prefer max_completion_tokens
    extra_body: dict = None  # dict annotation selects the JSON-object CLI parser

    # Per-option help text used when registering CLI arguments.
    # (Fixed: "OpenRouter of vLLM" typo and the malformed JSON example for
    # extra_body, which previously showed a doubled "reasoning" key that is
    # not a valid JSON object literal.)
    _help: ClassVar[dict[str, str]] = {
        "frequency_penalty": "Penalty for token frequency (-2.0 to 2.0, positive values discourage repetition)",
        "max_completion_tokens": "Maximum number of tokens to generate (optional, leave empty for model default)",
        "presence_penalty": "Penalty for token presence (-2.0 to 2.0, positive values encourage new topics)",
        "reasoning_effort": "Reasoning effort level for o1 models (low, medium, high)",
        "safety_identifier": "Safety identifier for content filtering (optional)",
        "service_tier": "Service tier for API usage (optional)",
        "stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
        "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
        "top_p": "Nucleus sampling parameter (0.0-1.0, lower = more focused)",
        "max_tokens": "Maximum number of tokens to generate (deprecated, use max_completion_tokens instead)",
        "extra_body": 'Extra body parameters for OpenRouter or vLLM (JSON dict, e.g., \'{"reasoning": {"enabled": false}}\')',
    }
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| if __name__ == "__main__": |
| import sys |
| import dotenv |
| |
|
|
| dotenv.load_dotenv(dotenv_path=".env", override=False) |
|
|
| |
| |
| |
| |
| |
|
|
| if len(sys.argv) > 1 and sys.argv[1] == "test": |
| |
| parser = ArgumentParser(description="Test binding options") |
| OllamaEmbeddingOptions.add_args(parser) |
| OllamaLLMOptions.add_args(parser) |
| OpenAILLMOptions.add_args(parser) |
|
|
| |
| args = parser.parse_args( |
| [ |
| "--ollama-embedding-num_ctx", |
| "1024", |
| "--ollama-llm-num_ctx", |
| "2048", |
| "--openai-llm-temperature", |
| "0.7", |
| "--openai-llm-max_completion_tokens", |
| "1000", |
| "--openai-llm-stop", |
| '["</s>", "\\n\\n"]', |
| "--openai-llm-reasoning", |
| '{"effort": "high", "max_tokens": 2000, "exclude": false, "enabled": true}', |
| ] |
| ) |
| print("Final args for LLM and Embedding:") |
| print(f"{args}\n") |
|
|
| print("Ollama LLM options:") |
| print(OllamaLLMOptions.options_dict(args)) |
|
|
| print("\nOllama Embedding options:") |
| print(OllamaEmbeddingOptions.options_dict(args)) |
|
|
| print("\nOpenAI LLM options:") |
| print(OpenAILLMOptions.options_dict(args)) |
|
|
| |
| openai_options = OpenAILLMOptions( |
| temperature=0.8, |
| max_completion_tokens=1500, |
| frequency_penalty=0.1, |
| presence_penalty=0.2, |
| stop=["<|end|>", "\n\n"], |
| ) |
| print("\nOpenAI LLM options instance:") |
| print(openai_options.asdict()) |
|
|
| |
| openai_options_with_reasoning = OpenAILLMOptions( |
| temperature=0.9, |
| max_completion_tokens=2000, |
| reasoning={ |
| "effort": "medium", |
| "max_tokens": 1500, |
| "exclude": True, |
| "enabled": True, |
| }, |
| ) |
| print("\nOpenAI LLM options instance with reasoning:") |
| print(openai_options_with_reasoning.asdict()) |
|
|
| |
| print("\n" + "=" * 50) |
| print("TESTING DICT PARSING FUNCTIONALITY") |
| print("=" * 50) |
|
|
| |
| test_parser = ArgumentParser(description="Test dict parsing") |
| OpenAILLMOptions.add_args(test_parser) |
|
|
| try: |
| test_args = test_parser.parse_args( |
| ["--openai-llm-reasoning", '{"effort": "low", "max_tokens": 1000}'] |
| ) |
| print("✓ Valid JSON dict parsing successful:") |
| print( |
| f" Parsed reasoning: {OpenAILLMOptions.options_dict(test_args)['reasoning']}" |
| ) |
| except Exception as e: |
| print(f"✗ Valid JSON dict parsing failed: {e}") |
|
|
| |
| try: |
| test_args = test_parser.parse_args( |
| [ |
| "--openai-llm-reasoning", |
| '{"effort": "low", "max_tokens": 1000', |
| ] |
| ) |
| print("✗ Invalid JSON should have failed but didn't") |
| except SystemExit: |
| print("✓ Invalid JSON dict parsing correctly rejected") |
| except Exception as e: |
| print(f"✓ Invalid JSON dict parsing correctly rejected: {e}") |
|
|
| |
| try: |
| test_args = test_parser.parse_args( |
| [ |
| "--openai-llm-reasoning", |
| '["not", "a", "dict"]', |
| ] |
| ) |
| print("✗ Non-dict JSON should have failed but didn't") |
| except SystemExit: |
| print("✓ Non-dict JSON parsing correctly rejected") |
| except Exception as e: |
| print(f"✓ Non-dict JSON parsing correctly rejected: {e}") |
|
|
| print("\n" + "=" * 50) |
| print("TESTING ENVIRONMENT VARIABLE SUPPORT") |
| print("=" * 50) |
|
|
| |
| import os |
|
|
| os.environ["OPENAI_LLM_REASONING"] = ( |
| '{"effort": "high", "max_tokens": 3000, "exclude": false}' |
| ) |
|
|
| env_parser = ArgumentParser(description="Test env var dict parsing") |
| OpenAILLMOptions.add_args(env_parser) |
|
|
| try: |
| env_args = env_parser.parse_args( |
| [] |
| ) |
| reasoning_from_env = OpenAILLMOptions.options_dict(env_args).get( |
| "reasoning" |
| ) |
| if reasoning_from_env: |
| print("✓ Environment variable dict parsing successful:") |
| print(f" Parsed reasoning from env: {reasoning_from_env}") |
| else: |
| print("✗ Environment variable dict parsing failed: No reasoning found") |
| except Exception as e: |
| print(f"✗ Environment variable dict parsing failed: {e}") |
| finally: |
| |
| if "OPENAI_LLM_REASONING" in os.environ: |
| del os.environ["OPENAI_LLM_REASONING"] |
|
|
| else: |
| print(BindingOptions.generate_dot_env_sample()) |
|
|