| import io |
| import json |
| import os |
| import traceback |
| import inspect |
| from typing import Any, Dict, List, Optional, Union |
|
|
| from flask import Flask, jsonify, request, render_template, send_file |
| import subprocess |
| import sys |
| import warnings |
| import matplotlib.pyplot as plt |
| import torch |
| import torch.nn.functional as F |
| import requests |
| import gzip |
|
|
| from .model import BitTransformerLM, infer_long_sequence |
| from .optimization import configure_optimizer |
| from .collapse import collapse_submodel |
| from .dashboard import plot_telemetry |
| from .scale import expand_model |
| from .bit_io import text_to_bits, bits_to_text |
| from .safety import hil_safe_inference |
| from .compression import model_output_decompress, compress_bits |
| from .distributed import wrap_fsdp |
| from .training import train_loop |
| from .telemetry import detect_metric_drift |
| from .quantization import prepare_qat_fx, convert_qat_fx |
| from torch.distributed.fsdp import FullyShardedDataParallel |
| from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint |
|
|
|
|
app = Flask(__name__)
# Cap request bodies at 1 MiB so oversized bit payloads are rejected early.
app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024


# Optional address of an MCP backend server; when set, most routes proxy
# their requests to it instead of using the local ModelManager.
MCP_SERVER_ADDR = os.getenv("MCP_SERVER_ADDR")
|
|
|
|
@app.errorhandler(Exception)
def handle_exception(err):
    """Convert any unhandled exception into a JSON error response.

    NOTE(review): the full stack trace is returned to the client, which is
    convenient for a local dashboard but leaks internals if this service is
    ever exposed publicly — confirm it stays behind a trusted boundary.
    """
    # HTTPExceptions carry a ``code`` attribute; everything else maps to 500.
    status_code = getattr(err, "code", 500)
    body = {"error": str(err), "trace": traceback.format_exc()}
    return jsonify(body), status_code
|
|
class MetricDriftWarning(UserWarning):
    """Warning emitted when a telemetry metric drifts past its configured threshold."""
|
|
def _switch_torch(use_gpu: bool) -> None:
    """Install the matching PyTorch wheel (CPU or CUDA) and re-exec the process.

    No-op when the installed build already matches *use_gpu*; otherwise pip
    installs the other wheel and the process image is replaced via os.execv,
    so this call never returns in that case.
    """
    cuda_build = torch.version.cuda is not None
    if use_gpu == cuda_build:
        return  # already running the requested build
    flavor = "cu118" if use_gpu else "cpu"
    wheel = f"torch==2.7.1+{flavor}"
    index_url = f"https://download.pytorch.org/whl/{flavor}"
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--index-url", index_url, wheel],
        check=True,
    )
    # Restart with identical argv so the freshly installed torch is imported.
    os.execv(sys.executable, [sys.executable] + sys.argv)
|
|
def mcp_post(path: str, data=None, timeout: float = 30.0):
    """POST *data* as JSON to the MCP server at *path*.

    Returns ``None`` when no MCP server is configured, raw bytes for image
    responses, and parsed JSON otherwise. Raises ``requests.HTTPError`` on
    non-2xx responses.
    """
    if not MCP_SERVER_ADDR:
        return None
    url = MCP_SERVER_ADDR.rstrip("/") + path
    # Bound the call so a hung MCP server cannot stall dashboard handlers
    # indefinitely (requests has no default timeout).
    resp = requests.post(url, json=data, timeout=timeout)
    resp.raise_for_status()
    if resp.headers.get("Content-Type", "").startswith("image/"):
        return resp.content
    return resp.json()
|
|
def mcp_get(path: str, timeout: float = 30.0):
    """GET *path* from the MCP server.

    Returns ``None`` when no MCP server is configured, raw bytes for image
    responses, and parsed JSON otherwise. Raises ``requests.HTTPError`` on
    non-2xx responses.
    """
    if not MCP_SERVER_ADDR:
        return None
    url = MCP_SERVER_ADDR.rstrip("/") + path
    # Bound the call so a hung MCP server cannot stall dashboard handlers
    # indefinitely (requests has no default timeout).
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    if resp.headers.get("Content-Type", "").startswith("image/"):
        return resp.content
    return resp.json()
