# NSF-RAG-Codex / vector_store.py
# Author: Alexander Sanchez
# Commit 8e14964 — "mt5-base added"
"""
vector_store.py
───────────────
Indexa pares HTR/GT en ChromaDB usando embeddings configurables:
- "openai" : text-embedding-3-small (API)
- "mpnet" : paraphrase-multilingual-mpnet-base-v2 (local)
- "e5" : multilingual-e5-small (local, por defecto)
- "mt5" : mt5-base fine-tuneado en corpus HTR/GT s.XVI (local)
Uso:
from vector_store import VectorStore
vs = VectorStore(embedding_model="mt5")
vs.index(pairs)
results = vs.retrieve("texto htr...", k=5)
"""
import json
import os
from typing import List, Dict
import torch
import numpy as np
from tqdm import tqdm
import chromadb
from chromadb.utils import embedding_functions
from chromadb import EmbeddingFunction, Documents, Embeddings
from transformers import MT5EncoderModel, AutoTokenizer
from dotenv import load_dotenv
load_dotenv()
CHROMA_PATH = os.getenv("CHROMA_PATH", "./chroma_db")
OPENAI_KEY = os.getenv("OPENAI_API_KEY", "")
EMBED_MODEL = "text-embedding-3-small"
COLLECTION = "scriptorium_corpus"
# ── Embedding personalizado para mt5 ─────────────────────────────────────────
class MT5EmbeddingFunction(EmbeddingFunction):
"""
Usa el encoder de mt5 fine-tuneado con pares HTR/GT para generar embeddings.
El encoder aprendiΓ³ que grafΓ­as del s.XVI y sus formas corregidas son
semΓ‘nticamente cercanas en el espacio vectorial β€” ideal para recuperar
ejemplos similares en el RAG.
Estrategia: mean-pooling sobre los hidden states del ΓΊltimo layer,
enmascarando los tokens de padding.
"""
def __init__(self, model_name: str):
print(f" Cargando MT5EncoderModel desde '{model_name}'...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = MT5EncoderModel.from_pretrained(model_name)
self.model.eval()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
print(f" MT5 cargado en {self.device} βœ“")
def __call__(self, input: Documents) -> Embeddings:
encoded = self.tokenizer(
list(input),
padding=True,
truncation=True,
max_length=512,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
outputs = self.model(**encoded)
# Mean-pooling con mΓ‘scara de atenciΓ³n
token_embs = outputs.last_hidden_state # (B, T, H)
attention = encoded["attention_mask"].unsqueeze(-1).float()
sum_embs = (token_embs * attention).sum(dim=1)
count = attention.sum(dim=1).clamp(min=1e-9)
embeddings = (sum_embs / count).cpu().numpy() # (B, H)
return embeddings.tolist()
# ── VectorStore ───────────────────────────────────────────────────────────────
class VectorStore:
def __init__(self, embedding_model: str = "e5"):
self.client = chromadb.PersistentClient(path=CHROMA_PATH)
if embedding_model == "mpnet":
self.ef = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
collection_name = "scriptorium_mpnet"
elif embedding_model == "mt5":
# Apunta a tu modelo fine-tuneado via variable de entorno.
# Puede ser un repo de HF ("alezsd/mt5-htr-xvi") o una ruta local.
mt5_path = os.getenv("MT5_MODEL_PATH", "google/mt5-base")
self.ef = MT5EmbeddingFunction(mt5_path)
collection_name = "scriptorium_mt5"
else: # e5 por defecto
self.ef = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name="intfloat/multilingual-e5-small"
)
collection_name = "scriptorium_e5"
self.collection = self.client.get_or_create_collection(
name=collection_name,
embedding_function=self.ef,
metadata={"hnsw:space": "cosine"},
)
# ── IndexaciΓ³n ────────────────────────────────────────────────────────────
def index(self, pairs: List[Dict], batch_size: int = 50) -> int:
"""
Indexa los pares HTR/GT. Cada fragmento se almacena con:
- document : texto que se embebe (htr + ' [SEP] ' + gt)
- metadata : tipo, regiΓ³n, fecha, htr, gt originales
- id : identificador ΓΊnico del par
Retorna el nΓΊmero de documentos nuevos aΓ±adidos.
"""
existing_ids = set(self.collection.get(include=[])["ids"])
to_add = [p for p in pairs if p["id"] not in existing_ids]
if not to_add:
print(f"β„Ή Vector store ya actualizado ({len(existing_ids)} documentos).")
return 0
print(f"πŸ”„ Indexando {len(to_add)} documentos nuevos...")
for i in tqdm(range(0, len(to_add), batch_size), desc="Indexando"):
batch = to_add[i : i + batch_size]
documents = [
f"HTR: {p['htr']} [SEP] GT: {p['gt']}"
for p in batch
]
metadatas = [
{
"htr": p["htr"],
"gt": p["gt"],
"type": p.get("type", ""),
"region": p.get("region", ""),
"date": p.get("date", ""),
"caligrafia": p.get("caligrafia", "desconocida"),
"corrections": json.dumps(
p.get("corrections", []), ensure_ascii=False
),
}
for p in batch
]
ids = [p["id"] for p in batch]
self.collection.add(
documents=documents,
metadatas=metadatas,
ids=ids,
)
print(f"βœ… IndexaciΓ³n completa. Total en store: {self.collection.count()}")
return len(to_add)
# ── RecuperaciΓ³n ──────────────────────────────────────────────────────────
def retrieve(self, query: str, k: int = 5) -> List[Dict]:
"""
Recupera los k pares mΓ‘s similares al texto HTR de consulta.
Retorna lista de dicts con htr, gt, type, region, date, corrections, score.
"""
if self.collection.count() == 0:
return []
results = self.collection.query(
query_texts=[query],
n_results=min(k, self.collection.count()),
include=["metadatas", "distances"],
)
retrieved = []
for meta, dist in zip(
results["metadatas"][0], results["distances"][0]
):
retrieved.append({
"htr": meta["htr"],
"gt": meta["gt"],
"type": meta.get("type", ""),
"region": meta.get("region", ""),
"date": meta.get("date", ""),
"caligrafia": meta.get("caligrafia", ""),
"corrections": json.loads(meta.get("corrections", "[]")),
"score": round(1 - dist, 4),
})
return retrieved
# ── Utilidades ────────────────────────────────────────────────────────────
def count(self) -> int:
return self.collection.count()
def reset(self):
"""Elimina y recrea la colecciΓ³n (ΓΊtil para re-indexar desde cero)."""
self.client.delete_collection(self.collection.name)
self.collection = self.client.get_or_create_collection(
name=self.collection.name,
embedding_function=self.ef,
metadata={"hnsw:space": "cosine"},
)
print("πŸ—‘ Vector store reseteado.")