feat(seed): apply_seed helper that returns the seed actually used
Browse files- server/seed.py +41 -0
- tests/test_seed.py +42 -0
server/seed.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Seed helper for reproducible Chatterbox generations.
|
| 2 |
+
|
| 3 |
+
`apply_seed(seed)`:
|
| 4 |
+
- if `seed` is `None` or `< 0`, draw a fresh non-negative 31-bit int
|
| 5 |
+
- call torch / cuda / mps / pyrandom seeding APIs with the chosen seed
|
| 6 |
+
- return the seed that was actually used (so the endpoint can echo it back)
|
| 7 |
+
|
| 8 |
+
Failures inside platform-specific seeding (e.g. mps not present) are
|
| 9 |
+
swallowed — the helper is best-effort, not a contract for determinism
|
| 10 |
+
across hardware.
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import random
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _maybe_seed_mps(seed: int) -> None:
|
| 21 |
+
mps = getattr(torch, "mps", None)
|
| 22 |
+
if mps is None:
|
| 23 |
+
return
|
| 24 |
+
fn = getattr(mps, "manual_seed", None)
|
| 25 |
+
if fn is None:
|
| 26 |
+
return
|
| 27 |
+
fn(seed)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def apply_seed(seed: Optional[int]) -> int:
|
| 31 |
+
if seed is None or seed < 0:
|
| 32 |
+
seed = random.randint(0, 2**31 - 1)
|
| 33 |
+
torch.manual_seed(seed)
|
| 34 |
+
if torch.cuda.is_available():
|
| 35 |
+
torch.cuda.manual_seed_all(seed)
|
| 36 |
+
try:
|
| 37 |
+
_maybe_seed_mps(seed)
|
| 38 |
+
except Exception:
|
| 39 |
+
pass
|
| 40 |
+
random.seed(seed)
|
| 41 |
+
return seed
|
tests/test_seed.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random as pyrandom
|
| 2 |
+
from unittest.mock import patch
|
| 3 |
+
|
| 4 |
+
from server.seed import apply_seed
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_apply_seed_returns_provided_value():
|
| 8 |
+
assert apply_seed(42) == 42
|
| 9 |
+
assert apply_seed(0) == 0
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_apply_seed_negative_draws_random():
|
| 13 |
+
s = apply_seed(-1)
|
| 14 |
+
assert isinstance(s, int)
|
| 15 |
+
assert 0 <= s < 2**31
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_apply_seed_none_draws_random():
|
| 19 |
+
s = apply_seed(None)
|
| 20 |
+
assert isinstance(s, int)
|
| 21 |
+
assert 0 <= s < 2**31
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_apply_seed_seeds_pyrandom_so_repeats_match():
|
| 25 |
+
s = apply_seed(123)
|
| 26 |
+
a = pyrandom.random()
|
| 27 |
+
apply_seed(s)
|
| 28 |
+
b = pyrandom.random()
|
| 29 |
+
assert a == b
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_apply_seed_calls_torch_manual_seed():
|
| 33 |
+
with patch("server.seed.torch.manual_seed") as m:
|
| 34 |
+
apply_seed(99)
|
| 35 |
+
m.assert_called_once_with(99)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_apply_seed_swallows_mps_failure():
|
| 39 |
+
with patch("server.seed._maybe_seed_mps", side_effect=RuntimeError("nope")):
|
| 40 |
+
# Should not raise
|
| 41 |
+
s = apply_seed(7)
|
| 42 |
+
assert s == 7
|