# TinyModel1Space / app.py
# Deployed from GitHub Actions by staindart (commit 736707b, verified).
import os
import gradio as gr
from transformers import pipeline
# Hub repo id of the fine-tuned text-classification model served by this Space.
MODEL_ID = "HyperlinksSpace/TinyModel1"
# Links surfaced in the UI: the Space itself, the model card, and the source repo.
PUBLIC_APP_URL = "https://hyperlinksspace-tinymodel1space.hf.space"
MODEL_HUB_URL = "https://huggingface.co/HyperlinksSpace/TinyModel1"
GITHUB_REPO_URL = "https://github.com/HyperlinksSpace/TinyModel"
# Process-wide cache for the lazily-built pipeline (see get_pipeline).
_clf = None
def get_pipeline():
    """Lazily build and memoize the text-classification pipeline.

    The pipeline is constructed once per process and cached in the
    module-level ``_clf``. If an ``HF_TOKEN`` environment variable is
    present it is forwarded to the Hub (e.g. for gated/private models).

    Returns:
        The cached ``transformers`` text-classification pipeline.
    """
    global _clf
    if _clf is None:
        extra = {}
        hub_token = os.getenv("HF_TOKEN")
        if hub_token:
            extra["token"] = hub_token
        _clf = pipeline(
            "text-classification",
            model=MODEL_ID,
            tokenizer=MODEL_ID,
            top_k=None,  # return scores for every label, not just the top one
            **extra,
        )
    return _clf
def _prediction_list(batch_output):
# One batch item: either a single {label, score} dict or a list of them.
if not batch_output:
return []
first = batch_output[0]
if isinstance(first, dict):
return [first]
if isinstance(first, list):
return first
return []
def predict(text):
    """Classify ``text`` and return ``(label->score mapping, status message)``.

    Both outputs are always well-formed for the Gradio ``Label``/``Textbox``
    components: on any failure the mapping is empty and the status carries a
    human-readable explanation instead of raising into the event handler.

    Args:
        text: Raw user input; ``None`` and whitespace-only strings are rejected.

    Returns:
        A ``(dict, str)`` pair — label-to-probability mapping plus a status.
    """
    text = (text or "").strip()
    if not text:
        return {}, "Please enter some text first."
    try:
        clf = get_pipeline()
    except Exception as exc:
        return {}, f"Model load failed for {MODEL_ID}: {exc}"
    try:
        # Truncate long inputs rather than letting the tokenizer raise.
        raw = clf(text, truncation=True, max_length=128)
    except Exception as exc:
        # Previously unguarded: a runtime inference error crashed the handler.
        return {}, f"Inference failed: {exc}"
    preds = _prediction_list(raw)
    if not preds:
        return {}, "Empty model output (unexpected pipeline shape)."
    preds = sorted(preds, key=lambda x: float(x["score"]), reverse=True)
    return {item["label"]: float(item["score"]) for item in preds}, "OK"
# News-style sample sentences offered as one-click examples in the UI
# (each inner list is one row of inputs for gr.Examples).
EXAMPLES = [
["Apple reported strong quarterly revenue growth and raised guidance."],
["The team won the championship after a dramatic overtime finish."],
["Scientists announced a new breakthrough in battery technology."],
["Leaders met to discuss tensions and trade policy in the region."],
]
# UI layout: title, links, a text input wired to predict() via button and Enter.
with gr.Blocks(title="TinyModel1Space") as demo:
    gr.Markdown("# TinyModel1Space")
    gr.Markdown("Model: `HyperlinksSpace/TinyModel1`")
    gr.Markdown(
        f"- **Public URL (direct app):** [{PUBLIC_APP_URL}]({PUBLIC_APP_URL})\n"
        f"- **Model on Hugging Face:** [{MODEL_HUB_URL}]({MODEL_HUB_URL})\n"
        f"- **Source code (GitHub):** [{GITHUB_REPO_URL}]({GITHUB_REPO_URL})"
    )
    inp = gr.Textbox(lines=4, label="Input text", placeholder="Paste a news sentence here...")
    out = gr.Label(num_top_classes=4, label="Predicted class probabilities")
    status = gr.Textbox(label="Status", interactive=False)
    run_btn = gr.Button("Run Inference", variant="primary")
    # Both the button and Enter-in-textbox trigger the same prediction.
    run_btn.click(fn=predict, inputs=inp, outputs=[out, status])
    inp.submit(fn=predict, inputs=inp, outputs=[out, status])
    # Examples are NOT pre-computed at startup: caching would run the model
    # once per example, which can hang, hit Hub rate limits, or break the Space.
    gr.Examples(examples=EXAMPLES, inputs=inp, cache_examples=False)
if __name__ == "__main__":
    # Log the public endpoint for the container logs, then serve the app.
    print("Public URL (direct): " + PUBLIC_APP_URL)
    demo.queue(default_concurrency_limit=4)  # cap concurrent inference requests
    demo.launch(ssr_mode=False)