|
|
class ModelManager:
    """Manage model state and training utilities for the dashboard.

    Owns the ``BitTransformerLM`` instance plus its optimizer/scheduler,
    persists weights and telemetry under ``snapshot_dir``, and records safety
    telemetry (negentropy, LZ complexity, symbiosis) with drift detection.
    """


    def __init__(
        self,
        snapshot_dir: Optional[str] = None,
        telemetry_log: Optional[str] = None,
        *,
        drift_window: int = 10,
        drift_threshold: float = 0.2,
    ) -> None:
        # Resolve storage locations from args, then env vars, then defaults.
        self.snapshot_dir = snapshot_dir or os.getenv("SNAPSHOT_DIR", "snapshots")
        self.telemetry_log = telemetry_log or os.getenv("TELEMETRY_LOG")
        if self.telemetry_log is None:
            self.telemetry_log = os.path.join(self.snapshot_dir, "metrics.json")
        os.makedirs(self.snapshot_dir, exist_ok=True)
        self.weights_path = os.path.join(self.snapshot_dir, "model.pt")


        self.model: Optional[BitTransformerLM] = None
        self.optimizer: Optional[torch.optim.Optimizer] = None
        self.scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None
        self.total_steps = 100  # scheduler horizon; recomputed by train_epochs
        # Telemetry history: one value appended per logged forward pass.
        self.metrics: Dict[str, List[float]] = {
            "negentropy_logits": [],
            "lz_complexity_logits": [],
            "symbiosis_score": [],
        }
        self.drift_window = drift_window
        self.drift_threshold = drift_threshold
        # λ weights applied to the telemetry regularizers (K/C/S).
        self.lambda_K = 1.0
        self.lambda_C = 1.0
        self.lambda_S = 1.0
        # Safety floors enforced by hil_safe_inference in infer_text.
        self.c_floor = 0.3
        self.s_floor = 0.5
        # Runtime toggles controlled by dashboard routes.
        self.causal = True
        self.diffusion = False
        self.decompress_output = False
        self.use_compression = False
        self.use_gpu = False
        self.qat = False


        # Best-effort restore of telemetry history from a previous run.
        if os.path.exists(self.telemetry_log):
            try:
                with open(self.telemetry_log) as f:
                    saved = json.load(f)
                for key in self.metrics:
                    self.metrics[key] = saved.get(key, [])
            except Exception:
                pass  # corrupt/unreadable log is ignored
        # Best-effort restore of the model saved by _save_state.
        # NOTE(review): torch.load unpickles arbitrary objects — safe only if
        # the snapshot dir is trusted; confirm it is not user-writable.
        if os.path.exists(self.weights_path):
            try:
                self.model = torch.load(self.weights_path, map_location="cpu")
                self.optimizer, self.scheduler = configure_optimizer(
                    self.model, lr=1e-3, total_steps=self.total_steps
                )
                self._apply_device()
            except Exception:
                self.model = None  # fall through to config-based init below


        # Optional declarative model config (e.g. mounted into a container).
        config_path = os.getenv("MODEL_CONFIG", "/config/model_params.json")
        if self.model is None and os.path.exists(config_path):
            try:
                with open(config_path) as f:
                    params = json.load(f)
                self.init_model(params)
            except Exception:
                pass  # dashboard can still initialize later via /init


    def init_model(self, params: Dict) -> None:
        """Build a fresh BitTransformerLM from *params*, coercing value types.

        Values arriving from JSON/forms may be strings, so known fields are
        coerced to int/float/bool before construction. Clears logged metrics.
        """
        int_fields = {
            "d_model",
            "nhead",
            "num_layers",
            "dim_feedforward",
            "max_seq_len",
            "chunk_size",
            "overlap",
        }
        float_fields = {"act_threshold"}
        bool_fields = {"reversible", "use_checkpoint"}
        clean: Dict[str, Any] = {}
        for k, v in params.items():
            if v is None:
                clean[k] = None
            elif k in int_fields:
                clean[k] = int(v)
            elif k in float_fields:
                clean[k] = float(v)
            elif k in bool_fields:
                clean[k] = bool(v)
            else:
                clean[k] = v
        self.model = BitTransformerLM(
            **clean,
            lambda_K=self.lambda_K,
            lambda_C=self.lambda_C,
            lambda_S=self.lambda_S,
        )
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._apply_device()
        # A new model invalidates the old telemetry history.
        for key in self.metrics:
            self.metrics[key].clear()


    def set_lambdas(self, k: float, c: float, s: float) -> None:
        """Update λ weights and propagate to the model."""
        self.lambda_K = k
        self.lambda_C = c
        self.lambda_S = s
        if self.model is not None:
            self.model.set_lambdas(k, c, s)


    def set_floors(self, c_floor: float, s_floor: float) -> None:
        """Update safety floors for complexity (C) and symbiosis (S)."""
        self.c_floor = c_floor
        self.s_floor = s_floor


    def set_diffusion(self, flag: bool) -> None:
        """Toggle Diffusion LM mode which disables causal masking and chunking."""
        self.diffusion = flag
        self.causal = not flag
        if self.model is not None and flag:
            self.model.chunk_size = None


    def set_decompress_output(self, flag: bool) -> None:
        """Enable or disable decompression of model outputs."""
        self.decompress_output = flag


    def set_compression(self, flag: bool) -> None:
        """Toggle automatic compression of inputs."""
        self.use_compression = flag


    def set_qat(self, flag: bool) -> None:
        """Enable or disable 4-bit quantization-aware training."""
        self.qat = flag
        if self.model is None:
            return
        if flag:
            self.model = prepare_qat_fx(self.model)
        else:
            self.model = convert_qat_fx(self.model)


    def set_gpu(self, flag: bool) -> None:
        """Toggle GPU acceleration and FSDP, reinstalling PyTorch if needed."""
        # May re-exec the whole process if the wrong torch build is installed.
        _switch_torch(flag)
        self.use_gpu = flag and torch.cuda.is_available()
        self._apply_device()


    def _apply_device(self) -> None:
        """Move the model to the selected device and wrap with FSDP if needed."""
        if self.model is None:
            return
        if self.use_gpu:
            device = torch.device("cuda")
            # Unwrap any existing FSDP shell before re-wrapping on the GPU.
            if isinstance(self.model, FullyShardedDataParallel):
                base = self.model.module
            else:
                base = self.model
            base = base.to(device)
            self.model = wrap_fsdp(base, device_id=device)
        else:
            device = torch.device("cpu")
            if isinstance(self.model, FullyShardedDataParallel):
                self.model = self.model.module
            self.model = self.model.to(device)


    def train_step(self, bits: torch.Tensor) -> "tuple[float, float]":
        """Run one optimizer step on a batch of bits.

        Returns ``(loss, ratio)`` where *ratio* is the achieved compression
        ratio (1.0 when compression is disabled). Persists state afterwards.
        """
        assert (
            self.model is not None
            and self.optimizer is not None
            and self.scheduler is not None
        )
        self.model.train()
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        ratio = 1.0
        if self.use_compression:
            comps = [compress_bits(row.to(torch.uint8)) for row in bits]
            comp_len = sum(c.numel() for c in comps)
            ratio = min(comp_len / bits.numel(), 1.0)
            logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
        else:
            logits, telemetry = self.model(bits, causal=self.causal)
        # Next-bit prediction: logits at position t predict the bit at t+1.
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = bits[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self._log_metrics(telemetry)
        self._save_state()
        return loss.item(), ratio


    def train_epochs(
        self,
        bits: torch.Tensor,
        *,
        epochs: int = 1,
        compress_prob: float = 0.5,
        direct_prob: float = 0.0,
        batch_size: int = 8,
        num_workers: int = 0,
        accum_steps: int = 1,
        amp: bool = False,
        compile_model: bool = False,
    ) -> List[Dict[str, float]]:
        """Run ``train_loop`` on a batch tensor and persist the state."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        import math
        # Rebuild optimizer/scheduler so the LR schedule spans this run.
        steps_per_epoch = max(1, math.ceil(len(bits) / batch_size))
        self.total_steps = math.ceil(epochs * steps_per_epoch / accum_steps)
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        metrics = train_loop(
            self.model,
            bits,
            epochs=epochs,
            compress_prob=compress_prob if self.use_compression else 0.0,
            direct_prob=direct_prob,
            batch_size=batch_size,
            num_workers=num_workers,
            accum_steps=accum_steps,
            amp=amp,
            compile_model=compile_model,
            forward_kwargs={"causal": self.causal},
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )
        self._save_state()
        return metrics


    def scale_up(self, width_mult: float = 1.0) -> None:
        """Grow the model: double the layer count and scale widths by *width_mult*."""
        assert self.model is not None
        params = dict(
            d_model=int(self.model.d_model * width_mult),
            nhead=self.model.layers[0].self_attn.num_heads,
            num_layers=self.model.num_layers * 2,
            dim_feedforward=int(self.model.layers[0].linear1.out_features * width_mult),
            max_seq_len=self.model.pos_enc.pe.size(0),
        )
        self.model = expand_model(self.model, {
            **params,
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
        })
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._save_state()


    def collapse(self, cluster_bits: List[List[int]], target_params: Dict, width_scale: float = 1.0) -> None:
        """Replace the model with a collapsed submodel distilled from *cluster_bits*.

        Clears logged metrics since they no longer describe the new model.
        """
        self.model, _ = collapse_submodel(
            cluster_bits,
            target_params,
            width_scale=width_scale,
            forward_kwargs={"causal": self.causal},
        )
        self.model.set_lambdas(self.lambda_K, self.lambda_C, self.lambda_S)
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._apply_device()
        for key in self.metrics:
            self.metrics[key].clear()


    def infer(self, bits: torch.Tensor) -> Dict:
        """Run a forward pass and return predicted bits plus JSON-safe telemetry."""
        assert self.model is not None
        self.model.eval()
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        ratio = 1.0
        with torch.no_grad():
            if self.use_compression:
                comps = [compress_bits(row.to(torch.uint8)) for row in bits]
                comp_len = sum(c.numel() for c in comps)
                ratio = min(comp_len / bits.numel(), 1.0)
                logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
            else:
                logits, telemetry = self.model(bits, causal=self.causal)
        self._log_metrics(telemetry)
        pred_bits = logits.argmax(-1)
        if self.decompress_output:
            try:
                pred_bits = model_output_decompress(pred_bits)
            except Exception as e:
                return {"error": f"Decompression failed: {e}", "suggestion": "Disable compression toggle."}
        # Recursively convert tensors so the result is JSON-serializable.
        def _to_python(obj):
            if isinstance(obj, torch.Tensor):
                return obj.tolist()
            if isinstance(obj, list):
                return [_to_python(o) for o in obj]
            if isinstance(obj, dict):
                return {kk: _to_python(vv) for kk, vv in obj.items()}
            return obj
        tele = {k: _to_python(v) for k, v in telemetry.items()}
        return {"predicted": pred_bits.squeeze(0).tolist(), "telemetry": tele, "ratio": ratio}


    def infer_long(self, bits: torch.Tensor, ctx_bits: int = 4096, overlap: int = 256) -> Dict:
        """Run sliding-window inference on a long sequence."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        preds, logs = infer_long_sequence(self.model, bits.squeeze(0), ctx_bits=ctx_bits, overlap=overlap)
        # Each window produces its own telemetry dict; log them all.
        for tele in logs:
            self._log_metrics(tele)
        return {"predicted": preds.tolist(), "windows": len(logs)}


    def _log_metrics(self, telemetry: Dict) -> None:
        """Append mean telemetry values and warn if any metric drifts."""
        for key in self.metrics:
            val = telemetry[key].mean().item()
            self.metrics[key].append(val)
        drift = detect_metric_drift(
            self.metrics, window=self.drift_window, threshold=self.drift_threshold
        )
        bad = [k for k, v in drift.items() if v]
        if bad:
            warnings.warn(
                f"Metric drift detected: {', '.join(bad)}",
                MetricDriftWarning,
            )


    def infer_text(self, text: str) -> Dict[str, Any]:
        """Run text through the model using the safety gate."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = torch.tensor(text_to_bits(text), dtype=torch.long).unsqueeze(0).to(device)
        # hil_safe_inference enforces the C/S floors before returning output.
        out_bits, telemetry = hil_safe_inference(
            self.model, bits, c_floor=self.c_floor, s_floor=self.s_floor
        )
        self._log_metrics(telemetry)
        return {
            "output": bits_to_text(out_bits.squeeze(0).tolist()),
            "telemetry": telemetry,
        }


    def get_status(self) -> Dict[str, Any]:
        """Return runtime flags plus current model dimensions (None/0 if no model)."""
        info: Dict[str, Any] = {
            "use_gpu": self.use_gpu,
            "diffusion": self.diffusion,
            "compression": self.use_compression,
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "c_floor": self.c_floor,
            "s_floor": self.s_floor,
            "qat": self.qat,
        }
        if self.model is not None:
            info.update(
                {
                    "d_model": self.model.d_model,
                    "num_layers": self.model.num_layers,
                    "d_ff": self.model.layers[0].linear1.out_features,
                    "nhead": self.model.layers[0].self_attn.num_heads,
                    "max_seq_len": self.model.pos_enc.pe.size(0),
                }
            )
        else:
            info.update(
                {
                    "d_model": None,
                    "num_layers": 0,
                    "d_ff": None,
                    "nhead": None,
                    "max_seq_len": None,
                }
            )
        return info


    def get_model_config(self) -> Dict[str, Any]:
        """Return current model hyperparameters and safety settings."""
        cfg: Dict[str, Any] = {
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "c_floor": self.c_floor,
            "s_floor": self.s_floor,
        }
        if self.model is not None:
            cfg.update(
                {
                    "d_model": self.model.d_model,
                    "nhead": self.model.layers[0].self_attn.num_heads,
                    "num_layers": self.model.num_layers,
                    "dim_feedforward": self.model.layers[0].linear1.out_features,
                    "max_seq_len": self.model.pos_enc.pe.size(0),
                    "chunk_size": self.model.chunk_size,
                    "reversible": self.model.reversible,
                    "use_checkpoint": self.model.use_checkpoint,
                }
            )
        else:
            cfg.update(
                {
                    "d_model": None,
                    "nhead": None,
                    "num_layers": 0,
                    "dim_feedforward": None,
                    "max_seq_len": None,
                    "chunk_size": None,
                    "reversible": None,
                    "use_checkpoint": None,
                }
            )
        return cfg


    def get_metrics(self) -> Dict[str, Any]:
        """Return logged telemetry metrics with summary statistics."""
        from statistics import mean, stdev


        data = {
            "negentropy": self.metrics["negentropy_logits"],
            "lz_complexity": self.metrics["lz_complexity_logits"],
            "symbiosis": self.metrics["symbiosis_score"],
        }
        summary: Dict[str, Dict[str, Optional[float]]] = {}
        for key, values in data.items():
            if values:
                m = mean(values)
                s = stdev(values) if len(values) > 1 else 0.0
                summary[key] = {"mean": m, "std": s}
            else:
                summary[key] = {"mean": None, "std": None}
        data["summary"] = summary
        return data




    def _save_state(self) -> None:
        """Persist the whole model object and the telemetry history to disk."""
        if self.model is None:
            return
        # Saves the full module (not just state_dict), matching the
        # torch.load in __init__.
        torch.save(self.model, self.weights_path)
        with open(self.telemetry_log, "w") as f:
            json.dump(self.metrics, f)
|
|
|
|
| manager: Optional[ModelManager] = None |
|
|
|
|
@app.route("/")
def index():
    """Render the dashboard page with current metrics and settings."""
    # Pre-populate the model form with BitTransformerLM's declared defaults.
    sig = inspect.signature(BitTransformerLM.__init__)
    defaults = {
        name: param.default
        for name, param in sig.parameters.items()
        if param.default is not inspect._empty
    }
    lambdas = {
        "lambda_K": manager.lambda_K,
        "lambda_C": manager.lambda_C,
        "lambda_S": manager.lambda_S,
    }
    return render_template(
        "dashboard.html",
        metrics=manager.metrics,
        lambdas=lambdas,
        diffusion=manager.diffusion,
        compression=manager.use_compression,
        defaults=defaults,
        c_floor=manager.c_floor,
        s_floor=manager.s_floor,
        qat=manager.qat,
    )
|
|
|
|
@app.route("/status", methods=["GET"])
def status():
    """Report runtime flags and model dimensions as JSON."""
    payload = mcp_get("/status") if MCP_SERVER_ADDR else manager.get_status()
    return jsonify(payload)
|
|
|
|
@app.route("/model_config", methods=["GET"])
def model_config():
    """Report model hyperparameters and safety settings as JSON."""
    payload = mcp_get("/model_config") if MCP_SERVER_ADDR else manager.get_model_config()
    return jsonify(payload)
|
|
|
|
@app.route("/metrics", methods=["GET"])
def metrics():
    """Report logged telemetry metrics plus summary statistics as JSON."""
    payload = mcp_get("/metrics") if MCP_SERVER_ADDR else manager.get_metrics()
    return jsonify(payload)
|
|
|
|
@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
    """Push the current model weights to a Hugging Face repository."""
    payload = request.json
    repo_id = payload.get("repo_id")
    # Explicit token in the request wins over the HF_TOKEN env var.
    token = payload.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(
            mcp_post("/save_checkpoint", {"repo_id": repo_id, "token": token})
        )
    if manager.model is None:
        return jsonify({"error": "model not initialized"}), 400
    if token:
        hf_login(token=token)
    save_checkpoint(manager.model, repo_id=repo_id)
    return jsonify({"status": "saved"})
|
|
|
|
@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
    """Fetch a gzipped checkpoint from Hugging Face and load it if a model exists."""
    payload = request.json
    repo_id = payload.get("repo_id")
    token = payload.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(
            mcp_post("/download_checkpoint", {"repo_id": repo_id, "token": token})
        )
    if token:
        hf_login(token=token)
    dest = manager.weights_path + ".gz"
    if not download_checkpoint(dest, repo_id=repo_id):
        return jsonify({"status": "failed"}), 500
    if manager.model is None:
        # No model to load the weights into; leave the archive on disk.
        return jsonify({"status": "downloaded", "loaded": False})
    with gzip.open(dest, "rb") as f:
        state = torch.load(f, map_location="cpu")
    manager.model.load_state_dict(state)
    manager.optimizer, manager.scheduler = configure_optimizer(
        manager.model, lr=1e-3, total_steps=manager.total_steps
    )
    manager._apply_device()
    manager._save_state()
    return jsonify({"status": "downloaded", "loaded": True})
|
|
|
|
@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
    """Convert posted text into a bit list, rejecting oversized payloads."""
    text = request.json.get("text", "")
    # Guard against huge conversions; 413 = Payload Too Large.
    if len(text) > 100_000:
        return jsonify({"error": "text too large"}), 413
    bits = text_to_bits(text)
    return jsonify({"bits": bits})
|
|
|
|
@app.route("/dataset", methods=["GET"])
def dataset_route():
    """Serve a bit dataset, falling back to random bits when wikitext is unavailable."""
    name = request.args.get("name", "")
    split = request.args.get("split", "train")
    size = int(request.args.get("size", 1))
    seq_len = int(request.args.get("seq_len", 64))
    if size * seq_len > 1_000_000:
        return jsonify({"error": "dataset too large"}), 413
    if name != "wikitext2":
        return jsonify({"error": "unknown dataset"}), 400
    try:
        from datasets import load_dataset

        ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
        lines = [t for t in ds["text"] if t.strip()][:size]
    except Exception:
        # datasets missing or download failed: synthesize random bits.
        fallback = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
        return jsonify({"bits": fallback.tolist()})
    bits_list = []
    for text in lines:
        row = text_to_bits(text)[:seq_len]
        row.extend([0] * (seq_len - len(row)))  # right-pad short rows with zeros
        bits_list.append(row)
    shortfall = size - len(bits_list)
    if shortfall > 0:
        # Not enough non-empty lines: top up with random rows.
        bits_list.extend(
            torch.randint(0, 2, (shortfall, seq_len), dtype=torch.long).tolist()
        )
    return jsonify({"bits": bits_list})
|
|
|
|
@app.route("/init", methods=["POST"])
def init_model():
    """Initialize a fresh model from JSON hyperparameters."""
    data = request.json or {}
    # JSON/form values may arrive as strings; coerce known fields.
    int_fields = {
        "d_model",
        "nhead",
        "num_layers",
        "dim_feedforward",
        "max_seq_len",
        "chunk_size",
        "overlap",
    }
    float_fields = {"act_threshold"}
    bool_fields = {"reversible", "use_checkpoint"}

    def _coerce(key, value):
        if value is None:
            return None
        if key in int_fields:
            return int(value)
        if key in float_fields:
            return float(value)
        if key in bool_fields:
            return bool(value)
        return value

    params = {k: _coerce(k, v) for k, v in data.items()}
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/init", params))
    manager.init_model(params)
    return jsonify({"status": "initialized", "params": params})
|
|
|
|
@app.route("/train", methods=["POST"])
def train_model():
    """Run a single training step on the posted bit batch."""
    raw_bits = request.json["bits"]
    bits = torch.tensor(raw_bits, dtype=torch.long)
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/train", {"bits": raw_bits}))
    loss, ratio = manager.train_step(bits)
    return jsonify({"loss": loss, "ratio": ratio})
|
|
|
|
@app.route("/train_epochs", methods=["POST"])
def train_epochs_route():
    """Train for several epochs via train_loop and return per-epoch metrics."""
    payload = request.json
    raw_bits = payload["bits"]
    bits = torch.tensor(raw_bits, dtype=torch.long)
    epochs = int(payload.get("epochs", 1))
    compress_prob = float(payload.get("compress_prob", 0.5))
    direct_prob = float(payload.get("direct_prob", 0.0))
    if MCP_SERVER_ADDR:
        forwarded = {
            "bits": raw_bits,
            "epochs": epochs,
            "compress_prob": compress_prob,
            "direct_prob": direct_prob,
        }
        return jsonify(mcp_post("/train_epochs", forwarded))
    metrics = manager.train_epochs(
        bits,
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=direct_prob,
    )
    return jsonify({"metrics": metrics})
|
|
|
|
@app.route("/scale_up", methods=["POST"])
def scale_up():
    """Double the model's layer count and scale widths by width_mult."""
    width_mult = float(request.json.get("width_mult", 1.0))
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/scale_up", {"width_mult": width_mult}))
    manager.scale_up(width_mult)
    summary = {
        "status": "scaled",
        "layers": manager.model.num_layers,
        "d_model": manager.model.d_model,
    }
    return jsonify(summary)
|
|
|
|
@app.route("/collapse", methods=["POST"])
def collapse_model():
    """Distill the model into a smaller one trained on the posted bit clusters."""
    payload = request.json
    cluster_bits = payload["clusters"]
    params = {key: int(val) for key, val in payload["params"].items()}
    width_scale = float(payload.get("width_scale", 1.0))
    if MCP_SERVER_ADDR:
        forwarded = {
            "clusters": cluster_bits,
            "params": params,
            "width_scale": width_scale,
        }
        return jsonify(mcp_post("/collapse", forwarded))
    manager.collapse(cluster_bits, params, width_scale)
    return jsonify({"status": "collapsed"})
|
|
|
|
@app.route("/lambdas", methods=["GET", "POST"])
def update_lambdas():
    """Read (GET) or update (POST) the three telemetry λ weights."""
    if request.method != "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/lambdas"))
        current = {
            "lambda_K": manager.lambda_K,
            "lambda_C": manager.lambda_C,
            "lambda_S": manager.lambda_S,
        }
        return jsonify(current)
    data = request.json
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/lambdas", data))
    manager.set_lambdas(
        float(data["lambda_K"]),
        float(data["lambda_C"]),
        float(data["lambda_S"]),
    )
    return jsonify({"status": "updated"})
