| import io |
| import time |
| import contextlib |
| from pathlib import Path |
| import sys |
| import torch |
|
|
# Repository root: the parent of the directory containing this script.
ROOT = Path(__file__).resolve().parents[1]

# Put the repository root at the front of the module search path so the
# project-local imports below resolve when this file is run directly; skip the
# insert if the entry is already present so repeated imports don't duplicate it.
if all(entry != str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))
|
|
| from progressive_scaleup import progressive_scale_up_text |
| from unified_workflow import run_workflow |
| from bit_transformer.bit_io import text_to_bits |
| from bit_transformer.safety import hil_safe_inference |
|
|
|
|
def capture_run(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` while capturing everything it prints to stdout.

    Args:
        func: The callable to run.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A ``(result, captured_stdout, duration_seconds)`` tuple, where
        ``result`` is whatever ``func`` returned, ``captured_stdout`` is the
        text written to ``sys.stdout`` during the call, and
        ``duration_seconds`` is the elapsed time of the call.

    Raises:
        Whatever ``func`` raises; stdout is restored by ``redirect_stdout``
        before the exception propagates.
    """
    buf = io.StringIO()
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system wall clock, which can jump (NTP sync, manual changes) and skew or
    # even negate measured durations.
    start = time.perf_counter()
    with contextlib.redirect_stdout(buf):
        result = func(*args, **kwargs)
    duration = time.perf_counter() - start
    return result, buf.getvalue(), duration
|
|
|
|
def _append_section(
    summary: list[str], title: str, log: str, dur: float, *extra: str
) -> None:
    """Append one report section to *summary*.

    A section is a ``###`` heading line, the captured log with surrounding
    whitespace stripped, any *extra* lines in the order given, and finally the
    duration formatted to two decimal places.
    """
    summary.append(f"### {title}\n")
    summary.append(log.strip())
    summary.extend(extra)
    summary.append(f"Duration: {dur:.2f}s\n")


def main() -> None:
    """Run each experiment configuration, capture its stdout, and print a report.

    Runs progressive scale-up with ``causal=True`` and ``causal=False`` forward
    passes, then the unified workflow with ``diffusion=False`` and
    ``diffusion=True``. The model returned by the non-diffusion workflow is also
    probed with ``hil_safe_inference`` on the bits of the text ``"hi"``.
    """
    summary: list[str] = []

    # Both progressive scale-up runs share every setting except the causal flag.
    for causal in (True, False):
        _, log, dur = capture_run(
            progressive_scale_up_text,
            improve_thresh=0.01,
            steps=10,
            width_mult=2.0,
            max_len=64,
            dataset_size=512,
            forward_kwargs={"causal": causal},
        )
        _append_section(summary, f"Progressive Scale-Up (causal={causal})", log, dur)

    # Shared settings for both unified-workflow runs; only ``diffusion`` differs.
    # ``diffusion`` is passed after the unpacked dict so keyword order matches
    # the original explicit calls.
    workflow_kwargs = {
        "steps": 2,
        "max_len": 32,
        "dataset_size": 32,
        "plateau_steps": 1,
        "epochs_per_step": 1,
        "extra_steps": 1,
    }

    (model, _), log, dur = capture_run(run_workflow, **workflow_kwargs, diffusion=False)
    # The inference probe runs outside capture_run, so anything it prints goes
    # straight to stdout rather than into the report.
    bits = text_to_bits("hi")
    # Add a leading dimension of size 1 (presumably the batch axis — confirm
    # against hil_safe_inference's expected input shape).
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    out_bits, _ = hil_safe_inference(model, tensor, c_floor=0.0, s_floor=0.0)
    _append_section(
        summary,
        "Unified Workflow (causal=True)",
        log,
        dur,
        f"Inference on 'hi': {out_bits.squeeze(0).tolist()}\n",
    )

    # The returned pair is unpacked (and discarded) to keep the same shape check
    # as the non-diffusion run.
    (_, _), log, dur = capture_run(run_workflow, **workflow_kwargs, diffusion=True)
    _append_section(summary, "Unified Workflow (causal=False / Diffusion)", log, dur)

    report = "\n".join(summary)
    print(report)
|
|
|
|
# Build and print the report only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|