fix(space): cast model weights to fp16 to keep cpu-basic worker alive
The 81M v0.9 chat checkpoint at fp32 is 324 MB of weights plus a similar
peak of forward-pass activations. On HF's cpu-basic Space (2 vCPU, shared
RAM pool with a per-worker memory budget), two consecutive generations
push the worker past its budget and the OS kills it. The Space restarts
cleanly but the user sees a client-side error after the second prompt and
has to reload.

Calling model.half() at load time on the Space halves both the weight
footprint (324 -> 162 MB) and the activation memory of each forward pass.
The cast is gated on os.environ.get('SPACE_ID') so local CPU/GPU users
still get fp32 by default. CPU-fp16 inference is slower than fp32 (no
tensor cores, no AVX-fp16) but at 81M we still expect ~15-25 s per reply,
which is acceptable for a CPU demo and far better than 'the demo errors
after two prompts'.
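
The weight figures above follow from the parameter count times bytes per parameter. A minimal sketch of that arithmetic, assuming exactly 81,000,000 parameters and decimal megabytes (the helper name `weight_megabytes` is illustrative and not part of the Space's code):

```python
# Back-of-envelope weight footprint. The 81M parameter count comes from
# the commit message; treating it as exactly 81,000,000 is an assumption.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}


def weight_megabytes(n_params: int, dtype: str) -> float:
    """Return weight memory in decimal megabytes (1 MB = 10**6 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e6


N_PARAMS = 81_000_000
fp32_mb = weight_megabytes(N_PARAMS, "fp32")
fp16_mb = weight_megabytes(N_PARAMS, "fp16")
print(f"fp32: {fp32_mb:.0f} MB, fp16: {fp16_mb:.0f} MB")
```

This covers weights only; activation memory depends on sequence length and architecture, so it is not estimated here.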
@@ -111,6 +111,20 @@ def load_model(path: str):
     state = ckpt.get("model_state_dict", ckpt.get("model"))
     model.load_state_dict(state, strict=False)
     model.eval()
+
+    # On the HF Space (cpu-basic, ~2 vCPU, ~16 GB shared) the 81M v0.9
+    # checkpoint at fp32 is 324 MB of weights plus a similar peak of
+    # forward-pass activations. Two consecutive generations push the
+    # worker over its memory budget and it OOM-crashes between turns
+    # (the Space restarts cleanly but the user sees an error in the UI
+    # and has to reload the page). Casting weights to fp16 halves the
+    # weight footprint to ~162 MB and roughly halves activation memory
+    # too. CPU-fp16 inference is slower than fp32 but the model is
+    # small enough that we still come in at ~15-25 s per reply, which
+    # is fine for a CPU demo.
+    if os.environ.get("SPACE_ID"):
+        model = model.half()
+
     return model, config, path
 
 
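
The gate relies on the truthiness of os.environ.get("SPACE_ID"), so an unset variable and an empty string both leave the model at fp32. A small torch-free sketch of that decision, where `pick_dtype` and the example value "user/demo" are hypothetical names used only for illustration:

```python
import os
from typing import Mapping, Optional


def pick_dtype(env: Optional[Mapping[str, str]] = None) -> str:
    """Mirror the commit's gate: fp16 only when SPACE_ID is set and non-empty."""
    env = os.environ if env is None else env
    # .get() returns None when the key is missing; None and "" are both falsy.
    return "fp16" if env.get("SPACE_ID") else "fp32"


print(pick_dtype({"SPACE_ID": "user/demo"}))  # on a Space
print(pick_dtype({}))                         # local run, variable unset
print(pick_dtype({"SPACE_ID": ""}))           # variable set but empty
```

Passing the environment as an argument keeps the check testable without mutating os.environ; the commit itself reads os.environ directly inside load_model.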