|
|
|
|
@app.route("/config/telemetry", methods=["GET", "POST"])
def telemetry_config():
    """Get or update telemetry λ weights and safety floors."""
    if request.method != "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_get("/config/telemetry"))
        current = {
            "lambda_K": manager.lambda_K,
            "lambda_C": manager.lambda_C,
            "lambda_S": manager.lambda_S,
            "c_floor": manager.c_floor,
            "s_floor": manager.s_floor,
        }
        return jsonify(current)
    data = request.json
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/config/telemetry", data))
    # Keys missing from the request keep their current values.
    manager.set_lambdas(
        float(data.get("lambda_K", manager.lambda_K)),
        float(data.get("lambda_C", manager.lambda_C)),
        float(data.get("lambda_S", manager.lambda_S)),
    )
    manager.set_floors(
        float(data.get("c_floor", manager.c_floor)),
        float(data.get("s_floor", manager.s_floor)),
    )
    return jsonify({"status": "updated"})
|
|
|
|
@app.route("/diffusion", methods=["GET", "POST"])
def update_diffusion():
    """Get or set Diffusion-LM mode (disables causal masking when on)."""
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/diffusion", request.json))
        flag = bool(request.json.get("diffusion", False))
        manager.set_diffusion(flag)
        return jsonify({"status": "updated"})
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/diffusion"))
    return jsonify({"diffusion": manager.diffusion})
