contextforge-demo / contextforge /dedup /dedup_engine.py
Pablo
ContextForge v0.1.0 - shared context compiler for multi-agent LLM systems
6d9c72b
raw
history blame
2.4 kB
"""Semantic deduplication using SBERT embeddings."""
import asyncio
import logging
from typing import Literal
from contextforge.dedup.embedder import Embedder
logger = logging.getLogger(__name__)
class SemanticDedupEngine:
    """Cosine-similarity deduplication over sentence embeddings.

    Embeddings are produced by ``Embedder``. The default model name,
    ``all-MiniLM-L6-v2``, is a sentence-transformers (SBERT) model.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Create the engine and its embedder.

        Args:
            model_name: Model name passed through to ``Embedder``.
        """
        self._embedder = Embedder(model_name)
        # NOTE(review): no method of this class acquires this lock; it is kept
        # only so that code elsewhere referencing ``_lock`` keeps working.
        self._lock = asyncio.Lock()

    async def embed(self, text: str) -> list[float]:
        """Return the embedding of *text* from the wrapped ``Embedder``.

        Delegates to ``Embedder.encode``; see that class for the vector format.
        """
        return await self._embedder.encode(text)

    async def similarity(self, emb1: list[float], emb2: list[float]) -> float:
        """Return the cosine similarity of two embedding vectors.

        Args:
            emb1: First vector.
            emb2: Second vector; must have the same length as *emb1*.

        Returns:
            ``dot(emb1, emb2) / (|emb1| * |emb2|)`` (within [-1, 1] up to
            rounding), or ``0.0`` if either vector has zero norm.

        Raises:
            ValueError: If the vectors have different lengths. ``zip`` would
                otherwise silently drop the tail of the longer vector and
                produce a meaningless score.
        """
        if len(emb1) != len(emb2):
            raise ValueError(
                f"embedding length mismatch: {len(emb1)} != {len(emb2)}"
            )
        dot = sum(a * b for a, b in zip(emb1, emb2))
        norm1 = sum(a * a for a in emb1) ** 0.5
        norm2 = sum(b * b for b in emb2) ** 0.5
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot / (norm1 * norm2)

    async def find_shared_prefix(self, context_a: str, context_b: str) -> str:
        """Return the longest common leading run of words of two texts.

        Both texts are split on whitespace (``str.split()``) and compared word
        by word from the start. Comparison stops at the first differing word.
        The common words are re-joined with single spaces, so the original
        spacing between them is not preserved.

        Returns:
            The shared leading words, or ``""`` if the first words differ or
            either text has no words.
        """
        shared: list[str] = []
        for word_a, word_b in zip(context_a.split(), context_b.split()):
            if word_a != word_b:
                break
            shared.append(word_a)
        return " ".join(shared)

    async def batch_deduplicate(
        self, contexts: list[str], *, threshold: float = 0.85
    ) -> dict[str, list[dict]]:
        """Find, for each context, the other contexts it is similar to.

        All contexts are embedded with a single ``encode_batch`` call. Every
        unordered pair is then scored once with :meth:`similarity`. Cosine
        similarity and the shared prefix are symmetric, so each qualifying
        pair is recorded under both of its contexts.

        Args:
            contexts: Texts to compare.
            threshold: Minimum cosine similarity (inclusive) for a pair to be
                reported as a match. Defaults to 0.85.

        Returns:
            ``{"context_<i>": [match, ...]}`` with one key per input position
            ``i`` (an empty dict for empty input). Each match is
            ``{"index": j, "similarity": score, "shared_prefix": str}`` for
            another position ``j``. Matches are listed in ascending ``j``
            order, and a context is never matched with itself.

        Raises:
            ValueError: If the embedder returns a different number of
                embeddings than there are contexts, or (via
                :meth:`similarity`) embeddings of differing lengths.
        """
        if not contexts:
            return {}
        # Materialize so the result can be indexed, whatever sequence type
        # the embedder returns.
        embeddings = list(await self._embedder.encode_batch(contexts))
        if len(embeddings) != len(contexts):
            raise ValueError(
                f"embedder returned {len(embeddings)} embeddings "
                f"for {len(contexts)} contexts"
            )

        count = len(contexts)
        matches: list[list[dict]] = [[] for _ in range(count)]
        for i, (ctx_i, emb_i) in enumerate(zip(contexts, embeddings)):
            # Only j > i: the (j, i) direction is filled in by the same pass.
            # Appending in this order keeps every row sorted by index.
            for j in range(i + 1, count):
                sim = await self.similarity(emb_i, embeddings[j])
                if sim >= threshold:
                    shared = await self.find_shared_prefix(ctx_i, contexts[j])
                    matches[i].append({
                        "index": j,
                        "similarity": sim,
                        "shared_prefix": shared,
                    })
                    matches[j].append({
                        "index": i,
                        "similarity": sim,
                        "shared_prefix": shared,
                    })
        return {f"context_{i}": row for i, row in enumerate(matches)}