| import json |
| from huggingface_hub import HfApi, ModelFilter, DatasetFilter, ModelSearchArguments |
| from pprint import pprint |
| from hf_search import HFSearch |
| import streamlit as st |
| import itertools |
|
|
from pbr.version import VersionInfo
# Log the installed hf_search package version at startup (pbr reads it from
# package metadata).  NOTE(review): runs at import time, so this prints on
# every Streamlit rerun of the script.
print("hf_search version:", VersionInfo('hf_search').version_string())


# Module-level search client shared by semantic_search() and bm25_search();
# top_k bounds the candidate pool retrieved before limiting/sorting —
# presumably larger than any UI `limit`, verify against hf_search defaults.
hf_search = HFSearch(top_k=200)
|
|
@st.cache
def hf_api(query, limit=5, sort=None, filters=None):
    """Search models on the Hugging Face Hub via the official ``HfApi``.

    Parameters
    ----------
    query : str
        Free-text search string forwarded to ``HfApi.list_models``.
    limit : int
        Maximum number of hits to return.
    sort : str or None
        Field to sort by, passed through to ``HfApi.list_models``.
    filters : dict or None
        Must contain ``"task"`` and ``"library"`` keys (a ``KeyError`` is
        raised otherwise, matching the original behavior).

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}`` where ``count`` is the
        pre-truncation number of results and each hit carries
        ``modelId``, ``tags``, ``downloads`` and ``likes``.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    filters = {} if filters is None else filters
    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    api = HfApi()
    filt = ModelFilter(
        task=filters["task"],
        library=filters["library"],
    )
    models = api.list_models(search=query, filter=filt, limit=limit, sort=sort, full=True)
    hits = [
        {
            "modelId": info.get("modelId"),
            "tags": info.get("tags"),
            "downloads": info.get("downloads"),
            "likes": info.get("likes"),
        }
        # ModelInfo objects expose their fields via __dict__.
        for info in (model.__dict__ for model in models)
    ]
    count = len(hits)  # report the total found, even if we truncate below
    if count > limit:
        hits = hits[:limit]
    return {"hits": hits, "count": count}
|
|
|
|
@st.cache
def semantic_search(query, limit=5, sort=None, filters=None):
    """Run a retrieve-and-rerank semantic search through the ``hf_search`` client.

    Parameters
    ----------
    query : str
        Free-text search string.
    limit : int
        Maximum number of hits requested from the backend.
    sort : str or None
        Sort key forwarded to ``HFSearch.search``.
    filters : dict or None
        Filter mapping forwarded to ``HFSearch.search``; defaults to an
        empty dict.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}``; each hit keeps ``modelId``,
        ``tags``, ``downloads``, ``likes`` and an optional ``readme``.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    filters = {} if filters is None else filters
    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    hits = hf_search.search(query=query, method="retrieve & rerank", limit=limit, sort=sort, filters=filters)
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme", None),
        }
        for hit in hits
    ]
    return {"hits": hits, "count": len(hits)}
|
|
|
|
@st.cache
def bm25_search(query, limit=5, sort=None, filters=None):
    """Run a BM25 keyword search through the ``hf_search`` client.

    Parameters
    ----------
    query : str
        Free-text search string.
    limit : int
        Maximum number of hits requested from the backend.
    sort : str or None
        Sort key forwarded to ``HFSearch.search``.
    filters : dict or None
        Filter mapping forwarded to ``HFSearch.search``; defaults to an
        empty dict.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}`` with duplicate ``modelId``
        entries removed (first occurrence wins).
    """
    # A fresh dict per call instead of a shared mutable default argument.
    filters = {} if filters is None else filters
    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    hits = hf_search.search(query=query, method="bm25", limit=limit, sort=sort, filters=filters)
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme", None),
        }
        for hit in hits
    ]
    # De-duplicate by modelId in O(n) with a seen-set, keeping the first
    # occurrence (the original rebuilt a prefix list per element: O(n^2)).
    seen = set()
    unique_hits = []
    for hit in hits:
        if hit["modelId"] not in seen:
            seen.add(hit["modelId"])
            unique_hits.append(hit)
    return {"hits": unique_hits, "count": len(unique_hits)}
|
|
|
|
def paginator(label, articles, articles_per_page=10, on_sidebar=True):
    """Lets the user paginate a set of articles.

    Parameters
    ----------
    label : str
        The label to display over the pagination widget.
    articles : Iterator[Any]
        The articles to display in the paginator.
    articles_per_page : int
        The number of articles to display per page.
    on_sidebar : bool
        Whether to display the paginator widget on the sidebar.

    Returns
    -------
    Iterator[Tuple[int, Any]]
        An iterator over *only the articles on the selected page*,
        including each item's index in the full list.
    """
    # Place the page selector either in the sidebar or inline.
    if on_sidebar:
        location = st.sidebar.empty()
    else:
        location = st.empty()

    # Materialize so we can count pages; ceil-divide for the page count.
    articles = list(articles)
    n_pages = (len(articles) - 1) // articles_per_page + 1

    def page_format_func(i):
        # BUG FIX: the original hardcoded 10 here, so labels were wrong
        # whenever articles_per_page != 10.
        start = i * articles_per_page
        return f"Results {start} to {start + articles_per_page - 1}"

    page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)

    # Slice out just the selected page, preserving global indices.
    min_index = page_number * articles_per_page
    max_index = min_index + articles_per_page
    return itertools.islice(enumerate(articles), min_index, max_index)
|
|