Alex W. committed on
Commit
f02f9b7
·
1 Parent(s): 9181528

add debug info for gemma-4-31b-it

Browse files
Files changed (1) hide show
  1. core/metrics.py +23 -0
core/metrics.py CHANGED
@@ -105,6 +105,29 @@ def analyze_layer(
105
  f" {'α_QK':>7} {'α_QV':>7} {'α_KV':>7}\n"
106
  )
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  for kv_h in range(n_kv):
109
  k_t = W_k[kv_h * d_head:(kv_h + 1) * d_head, :]
110
  v_t = W_v[kv_h * d_head:(kv_h + 1) * d_head, :]
 
105
  f" {'α_QK':>7} {'α_QV':>7} {'α_KV':>7}\n"
106
  )
107
 
108
+ # 打印 W_k 每个 d_head 块的 L2 norm,看能量分布
109
+ lines.append(f"[DEBUG] W_k 各头能量(行块 L2 norm):\n")
110
+ for i in range(n_kv):
111
+ block = W_k[i * d_head:(i + 1) * d_head, :]
112
+ norm = float(block.norm())
113
+ # 同时打印该块的最大奇异值
114
+ s_tmp = torch.linalg.svd(block, full_matrices=False)[1]
115
+ lines.append(
116
+ f" KV头{i:2d}: block_norm={norm:.2f} "
117
+ f"sigma_max={float(s_tmp[0]):.4f}\n"
118
+ )
119
+
120
+ lines.append(f"[DEBUG] W_q 各头能量:\n")
121
+ for i in range(n_q):
122
+ block = W_q[i * d_head:(i + 1) * d_head, :]
123
+ norm = float(block.norm())
124
+ s_tmp = torch.linalg.svd(block, full_matrices=False)[1]
125
+ lines.append(
126
+ f" Q头{i:2d}: block_norm={norm:.2f} "
127
+ f"sigma_max={float(s_tmp[0]):.4f}\n"
128
+ )
129
+
130
+
131
  for kv_h in range(n_kv):
132
  k_t = W_k[kv_h * d_head:(kv_h + 1) * d_head, :]
133
  v_t = W_v[kv_h * d_head:(kv_h + 1) * d_head, :]