import os
import json

import torch
from sentence_transformers import SentenceTransformer

from .utils import get_md5


class ToolRAGModel:
    """Retrieve the tools most relevant to a query via embedding similarity."""

    def __init__(self, rag_model_name):
        self.rag_model_name = rag_model_name
        self.rag_model = None
        self.tool_desc_embedding = None
        self.tool_name = None
        self.tool_embedding_path = None
        # Load the embedding model eagerly so failures surface at construction.
        self.load_rag_model()

    def load_rag_model(self):
        self.rag_model = SentenceTransformer(self.rag_model_name)
        # Allow long tool descriptions to fit in a single sequence.
        self.rag_model.max_seq_length = 4096
        self.rag_model.tokenizer.padding_side = "right"

    def load_tool_desc_embedding(self, toolbox):
        self.tool_name, _ = toolbox.refresh_tool_name_desc(enable_full_desc=True)
        all_tools_str = [json.dumps(each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
        # Key the embedding cache on a hash of the serialized tool list so the
        # embeddings are regenerated whenever the toolbox changes.
        md5_value = get_md5(str(all_tools_str))
        print("Computed MD5 for tool embedding:", md5_value)

        self.tool_embedding_path = os.path.join(
            os.path.dirname(__file__),
            self.rag_model_name.split("/")[-1] + f"_tool_embedding_{md5_value}.pt",
        )

        # Reuse cached embeddings when they match the current toolbox.
        if os.path.exists(self.tool_embedding_path):
            try:
                self.tool_desc_embedding = torch.load(self.tool_embedding_path, map_location="cpu")
                assert len(self.tool_desc_embedding) == len(toolbox.all_tools), \
                    "Tool count mismatch with loaded embeddings."
                print("\033[92mLoaded cached tool_desc_embedding.\033[0m")
                return
            except Exception as e:
                print(f"⚠️ Failed loading cached embeddings: {e}")
                self.tool_desc_embedding = None

        print("\033[93mGenerating new tool_desc_embedding...\033[0m")
        self.tool_desc_embedding = self.rag_model.encode(
            all_tools_str, prompt="", normalize_embeddings=True
        )
        torch.save(self.tool_desc_embedding, self.tool_embedding_path)
        print(f"\033[92mSaved new tool_desc_embedding to {self.tool_embedding_path}\033[0m")
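
    # A minimal sketch of the assumed `get_md5` helper imported from .utils
    # (hypothetical reconstruction; any stable string hash works as a cache key):
    #
    #     import hashlib
    #
    #     def get_md5(s: str) -> str:
    #         return hashlib.md5(s.encode("utf-8")).hexdigest()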

    def rag_infer(self, query, top_k=5):
        # Fail fast before encoding the query if embeddings were never built.
        if self.tool_desc_embedding is None:
            raise RuntimeError(
                "tool_desc_embedding is not initialized. "
                "Did you forget to call load_tool_desc_embedding()?"
            )
        torch.cuda.empty_cache()
        query_embeddings = self.rag_model.encode(
            [query], prompt="", normalize_embeddings=True
        )

        # Score the query against every tool description; shape (1, n_tools).
        scores = self.rag_model.similarity(query_embeddings, self.tool_desc_embedding)
        top_k = min(top_k, len(self.tool_name))
        top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
        return [self.tool_name[i] for i in top_k_indices]
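

# Usage sketch (illustrative only): the `toolbox` object below is assumed to
# provide refresh_tool_name_desc(), prepare_tool_prompts(), and an `all_tools`
# list, as required by load_tool_desc_embedding(); the model name is an example.
#
#     from toolbox import ToolBox  # hypothetical import
#
#     toolbox = ToolBox()
#     rag = ToolRAGModel("sentence-transformers/all-MiniLM-L6-v2")
#     rag.load_tool_desc_embedding(toolbox)  # builds or loads cached embeddings
#     rag.rag_infer("extract text from a PDF file", top_k=3)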