# carbon-helix-2d / app.py
# Author: ysharma (HF Staff)
# Commit e6d345e (verified): ZeroGPU: cap duration at 120s via dynamic callable; place model on cuda at module level
"""Carbon Helix — gr.Server + ZeroGPU.
Custom Three.js frontend served from index.html, streaming generation
on ZeroGPU via @server.api. The JS client (@gradio/client) connects to
the same origin and calls /generate.
Heavy model load is guarded by `if gr.NO_RELOAD:` so `gradio app.py`
hot-reload does not re-download a multi-gigabyte model on every save.
"""
import os
import gradio as gr
import torch
from fastapi.responses import HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
import spaces
_HAS_SPACES = True
except ImportError:
_HAS_SPACES = False
# Absolute directory containing this file.
HERE = os.path.split(os.path.abspath(__file__))[0]
# Model id and access token come from the environment; only the model id
# has a fallback value, the token is None when the variable is unset.
MODEL_NAME = os.getenv("MODEL_NAME", "HuggingFaceBio/Carbon-3B")
HF_TOKEN = os.getenv("HF_TOKEN")
def left_pad_to_six(seq: str) -> tuple[str, int]:
    """Left-pad ``seq`` with ``"A"`` so its length becomes a multiple of six.

    Returns a ``(sequence, n_pad)`` pair where ``n_pad`` is the number of
    ``"A"`` characters prepended. An empty sequence, or one whose length is
    already a multiple of six, is returned unchanged with ``n_pad == 0``.
    """
    # -len % 6 is the distance up to the next multiple of six (0 if aligned).
    n_pad = -len(seq) % 6 if seq else 0
    if n_pad == 0:
        return seq, 0
    return "A" * n_pad + seq, n_pad
# One-time heavy setup: load the tokenizer and the bf16 model, then derive the
# parameter count and the stop-token ids used by the sampling loop. Per the
# module docstring, the `gr.NO_RELOAD` guard keeps `gradio app.py` hot-reload
# from repeating the model load on every save.
# NOTE(review): indentation was lost in the reviewed copy; the guarded block
# is assumed to extend through the stop-id setup below — confirm.
if gr.NO_RELOAD:
    print(f"[carbon-helix] loading tokenizer {MODEL_NAME}...", flush=True)
    tok = AutoTokenizer.from_pretrained(
        MODEL_NAME, trust_remote_code=True, token=HF_TOKEN
    )
    print(f"[carbon-helix] loading model {MODEL_NAME} (bf16)...", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
    )
    model.eval()
    # Place on cuda at module level. Outside @spaces.GPU, PyTorch's CUDA
    # emulation mode (set up by the `spaces` library) accepts this without
    # holding a real GPU; ZeroGPU then fast-snapshots the placement so each
    # @spaces.GPU call restores instantly. Moving to cuda *inside* the GPU
    # function is significantly slower per the ZeroGPU docs.
    # https://huggingface.co/docs/hub/spaces-zerogpu#model-loading
    if _HAS_SPACES:
        model.to("cuda")
    # Parameter count in billions; reported to the frontend via /app-info.
    N_PARAMS_B = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"[carbon-helix] model ready — {N_PARAMS_B:.2f}B params", flush=True)
    # Generation stops on either of these tokens. `vocab.get` returns None
    # for a token missing from the vocabulary, and such ids are left out of
    # STOP_IDS rather than raising.
    vocab = tok.get_vocab()
    EOS_ID = vocab.get("<|endoftext|>")
    DNA_CLOSE_ID = vocab.get("</dna>")
    STOP_IDS: set[int] = {i for i in (EOS_ID, DNA_CLOSE_ID) if i is not None}
    print(f"[carbon-helix] stop ids: eos={EOS_ID} </dna>={DNA_CLOSE_ID}", flush=True)
