SHAP values on this model

Discussion #1, opened by valimikayilov

I’m planning a comprehensive SHAP and explainability analysis of this xLSTM model (built only from mLSTM blocks): https://huggingface.co/stefan-it/xlstm-german-wikipedia
Main goals:
• Understand how the model makes predictions through feature attributions
• Explore how the mLSTM memory mechanism works under the hood
• Visualize what the model “pays attention to” when processing text
Any advice on the best approach to tackle this? Would appreciate suggestions on tools, methods, or workflows that work well for this kind of analysis.
Thanks!

Hi — nice plan, sounds very interesting!
Here’s a concise, practical workflow, plus tools that work well for explainability on xLSTM / mLSTM models, based on my experience:

Suggested approach

Token-level attributions: use Integrated Gradients (Captum), or DeepSHAP if shap’s DeepExplainer handles the model’s custom layers, to get token importance; aggregate subword attributions to the word level.

Perturbation / occlusion tests: systematically mask or replace tokens and measure the change in the target logit (a good sanity check, and a more direct causal signal than gradients).

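Hidden & cell state logging: record hidden and cell states at each time step for representative examples (in mLSTM the cell state is a matrix memory per head, so you may need to summarize it, e.g. by its norm); visualize trajectories and look for dimensions with consistent, distinctive patterns (spikes, plateaus).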

Clustering & projection: run PCA / UMAP / t-SNE on the recorded states to find clusters; inspect the tokens and contexts in each cluster to interpret it.

Ablation experiments: zero out or perturb specific cell dimensions mid-sequence and re-run the forward pass to measure their causal impact on outputs.

Synthetic tests: create controlled sentences to probe long-distance dependencies, negation, named entities, numeric patterns, etc.

Combine methods: compare IG/SHAP rankings with occlusion results; signals that agree across methods are more trustworthy than any single method.
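For the token-level attribution step, it helps to have an exact reference on a few very short inputs before trusting DeepSHAP or KernelSHAP approximations. A minimal sketch, assuming a hypothetical `score_fn` that maps a list of token IDs to a scalar (for example the target logit) and a hypothetical `mask_id` used to "remove" a token while keeping positions fixed. It costs 2**n calls to `score_fn`, so it is only practical for roughly a dozen tokens or fewer.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence


def exact_token_shapley(
    token_ids: Sequence[int],
    score_fn: Callable[[List[int]], float],
    mask_id: int,
) -> List[float]:
    """Exact Shapley value of each token position (2**n score_fn calls, small n only).

    A token that is absent from a coalition is replaced by mask_id (hypothetical
    masking choice), so sequence length and positions stay fixed.
    """
    n = len(token_ids)
    cache = {}

    def value(coalition: frozenset) -> float:
        # Score the input with only the coalition's tokens kept; memoize calls.
        if coalition not in cache:
            ids = [t if i in coalition else mask_id for i, t in enumerate(token_ids)]
            cache[coalition] = score_fn(ids)
        return cache[coalition]

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi
```

The mask token plays the same role as the IG baseline, so these values can shift with that choice; running shap’s KernelExplainer with the same masking on the same short inputs gives a like-for-like comparison.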
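The occlusion test can be sketched model-agnostically as below, assuming a hypothetical `score_fn` (token IDs to target logit) and a `replacement_id` such as a padding or unknown-token ID; which replacement token is appropriate for this tokenizer is an assumption to check.

```python
from typing import Callable, List, Sequence


def occlusion_deltas(
    token_ids: Sequence[int],
    score_fn: Callable[[List[int]], float],
    replacement_id: int,
    window: int = 1,
) -> List[float]:
    """Score drop when each window of tokens is replaced by replacement_id.

    delta = score(original) - score(occluded); a positive delta means the
    occluded tokens supported the prediction. Positions covered by several
    windows receive the average of those deltas.
    """
    ids = list(token_ids)
    base = score_fn(ids)
    n = len(ids)
    totals = [0.0] * n
    counts = [0] * n
    for start in range(n - window + 1):
        occluded = ids[:start] + [replacement_id] * window + ids[start + window:]
        delta = base - score_fn(occluded)
        for pos in range(start, start + window):
            totals[pos] += delta
            counts[pos] += 1
    return [t / c if c else 0.0 for t, c in zip(totals, counts)]
```

A window larger than 1 is useful for German compounds that split into several subwords, where masking a single piece can leave the word recoverable.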
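For state logging and ablation, PyTorch forward hooks avoid modifying the model code. The sketch below uses a toy GRU model as a stand-in (the class and function names are hypothetical); on the real checkpoint you would pick the target module from `model.named_modules()`, and its output may be a tuple rather than a tensor, which the hooks would need to unpack. One important caveat: zeroing the output of a block only changes what downstream layers see, while the block’s own recurrence still carries its un-ablated memory forward. Ablating the mLSTM memory itself would require patching the recurrent step inside the xLSTM implementation.

```python
import torch
import torch.nn as nn


class GRUBlock(nn.Module):
    """Recurrent block returning only per-step outputs, shape (batch, time, dim)."""

    def __init__(self, dim: int):
        super().__init__()
        self.gru = nn.GRU(dim, dim, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.gru(x)
        return out


class ToyLM(nn.Module):
    """Hypothetical stand-in for the xLSTM model: embedding -> recurrent block -> head."""

    def __init__(self, vocab: int = 50, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.block = GRUBlock(dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.block(self.embed(ids)))


def record_states(model, module, ids):
    """Return (module output, logits) for one forward pass, via a temporary hook."""
    captured = {}

    def hook(mod, inputs, output):
        captured["states"] = output.detach().cpu()

    handle = module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            logits = model(ids)
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["states"], logits


def ablate_dims(model, module, ids, dims, from_step):
    """Logits with `dims` of the module's output zeroed at time steps >= from_step."""

    def hook(mod, inputs, output):
        output = output.clone()
        output[:, from_step:, dims] = 0.0
        return output  # a non-None return value replaces the module's output

    handle = module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            return model(ids)
    finally:
        handle.remove()


torch.manual_seed(0)
model = ToyLM().eval()
ids = torch.randint(0, 50, (1, 10))
states, logits = record_states(model, model.block, ids)
ablated = ablate_dims(model, model.block, ids, dims=[0, 3], from_step=5)
impact = (logits - ablated).abs().amax(dim=-1)  # max logit change per position
```

Plotting `states[0]` as a heatmap (time by dimension) gives the trajectory view, and `impact` shows where an ablation starts to matter.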
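Once states are logged as an array of shape (examples, tokens, dim), flatten them to one row per token and keep the token strings alongside so clusters can be read back. A minimal NumPy PCA via SVD (equivalent, up to sign, to `sklearn.decomposition.PCA(n_components=k).fit_transform`) is sketched below on synthetic data; `umap.UMAP` or `sklearn.manifold.TSNE` can be applied to the same matrix for non-linear projections.

```python
import numpy as np


def pca_project(states: np.ndarray, k: int = 2):
    """Project (n_samples, dim) states onto their top-k principal components.

    Returns (coords of shape (n_samples, k), explained-variance ratio per component).
    """
    centered = states - states.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, sorted by singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    var = s ** 2
    return centered @ vt[:k].T, var[:k] / var.sum()


# Synthetic stand-in for logged states: one row per (example, token) pair,
# with one deliberately dominant direction. Labels are kept for inspection.
rng = np.random.default_rng(0)
states = rng.normal(size=(200, 5)) * np.array([10.0, 1.0, 1.0, 1.0, 1.0])
tokens = [f"tok{i}" for i in range(200)]
coords, ratio = pca_project(states, k=2)
```

For real logs of shape (examples, tokens, dim), `states.reshape(-1, states.shape[-1])` gives the row-per-token layout this function expects.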
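For the synthetic tests, minimal pairs that differ in a single word keep the comparison clean. The sketch below builds hypothetical German subject-verb agreement pairs whose distance grows with each plural distractor phrase, and assumes a hypothetical `logprob_fn(prefix, continuation)` that returns the model’s summed log-probability of the continuation tokens given the prefix. The templates are illustrative only, not a validated test suite.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical templates: singular subjects, plural distractor phrases.
SUBJECTS = ["Der Mann", "Die Frau", "Das Kind"]
FILLERS = ["neben den Häusern", "mit den roten Türen", "in den alten Straßen"]


def agreement_pairs(max_fillers: int = 3) -> List[Tuple[int, str, str, str]]:
    """(distance, prefix, grammatical verb, ungrammatical verb) minimal pairs.

    Each extra filler phrase adds a plural noun between subject and verb.
    """
    pairs = []
    for k in range(min(max_fillers, len(FILLERS)) + 1):
        for subject in SUBJECTS:
            prefix = " ".join([subject] + FILLERS[:k])
            pairs.append((k, prefix, " ist", " sind"))
    return pairs


def accuracy_by_distance(pairs, logprob_fn: Callable[[str, str], float]) -> Dict[int, float]:
    """Share of pairs where the model prefers the grammatical verb, per distance."""
    hits: Dict[int, List[bool]] = {}
    for k, prefix, good, bad in pairs:
        hits.setdefault(k, []).append(logprob_fn(prefix, good) > logprob_fn(prefix, bad))
    return {k: sum(v) / len(v) for k, v in hits.items()}
```

Logging states on these same prefixes lets you check whether some dimensions keep encoding the subject’s number as the distance grows.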
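To compare methods, rank correlation and top-k overlap between two attribution vectors for the same input are simple summaries. A NumPy sketch follows (the function names are hypothetical; the rank computation ignores ties, and `scipy.stats.spearmanr` handles them properly if they occur):

```python
import numpy as np


def spearman(a, b) -> float:
    """Spearman rank correlation of two attribution vectors (ties not handled)."""
    ra = np.argsort(np.argsort(a)).astype(float)
    rb = np.argsort(np.argsort(b)).astype(float)
    return float(np.corrcoef(ra, rb)[0, 1])


def top_k_overlap(a, b, k: int = 5) -> float:
    """Fraction of positions shared by the k largest |attributions| of each method."""
    top_a = set(np.argsort(-np.abs(a))[:k].tolist())
    top_b = set(np.argsort(-np.abs(b))[:k].tolist())
    return len(top_a & top_b) / k
```

Averaging both numbers over a set of inputs, e.g. IG vs. occlusion deltas, gives a quick read on whether the methods tell the same story.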

Tools & libs

PyTorch + Hugging Face tokenizers

Captum (IntegratedGradients, LayerIntegratedGradients, DeepLift)

shap (DeepExplainer / KernelExplainer for checks)

numpy / pandas / matplotlib / plotly for viz

scikit-learn / umap-learn for projections

spaCy for POS / NER features to correlate with states

Practical tips

Use embedding-level IG (or LayerIntegratedGradients on the embedding layer) to get gradient attributions for tokens; token IDs are discrete, so gradients have to be taken with respect to the embeddings.

Baseline choice matters (zeros vs average embeddings); test several baselines.

Aggregate subword attributions to the word level for readable visuals.

Keep sanity checks in place, e.g. the correlation between occlusion deltas and IG/SHAP attributions.

Log states for small sets (50–200 examples) first — state arrays grow fast.

Document seeds, preprocessing, sequence length, and baseline for reproducibility.
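Embedding-level IG and the baseline comparison can be sketched directly in PyTorch, which also makes the completeness property visible: the attributions should sum to F(input) minus F(baseline). This is a hand-rolled midpoint Riemann-sum version on a toy GRU model, not Captum’s implementation, and it assumes a function that runs the model from embeddings (`forward_from_embeds`, hypothetical). Whether the xLSTM checkpoint accepts embeddings directly is unknown, which is why hooking the embedding layer with Captum’s `LayerIntegratedGradients` is often more convenient in practice.

```python
import torch
import torch.nn as nn


def integrated_gradients(forward_from_embeds, embeds, baseline, target, n_steps=64):
    """Integrated Gradients w.r.t. input embeddings (midpoint Riemann sum).

    forward_from_embeds: maps (batch, T, D) embeddings to (batch, vocab) scores.
    embeds, baseline: (T, D). Returns per-token attributions (T,) summed over D,
    plus the completeness gap sum(attr) - (F(x) - F(baseline)), ideally ~0.
    """
    embeds, baseline = embeds.detach(), baseline.detach()
    alphas = (torch.arange(n_steps, dtype=embeds.dtype) + 0.5) / n_steps
    path = baseline + alphas[:, None, None] * (embeds - baseline)  # (n_steps, T, D)
    path.requires_grad_(True)
    scores = forward_from_embeds(path)[:, target]  # each path point is one batch row
    (grads,) = torch.autograd.grad(scores.sum(), path)
    token_attr = ((embeds - baseline) * grads.mean(dim=0)).sum(dim=-1)
    with torch.no_grad():
        f_x = forward_from_embeds(embeds[None])[0, target]
        f_b = forward_from_embeds(baseline[None])[0, target]
    return token_attr.detach(), float(token_attr.sum() - (f_x - f_b))


# Toy stand-in for the real model (hypothetical): embedding -> GRU -> next-token head.
torch.manual_seed(0)
vocab, dim = 50, 16
embed = nn.Embedding(vocab, dim)
rnn = nn.GRU(dim, dim, batch_first=True)
head = nn.Linear(dim, vocab)


def forward_from_embeds(e):
    out, _ = rnn(e)
    return head(out[:, -1])  # scores for the token after the last position


ids = torch.tensor([3, 14, 15, 9, 26])
x = embed(ids).detach()
target = int(forward_from_embeds(x[None]).argmax())
baselines = {
    "zeros": torch.zeros_like(x),
    "mean_embedding": embed.weight.mean(dim=0).detach().expand_as(x),
}
results = {name: integrated_gradients(forward_from_embeds, x, b, target, n_steps=128)
           for name, b in baselines.items()}
```

A large completeness gap usually means too few steps; comparing the attributions in `results["zeros"]` and `results["mean_embedding"]` shows how much the baseline moves the token ranking.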
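Subword-to-word aggregation needs the token-to-word mapping. The sketch assumes a Hugging Face fast tokenizer, whose `BatchEncoding.word_ids()` returns a word index per token (None for special tokens); whether this model ships a fast tokenizer is an assumption to verify. Summing is the natural default for IG because its attributions are additive.

```python
from typing import List, Optional, Sequence


def aggregate_to_words(
    word_ids: Sequence[Optional[int]],
    attributions: Sequence[float],
    how: str = "sum",
) -> List[float]:
    """Collapse subword attributions to one score per word, in word order.

    word_ids: per-token word index, None for special tokens (which are skipped).
    how: "sum" (keeps IG additivity), "mean", or "max_abs".
    """
    groups = {}
    for wid, a in zip(word_ids, attributions):
        if wid is not None:
            groups.setdefault(wid, []).append(float(a))
    agg = {
        "sum": sum,
        "mean": lambda v: sum(v) / len(v),
        "max_abs": lambda v: max(v, key=abs),
    }[how]
    return [agg(groups[w]) for w in sorted(groups)]
```

For example, `aggregate_to_words([None, 0, 0, 1, None], [9.0, 0.2, 0.3, -0.5, 9.0])` drops the two special tokens and returns approximately `[0.5, -0.5]`.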
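To see why state arrays grow fast, the arithmetic below uses hypothetical sizes (200 examples, 256 tokens, 24 layers, float32; not this model’s actual configuration). Because the mLSTM cell state is a matrix per head, logging it costs far more than logging hidden vectors, which argues for storing per-step summaries such as norms, or using float16.

```python
def state_log_bytes(n_examples: int, seq_len: int, n_layers: int,
                    elems_per_layer: int, bytes_per_elem: int = 4) -> int:
    """Bytes to keep one state tensor per token, per layer, per example."""
    return n_examples * seq_len * n_layers * elems_per_layer * bytes_per_elem


GIB = 2 ** 30
# Hypothetical sizes, not this model's actual configuration.
vector_state = state_log_bytes(200, 256, 24, 512)            # 512-dim hidden vector
matrix_state = state_log_bytes(200, 256, 24, 4 * 128 * 128)  # 4 heads of 128x128 memory
print(vector_state / GIB, matrix_state / GIB)  # 2.34375 300.0
```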
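For reproducibility, one option is to keep every run setting in a single config object, seed all RNGs from it, and save it as JSON next to the outputs. The field names and defaults below are illustrative, not a required schema.

```python
import json
import random
from dataclasses import asdict, dataclass

import numpy as np


@dataclass
class RunConfig:
    """Settings needed to reproduce an attribution run (illustrative fields)."""
    model_id: str = "stefan-it/xlstm-german-wikipedia"
    seed: int = 42
    max_seq_len: int = 128
    ig_baseline: str = "zeros"
    ig_steps: int = 64
    preprocessing: str = "raw text, no lowercasing"


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and (if installed) PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass


def config_json(cfg: RunConfig) -> str:
    """Serialize the config so it can be saved next to the attribution outputs."""
    return json.dumps(asdict(cfg), indent=2, sort_keys=True)


cfg = RunConfig()
seed_everything(cfg.seed)
saved = config_json(cfg)
```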

If you want, I can:

provide a ready-to-run notebook (Captum + occlusion + state logging) tailored to stefan-it/xlstm-german-wikipedia, or

share small code snippets to log hidden/cell states, do IG, and plot attributions.

Would you like a notebook or a few starter scripts?

— Soltani / AbramarSolution
