Spaces:
Paused
Paused
gary-boon Claude Opus 4.5 committed on
Commit ·
2bdf299
1
Parent(s): 2c6343b
Fix MistralTokenizer not loaded during model switch
Browse files
The switch_model endpoint was not creating the MistralTokenizer,
causing special tokens to be decoded incorrectly when switching
from CodeGen to Devstral.
Also adds attention overlay feature changes.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- backend/mistral_tokenizer.py +2 -1
- backend/model_service.py +134 -0
backend/mistral_tokenizer.py
CHANGED
|
@@ -115,7 +115,8 @@ class MistralTokenizerWrapper:
|
|
| 115 |
if not self._available:
|
| 116 |
raise RuntimeError("MistralTokenizer not available")
|
| 117 |
|
| 118 |
-
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
|
|
|
|
| 115 |
if not self._available:
|
| 116 |
raise RuntimeError("MistralTokenizer not available")
|
| 117 |
|
| 118 |
+
result = self.tokenizer.decode([token_id])
|
| 119 |
+
return result
|
| 120 |
|
| 121 |
|
| 122 |
def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
|
backend/model_service.py
CHANGED
|
@@ -150,6 +150,55 @@ class MatrixCache:
|
|
| 150 |
"ttl_seconds": self._ttl
|
| 151 |
}
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
# Global matrix cache instance
|
| 155 |
matrix_cache = MatrixCache(ttl_seconds=3600) # 60 min TTL
|
|
@@ -1363,6 +1412,16 @@ async def switch_model(request: Dict[str, Any], authenticated: bool = Depends(ve
|
|
| 1363 |
# Create adapter
|
| 1364 |
manager.adapter = create_adapter(manager.model, manager.tokenizer, model_id)
|
| 1365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1366 |
logger.info(f"✅ {config['display_name']} loaded successfully")
|
| 1367 |
logger.info(f" Layers: {manager.adapter.get_num_layers()}, Heads: {manager.adapter.get_num_heads()}")
|
| 1368 |
|
|
@@ -2976,6 +3035,81 @@ async def get_matrix_cache_stats(authenticated: bool = Depends(verify_api_key)):
|
|
| 2976 |
return matrix_cache.get_stats()
|
| 2977 |
|
| 2978 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2979 |
@app.post("/analyze/study")
|
| 2980 |
async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
|
| 2981 |
"""
|
|
|
|
| 150 |
"ttl_seconds": self._ttl
|
| 151 |
}
|
| 152 |
|
| 153 |
+
def get_attention_row(self, request_id: str, step: int, layer: int, head: int) -> Optional[list]:
|
| 154 |
+
"""
|
| 155 |
+
Extract single attention row (last token's attention to all preceding positions).
|
| 156 |
+
Used for attention overlay visualization.
|
| 157 |
+
"""
|
| 158 |
+
data = self.get(request_id, step, layer, head)
|
| 159 |
+
if not data or 'attention_weights' not in data:
|
| 160 |
+
return None
|
| 161 |
+
attention = data['attention_weights']
|
| 162 |
+
if attention is None or len(attention) == 0:
|
| 163 |
+
return None
|
| 164 |
+
# Return last row (query token attending to all keys)
|
| 165 |
+
# Handle both numpy arrays and lists
|
| 166 |
+
last_row = attention[-1]
|
| 167 |
+
if hasattr(last_row, 'tolist'):
|
| 168 |
+
return last_row.tolist()
|
| 169 |
+
return list(last_row)
|
| 170 |
+
|
| 171 |
+
def get_aggregate_row(self, request_id: str, step: int, layer: int,
|
| 172 |
+
num_heads: int, mode: str = "mean") -> Optional[list]:
|
| 173 |
+
"""
|
| 174 |
+
Compute aggregated attention row across all heads for a layer.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
request_id: UUID from analysis
|
| 178 |
+
step: Generation step
|
| 179 |
+
layer: Layer index
|
| 180 |
+
num_heads: Number of attention heads in model
|
| 181 |
+
mode: Aggregation mode - "mean" or "max"
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
List of aggregated attention weights, or None if data unavailable
|
| 185 |
+
"""
|
| 186 |
+
rows = []
|
| 187 |
+
for h in range(num_heads):
|
| 188 |
+
row = self.get_attention_row(request_id, step, layer, h)
|
| 189 |
+
if row:
|
| 190 |
+
rows.append(row)
|
| 191 |
+
if not rows:
|
| 192 |
+
return None
|
| 193 |
+
arr = np.array(rows)
|
| 194 |
+
if mode == "mean":
|
| 195 |
+
return np.mean(arr, axis=0).tolist()
|
| 196 |
+
elif mode == "max":
|
| 197 |
+
return np.max(arr, axis=0).tolist()
|
| 198 |
+
else:
|
| 199 |
+
# Default to mean for unknown modes
|
| 200 |
+
return np.mean(arr, axis=0).tolist()
|
| 201 |
+
|
| 202 |
|
| 203 |
# Global matrix cache instance
|
| 204 |
matrix_cache = MatrixCache(ttl_seconds=3600) # 60 min TTL
|
|
|
|
| 1412 |
# Create adapter
|
| 1413 |
manager.adapter = create_adapter(manager.model, manager.tokenizer, model_id)
|
| 1414 |
|
| 1415 |
+
# For Devstral, also load MistralTokenizer for correct Tekken encoding
|
| 1416 |
+
manager.mistral_tokenizer = None
|
| 1417 |
+
if model_id == "devstral-small":
|
| 1418 |
+
from .mistral_tokenizer import create_mistral_tokenizer
|
| 1419 |
+
manager.mistral_tokenizer = create_mistral_tokenizer(manager.model_name)
|
| 1420 |
+
if manager.mistral_tokenizer:
|
| 1421 |
+
logger.info("Loaded MistralTokenizer for Devstral (correct Tekken encoding)")
|
| 1422 |
+
else:
|
| 1423 |
+
logger.warning("MistralTokenizer not available - Devstral may produce garbage output")
|
| 1424 |
+
|
| 1425 |
logger.info(f"✅ {config['display_name']} loaded successfully")
|
| 1426 |
logger.info(f" Layers: {manager.adapter.get_num_layers()}, Heads: {manager.adapter.get_num_heads()}")
|
| 1427 |
|
|
|
|
| 3035 |
return matrix_cache.get_stats()
|
| 3036 |
|
| 3037 |
|
| 3038 |
+
@app.get("/analyze/research/attention/row")
|
| 3039 |
+
async def get_attention_row(
|
| 3040 |
+
request_id: str,
|
| 3041 |
+
step: int,
|
| 3042 |
+
layer: int,
|
| 3043 |
+
head: Optional[int] = None,
|
| 3044 |
+
aggregate_mode: str = "mean",
|
| 3045 |
+
authenticated: bool = Depends(verify_api_key)
|
| 3046 |
+
):
|
| 3047 |
+
"""
|
| 3048 |
+
Retrieve single attention row for overlay visualization.
|
| 3049 |
+
|
| 3050 |
+
Returns the attention weights from the query token (at position `step`)
|
| 3051 |
+
to all preceding positions. This is a minimal payload for efficient
|
| 3052 |
+
lazy-loading in the attention overlay feature.
|
| 3053 |
+
|
| 3054 |
+
Parameters:
|
| 3055 |
+
- request_id: UUID from the original analysis response
|
| 3056 |
+
- step: Generation step (0 = first generated token)
|
| 3057 |
+
- layer: Layer index (0-based)
|
| 3058 |
+
- head: Head index (0-based), or None for aggregated view
|
| 3059 |
+
- aggregate_mode: "mean" or "max" when head is None
|
| 3060 |
+
|
| 3061 |
+
Returns:
|
| 3062 |
+
- attention_weights: List of attention weights [0..seq_len]
|
| 3063 |
+
- seq_len: Number of positions in the sequence
|
| 3064 |
+
- layer: Layer index
|
| 3065 |
+
- head: Head index (null if aggregated)
|
| 3066 |
+
- aggregate_mode: Mode used if aggregated (null otherwise)
|
| 3067 |
+
"""
|
| 3068 |
+
# Get number of heads from model config
|
| 3069 |
+
if not manager.model:
|
| 3070 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 3071 |
+
|
| 3072 |
+
config = manager.model.config
|
| 3073 |
+
num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 16))
|
| 3074 |
+
|
| 3075 |
+
if head is not None:
|
| 3076 |
+
# Fetch specific head
|
| 3077 |
+
attention_row = matrix_cache.get_attention_row(request_id, step, layer, head)
|
| 3078 |
+
if attention_row is None:
|
| 3079 |
+
logger.warning(f"Attention row cache miss: request_id={request_id}, step={step}, layer={layer}, head={head}")
|
| 3080 |
+
raise HTTPException(
|
| 3081 |
+
status_code=404,
|
| 3082 |
+
detail="Attention data not found. Cache may have expired (60 min TTL). Please re-analyze."
|
| 3083 |
+
)
|
| 3084 |
+
logger.info(f"Attention row cache hit: request_id={request_id}, step={step}, layer={layer}, head={head}")
|
| 3085 |
+
return {
|
| 3086 |
+
"attention_weights": attention_row,
|
| 3087 |
+
"seq_len": len(attention_row),
|
| 3088 |
+
"layer": layer,
|
| 3089 |
+
"head": head,
|
| 3090 |
+
"aggregate_mode": None
|
| 3091 |
+
}
|
| 3092 |
+
else:
|
| 3093 |
+
# Aggregate across all heads
|
| 3094 |
+
attention_row = matrix_cache.get_aggregate_row(
|
| 3095 |
+
request_id, step, layer, num_heads, aggregate_mode
|
| 3096 |
+
)
|
| 3097 |
+
if attention_row is None:
|
| 3098 |
+
logger.warning(f"Attention row aggregate cache miss: request_id={request_id}, step={step}, layer={layer}")
|
| 3099 |
+
raise HTTPException(
|
| 3100 |
+
status_code=404,
|
| 3101 |
+
detail="Attention data not found. Cache may have expired (60 min TTL). Please re-analyze."
|
| 3102 |
+
)
|
| 3103 |
+
logger.info(f"Attention row aggregate cache hit: request_id={request_id}, step={step}, layer={layer}, mode={aggregate_mode}")
|
| 3104 |
+
return {
|
| 3105 |
+
"attention_weights": attention_row,
|
| 3106 |
+
"seq_len": len(attention_row),
|
| 3107 |
+
"layer": layer,
|
| 3108 |
+
"head": None,
|
| 3109 |
+
"aggregate_mode": aggregate_mode
|
| 3110 |
+
}
|
| 3111 |
+
|
| 3112 |
+
|
| 3113 |
@app.post("/analyze/study")
|
| 3114 |
async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)):
|
| 3115 |
"""
|