def _sample(logits: torch.Tensor, temperature: float, top_p: float) -> tuple[int, float]:
    """Choose the next token id from a 1-D logits vector.

    With ``temperature <= 1e-3`` the highest-logit token is taken greedily.
    Otherwise the logits are divided by the temperature, restricted to the
    top-p nucleus (the highest-ranked token is always kept), and one token is
    drawn from the renormalised distribution.

    Returns ``(token_id, logprob)``; the log-probability is always taken from
    the softmax of the *unscaled, unfiltered* logits.
    """
    base_logprobs = torch.log_softmax(logits.float(), dim=-1)

    def _with_logprob(idx: int) -> tuple[int, float]:
        # Pair a chosen id with its log-probability under the raw logits.
        return idx, float(base_logprobs[idx].item())

    if temperature <= 1e-3:
        return _with_logprob(int(logits.argmax().item()))

    tempered = logits.float() / max(temperature, 1e-5)
    ranked_logits, ranked_ids = torch.sort(tempered, descending=True)
    mass = torch.softmax(ranked_logits, dim=-1).cumsum(dim=-1)
    over_budget = mass > top_p

    # Shift the mask right by one rank so the token that crosses the top_p
    # threshold stays in the nucleus; rank 0 is therefore never dropped.
    drop = torch.zeros_like(over_budget)
    drop[1:] = over_budget[:-1]

    nucleus = tempered.clone()
    nucleus[ranked_ids[drop]] = float("-inf")
    weights = torch.softmax(nucleus, dim=-1)
    chosen = int(torch.multinomial(weights, num_samples=1).item())
    return _with_logprob(chosen)
def _gpu_stream_impl(input_ids_cpu: torch.Tensor, max_tokens: int,
                     temperature: float, top_p: float):
    """Stream sampled tokens from the model, one ``(token_id, logprob)`` at a time.

    The prompt is run through the model once with ``use_cache=True``; each
    sampled token is then fed back on its own together with the cached
    keys/values. Generation ends after ``max_tokens`` tokens or as soon as a
    token in ``STOP_IDS`` is sampled (the stop token itself is still yielded).

    Args:
        input_ids_cpu: prompt token ids; moved to ``"cuda"`` here.
        max_tokens: maximum number of tokens to yield.
        temperature: sampling temperature forwarded to ``_sample``.
        top_p: nucleus threshold forwarded to ``_sample``.

    Yields:
        ``(token_id, logprob)`` tuples as produced by ``_sample``.
    """
    # Model is already on cuda (placed at module level). Only inputs move.
    input_ids = input_ids_cpu.to("cuda")
    with torch.no_grad():
        out = model(input_ids=input_ids, use_cache=True)
        past_kv = out.past_key_values
        next_logits = out.logits[0, -1, :]
        last_step = max_tokens - 1
        for step in range(max_tokens):
            token_id, logprob = _sample(next_logits, temperature, top_p)
            yield token_id, logprob
            # Stop before the next forward pass when nothing further will be
            # sampled: either a stop token was produced or the token budget
            # is used up. This avoids one wasted model call per request.
            if token_id in STOP_IDS or step == last_step:
                break
            nt = torch.tensor([[token_id]], device="cuda")
            out = model(input_ids=nt, past_key_values=past_kv, use_cache=True)
            past_kv = out.past_key_values
            next_logits = out.logits[0, -1, :]
def _gen_duration(input_ids_cpu, max_tokens, temperature, top_p) -> int:
    """Return the ZeroGPU lease length, in seconds, for one generate call.

    Takes the same arguments as the GPU function it is paired with, but only
    ``max_tokens`` influences the result. The estimate is 15 seconds plus
    half a second per requested token, clamped to the range 20..120.
    """
    # Dynamic ZeroGPU lease: shorter durations earn higher queue priority and
    # don't over-consume the daily quota. The author's budget is ~15s for
    # warmup/restore plus ~0.5s/token at the RTX Pro 6000 Blackwell's expected
    # throughput; 120s is noted as the current ZeroGPU per-call maximum.
    floor_s, ceiling_s = 20, 120
    estimate = 15 + int(int(max_tokens) * 0.5)
    if estimate < floor_s:
        return floor_s
    if estimate > ceiling_s:
        return ceiling_s
    return estimate
