import os
import time

import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from config import (
    FLASH_ATTN,
    KV_CACHE_TYPE,
    MAX_TOKENS,
    MIN_P,
    N_CTX,
    PRESENCE_PENALTY,
    REPEAT_PENALTY,
    TEMPERATURE,
    TOP_K,
    TOP_P,
    header_info,
    model_zoo,
    system_prompt,
)
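
# Values mirror llama.cpp's ggml_type enum for KV-cache quantization; the
# gap at 4 and 5 is where the removed q4_2/q4_3 types used to sit.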
_KV_TYPE: dict[str, int] = {
    "f32": 0,
    "f16": 1,
    "q4_0": 2,
    "q4_1": 3,
    "q5_0": 6,
    "q5_1": 7,
    "q8_0": 8,
}

_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"
_METRICS_SEP = "\n"
# Metrics lines appended in respond() start with this; _strip_metrics() only
# removes a trailing line when it recognizes the tag.
_METRICS_TAG = "`⚙️"
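
# Heuristic assuming SMT: physical cores ≈ half the logical count. Generation
# runs on N_PHYS threads; prompt/batch processing uses all N_CPU threads.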
N_CPU = os.cpu_count() or 4
N_PHYS = max(1, N_CPU // 2)

_DEFAULT_MODEL = next(iter(model_zoo))
_loaded: dict[str, Llama] = {}


class ThinkStripper:
    """Streaming filter that removes <think>…</think> blocks.

    Feed it text incrementally; it returns the visible portion and holds
    back anything that might still be part of an unfinished tag or block.
    """

    def __init__(self) -> None:
        self.in_think = False
        self.buf = ""

    def feed(self, text: str) -> str:
        self.buf += text
        out: list[str] = []

        while self.buf:
            if self.in_think:
                end = self.buf.find(_THINK_CLOSE)
                if end == -1:
                    # Closing tag not seen yet: keep discarding silently.
                    self.buf = ""
                    break
                self.buf = self.buf[end + len(_THINK_CLOSE) :]
                self.in_think = False
                continue

            start = self.buf.find(_THINK_OPEN)
            end = self.buf.find(_THINK_CLOSE)

            if start == -1 and end == -1:
                # No complete tag. Hold back any suffix that could be the
                # beginning of a tag split across chunks, so partial tags
                # never leak into the visible stream.
                keep = 0
                max_keep = min(len(self.buf), len(_THINK_CLOSE) - 1)
                for n in range(1, max_keep + 1):
                    tail = self.buf[-n:]
                    if _THINK_OPEN.startswith(tail) or _THINK_CLOSE.startswith(tail):
                        keep = n
                out.append(self.buf[: len(self.buf) - keep])
                self.buf = self.buf[len(self.buf) - keep :]
                break
            elif start == -1:
                # Stray close tag: keep the text before it, drop the tag.
                out.append(self.buf[:end])
                self.buf = self.buf[end + len(_THINK_CLOSE) :]
            else:
                out.append(self.buf[:start])
                self.buf = self.buf[start + len(_THINK_OPEN) :]
                self.in_think = True

        return "".join(out)
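
# Example: fed "Hi <thi" and then "nk>plan</think> there", a single
# ThinkStripper instance yields "Hi " and then " there"; neither the partial
# tag nor the think block ever reaches the UI.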


def _load_model(name: str) -> Llama:
    cfg = model_zoo[name]
    path = hf_hub_download(repo_id=cfg["repo_id"], filename=cfg["model_file"])

    base = dict(
        model_path=path,
        n_ctx=N_CTX,
        n_batch=1024,
        n_ubatch=1024,
        n_threads=N_PHYS,
        n_threads_batch=N_CPU,
        flash_attn=bool(FLASH_ATTN),
        use_mmap=True,
        use_mlock=False,
        verbose=False,
    )

    kv = _KV_TYPE.get(KV_CACHE_TYPE)
    if kv is None:
        print(f"Unknown KV cache type '{KV_CACHE_TYPE}', using backend default.")
        return Llama(**base)
    try:
        model = Llama(**base, type_k=kv, type_v=kv)
        print(f"KV cache type: {KV_CACHE_TYPE}")
    except ValueError:
        print(f"KV cache '{KV_CACHE_TYPE}' unsupported on this backend, using default.")
        model = Llama(**base)
    return model
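
# Note: llama.cpp typically requires flash attention for a quantized V cache,
# so with FLASH_ATTN disabled, "f16" is the safer KV_CACHE_TYPE choice.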


print(f"Loading {_DEFAULT_MODEL} …")
_loaded[_DEFAULT_MODEL] = _load_model(_DEFAULT_MODEL)
print("Model ready.")


def _to_str(content) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Multimodal content arrives as a list of blocks; keep the text parts.
        return " ".join(b.get("text", "") for b in content if isinstance(b, dict))
    return str(content)


def _strip_think(text: str) -> str:
    """Strip think blocks from the full text accumulated so far."""
    # A fresh stripper per call: callers always pass cumulative text, and a
    # shared instance would leak buffer state across turns and requests.
    return ThinkStripper().feed(text)


def _strip_metrics(text: str) -> str:
    """Drop the trailing metrics line we appended to assistant messages."""
    # Inspect only the last line; splitting on the first newline would
    # truncate genuinely multi-line responses.
    head, sep, tail = text.rpartition(_METRICS_SEP)
    if sep and tail.startswith(_METRICS_TAG):
        return head
    return text
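
# e.g. _strip_metrics("Hi\nthere\n`⚙️ 42t | ⏱️ 3.1s | 🚀 13.5t/s`") returns
# "Hi\nthere"; a multi-line reply without the tag passes through unchanged.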


def _display_content(turn: dict) -> str:
    """User-visible content (without metrics line) of a history turn."""
    return _strip_metrics(_to_str(turn.get("content", "")))


def _pick_feed_content(disp_turn: dict, raw_turn: dict | None) -> str:
    """
    Choose the content to feed back into the model for a given turn.

    Prefer the raw version (which keeps <think>…</think>) so the KV-cache
    prefix can be reused; if the user clearly edited the message via
    `editable=True`, fall back to the displayed version instead.
    """
    disp = _display_content(disp_turn)

    if not (
        isinstance(raw_turn, dict) and raw_turn.get("role") == disp_turn.get("role")
    ):
        return disp

    raw = _to_str(raw_turn.get("content", ""))

    if disp_turn.get("role") == "assistant":
        # An unedited assistant turn round-trips exactly: raw minus its
        # think block equals what is displayed.
        if _strip_think(raw).strip() == disp.strip():
            return raw
        return disp

    # User turns carry no think blocks, so any mismatch means an edit.
    return raw if raw.strip() == disp.strip() else disp
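
# e.g. a raw assistant turn "<think>plan</think>Answer" displayed as "Answer"
# feeds the raw string back; if the bubble was edited to "Answer!", the edit
# wins and the think block is dropped for that turn.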


def respond(
    message: str, history: list[dict], model_name: str, raw_history: list[dict]
):
    # Lazy-load models on first use so switching doesn't block app startup.
    if model_name not in _loaded:
        print(f"Switching to {model_name} …")
        _loaded[model_name] = _load_model(model_name)
        print(f"{model_name} ready.")
    llm = _loaded[model_name]

    if not isinstance(history, list):
        history = []
    if not isinstance(raw_history, list):
        raw_history = []

    # Rebuild the prompt: system prompt first, then every prior turn, feeding
    # the raw (think-inclusive) variant whenever the user hasn't edited it.
    messages: list[dict] = [{"role": "system", "content": system_prompt}]
    aligned_raw: list[dict] = []
    for i, turn in enumerate(history):
        if not isinstance(turn, dict) or "role" not in turn or "content" not in turn:
            continue
        raw_turn = raw_history[i] if i < len(raw_history) else None
        feed = _pick_feed_content(turn, raw_turn)
        messages.append({"role": turn["role"], "content": feed})
        aligned_raw.append({"role": turn["role"], "content": feed})
    messages.append({"role": "user", "content": message})

    t_start = time.perf_counter()
    n_gen = 0  # streamed chunks; llama.cpp emits roughly one per token
    raw = ""
    prev_visible = ""

    for chunk in llm.create_chat_completion(
        messages=messages,
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        repeat_penalty=REPEAT_PENALTY,
        presence_penalty=PRESENCE_PENALTY,
        min_p=MIN_P,
        stream=True,
    ):
        delta = chunk["choices"][0]["delta"].get("content") or ""
        if not delta:
            continue

        raw += delta
        n_gen += 1
        visible = _strip_think(raw)
        if visible != prev_visible:
            # Push only real changes; raw_history is replaced once at the end.
            yield visible, raw_history
            prev_visible = visible

    total_time = time.perf_counter() - t_start
    overall_tps = n_gen / total_time if total_time > 0 else 0.0
    metrics_line = f"⚙️ {n_gen}t | ⏱️ {total_time:.1f}s | 🚀 {overall_tps:.1f}t/s"

    # Persist the think-inclusive transcript so the next turn can reuse it.
    new_raw_history = [
        *aligned_raw,
        {"role": "user", "content": message},
        {"role": "assistant", "content": raw},
    ]

    response = _strip_think(raw)
    # The appended line starts with _METRICS_TAG; keep the two in sync.
    yield f"{response}{_METRICS_SEP}`{metrics_line}`", new_raw_history


with open("./style.css") as f:
    CSS = f.read()

# Custom CSS must go on gr.Blocks itself; Blocks.launch() has no css argument.
with gr.Blocks(title="EdgeRazor Playground", css=CSS) as demo:
    gr.Image(
        value="https://raw.githubusercontent.com/zhangsq-nju/EdgeRazor/main/asset/Logo-full.png",
        show_label=False,
        container=False,
        interactive=False,
        elem_classes=["logo-wrap"],
    )
    gr.Markdown(header_info, elem_classes=["header-md"])

    current_model = gr.State(_DEFAULT_MODEL)
    raw_history_state = gr.State([])

    with gr.Row():
        model_dd = gr.Dropdown(
            choices=list(model_zoo.keys()),
            value=_DEFAULT_MODEL,
            label="Model",
            interactive=True,
            elem_id="model-selector",
        )

    chat_iface = gr.ChatInterface(
        fn=respond,
        additional_inputs=[current_model, raw_history_state],
        additional_outputs=[raw_history_state],
        additional_inputs_accordion=gr.Accordion(label="", open=False, visible=False),
        editable=True,
        chatbot=gr.Chatbot(label="", height=480),
    )
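
    # respond() yields (visible_text, new_raw_history): ChatInterface routes
    # the first value into the chatbot and the rest into additional_outputs.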

    def _on_model_change(new_model, cur_model, history, raw_history):
        # Re-selecting the active model is a no-op: keep both histories.
        if new_model == cur_model:
            safe_history = history if isinstance(history, list) else []
            safe_raw = raw_history if isinstance(raw_history, list) else []
            return (
                cur_model,
                gr.update(value=cur_model),
                safe_history,
                safe_history,
                safe_raw,
            )
        # An actual switch clears the chat, its state, and the raw history.
        return (
            new_model,
            gr.update(value=new_model),
            [],
            [],
            [],
        )

    model_dd.change(
        fn=_on_model_change,
        inputs=[model_dd, current_model, chat_iface.chatbot_state, raw_history_state],
        outputs=[
            current_model,
            model_dd,
            chat_iface.chatbot,
            chat_iface.chatbot_state,
            raw_history_state,
        ],
    )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
    )