Spaces:
Running on Zero
Running on Zero
| """Carbon Helix — gr.Server + ZeroGPU. | |
| Custom Three.js frontend served from index.html, streaming generation | |
| on ZeroGPU via @server.api. The JS client (@gradio/client) connects to | |
| the same origin and calls /generate. | |
| Heavy model load is guarded by `if gr.NO_RELOAD:` so `gradio app.py` | |
| hot-reload does not re-download a multi-gigabyte model on every save. | |
| """ | |
| import os | |
| import gradio as gr | |
| import torch | |
| from fastapi.responses import HTMLResponse | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| try: | |
| import spaces | |
| _HAS_SPACES = True | |
| except ImportError: | |
| _HAS_SPACES = False | |
# Directory containing this file; index.html is resolved relative to it.
HERE = os.path.dirname(os.path.abspath(__file__))
# Hub repo id of the checkpoint, overridable via the MODEL_NAME env var.
MODEL_NAME = os.getenv("MODEL_NAME", "HuggingFaceBio/Carbon-3B")
# Optional Hub access token (None when unset), forwarded to from_pretrained.
HF_TOKEN = os.getenv("HF_TOKEN")
def left_pad_to_six(seq: str) -> tuple[str, int]:
    """Left-pad *seq* with ``"A"`` so its length becomes a multiple of six.

    Returns ``(padded_seq, n_pad)`` where ``n_pad`` is the number of ``"A"``
    bases prepended (0-5). A falsy *seq* (e.g. ``""``) is returned unchanged
    with a pad count of 0, as is a sequence whose length is already a
    multiple of six.
    """
    shortfall = (-len(seq)) % 6 if seq else 0
    if shortfall == 0:
        return seq, 0
    return "A" * shortfall + seq, shortfall
if gr.NO_RELOAD:
    # One-time heavy initialisation. Guarded by gr.NO_RELOAD so that
    # `gradio app.py` hot-reload does not re-download / reload the model
    # on every save; the already-loaded `tok` and `model` are reused.
    print(f"[carbon-helix] loading tokenizer {MODEL_NAME}...", flush=True)
    tok = AutoTokenizer.from_pretrained(
        MODEL_NAME, trust_remote_code=True, token=HF_TOKEN
    )
    print(f"[carbon-helix] loading model {MODEL_NAME} (bf16)...", flush=True)
    # NOTE(review): the tokenizer is loaded with trust_remote_code=True but
    # the model is not — confirm the model architecture needs no remote code.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
    )
    # Inference only: put the module in eval mode.
    model.eval()
    # Place on cuda at module level. Outside @spaces.GPU, PyTorch's CUDA
    # emulation mode (set up by the `spaces` library) accepts this without
    # holding a real GPU; ZeroGPU then fast-snapshots the placement so each
    # @spaces.GPU call restores instantly. Moving to cuda *inside* the GPU
    # function is significantly slower per the ZeroGPU docs.
    # https://huggingface.co/docs/hub/spaces-zerogpu#model-loading
    # NOTE(review): when `spaces` is not importable the model is left on the
    # device from_pretrained loaded it to; code that feeds it inputs must
    # place them on the model's device rather than assuming cuda.
    if _HAS_SPACES:
        model.to("cuda")
    # Total parameter count in billions, reported by the info endpoint.
    N_PARAMS_B = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"[carbon-helix] model ready — {N_PARAMS_B:.2f}B params", flush=True)
