Add core DKM layer implementation

dkm/dkm_layer.py (ADDED, +349 -0)
@@ -0,0 +1,349 @@
"""
DKM Layer: Differentiable K-Means Clustering Layer

Implements the core algorithm from Sections 3.2 and 3.3 of the paper:
1. Compute the distance matrix D between weights W and centroids C
2. Apply softmax with temperature τ to get the attention matrix A
3. Update centroids: c_j = Σ_i(a_ij * w_i) / Σ_i(a_ij)
4. Iterate until convergence or max iterations
5. Compute compressed weights: W_tilde = A @ C

Supports both 1D and multi-dimensional clustering (Section 3.3).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class DKMLayer(nn.Module):
    """
    Differentiable K-Means Clustering Layer.

    This layer performs differentiable weight clustering by casting k-means
    as an attention problem. During training, soft assignment via attention
    allows gradients to flow through the clustering process. During inference,
    weights are snapped to their nearest centroids (hard assignment).

    Args:
        weight_tensor: The weight parameter to cluster (nn.Parameter)
        n_clusters: Number of cluster centroids (k = 2^bits)
        tau: Temperature for softmax attention (controls hardness of assignment)
        dim: Dimension for multi-dimensional clustering (default=1 for scalar)
        max_iter: Maximum number of DKM iterations per forward pass
        epsilon: Convergence threshold for centroid updates
        init_method: Centroid initialization method ('kmeans++' or 'random')
    """

    def __init__(
        self,
        weight_tensor: nn.Parameter,
        n_clusters: int = 16,
        tau: float = 2e-5,
        dim: int = 1,
        max_iter: int = 5,
        epsilon: float = 1e-4,
        init_method: str = "kmeans++",
    ):
        super().__init__()

        self.n_clusters = n_clusters
        self.tau = tau
        self.dim = dim
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.init_method = init_method

        # Store a reference to the weight parameter (not a copy)
        self.weight = weight_tensor
        self.original_shape = weight_tensor.shape

        # Validate dimensions
        n_elements = weight_tensor.numel()
        if n_elements % dim != 0:
            raise ValueError(
                f"Weight tensor has {n_elements} elements, which is not "
                f"divisible by dim={dim}. Choose a different dim."
            )

        self.n_vectors = n_elements // dim

        # Initialize centroids as a buffer (not a parameter - updated by DKM iterations)
        # Shape: (n_clusters, dim) for multi-dim, (n_clusters, 1) for scalar
        centroids = self._initialize_centroids(weight_tensor)
        self.register_buffer("centroids", centroids)

        # Track whether this is the first forward pass
        self.register_buffer("initialized", torch.tensor(False))

    def _initialize_centroids(self, weight_tensor: torch.Tensor) -> torch.Tensor:
        """
        Initialize centroids using k-means++ or random selection.

        For multi-dimensional clustering, weights are reshaped into
        (N/d, d) sub-vectors before selecting initial centroids.
        """
        with torch.no_grad():
            # Flatten and reshape for multi-dim clustering
            flat_weights = weight_tensor.detach().reshape(-1)

            if self.dim == 1:
                vectors = flat_weights.unsqueeze(1)  # (N, 1)
            else:
                vectors = flat_weights.reshape(self.n_vectors, self.dim)  # (N/d, d)

            if self.init_method == "kmeans++":
                centroids = self._kmeans_plus_plus_init(vectors)
            else:
                # Random initialization: select k random weight vectors
                indices = torch.randperm(vectors.shape[0])[:self.n_clusters]
                centroids = vectors[indices].clone()

        return centroids  # (n_clusters, dim)

    def _kmeans_plus_plus_init(self, vectors: torch.Tensor) -> torch.Tensor:
        """
        K-means++ initialization for better centroid starting positions.

        Args:
            vectors: (N, d) weight vectors

        Returns:
            centroids: (k, d) initial centroids
        """
        n = vectors.shape[0]
        k = self.n_clusters

        # If there are fewer vectors than clusters, duplicate with small noise
        if n < k:
            repeats = (k // n) + 1
            vectors_expanded = vectors.repeat(repeats, 1)[:k]
            noise = torch.randn_like(vectors_expanded) * 1e-6
            return (vectors_expanded + noise).clone()

        # Choose the first centroid uniformly at random
        idx = torch.randint(0, n, (1,)).item()
        centroids = [vectors[idx].clone()]

        for _ in range(1, k):
            # Compute distances to the nearest existing centroid
            stacked = torch.stack(centroids, dim=0)  # (current_k, d)
            # distances: (N, current_k)
            dists = torch.cdist(vectors.unsqueeze(0), stacked.unsqueeze(0)).squeeze(0)
            min_dists = dists.min(dim=1).values  # (N,)

            # Choose the next centroid with probability proportional to squared distance
            probs = min_dists ** 2
            prob_sum = probs.sum()

            if prob_sum < 1e-30 or torch.isnan(prob_sum) or torch.isinf(prob_sum):
                # Fallback when all distances are (near) zero, e.g. uniform weights:
                # pick a uniformly random index; duplicate centroids get noise below
                idx = torch.randint(0, n, (1,)).item()
            else:
                probs = probs / prob_sum
                # Clamp to avoid negative values from float errors
                probs = probs.clamp(min=0.0)
                idx = torch.multinomial(probs, 1).item()

            centroids.append(vectors[idx].clone())

        result = torch.stack(centroids, dim=0)  # (k, d)

        # Add tiny noise to break ties if any centroids are identical
        if result.unique(dim=0).shape[0] < k:
            noise = torch.randn_like(result) * (result.abs().mean() * 1e-4 + 1e-8)
            result = result + noise

        return result

    def _compute_distance_matrix(
        self, weights: torch.Tensor, centroids: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the negative squared Euclidean distance matrix D.

        D[i,j] = -||w_i - c_j||^2

        Per the paper: d_ij = -f(w_i, c_j) where f is the Euclidean distance.
        We use squared Euclidean for efficiency; it preserves the
        nearest-centroid ordering, though the softmax sharpness differs
        slightly from plain Euclidean distance.

        Args:
            weights: (N/d, d) weight sub-vectors
            centroids: (k, d) centroid vectors

        Returns:
            D: (N/d, k) negative distance matrix
        """
        # Efficient computation: ||w - c||^2 = ||w||^2 - 2*w.c + ||c||^2
        w_sq = (weights ** 2).sum(dim=1, keepdim=True)  # (N/d, 1)
        c_sq = (centroids ** 2).sum(dim=1, keepdim=True).t()  # (1, k)
        wc = weights @ centroids.t()  # (N/d, k)

        D = -(w_sq - 2 * wc + c_sq)  # (N/d, k) negative squared Euclidean
        return D

    def _compute_attention(self, D: torch.Tensor) -> torch.Tensor:
        """
        Compute the attention matrix A from the distance matrix D using
        softmax with temperature τ:

        a_ij = exp(d_ij / τ) / Σ_k exp(d_ik / τ)

        This is the key differentiable component that enables gradient flow
        through the clustering assignment.

        Args:
            D: (N/d, k) negative distance matrix

        Returns:
            A: (N/d, k) attention matrix (rows sum to 1)
        """
        # Scale by temperature and apply softmax along the cluster dimension
        A = F.softmax(D / self.tau, dim=1)  # (N/d, k)
        return A

    def _update_centroids(
        self, A: torch.Tensor, weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Update centroids using the attention-weighted average of the weights:

        c_j_new = Σ_i(a_ij * w_i) / Σ_i(a_ij)

        This is the M-step equivalent in the EM interpretation (Appendix G).

        Args:
            A: (N/d, k) attention matrix
            weights: (N/d, d) weight sub-vectors

        Returns:
            new_centroids: (k, d) updated centroids
        """
        # Numerator: Σ_i(a_ij * w_i) for each centroid j
        # A.t() @ weights: (k, N/d) @ (N/d, d) = (k, d)
        numerator = A.t() @ weights  # (k, d)

        # Denominator: Σ_i(a_ij) for each centroid j
        denominator = A.sum(dim=0, keepdim=True).t()  # (k, 1)

        # Avoid division by zero
        denominator = denominator.clamp(min=1e-10)

        new_centroids = numerator / denominator  # (k, d)
        return new_centroids

    def forward(self, weight_override: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass: perform differentiable k-means clustering.

        The iterative process (Fig. 2 of the paper):
        1. Compute the distance matrix D between weights and centroids
        2. Compute the attention matrix A = softmax(D/τ)
        3. Update centroids: C_new = (A^T @ W) / sum(A)
        4. Repeat until convergence or max_iter is reached
        5. Return compressed weights: W_tilde = A @ C

        During training: returns soft-assigned weights (differentiable).
        During eval: returns hard-assigned weights (nearest centroid).

        Args:
            weight_override: Optional weight tensor to use instead of self.weight

        Returns:
            compressed_weight: Tensor with the same shape as the original weight
        """
        weight_tensor = weight_override if weight_override is not None else self.weight

        # Reshape weights into sub-vectors for multi-dim clustering
        flat_weights = weight_tensor.reshape(-1)
        if self.dim == 1:
            W = flat_weights.unsqueeze(1)  # (N, 1)
        else:
            W = flat_weights.reshape(self.n_vectors, self.dim)  # (N/d, d)

        # Re-initialize centroids on the first call (ensures correct device/dtype)
        if not self.initialized:
            self.centroids = self._initialize_centroids(weight_tensor).to(
                weight_tensor.device
            )
            self.initialized = torch.tensor(True, device=weight_tensor.device)

        # Current centroids (detached from the previous iteration's graph)
        C = self.centroids.clone()

        # Iterative DKM clustering (Section 3.2, Fig. 2)
        for _ in range(self.max_iter):
            # Step 1: Compute the distance matrix
            D = self._compute_distance_matrix(W, C)

            # Step 2: Compute the attention matrix with temperature τ
            A = self._compute_attention(D)

            # Step 3: Update the centroids
            C_new = self._update_centroids(A, W)

            # Step 4: Check convergence |C - C_new| ≤ ε
            delta = (C - C_new).abs().max().item()
            C = C_new

            if delta <= self.epsilon:
                break

        # Store the converged centroids for the next batch (warm start)
        self.centroids = C.detach().clone()

        if self.training:
            # Training: use soft assignment (differentiable)
            # W_tilde = A @ C (attention-weighted centroids)
            W_tilde = A @ C  # (N/d, d)
        else:
            # Inference: snap to the nearest centroid (hard assignment)
            # Find the nearest centroid for each weight vector
            D_final = self._compute_distance_matrix(W, C)
            assignments = D_final.argmax(dim=1)  # (N/d,) argmax since D holds negative distances
            W_tilde = C[assignments]  # (N/d, d)

        # Reshape back to the original weight shape
        compressed_weight = W_tilde.reshape(self.original_shape)

        return compressed_weight

    def get_assignments(self) -> torch.Tensor:
        """
        Get hard cluster assignments for each weight (for inference/analysis).

        Returns:
            assignments: (N/d,) tensor of cluster indices
        """
        with torch.no_grad():
            flat_weights = self.weight.detach().reshape(-1)
            if self.dim == 1:
                W = flat_weights.unsqueeze(1)
            else:
                W = flat_weights.reshape(self.n_vectors, self.dim)

            D = self._compute_distance_matrix(W, self.centroids)
            assignments = D.argmax(dim=1)

        return assignments

    def get_codebook(self) -> torch.Tensor:
        """
        Get the current centroid codebook.

        Returns:
            centroids: (k, d) centroid values
        """
        return self.centroids.clone()

    def extra_repr(self) -> str:
        bits = math.log2(self.n_clusters)
        bpw = bits / self.dim
        return (
            f"n_clusters={self.n_clusters}, tau={self.tau}, dim={self.dim}, "
            f"max_iter={self.max_iter}, eps={self.epsilon}, "
            f"bits={bits:.1f}, bits_per_weight={bpw:.2f}, "
            f"weight_shape={list(self.original_shape)}"
        )
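
Usage note: for context, a minimal sketch of how this layer might be driven during compression-aware training. The Linear module, optimizer, and random data below are illustrative assumptions, not part of this commit; only DKMLayer and its forward() come from the file above.

import torch
import torch.nn.functional as F
from dkm.dkm_layer import DKMLayer

# Hypothetical setup: cluster a Linear layer's weight into 16 centroids (4-bit palette)
linear = torch.nn.Linear(128, 64)
dkm = DKMLayer(linear.weight, n_clusters=16, dim=1)
opt = torch.optim.SGD(linear.parameters(), lr=1e-2)

x, target = torch.randn(32, 128), torch.randn(32, 64)

dkm.train()
w_soft = dkm()  # soft-assigned weights; gradients reach linear.weight through A
loss = F.mse_loss(F.linear(x, w_soft, linear.bias), target)
loss.backward()
opt.step()

dkm.eval()
w_hard = dkm()  # weights snapped to their nearest centroids
assert w_hard.shape == linear.weight.shape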
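The tau default (2e-5) sits at the hard end of the assignment spectrum. A hypothetical sanity check of how the gap between soft and hard assignment closes as tau shrinks (exact values depend on the random k-means++ init) could look like:

import torch
from dkm.dkm_layer import DKMLayer

w = torch.nn.Parameter(torch.randn(64, 64))
for tau in (1e-1, 1e-3, 1e-5):
    layer = DKMLayer(w, n_clusters=16, tau=tau, max_iter=10)
    layer.train()
    soft = layer().detach()
    layer.eval()
    hard = layer()
    # As tau shrinks, softmax(D / tau) approaches one-hot, so A @ C -> C[argmax]
    print(f"tau={tau:g}  max |soft - hard| = {(soft - hard).abs().max().item():.3e}")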
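Since extra_repr reports bits_per_weight = log2(n_clusters) / dim, scalar and multi-dimensional configurations can hit the same rate with different codebook expressiveness, e.g.:

import math

for k, d in [(16, 1), (256, 2), (8, 1)]:
    print(f"k={k:3d}, dim={d}: {math.log2(k) / d:.1f} bits per weight")
# k= 16, dim=1: 4.0 bits per weight  (4-bit scalar palettization)
# k=256, dim=2: 4.0 bits per weight  (same rate, 2-D codewords; Section 3.3)
# k=  8, dim=1: 3.0 bits per weight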