"""Carbon Helix — gr.Server + ZeroGPU. Custom Three.js frontend served from index.html, streaming generation on ZeroGPU via @server.api. The JS client (@gradio/client) connects to the same origin and calls /generate. Heavy model load is guarded by `if gr.NO_RELOAD:` so `gradio app.py` hot-reload does not re-download a multi-gigabyte model on every save. """ import os import gradio as gr import torch from fastapi.responses import HTMLResponse from transformers import AutoModelForCausalLM, AutoTokenizer try: import spaces _HAS_SPACES = True except ImportError: _HAS_SPACES = False HERE = os.path.dirname(os.path.abspath(__file__)) MODEL_NAME = os.environ.get("MODEL_NAME", "HuggingFaceBio/Carbon-3B") HF_TOKEN = os.environ.get("HF_TOKEN") def left_pad_to_six(seq: str) -> tuple[str, int]: if not seq: return seq, 0 rem = len(seq) % 6 if rem == 0: return seq, 0 n_pad = 6 - rem return ("A" * n_pad) + seq, n_pad if gr.NO_RELOAD: print(f"[carbon-helix] loading tokenizer {MODEL_NAME}...", flush=True) tok = AutoTokenizer.from_pretrained( MODEL_NAME, trust_remote_code=True, token=HF_TOKEN ) print(f"[carbon-helix] loading model {MODEL_NAME} (bf16)...", flush=True) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.bfloat16, token=HF_TOKEN, low_cpu_mem_usage=True, ) model.eval() # Place on cuda at module level. Outside @spaces.GPU, PyTorch's CUDA # emulation mode (set up by the `spaces` library) accepts this without # holding a real GPU; ZeroGPU then fast-snapshots the placement so each # @spaces.GPU call restores instantly. Moving to cuda *inside* the GPU # function is significantly slower per the ZeroGPU docs. # https://huggingface.co/docs/hub/spaces-zerogpu#model-loading if _HAS_SPACES: model.to("cuda") N_PARAMS_B = sum(p.numel() for p in model.parameters()) / 1e9 print(f"[carbon-helix] model ready — {N_PARAMS_B:.2f}B params", flush=True) vocab = tok.get_vocab() EOS_ID = vocab.get("<|endoftext|>") DNA_CLOSE_ID = vocab.get("") STOP_IDS: set[int] = {i for i in (EOS_ID, DNA_CLOSE_ID) if i is not None} print(f"[carbon-helix] stop ids: eos={EOS_ID} ={DNA_CLOSE_ID}", flush=True) def _sample(logits: torch.Tensor, temperature: float, top_p: float) -> tuple[int, float]: original_logprobs = torch.log_softmax(logits.float(), dim=-1) if temperature <= 1e-3: token_id = int(logits.argmax().item()) return token_id, float(original_logprobs[token_id].item()) scaled = logits.float() / max(temperature, 1e-5) sorted_logits, sorted_indices = torch.sort(scaled, descending=True) cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) remove = cumulative > top_p remove[1:] = remove[:-1].clone() remove[0] = False filtered = scaled.clone() filtered[sorted_indices[remove]] = float("-inf") probs = torch.softmax(filtered, dim=-1) token_id = int(torch.multinomial(probs, num_samples=1).item()) return token_id, float(original_logprobs[token_id].item()) def _gpu_stream_impl(input_ids_cpu: torch.Tensor, max_tokens: int, temperature: float, top_p: float): # Model is already on cuda (placed at module level). Only inputs move. input_ids = input_ids_cpu.to("cuda") with torch.no_grad(): out = model(input_ids=input_ids, use_cache=True) past_kv = out.past_key_values next_logits = out.logits[0, -1, :] for _ in range(max_tokens): token_id, logprob = _sample(next_logits, temperature, top_p) yield token_id, logprob if token_id in STOP_IDS: break nt = torch.tensor([[token_id]], device="cuda") out = model(input_ids=nt, past_key_values=past_kv, use_cache=True) past_kv = out.past_key_values next_logits = out.logits[0, -1, :] def _gen_duration(input_ids_cpu, max_tokens, temperature, top_p) -> int: # Dynamic ZeroGPU lease: shorter durations earn higher queue priority and # don't over-consume the daily quota. ~15s warmup/restore + ~0.5s/token at # the RTX Pro 6000 Blackwell's expected throughput. Capped at 120s, which # is the current ZeroGPU per-call maximum. return min(120, max(20, 15 + int(int(max_tokens) * 0.5))) if _HAS_SPACES: _gpu_stream = spaces.GPU(duration=_gen_duration)(_gpu_stream_impl) else: _gpu_stream = _gpu_stream_impl server = gr.Server() @server.get("/", response_class=HTMLResponse) async def homepage(): with open(os.path.join(HERE, "index.html"), "r", encoding="utf-8") as f: return f.read() @server.get("/app-info") async def app_info(): # `/config` is owned by Gradio (reserved internal route), so we expose # our own info endpoint for the frontend's model-name display. return { "model": MODEL_NAME, "params_b": round(N_PARAMS_B, 2), } @server.api(name="generate", concurrency_limit=1) def generate(prompt: str, metadata: str, max_tokens: int, temperature: float, top_p: float) -> dict: seq = "".join(c for c in (prompt or "").upper() if c in "ACGTN") if not seq: yield {"event": "error", "message": "prompt must contain DNA bases (A/C/G/T)"} return seq_padded, pad_bases = left_pad_to_six(seq) full_prompt = (metadata or "") + "" + seq_padded input_ids = tok( full_prompt, return_tensors="pt", add_special_tokens=False ).input_ids yield { "event": "start", "input_length": len(seq), "pad_bases": pad_bases, "model": MODEL_NAME, } try: for tid, lp in _gpu_stream( input_ids, int(max_tokens), float(temperature), float(top_p) ): text = tok.decode([tid], skip_special_tokens=False) yield {"event": "token", "tokens": [text], "logprobs": [lp]} yield {"event": "done"} except Exception as e: yield {"event": "error", "message": str(e)} print(f"[carbon-helix] module loaded — model={MODEL_NAME}", flush=True) # Spaces' hot-reload launcher looks for a module-level `demo` variable. # Alias our gr.Server so it's discovered without falling back to a stub. demo = server if __name__ == "__main__": server.launch(server_name="0.0.0.0", server_port=7860, show_error=True)