| from typing import List, TYPE_CHECKING |
| import torch |
| import sys |
|
|
def _select_compile_fn():
    """Return ``torch.compile`` if this torch/Python pair may use it, else ``None``.

    ``None`` means "fall back to the no-op decorator". That happens when the
    torch version string is empty, when torch is older than 2.0, when the
    interpreter is Python 3.11 or newer, or when any step of this probe raises.
    """
    try:
        version = torch.__version__
        if not version:
            return None
        major_minor = tuple(int(piece) for piece in version.split(".")[:2])
        # NOTE(review): the Python < 3.11 cap is presumably about
        # torch.compile support in the targeted torch releases; confirm.
        if major_minor < (2, 0) or sys.version_info >= (3, 11):
            return None
        return torch.compile
    except Exception:
        # Unparseable version strings, or a torch without ``compile``, degrade
        # to the identity decorator instead of failing at import time.
        return None


_selected_compile = _select_compile_fn()

if _selected_compile is not None:
    compile_fn = _selected_compile
else:

    def compile_fn(fn=None, **kwargs):
        """No-op stand-in for ``torch.compile``.

        Works both as ``@compile_fn`` and as ``@compile_fn(...)``. Keyword
        options are accepted and ignored, and ``fn`` is returned unchanged.
        """
        if fn is None:
            return lambda f: f
        return fn
|
|
|
|
| if TYPE_CHECKING: |
| from .model import BitTransformerLM |
|
|
|
|
@compile_fn
def bytes_to_bits(data: bytes) -> List[int]:
    """Expand every byte into 8 data bits plus one parity bit.

    Data bits are emitted most-significant first. The parity bit is
    ``sum(data_bits) % 2``, so each 9-bit group has an even number of ones.
    The result therefore has ``9 * len(data)`` entries.
    """
    bits_out: List[int] = []
    for byte in data:
        byte_bits = [(byte >> shift) & 1 for shift in range(7, -1, -1)]
        byte_bits.append(sum(byte_bits) % 2)
        bits_out += byte_bits
    return bits_out
|
|
|
|
@compile_fn
def bits_to_bytes(bits: List[int]) -> bytes:
    """Convert parity-protected bits back to bytes.

    This is the inverse of ``bytes_to_bits``. ``bits`` is read in groups of 9.
    Each group holds 8 payload bits (most significant first) followed by a
    parity bit that must equal ``sum(payload) % 2``.

    Args:
        bits: Flat sequence of 0/1 integers whose length is a multiple of 9.

    Returns:
        The decoded bytes, one byte per 9-bit group.

    Raises:
        ValueError: If the length is not a multiple of 9, if any entry is
            not 0 or 1, or if a group's parity bit does not match its payload.
    """
    if len(bits) % 9 != 0:
        raise ValueError("Bit stream length must be multiple of 9")
    out = bytearray()
    for i in range(0, len(bits), 9):
        chunk = bits[i : i + 9]
        # Reject non-binary entries up front. Otherwise ``(value << 1) | bit``
        # silently yields a wrong byte, or bytearray.append fails with an
        # unrelated "byte must be in range" error.
        if any(bit not in (0, 1) for bit in chunk):
            raise ValueError("Bit stream must contain only 0 and 1 values")
        payload = chunk[:8]
        parity = chunk[8]
        if parity != sum(payload) % 2:
            raise ValueError("Parity check failed")
        value = 0
        for bit in payload:
            value = (value << 1) | bit
        out.append(value)
    return bytes(out)
|
|
|
|
def text_to_bits(text: str) -> List[int]:
    """Return the parity-protected bit list for ``text`` encoded as UTF-8.

    Each UTF-8 byte contributes 9 entries: 8 data bits and 1 parity bit.
    """
    utf8_payload = text.encode("utf-8")
    return bytes_to_bits(utf8_payload)
|
|
|
|
def bits_to_text(bits: List[int]) -> str:
    """Decode parity-protected bits into a string.

    Invalid UTF-8 sequences are replaced with U+FFFD instead of raising.
    Length and parity errors from ``bits_to_bytes`` still propagate as
    ``ValueError``.
    """
    raw = bits_to_bytes(bits)
    return raw.decode("utf-8", errors="replace")
