trohrbaugh committed on
Commit
4b77797
·
verified ·
1 Parent(s): ace7f2c

Wire modelatlas_similar to HF parquet dataset (2,435 models)

Browse files
Files changed (2) hide show
  1. requirements.txt +3 -0
  2. scan.py +76 -17
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
  requests>=2.31.0
2
  psycopg2-binary>=2.9.9
 
 
 
 
1
  requests>=2.31.0
2
  psycopg2-binary>=2.9.9
3
+ pandas>=2.0.0
4
+ pyarrow>=14.0.0
5
+ huggingface_hub>=0.23.0
scan.py CHANGED
@@ -12,10 +12,51 @@ from datetime import datetime, timezone
12
  from pathlib import Path
13
  from typing import Optional
14
  import requests
15
- import psycopg2, psycopg2.extras
16
 
17
- DB = "postgresql:///modelatlas?host=/var/run/postgresql&port=5433&user=tim"
18
  HF_API = "https://huggingface.co"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Known base model reference configs (canonical identifiers)
21
  KNOWN_BASES = {
@@ -196,26 +237,44 @@ def stage1_screen(model_id: str, config: dict) -> dict:
196
  "evidence": reasons,
197
  })
198
 
199
- # Check ModelAtlas DB for exact signature
200
  db_matches = []
201
  try:
202
- conn = psycopg2.connect(DB)
203
- cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
204
- cur.execute("""
205
- SELECT m.model_id, o.name AS lab, m.hf_downloads, m.release_date,
206
- a.technique_signature, a.total_params, a.num_layers, a.hidden_size, a.vocab_size
207
- FROM analyses a JOIN models m ON m.id=a.model_id
208
- JOIN organizations o ON m.org_id=o.id
209
- WHERE a.is_current=true AND a.vocab_size=%s AND a.hidden_size=%s
210
- AND m.model_id NOT ILIKE '%%tiny%%' AND m.model_id NOT ILIKE '/%%'
211
- ORDER BY m.hf_downloads DESC NULLS LAST
212
- LIMIT 5
213
- """, (vocab, hidden))
214
- db_matches = [dict(r) for r in cur.fetchall()]
215
- cur.close(); conn.close()
216
  except Exception:
217
  pass
218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  return {
220
  "arch_signature": arch_sig,
221
  "config_signals": {
 
12
  from pathlib import Path
13
  from typing import Optional
14
  import requests
 
15
 
 
16
  HF_API = "https://huggingface.co"
17
+ HF_DATASET = "RadicalNotionAI/modelatlas-reference"
18
+ DB = "postgresql:///modelatlas?host=/var/run/postgresql&port=5433&user=tim"
19
+
20
+ # In-process cache — loaded once per worker, refreshes when the file changes
21
+ _REF_DF = None
22
+ _REF_LOADED_AT: float = 0.0
23
+ _REF_TTL = 3600 # reload at most once per hour
24
+
25
+
26
+ def _load_reference_df():
27
+ """Load ModelAtlas reference parquet. Tries local snapshot first, then HF dataset."""
28
+ global _REF_DF, _REF_LOADED_AT
29
+ now = time.time()
30
+ if _REF_DF is not None and (now - _REF_LOADED_AT) < _REF_TTL:
31
+ return _REF_DF
32
+
33
+ import pandas as pd
34
+
35
+ # 1. Local snapshot (fast, used in dev / on local server)
36
+ local_path = Path(__file__).parent.parent / "snapshots" / "modeldna_reference.parquet"
37
+ if local_path.exists():
38
+ try:
39
+ _REF_DF = pd.read_parquet(local_path)
40
+ _REF_LOADED_AT = now
41
+ return _REF_DF
42
+ except Exception:
43
+ pass
44
+
45
+ # 2. HF dataset (used on HF Space — downloaded and cached by huggingface_hub)
46
+ try:
47
+ from huggingface_hub import hf_hub_download
48
+ path = hf_hub_download(
49
+ repo_id=HF_DATASET,
50
+ filename="modeldna_reference.parquet",
51
+ repo_type="dataset",
52
+ )
53
+ _REF_DF = pd.read_parquet(path)
54
+ _REF_LOADED_AT = now
55
+ return _REF_DF
56
+ except Exception:
57
+ pass
58
+
59
+ return None
60
 
61
  # Known base model reference configs (canonical identifiers)
62
  KNOWN_BASES = {
 
237
  "evidence": reasons,
238
  })
239
 
240
+ # Query ModelAtlas reference parquet for architecturally similar models
241
  db_matches = []
242
  try:
243
+ ref = _load_reference_df()
244
+ if ref is not None and vocab and hidden:
245
+ hit = ref[
246
+ (ref["vocab_size"] == vocab) &
247
+ (ref["hidden_size"] == hidden) &
248
+ (~ref["model_id"].str.contains("tiny|/", case=False, na=False))
249
+ ].sort_values("hf_downloads", ascending=False).head(5)
250
+ db_matches = hit[
251
+ ["model_id", "org_display", "hf_downloads", "total_params",
252
+ "technique_signature", "num_layers", "hidden_size", "vocab_size"]
253
+ ].rename(columns={"org_display": "lab"}).to_dict("records")
 
 
 
254
  except Exception:
255
  pass
256
 
257
+ # Also try local DB if available (dev / local server)
258
+ if not db_matches:
259
+ try:
260
+ import psycopg2, psycopg2.extras
261
+ conn = psycopg2.connect(DB)
262
+ cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
263
+ cur.execute("""
264
+ SELECT m.model_id, o.name AS lab, m.hf_downloads, m.release_date,
265
+ a.technique_signature, a.total_params, a.num_layers, a.hidden_size, a.vocab_size
266
+ FROM analyses a JOIN models m ON m.id=a.model_id
267
+ JOIN organizations o ON m.org_id=o.id
268
+ WHERE a.is_current=true AND a.vocab_size=%s AND a.hidden_size=%s
269
+ AND m.model_id NOT ILIKE '%%tiny%%' AND m.model_id NOT ILIKE '/%%'
270
+ ORDER BY m.hf_downloads DESC NULLS LAST
271
+ LIMIT 5
272
+ """, (vocab, hidden))
273
+ db_matches = [dict(r) for r in cur.fetchall()]
274
+ cur.close(); conn.close()
275
+ except Exception:
276
+ pass
277
+
278
  return {
279
  "arch_signature": arch_sig,
280
  "config_signals": {