| """ | |
| vector_store.py | |
| βββββββββββββββ | |
| Indexa pares HTR/GT en ChromaDB usando embeddings configurables: | |
| - "openai" : text-embedding-3-small (API) | |
| - "mpnet" : paraphrase-multilingual-mpnet-base-v2 (local) | |
| - "e5" : multilingual-e5-small (local, por defecto) | |
| - "mt5" : mt5-base fine-tuneado en corpus HTR/GT s.XVI (local) | |
| Uso: | |
| from vector_store import VectorStore | |
| vs = VectorStore(embedding_model="mt5") | |
| vs.index(pairs) | |
| results = vs.retrieve("texto htr...", k=5) | |
| """ | |
| import json | |
| import os | |
| from typing import List, Dict | |
| import torch | |
| import numpy as np | |
| from tqdm import tqdm | |
| import chromadb | |
| from chromadb.utils import embedding_functions | |
| from chromadb import EmbeddingFunction, Documents, Embeddings | |
| from transformers import MT5EncoderModel, AutoTokenizer | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| CHROMA_PATH = os.getenv("CHROMA_PATH", "./chroma_db") | |
| OPENAI_KEY = os.getenv("OPENAI_API_KEY", "") | |
| EMBED_MODEL = "text-embedding-3-small" | |
| COLLECTION = "scriptorium_corpus" | |
| # ββ Embedding personalizado para mt5 βββββββββββββββββββββββββββββββββββββββββ | |
class MT5EmbeddingFunction(EmbeddingFunction):
    """Embedding function backed by the encoder of a fine-tuned mT5 model.

    The encoder was fine-tuned on HTR/GT pairs so that 16th-century
    spellings and their corrected forms end up close together in the
    vector space — exactly what we need when retrieving similar
    examples for the RAG pipeline.

    Strategy: mean-pooling over the last layer's hidden states, with
    padding positions masked out.
    """

    def __init__(self, model_name: str):
        print(f" Cargando MT5EncoderModel desde '{model_name}'...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = MT5EncoderModel.from_pretrained(model_name)
        self.model.eval()
        # Prefer the GPU when one is available; otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        print(f" MT5 cargado en {self.device} β")

    def __call__(self, input: Documents) -> Embeddings:
        # Tokenize the whole batch in one call; long texts are clipped at 512.
        batch = self.tokenizer(
            list(input),
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            hidden = self.model(**batch).last_hidden_state  # (B, T, H)

        # Masked mean-pooling: padding tokens contribute nothing to the sum,
        # and the clamp guards against a zero-length division.
        mask = batch["attention_mask"].unsqueeze(-1).float()
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return pooled.cpu().numpy().tolist()
| # ββ VectorStore βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
class VectorStore:
    """Chroma-backed store of HTR/GT pairs with a pluggable embedding model.

    Supported models: "mpnet", "mt5" (fine-tuned, path via MT5_MODEL_PATH),
    and "e5" (the default for any other value). Each model writes to its
    own collection so the indexes never mix.
    """

    def __init__(self, embedding_model: str = "e5"):
        self.client = chromadb.PersistentClient(path=CHROMA_PATH)

        if embedding_model == "mpnet":
            collection_name = "scriptorium_mpnet"
            self.ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
            )
        elif embedding_model == "mt5":
            # Point at the fine-tuned model through an environment variable.
            # It may be an HF repo id ("alezsd/mt5-htr-xvi") or a local path.
            collection_name = "scriptorium_mt5"
            self.ef = MT5EmbeddingFunction(
                os.getenv("MT5_MODEL_PATH", "google/mt5-base")
            )
        else:  # any other value falls back to e5
            collection_name = "scriptorium_e5"
            self.ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="intfloat/multilingual-e5-small"
            )

        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=self.ef,
            metadata={"hnsw:space": "cosine"},
        )

    # ── Indexing ─────────────────────────────────────────────────────────

    def index(self, pairs: List[Dict], batch_size: int = 50) -> int:
        """Index the HTR/GT pairs that are not yet in the collection.

        Each pair is stored with:
          - document : the text that gets embedded ("HTR: ... [SEP] GT: ...")
          - metadata : type, region, date, caligrafia, original htr/gt, and
                       corrections serialized as JSON
          - id       : the pair's unique identifier

        Returns the number of newly added documents.
        """
        # Only ids are fetched here (include=[]); embeddings/documents stay out.
        known = set(self.collection.get(include=[])["ids"])
        pending = [pair for pair in pairs if pair["id"] not in known]
        if not pending:
            print(f"βΉ Vector store ya actualizado ({len(known)} documentos).")
            return 0

        print(f"π Indexando {len(pending)} documentos nuevos...")
        for start in tqdm(range(0, len(pending), batch_size), desc="Indexando"):
            chunk = pending[start : start + batch_size]
            self.collection.add(
                documents=[f"HTR: {p['htr']} [SEP] GT: {p['gt']}" for p in chunk],
                metadatas=[self._to_metadata(p) for p in chunk],
                ids=[p["id"] for p in chunk],
            )
        print(f"β IndexaciΓ³n completa. Total en store: {self.collection.count()}")
        return len(pending)

    @staticmethod
    def _to_metadata(pair: Dict) -> Dict:
        """Flatten one pair into Chroma-compatible metadata (lists as JSON)."""
        return {
            "htr": pair["htr"],
            "gt": pair["gt"],
            "type": pair.get("type", ""),
            "region": pair.get("region", ""),
            "date": pair.get("date", ""),
            "caligrafia": pair.get("caligrafia", "desconocida"),
            "corrections": json.dumps(
                pair.get("corrections", []), ensure_ascii=False
            ),
        }

    # ── Retrieval ────────────────────────────────────────────────────────

    def retrieve(self, query: str, k: int = 5) -> List[Dict]:
        """Return the k stored pairs most similar to the query HTR text.

        Each result is a dict with htr, gt, type, region, date, caligrafia,
        corrections (deserialized), and a cosine-similarity score.
        """
        total = self.collection.count()
        if total == 0:
            return []

        hits = self.collection.query(
            query_texts=[query],
            n_results=min(k, total),
            include=["metadatas", "distances"],
        )
        return [
            {
                "htr": meta["htr"],
                "gt": meta["gt"],
                "type": meta.get("type", ""),
                "region": meta.get("region", ""),
                "date": meta.get("date", ""),
                "caligrafia": meta.get("caligrafia", ""),
                "corrections": json.loads(meta.get("corrections", "[]")),
                # Cosine distance -> similarity score in [−1, 1].
                "score": round(1 - dist, 4),
            }
            for meta, dist in zip(hits["metadatas"][0], hits["distances"][0])
        ]

    # ── Utilities ────────────────────────────────────────────────────────

    def count(self) -> int:
        """Number of documents currently stored."""
        return self.collection.count()

    def reset(self):
        """Drop and recreate the collection (useful to re-index from scratch)."""
        name = self.collection.name
        self.client.delete_collection(name)
        self.collection = self.client.get_or_create_collection(
            name=name,
            embedding_function=self.ef,
            metadata={"hnsw:space": "cosine"},
        )
        print("π Vector store reseteado.")