# Token ids that end generation: the end-of-text marker and the closing
# </dna> tag. Either may be missing from the vocabulary (vocab.get -> None).
vocab = tok.get_vocab()
EOS_ID, DNA_CLOSE_ID = (vocab.get(marker) for marker in ("<|endoftext|>", "</dna>"))
STOP_IDS: set[int] = {EOS_ID, DNA_CLOSE_ID} - {None}
print(f"[carbon-helix] stop ids: eos={EOS_ID} </dna>={DNA_CLOSE_ID}", flush=True)
| def _sample(logits: torch.Tensor, temperature: float, top_p: float) -> tuple[int, float]: | |
| original_logprobs = torch.log_softmax(logits.float(), dim=-1) | |
| if temperature <= 1e-3: | |
| token_id = int(logits.argmax().item()) | |
| return token_id, float(original_logprobs[token_id].item()) | |
| scaled = logits.float() / max(temperature, 1e-5) | |
| sorted_logits, sorted_indices = torch.sort(scaled, descending=True) | |
| cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) | |
| remove = cumulative > top_p | |
| remove[1:] = remove[:-1].clone() | |
| remove[0] = False | |
| filtered = scaled.clone() | |
| filtered[sorted_indices[remove]] = float("-inf") | |
| probs = torch.softmax(filtered, dim=-1) | |
| token_id = int(torch.multinomial(probs, num_samples=1).item()) | |
| return token_id, float(original_logprobs[token_id].item()) | |
def _gpu_stream_impl(input_ids_cpu: torch.Tensor, max_tokens: int,
                     temperature: float, top_p: float):
    """Autoregressively sample from ``model``, yielding one token at a time.

    Args:
        input_ids_cpu: prompt ids as produced by the tokenizer with
            ``return_tensors="pt"``; only batch row 0 of the logits is read
            (assumes a single-sequence batch — TODO confirm with callers).
        max_tokens: maximum number of tokens to yield; ``<= 0`` yields
            nothing and skips the prefill forward pass entirely.
        temperature, top_p: forwarded unchanged to ``_sample``.

    Yields:
        ``(token_id, logprob)`` pairs. Generation ends after a token in
        ``STOP_IDS`` (that token is still yielded) or after ``max_tokens``
        tokens, whichever comes first.
    """
    if max_tokens <= 0:
        return

    # Inputs follow the model. Under ZeroGPU the model was placed on cuda at
    # module level; without `spaces` it is never moved there, so hard-coding
    # "cuda" here would cause a device mismatch in local runs.
    device = model.device
    input_ids = input_ids_cpu.to(device)

    # Forward passes run under no_grad in tight scopes: grad mode is
    # thread-local, and yielding from inside a `with torch.no_grad()` block
    # would leave it disabled for whichever thread resumes the consumer.
    with torch.no_grad():
        out = model(input_ids=input_ids, use_cache=True)
    past_kv = out.past_key_values
    next_logits = out.logits[0, -1, :]

    for step in range(max_tokens):
        token_id, logprob = _sample(next_logits, temperature, top_p)
        yield token_id, logprob
        # Stop before computing logits that would never be sampled: either a
        # stop token was produced or the token budget is exhausted.
        if token_id in STOP_IDS or step == max_tokens - 1:
            break
        nt = torch.tensor([[token_id]], device=device)
        with torch.no_grad():
            out = model(input_ids=nt, past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_logits = out.logits[0, -1, :]
| def _gen_duration(input_ids_cpu, max_tokens, temperature, top_p) -> int: | |
| # Dynamic ZeroGPU lease: shorter durations earn higher queue priority and | |
| # don't over-consume the daily quota. ~15s warmup/restore + ~0.5s/token at | |
| # the RTX Pro 6000 Blackwell's expected throughput. Capped at 120s, which | |
| # is the current ZeroGPU per-call maximum. | |
| return min(120, max(20, 15 + int(int(max_tokens) * 0.5))) | |
# On Spaces, wrap the streaming generator in a ZeroGPU lease whose length is
# computed per call by _gen_duration; elsewhere call it directly.
_gpu_stream = (
    spaces.GPU(duration=_gen_duration)(_gpu_stream_impl)
    if _HAS_SPACES
    else _gpu_stream_impl
)

server = gr.Server()
async def homepage():
    """Return the contents of ``index.html`` (next to this file) as a string.

    NOTE(review): no route decorator is visible on this coroutine, and
    ``HTMLResponse`` is imported but unused in the visible code; presumably a
    decorator such as ``@server.get("/", response_class=HTMLResponse)`` was
    intended — confirm the page is actually registered.
    """
    index_path = os.path.join(HERE, "index.html")
    # Synchronous read inside a coroutine; fine as long as the page is small.
    with open(index_path, "r", encoding="utf-8") as page:
        html = page.read()
    return html
async def app_info():
    """Return the model id and its parameter count (billions, 2 decimals).

    Used by the frontend's model-name display. `/config` is a reserved
    Gradio-internal route, so this info is exposed on its own endpoint.
    NOTE(review): no route decorator is visible here — confirm registration.
    """
    params_b = round(N_PARAMS_B, 2)
    return dict(model=MODEL_NAME, params_b=params_b)
def generate(prompt: str, metadata: str, max_tokens: int,
             temperature: float, top_p: float) -> dict:
    """Stream a DNA continuation of *prompt* as a sequence of event dicts.

    NOTE(review): despite the ``-> dict`` annotation this is a generator;
    the annotation is left untouched in case an API-schema layer reads it.

    The prompt is upper-cased and every character other than A/C/G/T/N is
    dropped. The sequence is left-padded to a multiple of six bases,
    prefixed with ``metadata`` and ``<dna>``, then tokenized.

    Yields, in order:
      * ``{"event": "error", ...}`` then stops, if no valid base remains;
      * ``{"event": "start", "input_length", "pad_bases", "model"}``;
      * one ``{"event": "token", "tokens": [text], "logprobs": [lp]}`` per
        sampled token;
      * ``{"event": "done"}`` — or ``{"event": "error", "message": ...}``
        if the generation loop raises.

    Tokenizer failures happen before the "start" event and propagate.
    """
    bases = "".join(filter(lambda ch: ch in "ACGTN", (prompt or "").upper()))
    if not bases:
        yield {"event": "error", "message": "prompt must contain DNA bases (A/C/G/T)"}
        return

    padded, n_pad = left_pad_to_six(bases)
    prompt_ids = tok(
        (metadata or "") + "<dna>" + padded,
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids

    yield {
        "event": "start",
        "input_length": len(bases),  # unpadded length
        "pad_bases": n_pad,
        "model": MODEL_NAME,
    }

    # Errors during generation are reported to the client as an event
    # rather than raised, so the stream always terminates cleanly.
    try:
        token_stream = _gpu_stream(
            prompt_ids, int(max_tokens), float(temperature), float(top_p)
        )
        for token_id, logprob in token_stream:
            piece = tok.decode([token_id], skip_special_tokens=False)
            yield {"event": "token", "tokens": [piece], "logprobs": [logprob]}
        yield {"event": "done"}
    except Exception as exc:
        yield {"event": "error", "message": str(exc)}
# Spaces' hot-reload launcher looks for a module-level `demo` variable;
# alias the gr.Server under that name so it is discovered instead of a stub.
demo = server
print(f"[carbon-helix] module loaded — model={MODEL_NAME}", flush=True)

if __name__ == "__main__":
    # `demo` and `server` are the same object.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)