diff --git "a/htmlcov/z_357ee38f49d3e320_grounding_py.html" "b/htmlcov/z_357ee38f49d3e320_grounding_py.html" new file mode 100644--- /dev/null +++ "b/htmlcov/z_357ee38f49d3e320_grounding_py.html" @@ -0,0 +1,495 @@ + + + + + Coverage for tinytroupe/agent/grounding.py: 0% + + + + + +
+
+

+ Coverage for tinytroupe / agent / grounding.py: + 0% +

+ +

+ 200 statements   + + + +

+

+ « prev     + ^ index     + » next +       + coverage.py v7.13.4, + created at 2026-02-28 17:48 +0000 +

+ +
+
+
+

from tinytroupe.utils import JsonSerializableRegistry

import tinytroupe.utils as utils

from tinytroupe.agent import logger
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, StorageContext, load_index_from_storage
from llama_index.core.vector_stores import SimpleVectorStore
from llama_index.readers.web import SimpleWebPageReader
import json
import tempfile
import os
import shutil


#######################################################################################################################
# Grounding connectors
#######################################################################################################################

class GroundingConnector(JsonSerializableRegistry):
    """
    An abstract class representing a grounding connector. A grounding connector is a component that allows an agent
    to ground its knowledge in external sources, such as files, web pages, databases, etc.
    """

    serializable_attributes = ["name"]

    def __init__(self, name:str) -> None:
        self.name = name

    def retrieve_relevant(self, relevance_target:str, source:str, top_k=20) -> list:
        raise NotImplementedError("Subclasses must implement this method.")

    def retrieve_by_name(self, name:str) -> str:
        raise NotImplementedError("Subclasses must implement this method.")

    def list_sources(self) -> list:
        raise NotImplementedError("Subclasses must implement this method.")

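# Illustrative sketch, not part of the upstream module: the smallest concrete
# connector this interface admits, backed by a plain in-memory dict. The class
# and its naive substring matching are hypothetical, for demonstration only.
class _DictGroundingConnector(GroundingConnector):
    def __init__(self, name:str, sources:dict) -> None:
        super().__init__(name)
        self.sources = sources  # maps a source name to its text content

    def retrieve_relevant(self, relevance_target:str, source:str, top_k=20) -> list:
        # substring matching stands in for real relevance scoring
        return [text for text in self.sources.values() if relevance_target in text][:top_k]

    def retrieve_by_name(self, name:str) -> str:
        return self.sources.get(name, "")

    def list_sources(self) -> list:
        return list(self.sources.keys())
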

@utils.post_init
class BaseSemanticGroundingConnector(GroundingConnector):
    """
    A base class for semantic grounding connectors. A semantic grounding connector is a component that indexes and
    retrieves documents based on so-called "semantic search" (i.e., embeddings-based search). This specific
    implementation is based on the VectorStoreIndex class from the LLaMa-Index library. Here, "documents" refer to
    llama-index's data structure that stores a unit of content, not necessarily a file.
    """

    serializable_attributes = ["documents", "index"]

    # needs custom deserialization to handle Pydantic models (Document is a Pydantic model)
    custom_deserializers = {"documents": lambda docs_json: [Document.from_json(doc_json) for doc_json in docs_json],
                            "index": lambda index_json: BaseSemanticGroundingConnector._deserialize_index(index_json)}

    custom_serializers = {"documents": lambda docs: [doc.to_json() for doc in docs] if docs is not None else None,
                          "index": lambda index: BaseSemanticGroundingConnector._serialize_index(index)}

    def __init__(self, name:str="Semantic Grounding") -> None:
        super().__init__(name)

        self.documents = None
        self.name_to_document = None
        self.index = None

        # @post_init ensures that _post_init is called after the __init__ method

    def _post_init(self):
        """
        This will run after __init__, since the class has the @post_init decorator.
        It is convenient to separate some of the initialization processes to make deserialization easier.
        """
        self.index = None

        if not hasattr(self, 'documents') or self.documents is None:
            self.documents = []

        if not hasattr(self, 'name_to_document') or self.name_to_document is None:
            self.name_to_document = {}

        if hasattr(self, 'documents') and self.documents is not None:
            for document in self.documents:
                # if the document has a semantic memory ID, we use it as the identifier
                name = document.metadata.get("semantic_memory_id", document.id_)

                # self.name_to_document[name] contains a list, since each source file could be split into multiple pages
                if name in self.name_to_document:
                    self.name_to_document[name].append(document)
                else:
                    self.name_to_document[name] = [document]

        # Rebuild index from documents if it's None or invalid
        if self.index is None and self.documents:
            logger.warning("No index found. Rebuilding index from documents.")
            vector_store = SimpleVectorStore()
            self.index = VectorStoreIndex.from_documents(
                self.documents,
                vector_store=vector_store,
                store_nodes_override=True
            )

        # TODO remove?
        #self.add_documents(self.documents)

    @staticmethod
    def _serialize_index(index):
        """Helper function to serialize an index together with its storage context."""
        if index is None:
            return None

        try:
            # Create a temporary directory to store the index
            with tempfile.TemporaryDirectory() as temp_dir:
                # Persist the index to the temporary directory
                index.storage_context.persist(persist_dir=temp_dir)

                # Read all the persisted files and store them in a dictionary
                persisted_data = {}
                for filename in os.listdir(temp_dir):
                    filepath = os.path.join(temp_dir, filename)
                    if os.path.isfile(filepath):
                        with open(filepath, 'r', encoding="utf-8", errors="replace") as f:
                            persisted_data[filename] = f.read()

                return persisted_data
        except Exception as e:
            logger.warning(f"Failed to serialize index: {e}")
            return None

    @staticmethod
    def _deserialize_index(index_data):
        """Helper function to deserialize an index with proper error handling."""
        if not index_data:
            return None

        try:
            # Create a temporary directory to restore the index
            with tempfile.TemporaryDirectory() as temp_dir:
                # Write all the persisted files to the temporary directory
                for filename, content in index_data.items():
                    filepath = os.path.join(temp_dir, filename)
                    with open(filepath, 'w', encoding="utf-8", errors="replace") as f:
                        f.write(content)

                # Load the index from the temporary directory
                storage_context = StorageContext.from_defaults(persist_dir=temp_dir)
                index = load_index_from_storage(storage_context)

                return index
        except Exception as e:
            # If deserialization fails, return None; the index will be rebuilt from documents in _post_init
            logger.warning(f"Failed to deserialize index: {e}. Index will be rebuilt.")
            return None

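    # Round-trip sketch (illustrative): the two helpers above are inverses, so a
    # live index survives JSON serialization of the connector, e.g.:
    #     data = BaseSemanticGroundingConnector._serialize_index(connector.index)
    #     index = BaseSemanticGroundingConnector._deserialize_index(data)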

    def retrieve_relevant(self, relevance_target:str, top_k=20) -> list:
        """
        Retrieves all values from memory that are relevant to a given target.
        """
        # Handle empty or None query
        if not relevance_target or not relevance_target.strip():
            return []

        if self.index is not None:
            retriever = self.index.as_retriever(similarity_top_k=top_k)
            nodes = retriever.retrieve(relevance_target)
        else:
            nodes = []

        retrieved = []
        for node in nodes:
            content = "SOURCE: " + node.metadata.get('file_name', '(unknown)')
            content += "\n" + "SIMILARITY SCORE:" + str(node.score)
            content += "\n" + "RELEVANT CONTENT:" + node.text
            retrieved.append(content)

            logger.debug(f"Content retrieved: {content[:200]}")

        return retrieved

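    # Each retrieved entry is a single formatted string; illustrative shape
    # (file name and score are made up):
    #     SOURCE: report.pdf
    #     SIMILARITY SCORE:0.83
    #     RELEVANT CONTENT:<matching text>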

    def retrieve_by_name(self, name:str) -> list:
        """
        Retrieves a content source by its name.
        """
        # TODO also optionally provide a relevance target?
        results = []
        if self.name_to_document is not None and name in self.name_to_document:
            docs = self.name_to_document[name]
            for i, doc in enumerate(docs):
                if doc is not None:
                    content = f"SOURCE: {name}\n"
                    content += f"PAGE: {i}\n"
                    content += "CONTENT: \n" + doc.text[:10000]  # TODO a more intelligent way to limit the content
                    results.append(content)

        return results

    def list_sources(self) -> list:
        """
        Lists the names of the available content sources.
        """
        if self.name_to_document is not None:
            return list(self.name_to_document.keys())
        else:
            return []

    def add_document(self, document) -> None:
        """
        Indexes a document for semantic retrieval.

        Assumes the document has a metadata field called "semantic_memory_id" that is used to identify the document
        within Semantic Memory.
        """
        self.add_documents([document])

    def add_documents(self, new_documents) -> None:
        """
        Indexes documents for semantic retrieval.
        """
        # index documents by name
        if len(new_documents) > 0:

            # process documents individually too
            for document in new_documents:
                logger.debug(f"Adding document {document} to index, text is: {document.text}")

                # out of an abundance of caution, we sanitize the text
                document.text = utils.sanitize_raw_string(document.text)

                logger.debug(f"Document text after sanitization: {document.text}")

                # add the new document to the list of documents after all sanitization and checks
                self.documents.append(document)

                if document.metadata.get("semantic_memory_id") is not None:
                    # if the document has a semantic memory ID, we use it as the identifier
                    name = document.metadata["semantic_memory_id"]

                    # Ensure name_to_document is initialized
                    if not hasattr(self, 'name_to_document') or self.name_to_document is None:
                        self.name_to_document = {}

                    # self.name_to_document[name] contains a list, since each source file could be split into multiple pages
                    if name in self.name_to_document:
                        self.name_to_document[name].append(document)
                    else:
                        self.name_to_document[name] = [document]

            # index documents for semantic retrieval
            if self.index is None:
                # Create a storage context with a vector store
                vector_store = SimpleVectorStore()
                storage_context = StorageContext.from_defaults(vector_store=vector_store)

                self.index = VectorStoreIndex.from_documents(
                    self.documents,
                    storage_context=storage_context,
                    store_nodes_override=True  # This ensures nodes (with text) are stored
                )
            else:
                self.index.refresh(self.documents)

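    # Note on the branch above: the first call builds the index from scratch and
    # stores the nodes' text alongside the embeddings; later calls go through
    # refresh(), which llama-index uses to insert new documents and update ones
    # whose content changed, rather than re-embedding everything.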

    @staticmethod
    def _set_internal_id_to_documents(documents:list, external_attribute_name:str="file_name") -> list:
        """
        Sets the internal ID for each document in the list of documents.
        This is useful to ensure that each document has a unique identifier.
        """
        for doc in documents:
            if not hasattr(doc, 'metadata'):
                doc.metadata = {}
            doc.metadata["semantic_memory_id"] = doc.metadata.get(external_attribute_name, doc.id_)

        return documents

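# Usage sketch (illustrative, not part of the upstream module): index a single
# llama-index Document and query it back. Building the index embeds the text,
# so a configured embedding model is assumed; the ids and queries are made up.
def _example_semantic_grounding():
    connector = BaseSemanticGroundingConnector("Example Grounding")
    note = Document(text="TinyTroupe agents can ground their knowledge in documents.",
                    metadata={"semantic_memory_id": "note-1"})
    connector.add_document(note)
    print(connector.list_sources())                      # ['note-1']
    print(connector.retrieve_by_name("note-1")[0][:60])  # starts with "SOURCE: note-1"
    print(connector.retrieve_relevant("grounding knowledge", top_k=1))
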

@utils.post_init
class LocalFilesGroundingConnector(BaseSemanticGroundingConnector):

    serializable_attributes = ["folders_paths"]

    def __init__(self, name:str="Local Files", folders_paths: list=None) -> None:
        super().__init__(name)

        self.folders_paths = folders_paths

        # @post_init ensures that _post_init is called after the __init__ method

    def _post_init(self):
        """
        This will run after __init__, since the class has the @post_init decorator.
        It is convenient to separate some of the initialization processes to make deserialization easier.
        """
        self.loaded_folders_paths = []

        if not hasattr(self, 'folders_paths') or self.folders_paths is None:
            self.folders_paths = []

        self.add_folders(self.folders_paths)

    def add_folders(self, folders_paths:list) -> None:
        """
        Adds paths to folders with files used for grounding.
        """
        if folders_paths is not None:
            for folder_path in folders_paths:
                try:
                    logger.debug(f"Adding the following folder to grounding index: {folder_path}")
                    self.add_folder(folder_path)
                except (FileNotFoundError, ValueError) as e:
                    print(f"Error: {e}")
                    print(f"Current working directory: {os.getcwd()}")
                    print(f"Provided path: {folder_path}")
                    print("Please check if the path exists and is accessible.")

    def add_folder(self, folder_path:str) -> None:
        """
        Adds a path to a folder with files used for grounding.
        """
        if folder_path not in self.loaded_folders_paths:
            self._mark_folder_as_loaded(folder_path)

            # for PDF files, please note that the document will be split into pages: https://github.com/run-llama/llama_index/issues/15903
            new_files = SimpleDirectoryReader(folder_path).load_data()
            BaseSemanticGroundingConnector._set_internal_id_to_documents(new_files, "file_name")

            self.add_documents(new_files)

    def add_file_path(self, file_path:str) -> None:
        """
        Adds a path to a file used for grounding.
        """
        # a trick to make SimpleDirectoryReader work with a single file
        new_files = SimpleDirectoryReader(input_files=[file_path]).load_data()

        logger.debug(f"Adding the following file to grounding index: {new_files}")
        BaseSemanticGroundingConnector._set_internal_id_to_documents(new_files, "file_name")

        # index the loaded documents; without this the file would be read but never become retrievable
        self.add_documents(new_files)

    def _mark_folder_as_loaded(self, folder_path:str) -> None:
        if folder_path not in self.loaded_folders_paths:
            self.loaded_folders_paths.append(folder_path)

        if folder_path not in self.folders_paths:
            self.folders_paths.append(folder_path)

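# Usage sketch (illustrative; the folder path and query are hypothetical):
# ground an agent in every readable file under a local folder.
def _example_local_files_grounding():
    connector = LocalFilesGroundingConnector(folders_paths=["./examples/data"])
    print(connector.list_sources())  # one entry per loaded file name
    for excerpt in connector.retrieve_relevant("project goals", top_k=3):
        print(excerpt[:200])
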

@utils.post_init
class WebPagesGroundingConnector(BaseSemanticGroundingConnector):

    serializable_attributes = ["web_urls"]

    def __init__(self, name:str="Web Pages", web_urls: list=None) -> None:
        super().__init__(name)

        self.web_urls = web_urls

        # @post_init ensures that _post_init is called after the __init__ method

    def _post_init(self):
        self.loaded_web_urls = []

        if not hasattr(self, 'web_urls') or self.web_urls is None:
            self.web_urls = []

        # load web urls
        self.add_web_urls(self.web_urls)

    def add_web_urls(self, web_urls:list) -> None:
        """
        Adds the data retrieved from the specified URLs to grounding.
        """
        filtered_web_urls = [url for url in web_urls if url not in self.loaded_web_urls]
        for url in filtered_web_urls:
            self._mark_web_url_as_loaded(url)

        if len(filtered_web_urls) > 0:
            new_documents = SimpleWebPageReader(html_to_text=True).load_data(filtered_web_urls)
            BaseSemanticGroundingConnector._set_internal_id_to_documents(new_documents, "url")
            self.add_documents(new_documents)

    def add_web_url(self, web_url:str) -> None:
        """
        Adds the data retrieved from the specified URL to grounding.
        """
        # we do it like this because add_web_urls could run scrapes in parallel, so it is better
        # to implement this method in terms of that one
        self.add_web_urls([web_url])

    def _mark_web_url_as_loaded(self, web_url:str) -> None:
        if web_url not in self.loaded_web_urls:
            self.loaded_web_urls.append(web_url)

        if web_url not in self.web_urls:
            self.web_urls.append(web_url)
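# Usage sketch (illustrative; the URL and query are hypothetical): ground an
# agent in the text extracted from fetched web pages. Source names fall back to
# the documents' ids, which SimpleWebPageReader derives from the URLs.
def _example_web_grounding():
    connector = WebPagesGroundingConnector(web_urls=["https://example.com"])
    print(connector.list_sources())
    print(connector.retrieve_relevant("example domain", top_k=1))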
