Spaces:
Running
Running
| from __future__ import annotations | |
| import os | |
| from functools import lru_cache | |
| from typing import Dict, List, Tuple | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # ============================================================================= | |
| # Model catalog | |
| # ============================================================================= | |
# Catalog of selectable checkpoints as (dropdown label, Hugging Face repo id)
# pairs. List order is the order shown in the model dropdown.
MODEL_CHOICES: List[Tuple[str, str]] = [
    ("PicoWord-5k", "Harley-ml/PicoWord-5k"),
    ("MicroWord-23k", "Harley-ml/MicroWord-23k"),
    ("TinyWord-134k", "Harley-ml/TinyWord-134k"),
    ("TinyWord2-128k", "Harley-ml/TinyWord2-128k"),
    ("MediumWord-559k", "Harley-ml/MediumWord-559k"),
    ("LargeWord-1.5M", "Harley-ml/LargeWord-1.5M"),
    ("Tenete-8M", "Harley-ml/Tenete-8M"),
    ("MCOD-4.7M", "Harley-ml/MCOD-4.7M"),
    # The next two labels are shortened; they differ from the repo names.
    ("LWTMoe", "Harley-ml/LWTMoE-10M-A6M"),
    ("LWTDense", "Harley-ml/LWTDense-6M"),
    ("MiniMD-28M", "Harley-ml/MiniMD-28M"),
    ("StopAskingQuestionsMini-656k", "Harley-ml/StopAskingQuestionsMini-656k"),
    ("Dillion-1.2M", "Harley-ml/Dillion-1.2M"),
]

# Lookup table from dropdown label to repo id, derived from the pairs above.
LABEL_TO_REPO: Dict[str, str] = dict(MODEL_CHOICES)
| # ============================================================================= | |
| # Device helpers | |
| # ============================================================================= | |
def get_device() -> str:
    """Return the name of the preferred torch device.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # torch builds that predate the MPS backend have no ``torch.backends.mps``
    # attribute at all; treat that the same as "MPS not available" instead of
    # raising AttributeError at import time.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
# Resolved once at import time; the loaders and the generation code all read
# this module-level value.
DEVICE = get_device()

# Half precision is used only on CUDA; every other device keeps float32.
DEFAULT_TORCH_DTYPE = torch.float32 if DEVICE != "cuda" else torch.float16
| # ============================================================================= | |
| # Loading | |
| # ============================================================================= | |
| def _tokenizer_kwargs() -> dict: | |
| # Some repos ship custom tokenizer code/config; trust_remote_code helps | |
| # with any special setup, while still working for standard tokenizers. | |
| return { | |
| "use_fast": True, | |
| "trust_remote_code": True, | |
| } | |
# Memoized: without the cache every Generate click (and every load_model call)
# re-reads the tokenizer from disk/hub. The bound covers the current
# MODEL_CHOICES catalog while still capping memory if the catalog grows.
@lru_cache(maxsize=16)
def load_tokenizer(model_id: str):
    """Load the tokenizer for ``model_id`` and make sure it can pad.

    The result is cached per ``model_id``, so repeated calls return the same
    shared tokenizer object. Callers should treat it as read-only.

    Args:
        model_id: Hugging Face repo id of the checkpoint.

    Returns:
        The tokenizer, with a pad token defined and ``padding_side`` set to
        ``"left"``.

    Raises:
        Whatever ``AutoTokenizer.from_pretrained`` raises when loading fails.
        Failures are not cached, so a later call retries the load.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, **_tokenizer_kwargs())

    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            # Reuse EOS as the pad token so no new vocabulary entry is needed.
            tokenizer.pad_token = tokenizer.eos_token
        else:
            # NOTE(review): this registers a brand-new token; nothing here
            # resizes the model's embeddings to match - confirm that is safe
            # for any checkpoint that reaches this branch.
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

    # Left padding is usually a safer default for batched generation.
    tokenizer.padding_side = "left"
    return tokenizer