# Wrap the streaming generator with the ZeroGPU decorator (using the dynamic
# duration callable) only when the `spaces` package was importable; otherwise
# the plain implementation is used as-is.
_gpu_stream = (
    spaces.GPU(duration=_gen_duration)(_gpu_stream_impl)
    if _HAS_SPACES
    else _gpu_stream_impl
)
server = gr.Server()
@server.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the frontend: return the text of ``index.html`` from ``HERE``.

    The file is read on every request and decoded as UTF-8.
    """
    index_path = os.path.join(HERE, "index.html")
    with open(index_path, mode="r", encoding="utf-8") as page:
        html = page.read()
    return html
@server.get("/app-info")
async def app_info():
    """Return the model name and its parameter count (billions, 2 decimals)."""
    # `/config` is owned by Gradio (reserved internal route), so we expose
    # our own info endpoint for the frontend's model-name display.
    info = {"model": MODEL_NAME}
    info["params_b"] = round(N_PARAMS_B, 2)
    return info
@server.api(name="generate", concurrency_limit=1)
def generate(prompt: str, metadata: str, max_tokens: int,
             temperature: float, top_p: float) -> dict:
    """Stream a DNA continuation for ``prompt`` as a sequence of event dicts.

    The prompt is upper-cased and reduced to the characters A, C, G, T and N
    (N passes the filter even though the error message lists only A/C/G/T).
    The cleaned sequence is left-padded via ``left_pad_to_six``, prefixed with
    ``metadata`` and a ``<dna>`` tag, tokenised without special tokens and
    handed to ``_gpu_stream``.

    Yields, in order:
        * ``{"event": "error", ...}`` and nothing else if no bases remain;
        * one ``"start"`` event with the input length, pad count and model;
        * one ``"token"`` event per generated token (decoded text + logprob);
        * a final ``"done"`` event, or an ``"error"`` event carrying
          ``str(exception)`` if generation raised.

    NOTE(review): the ``-> dict`` annotation describes each yielded item;
    the function itself is a generator.
    """
    allowed_bases = frozenset("ACGTN")
    kept = [ch for ch in (prompt or "").upper() if ch in allowed_bases]
    if not kept:
        yield {"event": "error", "message": "prompt must contain DNA bases (A/C/G/T)"}
        return
    seq = "".join(kept)

    seq_padded, pad_bases = left_pad_to_six(seq)
    full_prompt = "".join([metadata or "", "<dna>", seq_padded])
    encoded = tok(full_prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = encoded.input_ids

    start_event = {
        "event": "start",
        "input_length": len(seq),
        "pad_bases": pad_bases,
        "model": MODEL_NAME,
    }
    yield start_event

    # Top-level boundary: any failure while converting the numeric arguments
    # or during generation is reported to the client as an error event.
    try:
        stream = _gpu_stream(
            input_ids, int(max_tokens), float(temperature), float(top_p)
        )
        for token_id, logprob in stream:
            piece = tok.decode([token_id], skip_special_tokens=False)
            yield {"event": "token", "tokens": [piece], "logprobs": [logprob]}
        yield {"event": "done"}
    except Exception as exc:
        yield {"event": "error", "message": str(exc)}
# Logged once the whole module body has executed.
print(f"[carbon-helix] module loaded — model={MODEL_NAME}", flush=True)
# Spaces' hot-reload launcher looks for a module-level `demo` variable.
# Alias our gr.Server so it's discovered without falling back to a stub.
demo = server
# Direct execution (`python app.py`): listen on all interfaces, port 7860,
# with show_error enabled.
if __name__ == "__main__":
    server.launch(server_name="0.0.0.0", server_port=7860, show_error=True)