|
|
|
|
def infer_text(
    model: "BitTransformerLM",
    text: str,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
) -> str:
    """Run ``text`` through ``hil_safe_inference`` and decode the output bits.

    Args:
        model: Model handed to ``hil_safe_inference``.
        text: Input string. It is encoded with ``text_to_bits`` and sent as a
            ``(1, n_bits)`` long tensor.
        c_floor: Forwarded unchanged to ``hil_safe_inference``.
        s_floor: Forwarded unchanged to ``hil_safe_inference``.

    Returns:
        The first value returned by ``hil_safe_inference``, squeezed on dim 0
        and decoded with ``bits_to_text``.
    """
    # Imported lazily. NOTE(review): presumably to avoid an import cycle with
    # .safety; confirm.
    from .safety import hil_safe_inference

    encoded = text_to_bits(text)
    batch = torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
    gated_bits, _ = hil_safe_inference(
        model,
        batch,
        c_floor=c_floor,
        s_floor=s_floor,
    )
    flat_bits = gated_bits.squeeze(0).tolist()
    return bits_to_text(flat_bits)
|
|
|
|
def sample_text(
    model: "BitTransformerLM",
    prompt: str,
    max_new_tokens: int = 16,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> str:
    """Generate text from the model using simple top-p sampling.

    The prompt is encoded with ``text_to_bits`` (9 bits per byte). New bits are
    then sampled one at a time, up to ``max_new_tokens * 9`` bits. Generation
    stops early once the sequence length reaches ``model.pos_enc.pe.size(0)``.

    Args:
        model: Autoregressive bit model. It is put in eval mode and called as
            ``model(x, causal=True)``. The first element of the returned pair
            is indexed as ``logits[0, -1]`` to get the next-bit scores.
        prompt: Text to condition on.
        max_new_tokens: Number of new bytes (9-bit groups) to sample.
        temperature: Divides the logits before the softmax. Values below 1
            sharpen the distribution and values above 1 flatten it. A value
            ``<= 0`` selects greedy (argmax) decoding.
        top_p: Nucleus threshold. A candidate is dropped once the mass of the
            more likely candidates before it already exceeds ``top_p``. The
            most likely candidate is always kept. ``top_p >= 1`` disables
            filtering.

    Returns:
        The decoded prompt plus continuation. A trailing partial 9-bit group,
        left when generation stops early, is discarded before decoding.

    Raises:
        ValueError: From ``bits_to_text`` if a sampled parity bit does not
            match its payload. Sampled parity bits are not corrected here.
    """
    model.eval()
    bits = text_to_bits(prompt)
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    # Sampling only: don't build an autograd graph on every step.
    with torch.no_grad():
        for _ in range(max_new_tokens * 9):
            if tensor.size(1) >= model.pos_enc.pe.size(0):
                break
            logits, _ = model(tensor, causal=True)
            next_logits = logits[0, -1]
            if temperature <= 0:
                next_bit = next_logits.argmax(-1, keepdim=True)
            else:
                # Temperature must scale the logits *before* the softmax.
                # Scaling probabilities afterwards is undone by renormalising.
                prob = (next_logits / temperature).softmax(-1)
                sorted_prob, sorted_idx = prob.sort(descending=True)
                if top_p < 1.0:
                    cumulative = sorted_prob.cumsum(0)
                    # Shifted mask: this keeps the candidate that crosses the
                    # threshold and can never empty the set (which would give
                    # NaN and make multinomial fail).
                    mask = (cumulative - sorted_prob) > top_p
                    mask[0] = False
                    sorted_prob = sorted_prob.masked_fill(mask, 0.0)
                sorted_prob = sorted_prob / sorted_prob.sum()
                next_bit = sorted_idx[torch.multinomial(sorted_prob, 1)]
            tensor = torch.cat([tensor, next_bit.view(1, 1)], dim=1)
    out_bits = tensor.squeeze(0).tolist()
    # Keep whole 9-bit groups only. An early stop at the positional-encoding
    # limit can leave a partial byte, which bits_to_text would reject.
    out_bits = out_bits[: len(out_bits) - len(out_bits) % 9]
    return bits_to_text(out_bits)
|
|