# Memoized: without the cache the weights are re-loaded and moved to the
# device on every Generate click. The bound covers the current MODEL_CHOICES
# catalog while still capping memory if the catalog grows.
@lru_cache(maxsize=16)
def load_model(model_id: str):
    """Load the causal LM for ``model_id``, ready for inference on ``DEVICE``.

    The result is cached per ``model_id``, so repeated calls return the same
    shared model object.

    Args:
        model_id: Hugging Face repo id of the checkpoint.

    Returns:
        The model in eval mode, moved to ``DEVICE``. When it exposes a
        ``generation_config``, its pad/EOS ids are aligned with the tokenizer.

    Raises:
        Whatever ``AutoModelForCausalLM.from_pretrained`` or ``model.to``
        raises. Failures are not cached, so a later call retries the load.
    """
    # Function-scope import keeps this block self-contained; the module does
    # not import logging at the top level.
    import logging

    log = logging.getLogger(__name__)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=DEFAULT_TORCH_DTYPE,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    # Some MoE/Qwen3 variants can expose this flag; disabling it avoids
    # a common generation-time mismatch in certain setups.
    if hasattr(model.config, "output_router_logits"):
        try:
            model.config.output_router_logits = False
        except Exception:
            # Still best-effort, but no longer silent.
            log.warning(
                "Could not disable output_router_logits for %s",
                model_id,
                exc_info=True,
            )

    model.eval()
    model.to(DEVICE)

    if hasattr(model, "generation_config"):
        try:
            tok = load_tokenizer(model_id)
            gen_config = model.generation_config
            # Copy only ids the tokenizer actually defines, so a missing
            # tokenizer id never replaces the checkpoint's value with None.
            if tok.pad_token_id is not None:
                gen_config.pad_token_id = tok.pad_token_id
            if tok.eos_token_id is not None:
                gen_config.eos_token_id = tok.eos_token_id
        except Exception:
            # Alignment is a convenience; the model is still usable without it.
            log.warning(
                "Could not align generation_config token ids for %s",
                model_id,
                exc_info=True,
            )

    return model
