Ghostgim commited on
Commit
088475f
·
verified ·
1 Parent(s): c6fa371

fix(space): cast model weights to fp16 to keep cpu-basic worker alive

Browse files

The 81M v0.9 chat checkpoint at fp32 is 324 MB of weights plus a similar
peak of forward-pass activations. On HF's cpu-basic Space (2 vCPU, shared
RAM pool with a per-worker memory budget), two consecutive generations
push the worker past its budget and the OS kills it. The Space restarts
cleanly but the user sees a client-side error after the second prompt and
has to reload.

Casting model.half() at load time on the Space halves both the weight
footprint (324 -> 162 MB) and the activation memory of each forward pass.
The cast is gated on os.environ.get('SPACE_ID') so local CPU/GPU users
still get fp32 by default. CPU-fp16 inference is slower than fp32 (no
tensor cores, no AVX-fp16) but at 81M we still expect ~15-25 s per reply,
which is acceptable for a CPU demo and far better than 'the demo errors
after two prompts'.

Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -111,6 +111,20 @@ def load_model(path: str):
111
  state = ckpt.get("model_state_dict", ckpt.get("model"))
112
  model.load_state_dict(state, strict=False)
113
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  return model, config, path
115
 
116
 
 
111
  state = ckpt.get("model_state_dict", ckpt.get("model"))
112
  model.load_state_dict(state, strict=False)
113
  model.eval()
114
+
115
+ # On the HF Space (cpu-basic, ~2 vCPU, ~16 GB shared) the 81M v0.9
116
+ # checkpoint at fp32 is 324 MB of weights plus a similar peak of
117
+ # forward-pass activations. Two consecutive generations push the
118
+ # worker over its memory budget and it OOM-crashes between turns
119
+ # (the Space restarts cleanly but the user sees an error in the UI
120
+ # and has to reload the page). Casting weights to fp16 halves the
121
+ # weight footprint to ~162 MB and roughly halves activation memory
122
+ # too. CPU-fp16 inference is slower than fp32 but the model is
123
+ # small enough that we still come in at ~15-25 s per reply, which
124
+ # is fine for a CPU demo.
125
+ if os.environ.get("SPACE_ID"):
126
+ model = model.half()
127
+
128
  return model, config, path
129
 
130