"""Gradio demo for generating text with the Harley-ml model catalogue.

The module builds a small web UI: pick a model from ``MODEL_CHOICES``, type a
prompt, adjust the sampling settings and generate a continuation. Tokenizers
and models are loaded lazily and memoised with ``lru_cache``.
"""

from __future__ import annotations

import logging
import os
from functools import lru_cache
from typing import Dict, List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

# =============================================================================
# Model catalog
# =============================================================================

# (label shown in the dropdown, Hugging Face repo id)
MODEL_CHOICES: List[Tuple[str, str]] = [
    ("PicoWord-5k", "Harley-ml/PicoWord-5k"),
    ("MicroWord-23k", "Harley-ml/MicroWord-23k"),
    ("TinyWord-134k", "Harley-ml/TinyWord-134k"),
    ("TinyWord2-128k", "Harley-ml/TinyWord2-128k"),
    ("MediumWord-559k", "Harley-ml/MediumWord-559k"),
    ("LargeWord-1.5M", "Harley-ml/LargeWord-1.5M"),
    ("Tenete-8M", "Harley-ml/Tenete-8M"),
    ("MCOD-4.7M", "Harley-ml/MCOD-4.7M"),
    ("LWTMoe", "Harley-ml/LWTMoE-10M-A6M"),
    ("LWTDense", "Harley-ml/LWTDense-6M"),
    ("MiniMD-28M", "Harley-ml/MiniMD-28M"),
    ("StopAskingQuestionsMini-656k", "Harley-ml/StopAskingQuestionsMini-656k"),
    ("Dillion-1.2M", "Harley-ml/Dillion-1.2M"),
]

LABEL_TO_REPO: Dict[str, str] = {label: repo for label, repo in MODEL_CHOICES}


# =============================================================================
# Device helpers
# =============================================================================


def get_device() -> str:
    """Return the preferred torch device name: "cuda", then "mps", then "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # Look the MPS backend up defensively so a torch build that does not
    # expose `torch.backends.mps` falls through to CPU instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


DEVICE = get_device()
DEFAULT_TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32


# =============================================================================
# Loading
# =============================================================================


def _tokenizer_kwargs() -> dict:
    """Return the keyword arguments passed to ``AutoTokenizer.from_pretrained``."""
    # Some repos ship custom tokenizer code/config; trust_remote_code helps
    # with any special setup, while still working for standard tokenizers.
    # SECURITY NOTE: trust_remote_code=True runs Python code shipped in the
    # repo. Only add repos you trust to MODEL_CHOICES.
    return {
        "use_fast": True,
        "trust_remote_code": True,
    }


@lru_cache(maxsize=8)
def load_tokenizer(model_id: str):
    """Load (and memoise) the tokenizer for ``model_id``.

    Guarantees a pad token: the EOS token is reused when there is one,
    otherwise a ``<|pad|>`` special token is added. Padding side is set to
    "left".
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, **_tokenizer_kwargs())

    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

    # Left padding is usually a safer default for batched generation.
    tokenizer.padding_side = "left"
    return tokenizer


@lru_cache(maxsize=8)
def load_model(model_id: str):
    """Load (and memoise) the causal LM for ``model_id`` on ``DEVICE``.

    The model is put in eval mode, and its generation config is aligned with
    the tokenizer's pad/EOS ids on a best-effort basis: failures there are
    logged rather than raised so the model is still returned.
    """
    # SECURITY NOTE: trust_remote_code=True runs Python code shipped in the
    # repo. Only add repos you trust to MODEL_CHOICES.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=DEFAULT_TORCH_DTYPE,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    # Some MoE/Qwen3 variants can expose this flag; disabling it avoids
    # a common generation-time mismatch in certain setups.
    if hasattr(model.config, "output_router_logits"):
        try:
            model.config.output_router_logits = False
        except Exception:
            # Best effort: keep going, but leave a trace instead of hiding it.
            logger.warning(
                "Could not disable output_router_logits for %s",
                model_id,
                exc_info=True,
            )

    model.eval()
    model.to(DEVICE)

    if hasattr(model, "generation_config"):
        try:
            tok = load_tokenizer(model_id)
            model.generation_config.pad_token_id = tok.pad_token_id
            model.generation_config.eos_token_id = tok.eos_token_id
        except Exception:
            # Best effort: generate_text passes both ids explicitly anyway.
            logger.warning(
                "Could not sync generation config with tokenizer for %s",
                model_id,
                exc_info=True,
            )

    return model


# =============================================================================
# Generation
# =============================================================================


