techfreakworm commited on
Commit
451dece
·
unverified ·
1 Parent(s): a75ec91

feat(seed): apply_seed helper that returns the seed actually used

Browse files
Files changed (2) hide show
  1. server/seed.py +41 -0
  2. tests/test_seed.py +42 -0
server/seed.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Seed helper for reproducible Chatterbox generations.
2
+
3
+ `apply_seed(seed)`:
4
+ - if `seed` is `None` or `< 0`, draw a fresh non-negative 31-bit int
5
+ - call torch / cuda / mps / pyrandom seeding APIs with the chosen seed
6
+ - return the seed that was actually used (so the endpoint can echo it back)
7
+
8
+ Failures inside platform-specific seeding (e.g. mps not present) are
9
+ swallowed — the helper is best-effort, not a contract for determinism
10
+ across hardware.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import random
15
+ from typing import Optional
16
+
17
+ import torch
18
+
19
+
20
+ def _maybe_seed_mps(seed: int) -> None:
21
+ mps = getattr(torch, "mps", None)
22
+ if mps is None:
23
+ return
24
+ fn = getattr(mps, "manual_seed", None)
25
+ if fn is None:
26
+ return
27
+ fn(seed)
28
+
29
+
30
+ def apply_seed(seed: Optional[int]) -> int:
31
+ if seed is None or seed < 0:
32
+ seed = random.randint(0, 2**31 - 1)
33
+ torch.manual_seed(seed)
34
+ if torch.cuda.is_available():
35
+ torch.cuda.manual_seed_all(seed)
36
+ try:
37
+ _maybe_seed_mps(seed)
38
+ except Exception:
39
+ pass
40
+ random.seed(seed)
41
+ return seed
tests/test_seed.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random as pyrandom
2
+ from unittest.mock import patch
3
+
4
+ from server.seed import apply_seed
5
+
6
+
7
+ def test_apply_seed_returns_provided_value():
8
+ assert apply_seed(42) == 42
9
+ assert apply_seed(0) == 0
10
+
11
+
12
+ def test_apply_seed_negative_draws_random():
13
+ s = apply_seed(-1)
14
+ assert isinstance(s, int)
15
+ assert 0 <= s < 2**31
16
+
17
+
18
+ def test_apply_seed_none_draws_random():
19
+ s = apply_seed(None)
20
+ assert isinstance(s, int)
21
+ assert 0 <= s < 2**31
22
+
23
+
24
+ def test_apply_seed_seeds_pyrandom_so_repeats_match():
25
+ s = apply_seed(123)
26
+ a = pyrandom.random()
27
+ apply_seed(s)
28
+ b = pyrandom.random()
29
+ assert a == b
30
+
31
+
32
+ def test_apply_seed_calls_torch_manual_seed():
33
+ with patch("server.seed.torch.manual_seed") as m:
34
+ apply_seed(99)
35
+ m.assert_called_once_with(99)
36
+
37
+
38
+ def test_apply_seed_swallows_mps_failure():
39
+ with patch("server.seed._maybe_seed_mps", side_effect=RuntimeError("nope")):
40
+ # Should not raise
41
+ s = apply_seed(7)
42
+ assert s == 7