Spaces:
Running
Running
AmrYassinIsFree committed on
Commit ·
d6ca6d1
1
Parent(s): 2daebaf
add custom models
Browse files
.gitignore
CHANGED
|
@@ -205,3 +205,6 @@ cython_debug/
|
|
| 205 |
marimo/_static/
|
| 206 |
marimo/_lsp/
|
| 207 |
__marimo__/
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
marimo/_static/
|
| 206 |
marimo/_lsp/
|
| 207 |
__marimo__/
|
| 208 |
+
|
| 209 |
+
# Embedding Bench custom models
|
| 210 |
+
custom_models.json
|
app.py
CHANGED
|
@@ -14,9 +14,18 @@ from corpus import build_corpus
|
|
| 14 |
from dataset_config import DATASET_PRESETS, DatasetConfig
|
| 15 |
from evals.quality import evaluate_quality
|
| 16 |
from evals.speed import evaluate_speed
|
| 17 |
-
from models import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from wrapper import load_model
|
| 19 |
|
|
|
|
|
|
|
| 20 |
# ---------------------------------------------------------------------------
|
| 21 |
# Page config & custom CSS
|
| 22 |
# ---------------------------------------------------------------------------
|
|
@@ -114,6 +123,40 @@ selected_models = st.sidebar.multiselect(
|
|
| 114 |
label_visibility="collapsed",
|
| 115 |
)
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
st.sidebar.markdown("**Datasets**")
|
| 118 |
available_datasets = list(DATASET_PRESETS.keys())
|
| 119 |
selected_datasets = st.sidebar.multiselect(
|
|
|
|
| 14 |
from dataset_config import DATASET_PRESETS, DatasetConfig
|
| 15 |
from evals.quality import evaluate_quality
|
| 16 |
from evals.speed import evaluate_speed
|
| 17 |
+
from models import (
|
| 18 |
+
REGISTRY,
|
| 19 |
+
VALID_BACKENDS,
|
| 20 |
+
ModelConfig,
|
| 21 |
+
load_custom_models_from_file,
|
| 22 |
+
register_model,
|
| 23 |
+
save_custom_model_to_file,
|
| 24 |
+
)
|
| 25 |
from wrapper import load_model
|
| 26 |
|
| 27 |
+
load_custom_models_from_file()
|
| 28 |
+
|
| 29 |
# ---------------------------------------------------------------------------
|
| 30 |
# Page config & custom CSS
|
| 31 |
# ---------------------------------------------------------------------------
|
|
|
|
| 123 |
label_visibility="collapsed",
|
| 124 |
)
|
| 125 |
|
| 126 |
+
with st.sidebar.expander("➕ Add Custom Model"):
|
| 127 |
+
with st.form("add_model_form", clear_on_submit=True):
|
| 128 |
+
new_key = st.text_input("Registry key", placeholder="my-model")
|
| 129 |
+
new_name = st.text_input("Display name", placeholder="My Custom Model")
|
| 130 |
+
new_model_id = st.text_input("HuggingFace model ID", placeholder="org/model-name")
|
| 131 |
+
new_backend = st.selectbox("Backend", sorted(VALID_BACKENDS))
|
| 132 |
+
new_gguf_file = st.text_input(
|
| 133 |
+
"GGUF filename (gguf backend only)", value="", placeholder="model.gguf"
|
| 134 |
+
)
|
| 135 |
+
new_is_baseline = st.checkbox("Mark as baseline", value=False)
|
| 136 |
+
new_persist = st.checkbox("Save to disk", value=False,
|
| 137 |
+
help="Persist to custom_models.json so it loads next session")
|
| 138 |
+
submitted = st.form_submit_button("Add Model", use_container_width=True)
|
| 139 |
+
if submitted:
|
| 140 |
+
if not new_key or not new_name or not new_model_id:
|
| 141 |
+
st.sidebar.error("Key, name, and model ID are required.")
|
| 142 |
+
elif new_backend == "gguf" and not new_gguf_file:
|
| 143 |
+
st.sidebar.error("GGUF filename is required for gguf backend.")
|
| 144 |
+
else:
|
| 145 |
+
cfg = ModelConfig(
|
| 146 |
+
name=new_name,
|
| 147 |
+
model_id=new_model_id,
|
| 148 |
+
is_baseline=new_is_baseline,
|
| 149 |
+
backend=new_backend,
|
| 150 |
+
gguf_file=new_gguf_file or None,
|
| 151 |
+
)
|
| 152 |
+
try:
|
| 153 |
+
register_model(new_key, cfg)
|
| 154 |
+
if new_persist:
|
| 155 |
+
save_custom_model_to_file(new_key, cfg)
|
| 156 |
+
st.rerun()
|
| 157 |
+
except ValueError as e:
|
| 158 |
+
st.sidebar.error(str(e))
|
| 159 |
+
|
| 160 |
st.sidebar.markdown("**Datasets**")
|
| 161 |
available_datasets = list(DATASET_PRESETS.keys())
|
| 162 |
selected_datasets = st.sidebar.multiselect(
|
bench.py
CHANGED
|
@@ -5,7 +5,7 @@ import argparse
|
|
| 5 |
from corpus import build_corpus
|
| 6 |
from dataset_config import DATASET_PRESETS, DatasetConfig
|
| 7 |
from evals import evaluate_memory, evaluate_quality, evaluate_speed
|
| 8 |
-
from models import REGISTRY
|
| 9 |
from report import print_report
|
| 10 |
from wrapper import load_model
|
| 11 |
|
|
@@ -18,9 +18,15 @@ def main(argv: list[str] | None = None) -> None:
|
|
| 18 |
parser.add_argument(
|
| 19 |
"--models",
|
| 20 |
nargs="+",
|
| 21 |
-
default=
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
)
|
| 25 |
parser.add_argument("--corpus-size", type=int, default=1000)
|
| 26 |
parser.add_argument("--batch-size", type=int, default=64)
|
|
@@ -61,6 +67,28 @@ def main(argv: list[str] | None = None) -> None:
|
|
| 61 |
|
| 62 |
args = parser.parse_args(argv)
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
# Build list of dataset configs
|
| 65 |
if args.dataset:
|
| 66 |
# Custom dataset overrides presets
|
|
|
|
| 5 |
from corpus import build_corpus
|
| 6 |
from dataset_config import DATASET_PRESETS, DatasetConfig
|
| 7 |
from evals import evaluate_memory, evaluate_quality, evaluate_speed
|
| 8 |
+
from models import REGISTRY, ModelConfig, load_custom_models_from_file, register_model
|
| 9 |
from report import print_report
|
| 10 |
from wrapper import load_model
|
| 11 |
|
|
|
|
| 18 |
parser.add_argument(
|
| 19 |
"--models",
|
| 20 |
nargs="+",
|
| 21 |
+
default=None,
|
| 22 |
+
help="Models to benchmark (default: all registered)",
|
| 23 |
+
)
|
| 24 |
+
parser.add_argument(
|
| 25 |
+
"--add-model",
|
| 26 |
+
action="append",
|
| 27 |
+
default=[],
|
| 28 |
+
metavar="KEY:NAME:MODEL_ID:BACKEND[:GGUF_FILE]",
|
| 29 |
+
help="Register a custom model. Can be repeated.",
|
| 30 |
)
|
| 31 |
parser.add_argument("--corpus-size", type=int, default=1000)
|
| 32 |
parser.add_argument("--batch-size", type=int, default=64)
|
|
|
|
| 67 |
|
| 68 |
args = parser.parse_args(argv)
|
| 69 |
|
| 70 |
+
# Load persisted custom models and register any --add-model entries
|
| 71 |
+
load_custom_models_from_file()
|
| 72 |
+
for spec in args.add_model:
|
| 73 |
+
parts = spec.split(":")
|
| 74 |
+
if len(parts) < 4:
|
| 75 |
+
parser.error(f"--add-model requires KEY:NAME:MODEL_ID:BACKEND, got: {spec}")
|
| 76 |
+
key, name, model_id, backend = parts[0], parts[1], parts[2], parts[3]
|
| 77 |
+
gguf_file = parts[4] if len(parts) > 4 else None
|
| 78 |
+
try:
|
| 79 |
+
register_model(key, ModelConfig(
|
| 80 |
+
name=name, model_id=model_id, backend=backend, gguf_file=gguf_file,
|
| 81 |
+
))
|
| 82 |
+
except ValueError as e:
|
| 83 |
+
parser.error(str(e))
|
| 84 |
+
|
| 85 |
+
if args.models is None:
|
| 86 |
+
args.models = list(REGISTRY.keys())
|
| 87 |
+
else:
|
| 88 |
+
for k in args.models:
|
| 89 |
+
if k not in REGISTRY:
|
| 90 |
+
parser.error(f"Unknown model key: '{k}'. Available: {list(REGISTRY.keys())}")
|
| 91 |
+
|
| 92 |
# Build list of dataset configs
|
| 93 |
if args.dataset:
|
| 94 |
# Custom dataset overrides presets
|
models.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
@dataclass
|
|
@@ -43,3 +45,34 @@ REGISTRY: dict[str, ModelConfig] = {
|
|
| 43 |
# backend="libembedding",
|
| 44 |
# ),
|
| 45 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import asdict, dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
|
| 7 |
|
| 8 |
@dataclass
|
|
|
|
| 45 |
# backend="libembedding",
|
| 46 |
# ),
|
| 47 |
}
|
| 48 |
+
|
| 49 |
+
VALID_BACKENDS = {"sbert", "fastembed", "libembedding", "gguf"}
|
| 50 |
+
CUSTOM_MODELS_PATH = Path(__file__).parent / "custom_models.json"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def register_model(key: str, config: ModelConfig) -> None:
    """Add ``config`` to ``REGISTRY`` under ``key``.

    Args:
        key: Registry key for the new model. Must be non-blank and not
            already present in ``REGISTRY``.
        config: Model configuration; its ``backend`` must be one of
            ``VALID_BACKENDS``.

    Raises:
        ValueError: if ``key`` is empty/whitespace-only, if ``key`` is already
            registered, or if ``config.backend`` is not a valid backend.
    """
    # An empty key can arrive from a spec such as ":name:id:sbert"; reject it
    # here so every caller gets the same ValueError it already handles.
    if not key.strip():
        raise ValueError("Model key must not be empty")
    if key in REGISTRY:
        raise ValueError(f"Model key '{key}' already exists in registry")
    if config.backend not in VALID_BACKENDS:
        # Sort the backends so the message is stable: a set's iteration order
        # can differ between interpreter runs.
        allowed = ", ".join(sorted(VALID_BACKENDS))
        raise ValueError(f"Invalid backend '{config.backend}'. Must be one of: {allowed}")
    REGISTRY[key] = config
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_custom_models_from_file(path: Path = CUSTOM_MODELS_PATH) -> None:
    """Merge models stored in a JSON file into ``REGISTRY``.

    The file must contain a JSON object mapping registry keys to the keyword
    arguments of ``ModelConfig``. Keys that are already in ``REGISTRY`` are
    left untouched. A missing file is not an error; the function just returns.

    Args:
        path: Location of the JSON file. Defaults to ``CUSTOM_MODELS_PATH``.

    Raises:
        ValueError: if the file is not valid JSON (``json.JSONDecodeError`` is
            a ``ValueError`` subclass) or its top level is not a JSON object.
    """
    try:
        # Explicit UTF-8 so a hand-edited file with non-ASCII display names
        # reads the same way regardless of the platform's locale encoding.
        with open(path, encoding="utf-8") as f:
            entries = json.load(f)
    except FileNotFoundError:
        # No persisted custom models yet.
        return

    if not isinstance(entries, dict):
        raise ValueError(
            f"Expected a JSON object of custom models in {path}, got {type(entries).__name__}"
        )

    for key, fields in entries.items():
        if key not in REGISTRY:
            REGISTRY[key] = ModelConfig(**fields)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def save_custom_model_to_file(key: str, config: ModelConfig, path: Path = CUSTOM_MODELS_PATH) -> None:
    """Insert or overwrite ``key`` in the JSON file of custom models.

    Any entries already stored in the file are read first and kept; the entry
    for ``key`` is set to ``asdict(config)`` and the whole mapping is written
    back with two-space indentation.

    Args:
        key: Registry key to store the model under.
        config: Model configuration (a dataclass instance) to persist.
        path: Location of the JSON file. Defaults to ``CUSTOM_MODELS_PATH``.

    Raises:
        ValueError: if the existing file is not valid JSON or its top level is
            not a JSON object.
        TypeError: if ``config`` contains a value that cannot be encoded as
            JSON; in that case the existing file is left unmodified.
    """
    existing: dict = {}
    try:
        with open(path, encoding="utf-8") as f:
            existing = json.load(f)
    except FileNotFoundError:
        # First model being saved: start from an empty mapping.
        pass

    if not isinstance(existing, dict):
        raise ValueError(
            f"Expected a JSON object of custom models in {path}, got {type(existing).__name__}"
        )

    existing[key] = asdict(config)

    # Serialize BEFORE opening for writing: opening with "w" truncates the
    # file, so an encoding failure afterwards would wipe every saved model.
    payload = json.dumps(existing, indent=2)
    with open(path, "w", encoding="utf-8") as f:
        f.write(payload)
|