drahnreb committed on
Commit
6c0bb1c
·
1 Parent(s): f47a464

fix linting

Browse files
examples/lightrag_gemini_demo_no_tiktoken.py CHANGED
@@ -51,10 +51,12 @@ class GemmaTokenizer(Tokenizer):
51
  "google/gemma3": _TokenizerConfig(
52
  tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
53
  tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
54
- )
55
- }
56
 
57
- def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None):
 
 
58
  # https://github.com/google/gemma_pytorch/tree/main/tokenizer
59
  if "1.5" in model_name or "1.0" in model_name:
60
  # up to gemini 1.5 gemma2 is a comparable local tokenizer
@@ -77,7 +79,9 @@ class GemmaTokenizer(Tokenizer):
77
  else:
78
  model_data = None
79
  if not model_data:
80
- model_data = self._load_from_url(file_url=file_url, expected_hash=expected_hash)
 
 
81
  self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
82
 
83
  tokenizer = spm.SentencePieceProcessor()
@@ -140,7 +144,7 @@ class GemmaTokenizer(Tokenizer):
140
 
141
  # def encode(self, content: str) -> list[int]:
142
  # return self.tokenizer.encode(content)
143
-
144
  # def decode(self, tokens: list[int]) -> str:
145
  # return self.tokenizer.decode(tokens)
146
 
