# Author: Harley-ml
# Last commit: ab91e93 ("Update app.py", verified)
from __future__ import annotations
import os
from functools import lru_cache
from typing import Dict, List, Tuple
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# =============================================================================
# Model catalog
# =============================================================================
# Selectable models as (dropdown label, repo id) pairs. The list order is the
# order in which the labels appear in the UI dropdown.
MODEL_CHOICES: List[Tuple[str, str]] = [
    ("PicoWord-5k", "Harley-ml/PicoWord-5k"),
    ("MicroWord-23k", "Harley-ml/MicroWord-23k"),
    ("TinyWord-134k", "Harley-ml/TinyWord-134k"),
    ("TinyWord2-128k", "Harley-ml/TinyWord2-128k"),
    ("MediumWord-559k", "Harley-ml/MediumWord-559k"),
    ("LargeWord-1.5M", "Harley-ml/LargeWord-1.5M"),
    ("Tenete-8M", "Harley-ml/Tenete-8M"),
    ("MCOD-4.7M", "Harley-ml/MCOD-4.7M"),
    ("LWTMoe", "Harley-ml/LWTMoE-10M-A6M"),
    ("LWTDense", "Harley-ml/LWTDense-6M"),
    ("MiniMD-28M", "Harley-ml/MiniMD-28M"),
    ("StopAskingQuestionsMini-656k", "Harley-ml/StopAskingQuestionsMini-656k"),
    ("Dillion-1.2M", "Harley-ml/Dillion-1.2M"),
]
# Reverse lookup: dropdown label -> repo id passed to from_pretrained().
LABEL_TO_REPO: Dict[str, str] = dict(MODEL_CHOICES)
# =============================================================================
# Device helpers
# =============================================================================
def get_device() -> str:
    """Return the name of the best available torch device.

    Preference order is ``"cuda"``, then ``"mps"``, then ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is missing on torch builds without the MPS backend
    # module; look it up defensively so import never fails with AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


# Resolved once at import time and shared by the loaders and generate_text().
DEVICE = get_device()
# Half precision only on CUDA; MPS and CPU stay in float32.
DEFAULT_TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# =============================================================================
# Loading
# =============================================================================
def _tokenizer_kwargs() -> dict:
# Some repos ship custom tokenizer code/config; trust_remote_code helps
# with any special setup, while still working for standard tokenizers.
return {
"use_fast": True,
"trust_remote_code": True,
}
@lru_cache(maxsize=8)
def load_tokenizer(model_id: str):
    """Load the tokenizer for ``model_id``, memoized per id (up to 8 entries).

    A pad token is guaranteed afterwards. The EOS token is reused when the
    tokenizer has one. Otherwise a new ``<|pad|>`` special token is registered.
    Padding is set to the left side. The cached object is shared between
    callers.
    """
    tok = AutoTokenizer.from_pretrained(model_id, **_tokenizer_kwargs())
    if tok.pad_token is None:
        eos = tok.eos_token
        if eos is None:
            # NOTE(review): this adds a vocab id that the model's embedding
            # matrix may not cover; nothing in this file resizes embeddings.
            tok.add_special_tokens({"pad_token": "<|pad|>"})
        else:
            tok.pad_token = eos
    # Left padding is usually a safer default for batched generation.
    tok.padding_side = "left"
    return tok
@lru_cache(maxsize=8)
def load_model(model_id: str):
    """Load the causal LM for ``model_id``, memoized per id (up to 8 entries).

    The model is loaded in ``DEFAULT_TORCH_DTYPE``, put in eval mode and moved
    to ``DEVICE``. Its ``generation_config`` pad/eos ids are aligned with the
    tokenizer from ``load_tokenizer(model_id)``, but only for ids that the
    tokenizer actually defines.

    The two config tweaks are best-effort. A failure emits a ``RuntimeWarning``
    instead of aborting the load.
    """
    import warnings

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=DEFAULT_TORCH_DTYPE,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    # Some MoE/Qwen3 variants can expose this flag; disabling it avoids
    # a common generation-time mismatch in certain setups.
    if hasattr(model.config, "output_router_logits"):
        try:
            model.config.output_router_logits = False
        except Exception as err:  # best-effort: keep loading, but say why
            warnings.warn(
                f"Could not disable output_router_logits for {model_id!r}: {err}",
                RuntimeWarning,
            )
    model.eval()
    model.to(DEVICE)
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        try:
            tok = load_tokenizer(model_id)
            # Copy only ids the tokenizer defines. A None would otherwise
            # overwrite the model's own generation defaults.
            if tok.pad_token_id is not None:
                generation_config.pad_token_id = tok.pad_token_id
            if tok.eos_token_id is not None:
                generation_config.eos_token_id = tok.eos_token_id
        except Exception as err:  # best-effort alignment; the model is usable
            warnings.warn(
                f"Could not align generation_config with the tokenizer for "
                f"{model_id!r}: {err}",
                RuntimeWarning,
            )
    return model
# =============================================================================
# Generation
# =============================================================================
def generate_text(
    model_label: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    do_sample: bool,
    return_full_text: bool,
) -> Tuple[str, str]:
    """Generate a continuation of ``prompt`` with the selected catalogue model.

    Args:
        model_label: Dropdown label; must be a key of ``LABEL_TO_REPO``.
        prompt: User prompt (``None`` is treated as ``""``). The tokenizer's
            BOS token, if it has one, is prepended.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value <= 0 switches to greedy
            decoding, because sampling requires a strictly positive temperature.
        top_p: Nucleus-sampling threshold; only used when sampling.
        top_k: Top-k cutoff; only used when sampling.
        repetition_penalty: Values <= 0 are invalid for generation and are
            replaced by 1.0 (no penalty).
        do_sample: Sample instead of decoding greedily.
        return_full_text: Return prompt + continuation rather than only the
            newly generated text.

    Returns:
        ``(text, status_markdown)`` for the output textbox and status panel.
    """
    model_id = LABEL_TO_REPO[model_label]
    tokenizer = load_tokenizer(model_id)
    model = load_model(model_id)

    # Special tokens are disabled in the tokenizer call, so BOS is added by hand.
    bos = tokenizer.bos_token or ""
    encoded = tokenizer(
        bos + (prompt or ""),
        return_tensors="pt",
        add_special_tokens=False,
    )
    # Move tensors to the model's device. token_type_ids is dropped,
    # presumably because some causal-LM forward() signatures reject it.
    inputs = {
        key: value.to(DEVICE) if hasattr(value, "to") else value
        for key, value in encoded.items()
        if key != "token_type_ids"
    }
    input_len = inputs["input_ids"].shape[-1]
    if input_len == 0:
        # generate() needs at least one starting token. With an empty prompt
        # and no BOS token there is nothing to condition on.
        return "", (
            f"Loaded model: `{model_id}`  \n"
            "Nothing to generate from: the prompt is empty and this "
            "tokenizer has no BOS token."
        )

    temperature = float(temperature)
    sample = bool(do_sample) and temperature > 0.0
    penalty = float(repetition_penalty)
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": sample,
        # Penalties <= 0 are rejected at generation time; 1.0 means "off".
        "repetition_penalty": penalty if penalty > 0.0 else 1.0,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = float(top_p)
        gen_kwargs["top_k"] = int(top_k)

    with torch.inference_mode():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; slice off the prompt part.
    generated_ids = output_ids[0][input_len:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    if return_full_text:
        full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    else:
        full_text = generated_text

    # Two trailing spaces before "\n" force a Markdown line break.
    status = (
        f"Loaded model: `{model_id}`  \n"
        f"Device: `{DEVICE}`  \n"
        f"Prompt tokens: `{input_len}`  \n"
        f"Generated tokens: `{generated_ids.shape[-1]}`"
    )
    return full_text, status
# =============================================================================
# UI
# =============================================================================
# Gradio UI: a model dropdown, a prompt box, generation controls, and an
# output/status area, all wired to generate_text() by the Generate button.
# NOTE(review): which components sit inside each gr.Row() below was
# reconstructed from a whitespace-stripped copy; confirm against the live layout.
with gr.Blocks(title="Harley-ml Catalogue", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# Harley-ml Catalogue\n"
        "Pick a model, type a prompt, and change the sampling settings."
    )
    with gr.Row():
        # Labels come from MODEL_CHOICES; generate_text maps them back to
        # repo ids through LABEL_TO_REPO.
        model_box = gr.Dropdown(
            choices=[label for label, _ in MODEL_CHOICES],
            value="Tenete-8M",
            label="Model",
        )
    prompt_box = gr.Textbox(
        label="Prompt",
        placeholder="Type whatever you want here",
        lines=5,
        value="The",
    )
    with gr.Row():
        max_new_tokens_box = gr.Slider(
            minimum=1,
            maximum=1024,
            value=128,
            step=1,
            label="Max new tokens",
        )
        # NOTE(review): 0.0 is selectable; transformers may reject a
        # non-positive temperature when sampling is enabled.
        temperature_box = gr.Slider(
            minimum=0.0,
            maximum=3.0,
            value=1.0,
            step=0.05,
            label="Temperature",
        )
        top_p_box = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.95,
            step=0.01,
            label="Top-p",
        )
    with gr.Row():
        # 0 is selectable; presumably treated as "top-k disabled" by
        # transformers — verify.
        top_k_box = gr.Slider(
            minimum=0,
            maximum=500,
            value=40,
            step=1,
            label="Top-k",
        )
        # NOTE(review): 0.0 is selectable, but transformers may reject
        # repetition penalties <= 0.
        repetition_penalty_box = gr.Slider(
            minimum=0.0,
            maximum=3.0,
            value=1.2,
            step=0.01,
            label="Repetition penalty",
        )
        do_sample_box = gr.Checkbox(value=True, label="Do sample")
        return_full_text_box = gr.Checkbox(value=False, label="Return full text")
    generate_btn = gr.Button("Generate", variant="primary")
    output_box = gr.Textbox(label="Output", lines=12)
    # Receives the backtick-formatted status string built by generate_text.
    status_box = gr.Markdown()
    # Inputs are passed positionally, so this order must match
    # generate_text's parameter order.
    generate_btn.click(
        fn=generate_text,
        inputs=[
            model_box,
            prompt_box,
            max_new_tokens_box,
            temperature_box,
            top_p_box,
            top_k_box,
            repetition_penalty_box,
            do_sample_box,
            return_full_text_box,
        ],
        outputs=[output_box, status_box],
    )
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", "7860")),
    )