gary-boon Claude Opus 4.5 committed on
Commit
2bdf299
·
1 Parent(s): 2c6343b

Fix MistralTokenizer not loaded during model switch

Browse files

The switch_model endpoint was not creating the MistralTokenizer,
causing special tokens to be decoded incorrectly when switching
from CodeGen to Devstral.

Also adds the attention overlay feature: cache helpers for extracting single and aggregated attention rows, plus a lazy-loading endpoint to serve them.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

backend/mistral_tokenizer.py CHANGED
@@ -115,7 +115,8 @@ class MistralTokenizerWrapper:
115
  if not self._available:
116
  raise RuntimeError("MistralTokenizer not available")
117
 
118
- return self.tokenizer.decode([token_id])
 
119
 
120
 
121
  def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
 
115
  if not self._available:
116
  raise RuntimeError("MistralTokenizer not available")
117
 
118
+ result = self.tokenizer.decode([token_id])
119
+ return result
120
 
121
 
122
  def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
backend/model_service.py CHANGED
@@ -150,6 +150,55 @@ class MatrixCache:
150
  "ttl_seconds": self._ttl
151
  }
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  # Global matrix cache instance
155
  matrix_cache = MatrixCache(ttl_seconds=3600) # 60 min TTL
@@ -1363,6 +1412,16 @@ async def switch_model(request: Dict[str, Any], authenticated: bool = Depends(ve
1363
  # Create adapter
1364
  manager.adapter = create_adapter(manager.model, manager.tokenizer, model_id)
1365
 
 
 
 
 
 
 
 
 
 
 
1366
  logger.info(f"✅ {config['display_name']} loaded successfully")
1367
  logger.info(f" Layers: {manager.adapter.get_num_layers()}, Heads: {manager.adapter.get_num_heads()}")
1368
 
@@ -2976,6 +3035,81 @@ async def get_matrix_cache_stats(authenticated: bool = Depends(verify_api_key)):
2976
  return matrix_cache.get_stats()
2977
 
2978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2979
  @app.post("/analyze/study")
2980
  async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
2981
  """
 
150
  "ttl_seconds": self._ttl
151
  }
152
 
153
+ def get_attention_row(self, request_id: str, step: int, layer: int, head: int) -> Optional[list]:
154
+ """
155
+ Extract single attention row (last token's attention to all preceding positions).
156
+ Used for attention overlay visualization.
157
+ """
158
+ data = self.get(request_id, step, layer, head)
159
+ if not data or 'attention_weights' not in data:
160
+ return None
161
+ attention = data['attention_weights']
162
+ if attention is None or len(attention) == 0:
163
+ return None
164
+ # Return last row (query token attending to all keys)
165
+ # Handle both numpy arrays and lists
166
+ last_row = attention[-1]
167
+ if hasattr(last_row, 'tolist'):
168
+ return last_row.tolist()
169
+ return list(last_row)
170
+
171
+ def get_aggregate_row(self, request_id: str, step: int, layer: int,
172
+ num_heads: int, mode: str = "mean") -> Optional[list]:
173
+ """
174
+ Compute aggregated attention row across all heads for a layer.
175
+
176
+ Args:
177
+ request_id: UUID from analysis
178
+ step: Generation step
179
+ layer: Layer index
180
+ num_heads: Number of attention heads in model
181
+ mode: Aggregation mode - "mean" or "max"
182
+
183
+ Returns:
184
+ List of aggregated attention weights, or None if data unavailable
185
+ """
186
+ rows = []
187
+ for h in range(num_heads):
188
+ row = self.get_attention_row(request_id, step, layer, h)
189
+ if row:
190
+ rows.append(row)
191
+ if not rows:
192
+ return None
193
+ arr = np.array(rows)
194
+ if mode == "mean":
195
+ return np.mean(arr, axis=0).tolist()
196
+ elif mode == "max":
197
+ return np.max(arr, axis=0).tolist()
198
+ else:
199
+ # Default to mean for unknown modes
200
+ return np.mean(arr, axis=0).tolist()
201
+
202
 
203
  # Global matrix cache instance
204
  matrix_cache = MatrixCache(ttl_seconds=3600) # 60 min TTL
 
1412
  # Create adapter
1413
  manager.adapter = create_adapter(manager.model, manager.tokenizer, model_id)
1414
 
1415
+ # For Devstral, also load MistralTokenizer for correct Tekken encoding
1416
+ manager.mistral_tokenizer = None
1417
+ if model_id == "devstral-small":
1418
+ from .mistral_tokenizer import create_mistral_tokenizer
1419
+ manager.mistral_tokenizer = create_mistral_tokenizer(manager.model_name)
1420
+ if manager.mistral_tokenizer:
1421
+ logger.info("Loaded MistralTokenizer for Devstral (correct Tekken encoding)")
1422
+ else:
1423
+ logger.warning("MistralTokenizer not available - Devstral may produce garbage output")
1424
+
1425
  logger.info(f"✅ {config['display_name']} loaded successfully")
1426
  logger.info(f" Layers: {manager.adapter.get_num_layers()}, Heads: {manager.adapter.get_num_heads()}")
1427
 
 
3035
  return matrix_cache.get_stats()
3036
 
3037
 
3038
+ @app.get("/analyze/research/attention/row")
3039
+ async def get_attention_row(
3040
+ request_id: str,
3041
+ step: int,
3042
+ layer: int,
3043
+ head: Optional[int] = None,
3044
+ aggregate_mode: str = "mean",
3045
+ authenticated: bool = Depends(verify_api_key)
3046
+ ):
3047
+ """
3048
+ Retrieve single attention row for overlay visualization.
3049
+
3050
+ Returns the attention weights from the query token (at position `step`)
3051
+ to all preceding positions. This is a minimal payload for efficient
3052
+ lazy-loading in the attention overlay feature.
3053
+
3054
+ Parameters:
3055
+ - request_id: UUID from the original analysis response
3056
+ - step: Generation step (0 = first generated token)
3057
+ - layer: Layer index (0-based)
3058
+ - head: Head index (0-based), or None for aggregated view
3059
+ - aggregate_mode: "mean" or "max" when head is None
3060
+
3061
+ Returns:
3062
+ - attention_weights: List of attention weights [0..seq_len]
3063
+ - seq_len: Number of positions in the sequence
3064
+ - layer: Layer index
3065
+ - head: Head index (null if aggregated)
3066
+ - aggregate_mode: Mode used if aggregated (null otherwise)
3067
+ """
3068
+ # Get number of heads from model config
3069
+ if not manager.model:
3070
+ raise HTTPException(status_code=503, detail="Model not loaded")
3071
+
3072
+ config = manager.model.config
3073
+ num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 16))
3074
+
3075
+ if head is not None:
3076
+ # Fetch specific head
3077
+ attention_row = matrix_cache.get_attention_row(request_id, step, layer, head)
3078
+ if attention_row is None:
3079
+ logger.warning(f"Attention row cache miss: request_id={request_id}, step={step}, layer={layer}, head={head}")
3080
+ raise HTTPException(
3081
+ status_code=404,
3082
+ detail="Attention data not found. Cache may have expired (60 min TTL). Please re-analyze."
3083
+ )
3084
+ logger.info(f"Attention row cache hit: request_id={request_id}, step={step}, layer={layer}, head={head}")
3085
+ return {
3086
+ "attention_weights": attention_row,
3087
+ "seq_len": len(attention_row),
3088
+ "layer": layer,
3089
+ "head": head,
3090
+ "aggregate_mode": None
3091
+ }
3092
+ else:
3093
+ # Aggregate across all heads
3094
+ attention_row = matrix_cache.get_aggregate_row(
3095
+ request_id, step, layer, num_heads, aggregate_mode
3096
+ )
3097
+ if attention_row is None:
3098
+ logger.warning(f"Attention row aggregate cache miss: request_id={request_id}, step={step}, layer={layer}")
3099
+ raise HTTPException(
3100
+ status_code=404,
3101
+ detail="Attention data not found. Cache may have expired (60 min TTL). Please re-analyze."
3102
+ )
3103
+ logger.info(f"Attention row aggregate cache hit: request_id={request_id}, step={step}, layer={layer}, mode={aggregate_mode}")
3104
+ return {
3105
+ "attention_weights": attention_row,
3106
+ "seq_len": len(attention_row),
3107
+ "layer": layer,
3108
+ "head": None,
3109
+ "aggregate_mode": aggregate_mode
3110
+ }
3111
+
3112
+
3113
  @app.post("/analyze/study")
3114
  async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
3115
  """