""" vector_store.py ─────────────── Indexa pares HTR/GT en ChromaDB usando embeddings configurables: - "openai" : text-embedding-3-small (API) - "mpnet" : paraphrase-multilingual-mpnet-base-v2 (local) - "e5" : multilingual-e5-small (local, por defecto) - "mt5" : mt5-base fine-tuneado en corpus HTR/GT s.XVI (local) Uso: from vector_store import VectorStore vs = VectorStore(embedding_model="mt5") vs.index(pairs) results = vs.retrieve("texto htr...", k=5) """ import json import os from typing import List, Dict import torch import numpy as np from tqdm import tqdm import chromadb from chromadb.utils import embedding_functions from chromadb import EmbeddingFunction, Documents, Embeddings from transformers import MT5EncoderModel, AutoTokenizer from dotenv import load_dotenv load_dotenv() CHROMA_PATH = os.getenv("CHROMA_PATH", "./chroma_db") OPENAI_KEY = os.getenv("OPENAI_API_KEY", "") EMBED_MODEL = "text-embedding-3-small" COLLECTION = "scriptorium_corpus" # ── Embedding personalizado para mt5 ───────────────────────────────────────── class MT5EmbeddingFunction(EmbeddingFunction): """ Usa el encoder de mt5 fine-tuneado con pares HTR/GT para generar embeddings. El encoder aprendió que grafías del s.XVI y sus formas corregidas son semánticamente cercanas en el espacio vectorial — ideal para recuperar ejemplos similares en el RAG. Estrategia: mean-pooling sobre los hidden states del último layer, enmascarando los tokens de padding. """ def __init__(self, model_name: str): print(f" Cargando MT5EncoderModel desde '{model_name}'...") self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = MT5EncoderModel.from_pretrained(model_name) self.model.eval() self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) print(f" MT5 cargado en {self.device} ✓") def __call__(self, input: Documents) -> Embeddings: encoded = self.tokenizer( list(input), padding=True, truncation=True, max_length=512, return_tensors="pt", ).to(self.device) with torch.no_grad(): outputs = self.model(**encoded) # Mean-pooling con máscara de atención token_embs = outputs.last_hidden_state # (B, T, H) attention = encoded["attention_mask"].unsqueeze(-1).float() sum_embs = (token_embs * attention).sum(dim=1) count = attention.sum(dim=1).clamp(min=1e-9) embeddings = (sum_embs / count).cpu().numpy() # (B, H) return embeddings.tolist() # ── VectorStore ─────────────────────────────────────────────────────────────── class VectorStore: def __init__(self, embedding_model: str = "e5"): self.client = chromadb.PersistentClient(path=CHROMA_PATH) if embedding_model == "mpnet": self.ef = embedding_functions.SentenceTransformerEmbeddingFunction( model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" ) collection_name = "scriptorium_mpnet" elif embedding_model == "mt5": # Apunta a tu modelo fine-tuneado via variable de entorno. # Puede ser un repo de HF ("alezsd/mt5-htr-xvi") o una ruta local. 
            mt5_path = os.getenv("MT5_MODEL_PATH", "google/mt5-base")
            self.ef = MT5EmbeddingFunction(mt5_path)
            collection_name = "scriptorium_mt5"
        else:
            # e5 by default
            self.ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="intfloat/multilingual-e5-small"
            )
            collection_name = "scriptorium_e5"

        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=self.ef,
            metadata={"hnsw:space": "cosine"},
        )

    # ── Indexing ──────────────────────────────────────────────────────────────
    def index(self, pairs: List[Dict], batch_size: int = 50) -> int:
        """
        Indexes the HTR/GT pairs. Each fragment is stored with:
        - document : the text that gets embedded (htr + ' [SEP] ' + gt)
        - metadata : type, region, date, original htr and gt
        - id       : unique identifier of the pair

        Returns the number of newly added documents.
        """
        existing_ids = set(self.collection.get(include=[])["ids"])
        to_add = [p for p in pairs if p["id"] not in existing_ids]
        if not to_add:
            print(f"ℹ Vector store already up to date ({len(existing_ids)} documents).")
            return 0

        print(f"🔄 Indexing {len(to_add)} new documents...")
        for i in tqdm(range(0, len(to_add), batch_size), desc="Indexing"):
            batch = to_add[i : i + batch_size]
            documents = [
                f"HTR: {p['htr']} [SEP] GT: {p['gt']}" for p in batch
            ]
            metadatas = [
                {
                    "htr": p["htr"],
                    "gt": p["gt"],
                    "type": p.get("type", ""),
                    "region": p.get("region", ""),
                    "date": p.get("date", ""),
                    "caligrafia": p.get("caligrafia", "desconocida"),
                    "corrections": json.dumps(
                        p.get("corrections", []), ensure_ascii=False
                    ),
                }
                for p in batch
            ]
            ids = [p["id"] for p in batch]
            self.collection.add(
                documents=documents,
                metadatas=metadatas,
                ids=ids,
            )

        print(f"✅ Indexing complete. Total in store: {self.collection.count()}")
        return len(to_add)

    # ── Retrieval ─────────────────────────────────────────────────────────────
    def retrieve(self, query: str, k: int = 5) -> List[Dict]:
        """
        Retrieves the k pairs most similar to the query HTR text.

        Returns a list of dicts with htr, gt, type, region, date, corrections
        and score.
        """
        if self.collection.count() == 0:
            return []

        results = self.collection.query(
            query_texts=[query],
            n_results=min(k, self.collection.count()),
            include=["metadatas", "distances"],
        )

        retrieved = []
        for meta, dist in zip(
            results["metadatas"][0], results["distances"][0]
        ):
            retrieved.append({
                "htr": meta["htr"],
                "gt": meta["gt"],
                "type": meta.get("type", ""),
                "region": meta.get("region", ""),
                "date": meta.get("date", ""),
                "caligrafia": meta.get("caligrafia", ""),
                "corrections": json.loads(meta.get("corrections", "[]")),
                "score": round(1 - dist, 4),  # cosine distance -> similarity
            })
        return retrieved

    # ── Utilities ─────────────────────────────────────────────────────────────
    def count(self) -> int:
        return self.collection.count()

    def reset(self):
        """Deletes and recreates the collection (useful for re-indexing from scratch)."""
        name = self.collection.name  # keep the name before deleting the collection
        self.client.delete_collection(name)
        self.collection = self.client.get_or_create_collection(
            name=name,
            embedding_function=self.ef,
            metadata={"hnsw:space": "cosine"},
        )
        print("🗑 Vector store reset.")
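

# ── Quick smoke test ───────────────────────────────────────────────────────────
# Minimal usage sketch, not part of the library surface. It assumes the pair
# dicts carry at least the "id", "htr" and "gt" keys described in index(); the
# sample texts below are illustrative placeholders, not corpus data.
if __name__ == "__main__":
    sample_pairs = [
        {
            "id": "demo-0001",
            "htr": "en el nonbre de dios nro sennor",    # hypothetical HTR output
            "gt": "En el nombre de Dios nuestro señor",  # hypothetical ground truth
            "type": "carta",
            "region": "Castilla",
            "date": "1542",
        },
    ]
    vs = VectorStore(embedding_model="e5")
    vs.index(sample_pairs)
    for hit in vs.retrieve("en el nonbre de dios", k=3):
        print(hit["score"], "->", hit["gt"])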