Fix: rename rerank parameter from top_k to top_n
Browse files
The change aligns with the API parameter naming used by the Jina and Cohere rerank services, ensuring consistency and clarity.
- examples/rerank_example.py +3 -3
- lightrag/api/lightrag_server.py +2 -2
- lightrag/operate.py +7 -7
- lightrag/rerank.py +19 -19
examples/rerank_example.py
CHANGED
|
@@ -57,7 +57,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
|
| 57 |
)
|
| 58 |
|
| 59 |
|
| 60 |
-
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
| 61 |
"""Custom rerank function with all settings included"""
|
| 62 |
return await custom_rerank(
|
| 63 |
query=query,
|
|
@@ -65,7 +65,7 @@ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwarg
|
|
| 65 |
model="BAAI/bge-reranker-v2-m3",
|
| 66 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
| 67 |
api_key="your_rerank_api_key_here",
|
| 68 |
-
top_k=top_k or 10,
|
| 69 |
**kwargs,
|
| 70 |
)
|
| 71 |
|
|
@@ -217,7 +217,7 @@ async def test_direct_rerank():
|
|
| 217 |
model="BAAI/bge-reranker-v2-m3",
|
| 218 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
| 219 |
api_key="your_rerank_api_key_here",
|
| 220 |
-
top_k=3,
|
| 221 |
)
|
| 222 |
|
| 223 |
print("\n✅ Rerank Results:")
|
|
|
|
| 57 |
)
|
| 58 |
|
| 59 |
|
| 60 |
+
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
| 61 |
"""Custom rerank function with all settings included"""
|
| 62 |
return await custom_rerank(
|
| 63 |
query=query,
|
|
|
|
| 65 |
model="BAAI/bge-reranker-v2-m3",
|
| 66 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
| 67 |
api_key="your_rerank_api_key_here",
|
| 68 |
+
top_n=top_n or 10,
|
| 69 |
**kwargs,
|
| 70 |
)
|
| 71 |
|
|
|
|
| 217 |
model="BAAI/bge-reranker-v2-m3",
|
| 218 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
| 219 |
api_key="your_rerank_api_key_here",
|
| 220 |
+
top_n=3,
|
| 221 |
)
|
| 222 |
|
| 223 |
print("\n✅ Rerank Results:")
|
lightrag/api/lightrag_server.py
CHANGED
|
@@ -298,7 +298,7 @@ def create_app(args):
|
|
| 298 |
from lightrag.rerank import custom_rerank
|
| 299 |
|
| 300 |
async def server_rerank_func(
|
| 301 |
-
query: str, documents: list, top_k: int = None, **kwargs
|
| 302 |
):
|
| 303 |
"""Server rerank function with configuration from environment variables"""
|
| 304 |
return await custom_rerank(
|
|
@@ -307,7 +307,7 @@ def create_app(args):
|
|
| 307 |
model=args.rerank_model,
|
| 308 |
base_url=args.rerank_binding_host,
|
| 309 |
api_key=args.rerank_binding_api_key,
|
| 310 |
-
top_k=top_k,
|
| 311 |
**kwargs,
|
| 312 |
)
|
| 313 |
|
|
|
|
| 298 |
from lightrag.rerank import custom_rerank
|
| 299 |
|
| 300 |
async def server_rerank_func(
|
| 301 |
+
query: str, documents: list, top_n: int = None, **kwargs
|
| 302 |
):
|
| 303 |
"""Server rerank function with configuration from environment variables"""
|
| 304 |
return await custom_rerank(
|
|
|
|
| 307 |
model=args.rerank_model,
|
| 308 |
base_url=args.rerank_binding_host,
|
| 309 |
api_key=args.rerank_binding_api_key,
|
| 310 |
+
top_n=top_n,
|
| 311 |
**kwargs,
|
| 312 |
)
|
| 313 |
|
lightrag/operate.py
CHANGED
|
@@ -3165,7 +3165,7 @@ async def apply_rerank_if_enabled(
|
|
| 3165 |
retrieved_docs: list[dict],
|
| 3166 |
global_config: dict,
|
| 3167 |
enable_rerank: bool = True,
|
| 3168 |
-
top_k: int = None,
|
| 3169 |
) -> list[dict]:
|
| 3170 |
"""
|
| 3171 |
Apply reranking to retrieved documents if rerank is enabled.
|
|
@@ -3175,7 +3175,7 @@ async def apply_rerank_if_enabled(
|
|
| 3175 |
retrieved_docs: List of retrieved documents
|
| 3176 |
global_config: Global configuration containing rerank settings
|
| 3177 |
enable_rerank: Whether to enable reranking from query parameter
|
| 3178 |
-
top_k: Number of top documents to return after reranking
|
| 3179 |
|
| 3180 |
Returns:
|
| 3181 |
Reranked documents if rerank is enabled, otherwise original documents
|
|
@@ -3192,18 +3192,18 @@ async def apply_rerank_if_enabled(
|
|
| 3192 |
|
| 3193 |
try:
|
| 3194 |
logger.debug(
|
| 3195 |
-
f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}"
|
| 3196 |
)
|
| 3197 |
|
| 3198 |
# Apply reranking - let rerank_model_func handle top_k internally
|
| 3199 |
reranked_docs = await rerank_func(
|
| 3200 |
query=query,
|
| 3201 |
documents=retrieved_docs,
|
| 3202 |
-
top_k=top_k,
|
| 3203 |
)
|
| 3204 |
if reranked_docs and len(reranked_docs) > 0:
|
| 3205 |
-
if len(reranked_docs) > top_k:
|
| 3206 |
-
reranked_docs = reranked_docs[:top_k]
|
| 3207 |
logger.info(
|
| 3208 |
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
|
| 3209 |
)
|
|
@@ -3263,7 +3263,7 @@ async def process_chunks_unified(
|
|
| 3263 |
retrieved_docs=unique_chunks,
|
| 3264 |
global_config=global_config,
|
| 3265 |
enable_rerank=query_param.enable_rerank,
|
| 3266 |
-
top_k=rerank_top_k,
|
| 3267 |
)
|
| 3268 |
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
| 3269 |
|
|
|
|
| 3165 |
retrieved_docs: list[dict],
|
| 3166 |
global_config: dict,
|
| 3167 |
enable_rerank: bool = True,
|
| 3168 |
+
top_n: int = None,
|
| 3169 |
) -> list[dict]:
|
| 3170 |
"""
|
| 3171 |
Apply reranking to retrieved documents if rerank is enabled.
|
|
|
|
| 3175 |
retrieved_docs: List of retrieved documents
|
| 3176 |
global_config: Global configuration containing rerank settings
|
| 3177 |
enable_rerank: Whether to enable reranking from query parameter
|
| 3178 |
+
top_n: Number of top documents to return after reranking
|
| 3179 |
|
| 3180 |
Returns:
|
| 3181 |
Reranked documents if rerank is enabled, otherwise original documents
|
|
|
|
| 3192 |
|
| 3193 |
try:
|
| 3194 |
logger.debug(
|
| 3195 |
+
f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_n}"
|
| 3196 |
)
|
| 3197 |
|
| 3198 |
# Apply reranking - let rerank_model_func handle top_k internally
|
| 3199 |
reranked_docs = await rerank_func(
|
| 3200 |
query=query,
|
| 3201 |
documents=retrieved_docs,
|
| 3202 |
+
top_n=top_n,
|
| 3203 |
)
|
| 3204 |
if reranked_docs and len(reranked_docs) > 0:
|
| 3205 |
+
if len(reranked_docs) > top_n:
|
| 3206 |
+
reranked_docs = reranked_docs[:top_n]
|
| 3207 |
logger.info(
|
| 3208 |
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
|
| 3209 |
)
|
|
|
|
| 3263 |
retrieved_docs=unique_chunks,
|
| 3264 |
global_config=global_config,
|
| 3265 |
enable_rerank=query_param.enable_rerank,
|
| 3266 |
+
top_n=rerank_top_k,
|
| 3267 |
)
|
| 3268 |
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
| 3269 |
|
lightrag/rerank.py
CHANGED
|
@@ -41,13 +41,13 @@ class RerankModel(BaseModel):
|
|
| 41 |
|
| 42 |
Or define a custom function directly:
|
| 43 |
```python
|
| 44 |
-
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
| 45 |
return await jina_rerank(
|
| 46 |
query=query,
|
| 47 |
documents=documents,
|
| 48 |
model="BAAI/bge-reranker-v2-m3",
|
| 49 |
api_key="your_api_key_here",
|
| 50 |
-
top_k=top_k or 10,
|
| 51 |
**kwargs
|
| 52 |
)
|
| 53 |
|
|
@@ -71,14 +71,14 @@ class RerankModel(BaseModel):
|
|
| 71 |
self,
|
| 72 |
query: str,
|
| 73 |
documents: List[Dict[str, Any]],
|
| 74 |
-
top_k: Optional[int] = None,
|
| 75 |
**extra_kwargs,
|
| 76 |
) -> List[Dict[str, Any]]:
|
| 77 |
"""Rerank documents using the configured model function."""
|
| 78 |
# Merge extra kwargs with model kwargs
|
| 79 |
kwargs = {**self.kwargs, **extra_kwargs}
|
| 80 |
return await self.rerank_func(
|
| 81 |
-
query=query, documents=documents, top_k=top_k, **kwargs
|
| 82 |
)
|
| 83 |
|
| 84 |
|
|
@@ -98,7 +98,7 @@ class MultiRerankModel(BaseModel):
|
|
| 98 |
query: str,
|
| 99 |
documents: List[Dict[str, Any]],
|
| 100 |
mode: str = "default",
|
| 101 |
-
top_k: Optional[int] = None,
|
| 102 |
**kwargs,
|
| 103 |
) -> List[Dict[str, Any]]:
|
| 104 |
"""Rerank using the appropriate model based on mode."""
|
|
@@ -116,7 +116,7 @@ class MultiRerankModel(BaseModel):
|
|
| 116 |
logger.warning(f"No rerank model available for mode: {mode}")
|
| 117 |
return documents
|
| 118 |
|
| 119 |
-
return await model.rerank(query, documents, top_k, **kwargs)
|
| 120 |
|
| 121 |
|
| 122 |
async def generic_rerank_api(
|
|
@@ -125,7 +125,7 @@ async def generic_rerank_api(
|
|
| 125 |
model: str,
|
| 126 |
base_url: str,
|
| 127 |
api_key: str,
|
| 128 |
-
top_k: Optional[int] = None,
|
| 129 |
**kwargs,
|
| 130 |
) -> List[Dict[str, Any]]:
|
| 131 |
"""
|
|
@@ -137,7 +137,7 @@ async def generic_rerank_api(
|
|
| 137 |
model: Model identifier
|
| 138 |
base_url: API endpoint URL
|
| 139 |
api_key: API authentication key
|
| 140 |
-
top_k: Number of top results to return
|
| 141 |
**kwargs: Additional API-specific parameters
|
| 142 |
|
| 143 |
Returns:
|
|
@@ -165,8 +165,8 @@ async def generic_rerank_api(
|
|
| 165 |
|
| 166 |
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
|
| 167 |
|
| 168 |
-
if top_k is not None:
|
| 169 |
-
data["top_k"] = min(top_k, len(prepared_docs))
|
| 170 |
|
| 171 |
try:
|
| 172 |
async with aiohttp.ClientSession() as session:
|
|
@@ -206,7 +206,7 @@ async def jina_rerank(
|
|
| 206 |
query: str,
|
| 207 |
documents: List[Dict[str, Any]],
|
| 208 |
model: str = "BAAI/bge-reranker-v2-m3",
|
| 209 |
-
top_k: Optional[int] = None,
|
| 210 |
base_url: str = "https://api.jina.ai/v1/rerank",
|
| 211 |
api_key: Optional[str] = None,
|
| 212 |
**kwargs,
|
|
@@ -218,7 +218,7 @@ async def jina_rerank(
|
|
| 218 |
query: The search query
|
| 219 |
documents: List of documents to rerank
|
| 220 |
model: Jina rerank model name
|
| 221 |
-
top_k: Number of top results to return
|
| 222 |
base_url: Jina API endpoint
|
| 223 |
api_key: Jina API key
|
| 224 |
**kwargs: Additional parameters
|
|
@@ -235,7 +235,7 @@ async def jina_rerank(
|
|
| 235 |
model=model,
|
| 236 |
base_url=base_url,
|
| 237 |
api_key=api_key,
|
| 238 |
-
top_k=top_k,
|
| 239 |
**kwargs,
|
| 240 |
)
|
| 241 |
|
|
@@ -244,7 +244,7 @@ async def cohere_rerank(
|
|
| 244 |
query: str,
|
| 245 |
documents: List[Dict[str, Any]],
|
| 246 |
model: str = "rerank-english-v2.0",
|
| 247 |
-
top_k: Optional[int] = None,
|
| 248 |
base_url: str = "https://api.cohere.ai/v1/rerank",
|
| 249 |
api_key: Optional[str] = None,
|
| 250 |
**kwargs,
|
|
@@ -256,7 +256,7 @@ async def cohere_rerank(
|
|
| 256 |
query: The search query
|
| 257 |
documents: List of documents to rerank
|
| 258 |
model: Cohere rerank model name
|
| 259 |
-
top_k: Number of top results to return
|
| 260 |
base_url: Cohere API endpoint
|
| 261 |
api_key: Cohere API key
|
| 262 |
**kwargs: Additional parameters
|
|
@@ -273,7 +273,7 @@ async def cohere_rerank(
|
|
| 273 |
model=model,
|
| 274 |
base_url=base_url,
|
| 275 |
api_key=api_key,
|
| 276 |
-
top_k=top_k,
|
| 277 |
**kwargs,
|
| 278 |
)
|
| 279 |
|
|
@@ -285,7 +285,7 @@ async def custom_rerank(
|
|
| 285 |
model: str,
|
| 286 |
base_url: str,
|
| 287 |
api_key: str,
|
| 288 |
-
top_k: Optional[int] = None,
|
| 289 |
**kwargs,
|
| 290 |
) -> List[Dict[str, Any]]:
|
| 291 |
"""
|
|
@@ -298,7 +298,7 @@ async def custom_rerank(
|
|
| 298 |
model=model,
|
| 299 |
base_url=base_url,
|
| 300 |
api_key=api_key,
|
| 301 |
-
top_k=top_k,
|
| 302 |
**kwargs,
|
| 303 |
)
|
| 304 |
|
|
@@ -317,7 +317,7 @@ if __name__ == "__main__":
|
|
| 317 |
query = "What is the capital of France?"
|
| 318 |
|
| 319 |
result = await jina_rerank(
|
| 320 |
-
query=query, documents=docs, top_k=2, api_key="your-api-key-here"
|
| 321 |
)
|
| 322 |
print(result)
|
| 323 |
|
|
|
|
| 41 |
|
| 42 |
Or define a custom function directly:
|
| 43 |
```python
|
| 44 |
+
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
| 45 |
return await jina_rerank(
|
| 46 |
query=query,
|
| 47 |
documents=documents,
|
| 48 |
model="BAAI/bge-reranker-v2-m3",
|
| 49 |
api_key="your_api_key_here",
|
| 50 |
+
top_n=top_n or 10,
|
| 51 |
**kwargs
|
| 52 |
)
|
| 53 |
|
|
|
|
| 71 |
self,
|
| 72 |
query: str,
|
| 73 |
documents: List[Dict[str, Any]],
|
| 74 |
+
top_n: Optional[int] = None,
|
| 75 |
**extra_kwargs,
|
| 76 |
) -> List[Dict[str, Any]]:
|
| 77 |
"""Rerank documents using the configured model function."""
|
| 78 |
# Merge extra kwargs with model kwargs
|
| 79 |
kwargs = {**self.kwargs, **extra_kwargs}
|
| 80 |
return await self.rerank_func(
|
| 81 |
+
query=query, documents=documents, top_n=top_n, **kwargs
|
| 82 |
)
|
| 83 |
|
| 84 |
|
|
|
|
| 98 |
query: str,
|
| 99 |
documents: List[Dict[str, Any]],
|
| 100 |
mode: str = "default",
|
| 101 |
+
top_n: Optional[int] = None,
|
| 102 |
**kwargs,
|
| 103 |
) -> List[Dict[str, Any]]:
|
| 104 |
"""Rerank using the appropriate model based on mode."""
|
|
|
|
| 116 |
logger.warning(f"No rerank model available for mode: {mode}")
|
| 117 |
return documents
|
| 118 |
|
| 119 |
+
return await model.rerank(query, documents, top_n, **kwargs)
|
| 120 |
|
| 121 |
|
| 122 |
async def generic_rerank_api(
|
|
|
|
| 125 |
model: str,
|
| 126 |
base_url: str,
|
| 127 |
api_key: str,
|
| 128 |
+
top_n: Optional[int] = None,
|
| 129 |
**kwargs,
|
| 130 |
) -> List[Dict[str, Any]]:
|
| 131 |
"""
|
|
|
|
| 137 |
model: Model identifier
|
| 138 |
base_url: API endpoint URL
|
| 139 |
api_key: API authentication key
|
| 140 |
+
top_n: Number of top results to return
|
| 141 |
**kwargs: Additional API-specific parameters
|
| 142 |
|
| 143 |
Returns:
|
|
|
|
| 165 |
|
| 166 |
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
|
| 167 |
|
| 168 |
+
if top_n is not None:
|
| 169 |
+
data["top_n"] = min(top_n, len(prepared_docs))
|
| 170 |
|
| 171 |
try:
|
| 172 |
async with aiohttp.ClientSession() as session:
|
|
|
|
| 206 |
query: str,
|
| 207 |
documents: List[Dict[str, Any]],
|
| 208 |
model: str = "BAAI/bge-reranker-v2-m3",
|
| 209 |
+
top_n: Optional[int] = None,
|
| 210 |
base_url: str = "https://api.jina.ai/v1/rerank",
|
| 211 |
api_key: Optional[str] = None,
|
| 212 |
**kwargs,
|
|
|
|
| 218 |
query: The search query
|
| 219 |
documents: List of documents to rerank
|
| 220 |
model: Jina rerank model name
|
| 221 |
+
top_n: Number of top results to return
|
| 222 |
base_url: Jina API endpoint
|
| 223 |
api_key: Jina API key
|
| 224 |
**kwargs: Additional parameters
|
|
|
|
| 235 |
model=model,
|
| 236 |
base_url=base_url,
|
| 237 |
api_key=api_key,
|
| 238 |
+
top_n=top_n,
|
| 239 |
**kwargs,
|
| 240 |
)
|
| 241 |
|
|
|
|
| 244 |
query: str,
|
| 245 |
documents: List[Dict[str, Any]],
|
| 246 |
model: str = "rerank-english-v2.0",
|
| 247 |
+
top_n: Optional[int] = None,
|
| 248 |
base_url: str = "https://api.cohere.ai/v1/rerank",
|
| 249 |
api_key: Optional[str] = None,
|
| 250 |
**kwargs,
|
|
|
|
| 256 |
query: The search query
|
| 257 |
documents: List of documents to rerank
|
| 258 |
model: Cohere rerank model name
|
| 259 |
+
top_n: Number of top results to return
|
| 260 |
base_url: Cohere API endpoint
|
| 261 |
api_key: Cohere API key
|
| 262 |
**kwargs: Additional parameters
|
|
|
|
| 273 |
model=model,
|
| 274 |
base_url=base_url,
|
| 275 |
api_key=api_key,
|
| 276 |
+
top_n=top_n,
|
| 277 |
**kwargs,
|
| 278 |
)
|
| 279 |
|
|
|
|
| 285 |
model: str,
|
| 286 |
base_url: str,
|
| 287 |
api_key: str,
|
| 288 |
+
top_n: Optional[int] = None,
|
| 289 |
**kwargs,
|
| 290 |
) -> List[Dict[str, Any]]:
|
| 291 |
"""
|
|
|
|
| 298 |
model=model,
|
| 299 |
base_url=base_url,
|
| 300 |
api_key=api_key,
|
| 301 |
+
top_n=top_n,
|
| 302 |
**kwargs,
|
| 303 |
)
|
| 304 |
|
|
|
|
| 317 |
query = "What is the capital of France?"
|
| 318 |
|
| 319 |
result = await jina_rerank(
|
| 320 |
+
query=query, documents=docs, top_n=2, api_key="your-api-key-here"
|
| 321 |
)
|
| 322 |
print(result)
|
| 323 |
|