GirishaBuilds01 commited on
Commit
716e4bc
·
verified ·
1 Parent(s): c89819d

Create visuals.py

Browse files
Files changed (1) hide show
  1. visuals.py +136 -0
visuals.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # visuals.py
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import gradio as gr
6
+ from PIL import Image
7
+
8
+ from uei_core.models import ModelPortfolio
9
+ from uei_core.uncertainty import UncertaintyEstimator
10
+ from uei_core.energy import EnergyProfiler
11
+
12
+ device = "cpu"
13
+
14
+ models = ModelPortfolio(device=device)
15
+ unc = UncertaintyEstimator()
16
+ energy = EnergyProfiler()
17
+
18
+
19
+ # ------------------------------
20
+ # 1️⃣ Plot: Uncertainty vs Energy Curve
21
+ # ------------------------------
22
+ def plot_unc_energy(img):
23
+ x = models.preprocess(img)
24
+
25
+ logits_s, E_s = energy.measure(models.infer_small, x)
26
+ logits_l, E_l = energy.measure(models.infer_large, x)
27
+
28
+ U_s = float(unc.estimate(logits_s))
29
+ U_l = float(unc.estimate(logits_l))
30
+
31
+ # Create plot
32
+ fig, ax = plt.subplots(figsize=(5,4), dpi=120)
33
+
34
+ xs = [E_s, E_l]
35
+ ys = [U_s, U_l]
36
+ labels = ["Small Model", "Large Model"]
37
+ colors = ["#1f77b4", "#ff7f0e"]
38
+
39
+ ax.scatter(xs, ys, s=150, color=colors)
40
+ ax.plot(xs, ys, linestyle="--", color="#888")
41
+
42
+ for i, label in enumerate(labels):
43
+ ax.annotate(label, (xs[i], ys[i]), textcoords="offset points",
44
+ xytext=(8,5), ha='left', fontsize=10)
45
+
46
+ ax.set_xlabel("Energy (proxy units)")
47
+ ax.set_ylabel("Estimated Uncertainty")
48
+ ax.set_title("Uncertainty vs Energy")
49
+ ax.grid(True, alpha=0.3)
50
+
51
+ return fig
52
+
53
+
54
+ # ------------------------------
55
+ # 2️⃣ Plot: Layer Activation Heatmap
56
+ # ------------------------------
57
+ def activation_heatmap(img):
58
+ x = models.preprocess(img)
59
+
60
+ # Register forward hook on the first conv
61
+ activations = {}
62
+
63
+ def hook(module, input, output):
64
+ activations["feat"] = output.detach().cpu()
65
+
66
+ h = models.small.features[0].register_forward_hook(hook)
67
+ models.small(x)
68
+ h.remove()
69
+
70
+ feat = activations["feat"][0] # first batch
71
+
72
+ # Average channels → 2D heatmap
73
+ heat = feat.mean(dim=0).numpy()
74
+
75
+ fig, ax = plt.subplots(figsize=(4,4), dpi=120)
76
+ ax.imshow(heat, cmap="viridis")
77
+ ax.set_title("Early Layer Activation Heatmap")
78
+ ax.axis("off")
79
+
80
+ return fig
81
+
82
+
83
+ # ------------------------------
84
+ # 3️⃣ Plot: Model Comparison Bars
85
+ # ------------------------------
86
+ def model_comparison(img):
87
+ x = models.preprocess(img)
88
+
89
+ logits_s, E_s = energy.measure(models.infer_small, x)
90
+ logits_l, E_l = energy.measure(models.infer_large, x)
91
+
92
+ U_s = float(unc.estimate(logits_s))
93
+ U_l = float(unc.estimate(logits_l))
94
+
95
+ fig, ax = plt.subplots(figsize=(6,4))
96
+
97
+ labels = ["Small Model", "Large Model"]
98
+ energy_vals = [E_s, E_l]
99
+ unc_vals = [U_s, U_l]
100
+
101
+ x_axis = np.arange(len(labels))
102
+ w = 0.35
103
+
104
+ ax.bar(x_axis - w/2, energy_vals, w, label="Energy", color="#2ca02c")
105
+ ax.bar(x_axis + w/2, unc_vals, w, label="Uncertainty", color="#d62728")
106
+
107
+ ax.set_xticks(x_axis)
108
+ ax.set_xticklabels(labels)
109
+ ax.set_title("Model Energy & Uncertainty Comparison")
110
+ ax.legend()
111
+ ax.grid(alpha=0.2)
112
+
113
+ return fig
114
+
115
+
116
+ # ------------------------------
117
+ # 🔥 Gradio Interface
118
+ # ------------------------------
119
+ def get_visual_ui():
120
+ with gr.Blocks() as demo:
121
+ gr.Markdown("## 🔍 UEI Visualization Dashboard")
122
+ gr.Markdown("Explore how UEI behaves internally with colorful charts")
123
+
124
+ img = gr.Image(type="pil", label="Upload Image")
125
+
126
+ with gr.Tabs():
127
+ with gr.Tab("Uncertainty vs Energy"):
128
+ gr.Plot(label="Chart").render(fn=plot_unc_energy, inputs=img)
129
+
130
+ with gr.Tab("Layer Activations"):
131
+ gr.Plot(label="Activation Heatmap").render(fn=activation_heatmap, inputs=img)
132
+
133
+ with gr.Tab("Model Comparison"):
134
+ gr.Plot(label="Energy & Uncertainty Bars").render(fn=model_comparison, inputs=img)
135
+
136
+ return demo