def generate_text(
    model_label: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    do_sample: bool,
    return_full_text: bool,
) -> Tuple[str, str]:
    """Generate a continuation of ``prompt`` with the selected model.

    Args:
        model_label: A label from ``MODEL_CHOICES``.
        prompt: Text to continue; ``None`` or empty is treated as "".
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; a value <= 0 selects greedy
            decoding instead of sampling.
        top_p: Nucleus sampling threshold (only used when sampling).
        top_k: Top-k cutoff (only used when sampling).
        repetition_penalty: Must be greater than 0.
        do_sample: Sample when true, otherwise decode greedily.
        return_full_text: Return prompt + continuation instead of only the
            continuation.

    Returns:
        ``(text, status)`` where ``status`` is a Markdown summary of the run.

    Raises:
        gr.Error: On an unknown model label, a non-positive repetition
            penalty, or a prompt that tokenizes to zero tokens.
    """
    # Validate cheap inputs first so bad requests fail before any model load.
    model_id = LABEL_TO_REPO.get(model_label)
    if model_id is None:
        raise gr.Error(f"Unknown model: {model_label!r}")

    penalty = float(repetition_penalty)
    if penalty <= 0.0:
        # The slider allows 0, but generation needs a strictly positive value.
        raise gr.Error("Repetition penalty must be greater than 0.")

    # A temperature of 0 cannot be used for sampling; treat it as a request
    # for deterministic (greedy) decoding instead of failing in generate().
    temp = float(temperature)
    sample = bool(do_sample) and temp > 0.0

    tokenizer = load_tokenizer(model_id)
    model = load_model(model_id)

    prompt = prompt or ""
    bos = tokenizer.bos_token or ""
    full_prompt = bos + prompt

    # The BOS token is prepended as text above, so the tokenizer must not add
    # its own special tokens on top.
    encoded = tokenizer(
        full_prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs = {
        key: value.to(DEVICE) if hasattr(value, "to") else value
        for key, value in encoded.items()
        if key != "token_type_ids"
    }

    input_len = inputs["input_ids"].shape[-1]
    if input_len == 0:
        # Happens with an empty prompt when the tokenizer has no BOS token.
        raise gr.Error(
            "The prompt is empty and this model has no BOS token; "
            "please enter a prompt."
        )

    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": sample,
        "repetition_penalty": penalty,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }

    if sample:
        gen_kwargs["temperature"] = temp
        gen_kwargs["top_p"] = float(top_p)
        gen_kwargs["top_k"] = int(top_k)

    with torch.inference_mode():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # generate() output starts with the prompt tokens; slice them off.
    generated_ids = output_ids[0][input_len:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    if return_full_text:
        full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    else:
        full_text = generated_text

    # Two trailing spaces + newline is a Markdown hard line break, so each
    # entry is rendered on its own line in the status panel.
    status = "  \n".join(
        [
            f"Loaded model: `{model_id}`",
            f"Device: `{DEVICE}`",
            f"Prompt tokens: `{input_len}`",
            f"Generated tokens: `{generated_ids.shape[-1]}`",
        ]
    )

    return full_text, status


# =============================================================================
# UI
# =============================================================================

with gr.Blocks(title="Harley-ml Catalogue", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# Harley-ml Catalogue\n"
        "Pick a model, type a prompt, and change the sampling settings."
    )

    with gr.Row():
        model_box = gr.Dropdown(
            choices=[label for label, _ in MODEL_CHOICES],
            value="Tenete-8M",
            label="Model",
        )

    prompt_box = gr.Textbox(
        label="Prompt",
        placeholder="Type whatever you want here",
        lines=5,
        value="The",
    )

    with gr.Row():
        max_new_tokens_box = gr.Slider(
            minimum=1,
            maximum=1024,
            value=128,
            step=1,
            label="Max new tokens",
        )
        temperature_box = gr.Slider(
            minimum=0.0,
            maximum=3.0,
            value=1.0,
            step=0.05,
            label="Temperature",
        )
        top_p_box = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.95,
            step=0.01,
            label="Top-p",
        )

    with gr.Row():
        top_k_box = gr.Slider(
            minimum=0,
            maximum=500,
            value=40,
            step=1,
            label="Top-k",
        )
        repetition_penalty_box = gr.Slider(
            minimum=0.0,
            maximum=3.0,
            value=1.2,
            step=0.01,
            label="Repetition penalty",
        )
        do_sample_box = gr.Checkbox(value=True, label="Do sample")
        return_full_text_box = gr.Checkbox(value=False, label="Return full text")

    generate_btn = gr.Button("Generate", variant="primary")

    output_box = gr.Textbox(label="Output", lines=12)
    status_box = gr.Markdown()

    generate_btn.click(
        fn=generate_text,
        inputs=[
            model_box,
            prompt_box,
            max_new_tokens_box,
            temperature_box,
            top_p_box,
            top_k_box,
            repetition_penalty_box,
            do_sample_box,
            return_full_text_box,
        ],
        outputs=[output_box, status_box],
    )


if __name__ == "__main__":
    port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port)