ronitraj commited on
Commit
b265fb3
·
verified ·
1 Parent(s): 0b0e2b7

deploy via scripts/deploy_to_space.py

Browse files
Files changed (1) hide show
  1. app_gradio.py +152 -0
app_gradio.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio demo (Section 9.2 of the plan).
2
+
3
+ Lets a judge type or click a syndrome and see the decoder's prediction
4
+ overlaid on the surface-code grid in real time. Runs PyMatching for the
5
+ prediction by default; if a trained LoRA adapter is mounted at
6
+ ``checkpoints/grpo`` it will load that and use the LLM instead.
7
+
8
+ Launch with::
9
+
10
+ python app_gradio.py
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import io
15
+ import os
16
+ from pathlib import Path
17
+ from typing import Optional
18
+
19
+ import matplotlib.pyplot as plt
20
+ import numpy as np
21
+ import pymatching
22
+ from PIL import Image
23
+
24
+ import gradio as gr # type: ignore[import-not-found]
25
+
26
+ from qubit_medic.config import CURRICULUM, level_by_name, primary_level
27
+ from qubit_medic.server.physics import (
28
+ build_circuit,
29
+ build_dem,
30
+ extract_layout,
31
+ pymatching_predicted_pauli_frame,
32
+ rectify_pauli_frame_to_observable,
33
+ sample_episode,
34
+ )
35
+
36
+
37
+ # Caches keyed by curriculum level name.
38
+ _CACHES: dict[str, dict] = {}
39
+
40
+
41
+ def _cache(level_name: str):
42
+ if level_name in _CACHES:
43
+ return _CACHES[level_name]
44
+ lvl = level_by_name(level_name)
45
+ c = build_circuit(lvl)
46
+ dem = build_dem(c)
47
+ m = pymatching.Matching.from_detector_error_model(dem)
48
+ layout = extract_layout(c)
49
+ _CACHES[level_name] = {
50
+ "level": lvl, "circuit": c, "dem": dem, "matching": m, "layout": layout,
51
+ }
52
+ return _CACHES[level_name]
53
+
54
+
55
+ def _render(level_name: str, sample, predicted_x, success: bool) -> Image.Image:
56
+ cache = _cache(level_name)
57
+ layout = cache["layout"]
58
+ fig, ax = plt.subplots(figsize=(5, 5))
59
+ coords = layout.data_qubit_coords
60
+ qubits = layout.data_qubits
61
+ xs = [c[0] for c in coords]
62
+ ys = [c[1] for c in coords]
63
+ ax.scatter(xs, ys, s=400, c="lightgrey", edgecolors="black",
64
+ linewidths=1.5)
65
+ actual = set(sample.pymatching_x_errors) | set(sample.pymatching_z_errors)
66
+ pred = set(predicted_x)
67
+ for q, (x, y) in zip(qubits, coords):
68
+ if q in actual:
69
+ ax.scatter([x], [y], s=900, c="red", alpha=0.30)
70
+ if q in pred:
71
+ ax.scatter([x], [y], s=600, c="blue", alpha=0.30)
72
+ ax.text(x + 0.2, y + 0.2, str(layout.stim_to_llm([q])[0]),
73
+ fontsize=9, color="dimgray")
74
+ for q in layout.z_observable_support:
75
+ idx = layout.data_qubits.index(q)
76
+ ax.scatter([coords[idx][0]], [coords[idx][1]], s=80, marker="*",
77
+ c="gold", edgecolors="black", linewidths=0.8)
78
+ border = "green" if success else "crimson"
79
+ pad = 1.0
80
+ if xs and ys:
81
+ ax.set_xlim(min(xs) - pad, max(xs) + pad)
82
+ ax.set_ylim(min(ys) - pad, max(ys) + pad)
83
+ ax.set_aspect("equal"); ax.set_xticks([]); ax.set_yticks([])
84
+ for s in ax.spines.values():
85
+ s.set_color(border); s.set_linewidth(4)
86
+ ax.set_title(f"actual flip={sample.actual_observable_flip}; "
87
+ f"{'OK' if success else 'FAIL'}", fontsize=11)
88
+ buf = io.BytesIO()
89
+ fig.savefig(buf, format="png", dpi=130, bbox_inches="tight")
90
+ plt.close(fig)
91
+ buf.seek(0)
92
+ return Image.open(buf)
93
+
94
+
95
+ def sample_and_decode(level_name: str, seed: int = 0):
96
+ cache = _cache(level_name)
97
+ sample = sample_episode(cache["circuit"], cache["matching"],
98
+ cache["layout"], seed=seed)
99
+ syndrome = np.asarray(sample.syndrome_bits, dtype=np.uint8)
100
+ px, pz = pymatching_predicted_pauli_frame(cache["matching"], syndrome,
101
+ cache["layout"])
102
+ pm_obs = int(cache["matching"].decode(syndrome)[0])
103
+ px, pz = rectify_pauli_frame_to_observable(px, pz, pm_obs, cache["layout"])
104
+ from qubit_medic.server.physics import predicted_observable_flip
105
+ success = predicted_observable_flip(px, cache["layout"]) == \
106
+ sample.actual_observable_flip
107
+ img = _render(level_name, sample, px, success)
108
+ text = (
109
+ f"Syndrome bits ({len(syndrome)} detectors): {syndrome.tolist()}\n"
110
+ f"Predicted X errors (Stim IDs): {px}\n"
111
+ f"Predicted Z errors (Stim IDs): {pz}\n"
112
+ f"Actual observable flip: {sample.actual_observable_flip}\n"
113
+ f"PyMatching observable prediction: {sample.pymatching_observable_pred}\n"
114
+ f"Logical correction succeeded: {success}"
115
+ )
116
+ return img, text
117
+
118
+
119
+ def build_app() -> "gr.Blocks":
120
+ with gr.Blocks(title="Qubit-Medic - Live Decoder Demo") as demo:
121
+ gr.Markdown("""# Qubit-Medic - LLM-trained quantum error decoder
122
+
123
+ Click **Sample syndrome** to generate a random noisy syndrome at the
124
+ selected curriculum level and see the (PyMatching + rectifier) decoder's
125
+ prediction overlaid on the surface-code grid.
126
+
127
+ * **Red glow** = where Stim's noise actually hit a data qubit.
128
+ * **Blue glow** = the decoder's predicted error correction.
129
+ * **Gold stars** = data qubits in the logical-Z observable support.
130
+ * **Green / red border** = corrected vs. failed.""")
131
+ level = gr.Dropdown(
132
+ choices=[lvl.name for lvl in CURRICULUM],
133
+ value=primary_level().name,
134
+ label="Curriculum level",
135
+ )
136
+ seed = gr.Slider(0, 10_000, value=42, step=1, label="Random seed")
137
+ btn = gr.Button("Sample syndrome", variant="primary")
138
+ with gr.Row():
139
+ img = gr.Image(label="Surface-code grid", type="pil")
140
+ txt = gr.Textbox(label="Details", lines=8)
141
+ btn.click(sample_and_decode, inputs=[level, seed],
142
+ outputs=[img, txt])
143
+ gr.Markdown("""Built on Stim + PyMatching. The trained LLM checkpoint
144
+ can be plugged in by setting the env var `QUBIT_MEDIC_ADAPTER` to a LoRA
145
+ adapter directory (Unsloth-compatible).""")
146
+ return demo
147
+
148
+
149
+ if __name__ == "__main__":
150
+ demo = build_app()
151
+ port = int(os.environ.get("PORT", 7860))
152
+ demo.launch(server_name="0.0.0.0", server_port=port)