|
|
|
|
@app.route("/gpu", methods=["GET", "POST"])
def update_gpu():
    """Get or set GPU acceleration (may reinstall torch and restart the process)."""
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/gpu", request.json))
        flag = bool(request.json.get("use_gpu", False))
        manager.set_gpu(flag)
        return jsonify({"status": "updated"})
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/gpu"))
    return jsonify({"use_gpu": manager.use_gpu})
|
|
|
|
@app.route("/compression", methods=["GET", "POST"])
def update_compression():
    """Get or set automatic input compression."""
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/compression", request.json))
        flag = bool(request.json.get("compression", False))
        manager.set_compression(flag)
        return jsonify({"status": "updated"})
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/compression"))
    return jsonify({"compression": manager.use_compression})
|
|
|
|
@app.route("/qat", methods=["GET", "POST"])
def update_qat():
    """Get or set 4-bit quantization-aware training."""
    if request.method == "POST":
        if MCP_SERVER_ADDR:
            return jsonify(mcp_post("/qat", request.json))
        flag = bool(request.json.get("qat", False))
        manager.set_qat(flag)
        return jsonify({"status": "updated"})
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/qat"))
    return jsonify({"qat": manager.qat})
|
|
|
|
@app.route("/infer", methods=["POST"])
def inference():
    """Run a forward pass on the posted bits and return predictions + telemetry."""
    raw_bits = request.json["bits"]
    bits = torch.tensor(raw_bits, dtype=torch.long)
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/infer", {"bits": raw_bits}))
    return jsonify(manager.infer(bits))