| # ============================================================================= | |
| # Generation | |
| # ============================================================================= | |
def generate_text(
    model_label: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    do_sample: bool,
    return_full_text: bool,
) -> Tuple[str, str]:
    """Generate a continuation of ``prompt`` with the selected catalog model.

    Args:
        model_label: Dropdown label; must be a key of ``LABEL_TO_REPO``.
        prompt: User text. The tokenizer's BOS token (if any) is prepended.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Used only when sampling; a value
            of 0 or below switches to greedy decoding.
        top_p: Nucleus sampling threshold. Used only when sampling.
        top_k: Top-k cutoff. Used only when sampling.
        repetition_penalty: Repetition penalty; must be greater than 0.
        do_sample: Sample when true, decode greedily otherwise.
        return_full_text: When true, the returned text is the decoded prompt
            plus continuation instead of the continuation alone.

    Returns:
        ``(text, status)`` where ``status`` is a Markdown summary of the
        model id, device and token counts.

    Raises:
        gr.Error: For an unknown model label, a non-positive repetition
            penalty, or a prompt that tokenizes to zero tokens.
    """
    model_id = LABEL_TO_REPO.get(model_label)
    if model_id is None:
        # Covers a cleared dropdown (None) as well as a stale label.
        raise gr.Error(f"Unknown model: {model_label!r}")

    penalty = float(repetition_penalty)
    if penalty <= 0.0:
        # The UI slider allows 0, but generation requires a strictly
        # positive penalty; fail with a readable message instead.
        raise gr.Error("Repetition penalty must be greater than 0.")

    temperature = float(temperature)
    # Sampling requires a strictly positive temperature. Temperature 0 is
    # conventionally "always pick the most likely token", i.e. greedy.
    sample = bool(do_sample) and temperature > 0.0

    tokenizer = load_tokenizer(model_id)
    model = load_model(model_id)

    # BOS is prepended as text, so special tokens are not added again below.
    full_prompt = (tokenizer.bos_token or "") + (prompt or "")
    encoded = tokenizer(
        full_prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    # Drop token_type_ids and move every tensor to the target device.
    inputs = {
        key: value.to(DEVICE) if hasattr(value, "to") else value
        for key, value in encoded.items()
        if key != "token_type_ids"
    }

    input_len = inputs["input_ids"].shape[-1]
    if input_len == 0:
        # Empty prompt and no BOS token: there is nothing to condition on.
        raise gr.Error(
            "The prompt is empty and this model's tokenizer has no BOS token; "
            "please enter some text."
        )

    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": sample,
        "repetition_penalty": penalty,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if sample:
        # Sampling-only knobs are omitted entirely for greedy decoding.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = float(top_p)
        gen_kwargs["top_k"] = int(top_k)

    with torch.inference_mode():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # The output sequence starts with the prompt tokens; slice them off.
    generated_ids = output_ids[0][input_len:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    if return_full_text:
        full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    else:
        full_text = generated_text

    status = (
        f"Loaded model: `{model_id}` \n"
        f"Device: `{DEVICE}` \n"
        f"Prompt tokens: `{input_len}` \n"
        f"Generated tokens: `{generated_ids.shape[-1]}`"
    )
    return full_text, status
| # ============================================================================= | |
| # UI | |
| # ============================================================================= | |
# Gradio UI: model picker, prompt, sampling controls, and the Generate button
# wired to generate_text.
# NOTE(review): the row nesting below was reconstructed from a listing whose
# indentation had been stripped - confirm which widgets belong in each Row.
with gr.Blocks(title="Harley-ml Catalogue", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# Harley-ml Catalogue\n"
        "Pick a model, type a prompt, and change the sampling settings."
    )

    with gr.Row():
        model_box = gr.Dropdown(
            choices=[label for label, _ in MODEL_CHOICES],
            value="Tenete-8M",
            label="Model",
        )

    prompt_box = gr.Textbox(
        label="Prompt",
        placeholder="Type whatever you want here",
        lines=5,
        value="The",
    )

    with gr.Row():
        max_new_tokens_box = gr.Slider(
            minimum=1,
            maximum=1024,
            value=128,
            step=1,
            label="Max new tokens",
        )
        temperature_box = gr.Slider(
            # Sampling needs a strictly positive temperature, so the slider
            # starts one step above 0. Untick "Do sample" for greedy decoding.
            minimum=0.05,
            maximum=3.0,
            value=1.0,
            step=0.05,
            label="Temperature",
        )
        top_p_box = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.95,
            step=0.01,
            label="Top-p",
        )

    with gr.Row():
        top_k_box = gr.Slider(
            minimum=0,
            maximum=500,
            value=40,
            step=1,
            label="Top-k",
        )
        repetition_penalty_box = gr.Slider(
            # The penalty must be strictly positive; 0 makes generation fail.
            minimum=0.1,
            maximum=3.0,
            value=1.2,
            step=0.01,
            label="Repetition penalty",
        )
        do_sample_box = gr.Checkbox(value=True, label="Do sample")
        return_full_text_box = gr.Checkbox(value=False, label="Return full text")

    generate_btn = gr.Button("Generate", variant="primary")
    output_box = gr.Textbox(label="Output", lines=12)
    status_box = gr.Markdown()

    # Input order must match the positional parameters of generate_text.
    generate_btn.click(
        fn=generate_text,
        inputs=[
            model_box,
            prompt_box,
            max_new_tokens_box,
            temperature_box,
            top_p_box,
            top_k_box,
            repetition_penalty_box,
            do_sample_box,
            return_full_text_box,
        ],
        outputs=[output_box, status_box],
    )
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and falls back to 7860 when it is unset.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
    )