drahnreb committed on
Commit ·
6c0bb1c
1
Parent(s): f47a464
fix linting
Browse files
- examples/lightrag_gemini_demo_no_tiktoken.py +13 -6
- lightrag/api/routers/ollama_api.py +1 -1
- lightrag/lightrag.py +13 -6
- lightrag/operate.py +16 -10
- lightrag/utils.py +10 -5
examples/lightrag_gemini_demo_no_tiktoken.py
CHANGED
|
@@ -51,10 +51,12 @@ class GemmaTokenizer(Tokenizer):
|
|
| 51 |
"google/gemma3": _TokenizerConfig(
|
| 52 |
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
|
| 53 |
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
|
| 54 |
-
)
|
| 55 |
-
}
|
| 56 |
|
| 57 |
-
def __init__(
|
|
|
|
|
|
|
| 58 |
# https://github.com/google/gemma_pytorch/tree/main/tokenizer
|
| 59 |
if "1.5" in model_name or "1.0" in model_name:
|
| 60 |
# up to gemini 1.5 gemma2 is a comparable local tokenizer
|
|
@@ -77,7 +79,9 @@ class GemmaTokenizer(Tokenizer):
|
|
| 77 |
else:
|
| 78 |
model_data = None
|
| 79 |
if not model_data:
|
| 80 |
-
model_data = self._load_from_url(
|
|
|
|
|
|
|
| 81 |
self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
|
| 82 |
|
| 83 |
tokenizer = spm.SentencePieceProcessor()
|
|
@@ -140,7 +144,7 @@ class GemmaTokenizer(Tokenizer):
|
|
| 140 |
|
| 141 |
# def encode(self, content: str) -> list[int]:
|
| 142 |
# return self.tokenizer.encode(content)
|
| 143 |
-
|
| 144 |
# def decode(self, tokens: list[int]) -> str:
|
| 145 |
# return self.tokenizer.decode(tokens)
|
| 146 |
|
|
@@ -187,7 +191,10 @@ async def initialize_rag():
|
|
| 187 |
rag = LightRAG(
|
| 188 |
working_dir=WORKING_DIR,
|
| 189 |
# tiktoken_model_name="gpt-4o-mini",
|
| 190 |
-
tokenizer=GemmaTokenizer(
|
|
|
|
|
|
|
|
|
|
| 191 |
llm_model_func=llm_model_func,
|
| 192 |
embedding_func=EmbeddingFunc(
|
| 193 |
embedding_dim=384,
|
|
|
|
| 51 |
"google/gemma3": _TokenizerConfig(
|
| 52 |
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
|
| 53 |
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
|
| 54 |
+
),
|
| 55 |
+
}
|
| 56 |
|
| 57 |
+
def __init__(
|
| 58 |
+
self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
|
| 59 |
+
):
|
| 60 |
# https://github.com/google/gemma_pytorch/tree/main/tokenizer
|
| 61 |
if "1.5" in model_name or "1.0" in model_name:
|
| 62 |
# up to gemini 1.5 gemma2 is a comparable local tokenizer
|
|
|
|
| 79 |
else:
|
| 80 |
model_data = None
|
| 81 |
if not model_data:
|
| 82 |
+
model_data = self._load_from_url(
|
| 83 |
+
file_url=file_url, expected_hash=expected_hash
|
| 84 |
+
)
|
| 85 |
self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
|
| 86 |
|
| 87 |
tokenizer = spm.SentencePieceProcessor()
|
|
|
|
| 144 |
|
| 145 |
# def encode(self, content: str) -> list[int]:
|
| 146 |
# return self.tokenizer.encode(content)
|
| 147 |
+
|
| 148 |
# def decode(self, tokens: list[int]) -> str:
|
| 149 |
# return self.tokenizer.decode(tokens)
|
| 150 |
|
|
|
|
| 191 |
rag = LightRAG(
|
| 192 |
working_dir=WORKING_DIR,
|
| 193 |
# tiktoken_model_name="gpt-4o-mini",
|
| 194 |
+
tokenizer=GemmaTokenizer(
|
| 195 |
+
tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
|
| 196 |
+
model_name="gemini-2.0-flash",
|
| 197 |
+
),
|
| 198 |
llm_model_func=llm_model_func,
|
| 199 |
embedding_func=EmbeddingFunc(
|
| 200 |
embedding_dim=384,
|
lightrag/api/routers/ollama_api.py
CHANGED
|
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
|
|
| 10 |
import asyncio
|
| 11 |
from ascii_colors import trace_exception
|
| 12 |
from lightrag import LightRAG, QueryParam
|
| 13 |
-
from lightrag.utils import TiktokenTokenizer
|
| 14 |
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
| 15 |
from fastapi import Depends
|
| 16 |
|
|
|
|
| 10 |
import asyncio
|
| 11 |
from ascii_colors import trace_exception
|
| 12 |
from lightrag import LightRAG, QueryParam
|
| 13 |
+
from lightrag.utils import TiktokenTokenizer
|
| 14 |
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
| 15 |
from fastapi import Depends
|
| 16 |
|
lightrag/lightrag.py
CHANGED
|
@@ -7,7 +7,18 @@ import warnings
|
|
| 7 |
from dataclasses import asdict, dataclass, field
|
| 8 |
from datetime import datetime
|
| 9 |
from functools import partial
|
| 10 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from lightrag.kg import (
|
| 13 |
STORAGES,
|
|
@@ -1147,11 +1158,7 @@ class LightRAG:
|
|
| 1147 |
for chunk_data in custom_kg.get("chunks", []):
|
| 1148 |
chunk_content = clean_text(chunk_data["content"])
|
| 1149 |
source_id = chunk_data["source_id"]
|
| 1150 |
-
tokens = len(
|
| 1151 |
-
self.tokenizer.encode(
|
| 1152 |
-
chunk_content
|
| 1153 |
-
)
|
| 1154 |
-
)
|
| 1155 |
chunk_order_index = (
|
| 1156 |
0
|
| 1157 |
if "chunk_order_index" not in chunk_data.keys()
|
|
|
|
| 7 |
from dataclasses import asdict, dataclass, field
|
| 8 |
from datetime import datetime
|
| 9 |
from functools import partial
|
| 10 |
+
from typing import (
|
| 11 |
+
Any,
|
| 12 |
+
AsyncIterator,
|
| 13 |
+
Callable,
|
| 14 |
+
Iterator,
|
| 15 |
+
cast,
|
| 16 |
+
final,
|
| 17 |
+
Literal,
|
| 18 |
+
Optional,
|
| 19 |
+
List,
|
| 20 |
+
Dict,
|
| 21 |
+
)
|
| 22 |
|
| 23 |
from lightrag.kg import (
|
| 24 |
STORAGES,
|
|
|
|
| 1158 |
for chunk_data in custom_kg.get("chunks", []):
|
| 1159 |
chunk_content = clean_text(chunk_data["content"])
|
| 1160 |
source_id = chunk_data["source_id"]
|
| 1161 |
+
tokens = len(self.tokenizer.encode(chunk_content))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1162 |
chunk_order_index = (
|
| 1163 |
0
|
| 1164 |
if "chunk_order_index" not in chunk_data.keys()
|
lightrag/operate.py
CHANGED
|
@@ -88,9 +88,7 @@ def chunking_by_token_size(
|
|
| 88 |
for index, start in enumerate(
|
| 89 |
range(0, len(tokens), max_token_size - overlap_token_size)
|
| 90 |
):
|
| 91 |
-
chunk_content = tokenizer.decode(
|
| 92 |
-
tokens[start : start + max_token_size]
|
| 93 |
-
)
|
| 94 |
results.append(
|
| 95 |
{
|
| 96 |
"tokens": min(max_token_size, len(tokens) - start),
|
|
@@ -126,9 +124,7 @@ async def _handle_entity_relation_summary(
|
|
| 126 |
if len(tokens) < summary_max_tokens: # No need for summary
|
| 127 |
return description
|
| 128 |
prompt_template = PROMPTS["summarize_entity_descriptions"]
|
| 129 |
-
use_description = tokenizer.decode(
|
| 130 |
-
tokens[:llm_max_tokens]
|
| 131 |
-
)
|
| 132 |
context_base = dict(
|
| 133 |
entity_name=entity_or_relation_name,
|
| 134 |
description_list=use_description.split(GRAPH_FIELD_SEP),
|
|
@@ -1378,10 +1374,15 @@ async def _get_node_data(
|
|
| 1378 |
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
|
| 1379 |
# get entitytext chunk
|
| 1380 |
use_text_units = await _find_most_related_text_unit_from_entities(
|
| 1381 |
-
node_datas,
|
|
|
|
|
|
|
|
|
|
| 1382 |
)
|
| 1383 |
use_relations = await _find_most_related_edges_from_entities(
|
| 1384 |
-
node_datas,
|
|
|
|
|
|
|
| 1385 |
)
|
| 1386 |
|
| 1387 |
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
|
|
@@ -1703,10 +1704,15 @@ async def _get_edge_data(
|
|
| 1703 |
)
|
| 1704 |
use_entities, use_text_units = await asyncio.gather(
|
| 1705 |
_find_most_related_entities_from_relationships(
|
| 1706 |
-
edge_datas,
|
|
|
|
|
|
|
| 1707 |
),
|
| 1708 |
_find_related_text_unit_from_relationships(
|
| 1709 |
-
edge_datas,
|
|
|
|
|
|
|
|
|
|
| 1710 |
),
|
| 1711 |
)
|
| 1712 |
logger.info(
|
|
|
|
| 88 |
for index, start in enumerate(
|
| 89 |
range(0, len(tokens), max_token_size - overlap_token_size)
|
| 90 |
):
|
| 91 |
+
chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
|
|
|
|
|
|
|
| 92 |
results.append(
|
| 93 |
{
|
| 94 |
"tokens": min(max_token_size, len(tokens) - start),
|
|
|
|
| 124 |
if len(tokens) < summary_max_tokens: # No need for summary
|
| 125 |
return description
|
| 126 |
prompt_template = PROMPTS["summarize_entity_descriptions"]
|
| 127 |
+
use_description = tokenizer.decode(tokens[:llm_max_tokens])
|
|
|
|
|
|
|
| 128 |
context_base = dict(
|
| 129 |
entity_name=entity_or_relation_name,
|
| 130 |
description_list=use_description.split(GRAPH_FIELD_SEP),
|
|
|
|
| 1374 |
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
|
| 1375 |
# get entitytext chunk
|
| 1376 |
use_text_units = await _find_most_related_text_unit_from_entities(
|
| 1377 |
+
node_datas,
|
| 1378 |
+
query_param,
|
| 1379 |
+
text_chunks_db,
|
| 1380 |
+
knowledge_graph_inst,
|
| 1381 |
)
|
| 1382 |
use_relations = await _find_most_related_edges_from_entities(
|
| 1383 |
+
node_datas,
|
| 1384 |
+
query_param,
|
| 1385 |
+
knowledge_graph_inst,
|
| 1386 |
)
|
| 1387 |
|
| 1388 |
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
|
|
|
|
| 1704 |
)
|
| 1705 |
use_entities, use_text_units = await asyncio.gather(
|
| 1706 |
_find_most_related_entities_from_relationships(
|
| 1707 |
+
edge_datas,
|
| 1708 |
+
query_param,
|
| 1709 |
+
knowledge_graph_inst,
|
| 1710 |
),
|
| 1711 |
_find_related_text_unit_from_relationships(
|
| 1712 |
+
edge_datas,
|
| 1713 |
+
query_param,
|
| 1714 |
+
text_chunks_db,
|
| 1715 |
+
knowledge_graph_inst,
|
| 1716 |
),
|
| 1717 |
)
|
| 1718 |
logger.info(
|
lightrag/utils.py
CHANGED
|
@@ -12,7 +12,7 @@ import re
|
|
| 12 |
from dataclasses import dataclass
|
| 13 |
from functools import wraps
|
| 14 |
from hashlib import md5
|
| 15 |
-
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
|
| 16 |
import xml.etree.ElementTree as ET
|
| 17 |
import numpy as np
|
| 18 |
from lightrag.prompt import PROMPTS
|
|
@@ -311,6 +311,7 @@ class TokenizerInterface(Protocol):
|
|
| 311 |
"""
|
| 312 |
Defines the interface for a tokenizer, requiring encode and decode methods.
|
| 313 |
"""
|
|
|
|
| 314 |
def encode(self, content: str) -> List[int]:
|
| 315 |
"""Encodes a string into a list of tokens."""
|
| 316 |
...
|
|
@@ -319,10 +320,12 @@ class TokenizerInterface(Protocol):
|
|
| 319 |
"""Decodes a list of tokens into a string."""
|
| 320 |
...
|
| 321 |
|
|
|
|
| 322 |
class Tokenizer:
|
| 323 |
"""
|
| 324 |
A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
|
| 325 |
"""
|
|
|
|
| 326 |
def __init__(self, model_name: str, tokenizer: TokenizerInterface):
|
| 327 |
"""
|
| 328 |
Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
|
|
@@ -363,6 +366,7 @@ class TiktokenTokenizer(Tokenizer):
|
|
| 363 |
"""
|
| 364 |
A Tokenizer implementation using the tiktoken library.
|
| 365 |
"""
|
|
|
|
| 366 |
def __init__(self, model_name: str = "gpt-4o-mini"):
|
| 367 |
"""
|
| 368 |
Initializes the TiktokenTokenizer with a specified model name.
|
|
@@ -385,9 +389,7 @@ class TiktokenTokenizer(Tokenizer):
|
|
| 385 |
tokenizer = tiktoken.encoding_for_model(model_name)
|
| 386 |
super().__init__(model_name=model_name, tokenizer=tokenizer)
|
| 387 |
except KeyError:
|
| 388 |
-
raise ValueError(
|
| 389 |
-
f"Invalid model_name: {model_name}."
|
| 390 |
-
)
|
| 391 |
|
| 392 |
|
| 393 |
def pack_user_ass_to_openai_messages(*args: str):
|
|
@@ -424,7 +426,10 @@ def is_float_regex(value: str) -> bool:
|
|
| 424 |
|
| 425 |
|
| 426 |
def truncate_list_by_token_size(
|
| 427 |
-
list_data: list[Any],
|
|
|
|
|
|
|
|
|
|
| 428 |
) -> list[int]:
|
| 429 |
"""Truncate a list of data by token size"""
|
| 430 |
if max_token_size <= 0:
|
|
|
|
| 12 |
from dataclasses import dataclass
|
| 13 |
from functools import wraps
|
| 14 |
from hashlib import md5
|
| 15 |
+
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
|
| 16 |
import xml.etree.ElementTree as ET
|
| 17 |
import numpy as np
|
| 18 |
from lightrag.prompt import PROMPTS
|
|
|
|
| 311 |
"""
|
| 312 |
Defines the interface for a tokenizer, requiring encode and decode methods.
|
| 313 |
"""
|
| 314 |
+
|
| 315 |
def encode(self, content: str) -> List[int]:
|
| 316 |
"""Encodes a string into a list of tokens."""
|
| 317 |
...
|
|
|
|
| 320 |
"""Decodes a list of tokens into a string."""
|
| 321 |
...
|
| 322 |
|
| 323 |
+
|
| 324 |
class Tokenizer:
|
| 325 |
"""
|
| 326 |
A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
|
| 327 |
"""
|
| 328 |
+
|
| 329 |
def __init__(self, model_name: str, tokenizer: TokenizerInterface):
|
| 330 |
"""
|
| 331 |
Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
|
|
|
|
| 366 |
"""
|
| 367 |
A Tokenizer implementation using the tiktoken library.
|
| 368 |
"""
|
| 369 |
+
|
| 370 |
def __init__(self, model_name: str = "gpt-4o-mini"):
|
| 371 |
"""
|
| 372 |
Initializes the TiktokenTokenizer with a specified model name.
|
|
|
|
| 389 |
tokenizer = tiktoken.encoding_for_model(model_name)
|
| 390 |
super().__init__(model_name=model_name, tokenizer=tokenizer)
|
| 391 |
except KeyError:
|
| 392 |
+
raise ValueError(f"Invalid model_name: {model_name}.")
|
|
|
|
|
|
|
| 393 |
|
| 394 |
|
| 395 |
def pack_user_ass_to_openai_messages(*args: str):
|
|
|
|
| 426 |
|
| 427 |
|
| 428 |
def truncate_list_by_token_size(
|
| 429 |
+
list_data: list[Any],
|
| 430 |
+
key: Callable[[Any], str],
|
| 431 |
+
max_token_size: int,
|
| 432 |
+
tokenizer: Tokenizer,
|
| 433 |
) -> list[int]:
|
| 434 |
"""Truncate a list of data by token size"""
|
| 435 |
if max_token_size <= 0:
|