|
|
|
|
@app.route("/infer_long", methods=["POST"])
def inference_long():
    """Sliding-window inference for sequences longer than the context size."""
    payload = request.json
    bits = torch.tensor(payload["bits"], dtype=torch.long)
    ctx = int(payload.get("ctx_bits", 4096))
    overlap = int(payload.get("overlap", 256))
    if MCP_SERVER_ADDR:
        forwarded = {"bits": payload["bits"], "ctx_bits": ctx, "overlap": overlap}
        return jsonify(mcp_post("/infer_long", forwarded))
    return jsonify(manager.infer_long(bits, ctx_bits=ctx, overlap=overlap))
|
|
|
|
@app.route("/infer_text", methods=["POST"])
def inference_text():
    """Run text through the model behind the human-in-the-loop safety gate."""
    text = request.json.get("text", "")
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/infer_text", {"text": text}))
    return jsonify(manager.infer_text(text))
|
|
@app.route("/plot.png")
def plot_png():
    """Render the telemetry history as a PNG image."""
    if MCP_SERVER_ADDR:
        # NOTE(review): this bypasses mcp_get (no timeout, no helper reuse)
        # because the helper would already return bytes here — consider
        # unifying and adding a timeout.
        upstream = requests.get(MCP_SERVER_ADDR.rstrip("/") + "/plot.png")
        upstream.raise_for_status()
        return send_file(io.BytesIO(upstream.content), mimetype="image/png")
    fig, _ = plot_telemetry(manager.metrics)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)  # free the figure to avoid matplotlib memory growth
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
|
|
|
|
def run_dashboard(host: Optional[str] = None, port: Optional[int] = None,
                  snapshot_dir: Optional[str] = None, telemetry_log: Optional[str] = None) -> None:
    """Launch the Flask dashboard server.

    Host/port fall back to the HOST/PORT environment variables (then
    0.0.0.0:5000). Creates the global ModelManager on first call.
    """
    host = host or os.getenv("HOST", "0.0.0.0")
    port = port or int(os.getenv("PORT", "5000"))
    global manager
    if manager is None:
        manager = ModelManager(snapshot_dir, telemetry_log)
    # SECURITY: Flask's debug mode enables the Werkzeug interactive debugger,
    # which allows arbitrary code execution if the port is reachable. The
    # default stays True (previous behavior) but can now be disabled by
    # setting FLASK_DEBUG=0 for any non-local deployment.
    debug = os.getenv("FLASK_DEBUG", "1").lower() not in {"0", "false", "no"}
    app.run(host=host, port=port, debug=debug)
|
|
|
|
if __name__ == "__main__":
    import argparse


    # CLI flags mirror the HOST/PORT/SNAPSHOT_DIR/TELEMETRY_LOG env vars,
    # with the env vars serving as defaults.
    parser = argparse.ArgumentParser(description="Run dashboard server")
    parser.add_argument("--host", default=os.getenv("HOST", "0.0.0.0"))
    parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "5000")))
    parser.add_argument("--snapshot-dir", default=os.getenv("SNAPSHOT_DIR", "snapshots"))
    parser.add_argument("--telemetry-log", default=os.getenv("TELEMETRY_LOG"))
    args = parser.parse_args()
    run_dashboard(args.host, args.port, args.snapshot_dir, args.telemetry_log)
|
|