@@ -187,7 +191,10 @@ async def initialize_rag():
187
  rag = LightRAG(
188
  working_dir=WORKING_DIR,
189
  # tiktoken_model_name="gpt-4o-mini",
190
- tokenizer=GemmaTokenizer(tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"), model_name="gemini-2.0-flash"),
 
 
 
191
  llm_model_func=llm_model_func,
192
  embedding_func=EmbeddingFunc(
193
  embedding_dim=384,
 
51
  "google/gemma3": _TokenizerConfig(
52
  tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
53
  tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
54
+ ),
55
+ }
56
 
57
+ def __init__(
58
+ self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
59
+ ):
60
  # https://github.com/google/gemma_pytorch/tree/main/tokenizer
61
  if "1.5" in model_name or "1.0" in model_name:
62
  # up to gemini 1.5 gemma2 is a comparable local tokenizer
 
79
  else:
80
  model_data = None
81
  if not model_data:
82
+ model_data = self._load_from_url(
83
+ file_url=file_url, expected_hash=expected_hash
84
+ )
85
  self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
86
 
87
  tokenizer = spm.SentencePieceProcessor()
 
144
 
145
  # def encode(self, content: str) -> list[int]:
146
  # return self.tokenizer.encode(content)
147
+
148
  # def decode(self, tokens: list[int]) -> str:
149
  # return self.tokenizer.decode(tokens)
150
 
 
191
  rag = LightRAG(
192
  working_dir=WORKING_DIR,
193
  # tiktoken_model_name="gpt-4o-mini",
194
+ tokenizer=GemmaTokenizer(
195
+ tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
196
+ model_name="gemini-2.0-flash",
197
+ ),
198
  llm_model_func=llm_model_func,
199
  embedding_func=EmbeddingFunc(
200
  embedding_dim=384,
lightrag/api/routers/ollama_api.py CHANGED
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
10
  import asyncio
11
  from ascii_colors import trace_exception
12
  from lightrag import LightRAG, QueryParam
13
- from lightrag.utils import TiktokenTokenizer
14
  from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
15
  from fastapi import Depends
16
 
 
10
  import asyncio
11
  from ascii_colors import trace_exception
12
  from lightrag import LightRAG, QueryParam
13
+ from lightrag.utils import TiktokenTokenizer
14
  from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
15
  from fastapi import Depends
16
 
lightrag/lightrag.py CHANGED
@@ -7,7 +7,18 @@ import warnings
7
  from dataclasses import asdict, dataclass, field
8
  from datetime import datetime
9
  from functools import partial
10
- from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal, Optional, List, Dict
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  from lightrag.kg import (
13
  STORAGES,
@@ -1147,11 +1158,7 @@ class LightRAG:
1147
  for chunk_data in custom_kg.get("chunks", []):
1148
  chunk_content = clean_text(chunk_data["content"])
1149
  source_id = chunk_data["source_id"]
1150
- tokens = len(
1151
- self.tokenizer.encode(
1152
- chunk_content
1153
- )
1154
- )
1155
  chunk_order_index = (
1156
  0
1157
  if "chunk_order_index" not in chunk_data.keys()
 
7
  from dataclasses import asdict, dataclass, field
8
  from datetime import datetime
9
  from functools import partial
10
+ from typing import (
11
+ Any,
12
+ AsyncIterator,
13
+ Callable,
14
+ Iterator,
15
+ cast,
16
+ final,
17
+ Literal,
18
+ Optional,
19
+ List,
20
+ Dict,
21
+ )
22
 
23
  from lightrag.kg import (
24
  STORAGES,
 
1158
  for chunk_data in custom_kg.get("chunks", []):
1159
  chunk_content = clean_text(chunk_data["content"])
1160
  source_id = chunk_data["source_id"]
1161
+ tokens = len(self.tokenizer.encode(chunk_content))
 
 
 
 
1162
  chunk_order_index = (
1163
  0
1164
  if "chunk_order_index" not in chunk_data.keys()
lightrag/operate.py CHANGED
@@ -88,9 +88,7 @@ def chunking_by_token_size(
88
  for index, start in enumerate(
89
  range(0, len(tokens), max_token_size - overlap_token_size)
90
  ):
91
- chunk_content = tokenizer.decode(
92
- tokens[start : start + max_token_size]
93
- )
94
  results.append(
95
  {
96
  "tokens": min(max_token_size, len(tokens) - start),
@@ -126,9 +124,7 @@ async def _handle_entity_relation_summary(
126
  if len(tokens) < summary_max_tokens: # No need for summary
127
  return description
128
  prompt_template = PROMPTS["summarize_entity_descriptions"]
129
- use_description = tokenizer.decode(
130
- tokens[:llm_max_tokens]
131
- )
132
  context_base = dict(
133
  entity_name=entity_or_relation_name,
134
  description_list=use_description.split(GRAPH_FIELD_SEP),
@@ -1378,10 +1374,15 @@ async def _get_node_data(
1378
  ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
1379
  # get entitytext chunk
1380
  use_text_units = await _find_most_related_text_unit_from_entities(
1381
- node_datas, query_param, text_chunks_db, knowledge_graph_inst,
 
 
 
1382
  )
1383
  use_relations = await _find_most_related_edges_from_entities(
1384
- node_datas, query_param, knowledge_graph_inst,
 
 
1385
  )
1386
 
1387
  tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
@@ -1703,10 +1704,15 @@ async def _get_edge_data(
1703
  )
1704
  use_entities, use_text_units = await asyncio.gather(
1705
  _find_most_related_entities_from_relationships(
1706
- edge_datas, query_param, knowledge_graph_inst,
 
 
1707
  ),
1708
  _find_related_text_unit_from_relationships(
1709
- edge_datas, query_param, text_chunks_db, knowledge_graph_inst,
 
 
 
1710
  ),
1711
  )
1712
  logger.info(
 
88
  for index, start in enumerate(
89
  range(0, len(tokens), max_token_size - overlap_token_size)
90
  ):
91
+ chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
 
 
92
  results.append(
93
  {
94
  "tokens": min(max_token_size, len(tokens) - start),
 
124
  if len(tokens) < summary_max_tokens: # No need for summary
125
  return description
126
  prompt_template = PROMPTS["summarize_entity_descriptions"]
127
+ use_description = tokenizer.decode(tokens[:llm_max_tokens])
 
 
128
  context_base = dict(
129
  entity_name=entity_or_relation_name,
130
  description_list=use_description.split(GRAPH_FIELD_SEP),
 
1374
  ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
1375
  # get entitytext chunk
1376
  use_text_units = await _find_most_related_text_unit_from_entities(
1377
+ node_datas,
1378
+ query_param,
1379
+ text_chunks_db,
1380
+ knowledge_graph_inst,
1381
  )
1382
  use_relations = await _find_most_related_edges_from_entities(
1383
+ node_datas,
1384
+ query_param,
1385
+ knowledge_graph_inst,
1386
  )
1387
 
1388
  tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
 
1704
  )
1705
  use_entities, use_text_units = await asyncio.gather(
1706
  _find_most_related_entities_from_relationships(
1707
+ edge_datas,
1708
+ query_param,
1709
+ knowledge_graph_inst,
1710
  ),
1711
  _find_related_text_unit_from_relationships(
1712
+ edge_datas,
1713
+ query_param,
1714
+ text_chunks_db,
1715
+ knowledge_graph_inst,
1716
  ),
1717
  )
1718
  logger.info(
lightrag/utils.py CHANGED
@@ -12,7 +12,7 @@ import re
12
  from dataclasses import dataclass
13
  from functools import wraps
14
  from hashlib import md5
15
- from typing import Any, Protocol, Callable, TYPE_CHECKING, List, Optional, Union
16
  import xml.etree.ElementTree as ET
17
  import numpy as np
18
  from lightrag.prompt import PROMPTS
@@ -311,6 +311,7 @@ class TokenizerInterface(Protocol):
311
  """
312
  Defines the interface for a tokenizer, requiring encode and decode methods.
313
  """
 
314
  def encode(self, content: str) -> List[int]:
315
  """Encodes a string into a list of tokens."""
316
  ...
@@ -319,10 +320,12 @@ class TokenizerInterface(Protocol):
319
  """Decodes a list of tokens into a string."""
320
  ...
321
 
 
322
  class Tokenizer:
323
  """
324
  A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
325
  """
 
326
  def __init__(self, model_name: str, tokenizer: TokenizerInterface):
327
  """
328
  Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
@@ -363,6 +366,7 @@ class TiktokenTokenizer(Tokenizer):
363
  """
364
  A Tokenizer implementation using the tiktoken library.
365
  """
 
366
  def __init__(self, model_name: str = "gpt-4o-mini"):
367
  """
368
  Initializes the TiktokenTokenizer with a specified model name.
@@ -385,9 +389,7 @@ class TiktokenTokenizer(Tokenizer):
385
  tokenizer = tiktoken.encoding_for_model(model_name)
386
  super().__init__(model_name=model_name, tokenizer=tokenizer)
387
  except KeyError:
388
- raise ValueError(
389
- f"Invalid model_name: {model_name}."
390
- )
391
 
392
 
393
  def pack_user_ass_to_openai_messages(*args: str):
@@ -424,7 +426,10 @@ def is_float_regex(value: str) -> bool:
424
 
425
 
426
  def truncate_list_by_token_size(
427
- list_data: list[Any], key: Callable[[Any], str], max_token_size: int, tokenizer: Tokenizer
 
 
 
428
  ) -> list[int]:
429
  """Truncate a list of data by token size"""
430
  if max_token_size <= 0:
 
12
  from dataclasses import dataclass
13
  from functools import wraps
14
  from hashlib import md5
15
+ from typing import Any, Protocol, Callable, TYPE_CHECKING, List
16
  import xml.etree.ElementTree as ET
17
  import numpy as np
18
  from lightrag.prompt import PROMPTS
 
311
  """
312
  Defines the interface for a tokenizer, requiring encode and decode methods.
313
  """
314
+
315
  def encode(self, content: str) -> List[int]:
316
  """Encodes a string into a list of tokens."""
317
  ...
 
320
  """Decodes a list of tokens into a string."""
321
  ...
322
 
323
+
324
  class Tokenizer:
325
  """
326
  A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
327
  """
328
+
329
  def __init__(self, model_name: str, tokenizer: TokenizerInterface):
330
  """
331
  Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
 
366
  """
367
  A Tokenizer implementation using the tiktoken library.
368
  """
369
+
370
  def __init__(self, model_name: str = "gpt-4o-mini"):
371
  """
372
  Initializes the TiktokenTokenizer with a specified model name.
 
389
  tokenizer = tiktoken.encoding_for_model(model_name)
390
  super().__init__(model_name=model_name, tokenizer=tokenizer)
391
  except KeyError:
392
+ raise ValueError(f"Invalid model_name: {model_name}.")
 
 
393
 
394
 
395
  def pack_user_ass_to_openai_messages(*args: str):
 
426
 
427
 
428
  def truncate_list_by_token_size(
429
+ list_data: list[Any],
430
+ key: Callable[[Any], str],
431
+ max_token_size: int,
432
+ tokenizer: Tokenizer,
433
  ) -> list[int]:
434
  """Truncate a list of data by token size"""
435
  if max_token_size <= 0: