# Gemma4-Text / test2.py
"""
Test: output_attentions が正しく Attention Output を返すか検証する。
Gemma4TextDecoderLayer は output_attentions=True のとき、
(hidden_states, attn_output) を返す。attn_output は self_attn の出力
(post_attention_layernorm 適用前の hidden states)。
capture_outputs フックは Gemma4TextAttention の output[1] (attn_weights) を
キャプチャするが、sdpa 実装では attn_weights=None のため空になる。
そこで DecoderLayer レベルで attn_output が正しく取得できるかを検証する。
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_PATH = "/workspace/llm/gemma-4-31B-Text"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
inputs = tokenizer("hello", return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
inputs = inputs.to(model.device)
num_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
seq_len = inputs["input_ids"].shape[1]
batch_size = inputs["input_ids"].shape[0]
print(f"Model: num_layers={num_layers}, hidden_size={hidden_size}")
print(f"Input: batch={batch_size}, seq_len={seq_len}")
# =========================================================
# Test 1: output_attentions=True on model.model (Gemma4TextModel)
# =========================================================
print("\n=== Test 1: Gemma4TextModel.forward(output_attentions=True) ===")
with torch.no_grad():
    text_outputs = model.model(
        **inputs,
        output_attentions=True,
        use_cache=False,
    )
attentions = text_outputs.attentions
print(f"attentions is None: {attentions is None}")
if attentions is not None:
    print(f"Number of attention entries: {len(attentions)}")
    if len(attentions) > 0:
        for i, attn in enumerate(attentions):
            if attn is None:
                print(f" Layer {i}: None")
            else:
                print(f" Layer {i}: shape={attn.shape}, dtype={attn.dtype}")
                if i == 0:
                    # attn_output should have shape (batch, seq_len, hidden_size)
                    expected_shape = (batch_size, seq_len, hidden_size)
                    if attn.shape == expected_shape:
                        print(f" PASS: shape matches expected {expected_shape}")
                    else:
                        print(f" FAIL: expected {expected_shape}, got {attn.shape}")
    else:
        print(" (empty tuple - capture_outputs hook did not collect anything)")
# =========================================================
# Test 2: call the DecoderLayer directly and inspect attn_output
# =========================================================
print("\n=== Test 2: DecoderLayer direct call with output_attentions=True ===")
with torch.no_grad():
    # Prepare the embeddings and position information first
    input_ids = inputs["input_ids"].to(model.device)
    inputs_embeds = model.model.embed_tokens(input_ids)
    position_ids = torch.arange(seq_len, device=model.device).unsqueeze(0)
    # Rotary embedding
    layer_type = model.config.layer_types[0]
    position_embeddings = model.model.rotary_emb(inputs_embeds, position_ids, layer_type)
    # Causal mask (simplified: None means full attention)
    first_layer = model.model.layers[0]
    layer_outputs = first_layer(
        inputs_embeds,
        per_layer_input=None,
        position_embeddings=position_embeddings,
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=True,
    )
print(f"DecoderLayer returned {len(layer_outputs)} outputs")
if len(layer_outputs) >= 2:
    hidden_out = layer_outputs[0]
    attn_out = layer_outputs[1]
    print(f" hidden_states: shape={hidden_out.shape}, dtype={hidden_out.dtype}")
    print(f" attn_output: shape={attn_out.shape}, dtype={attn_out.dtype}")
    expected_shape = (batch_size, seq_len, hidden_size)
    if attn_out.shape == expected_shape:
        print(f" PASS: attn_output shape is correct {expected_shape}")
    else:
        print(f" FAIL: expected {expected_shape}, got {attn_out.shape}")
    # Check that attn_output is not all zeros
    if attn_out.abs().sum() > 0:
        print(f" PASS: attn_output is non-zero (norm={attn_out.float().norm().item():.4f})")
    else:
        print(f" FAIL: attn_output is all zeros")
    # Check that hidden_states and attn_output differ
    # (attn_output is taken before the layernorm + residual, so it should differ from hidden_states)
    if not torch.equal(hidden_out, attn_out):
        print(f" PASS: attn_output differs from hidden_states (as expected)")
    else:
        print(f" FAIL: attn_output is identical to hidden_states")
else:
    print(f" FAIL: expected 2 outputs, got {len(layer_outputs)}")
# =========================================================
# Test 3: no attn_output should be returned when output_attentions=False
# =========================================================
print("\n=== Test 3: DecoderLayer with output_attentions=False ===")
with torch.no_grad():
    layer_outputs_no_attn = first_layer(
        inputs_embeds,
        per_layer_input=None,
        position_embeddings=position_embeddings,
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=False,
    )
print(f"DecoderLayer returned {len(layer_outputs_no_attn)} outputs")
if len(layer_outputs_no_attn) == 1:
    print(" PASS: only hidden_states returned (no attn_output)")
else:
    print(f" FAIL: expected 1 output, got {len(layer_outputs_no_attn)}")
# =========================================================
# Test 4: check that output_attentions propagates through the CausalLM
# =========================================================
print("\n=== Test 4: Gemma4ForCausalLM output_attentions propagation ===")
with torch.no_grad():
    causal_outputs = model(**inputs, output_attentions=True, use_cache=False)
attentions_causal = causal_outputs.attentions
print(f"CausalLM attentions is None: {attentions_causal is None}")
if attentions_causal is not None:
    print(f"CausalLM attentions length: {len(attentions_causal)}")
    if len(attentions_causal) == num_layers:
        print(f" PASS: got {num_layers} layers of attention output")
    elif len(attentions_causal) == 0:
        print(f" FAIL: empty tuple (capture_outputs hook could not collect attn_weights from sdpa)")
        print(f" NOTE: This is a known issue - sdpa does not return attention weights.")
        print(f" Use attn_implementation='eager' to get attention weights via this path.")
    else:
        print(f" Got {len(attentions_causal)} (expected {num_layers})")
# =========================================================
# Summary
# =========================================================
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print("- DecoderLayer correctly returns attn_output when output_attentions=True")
print("- DecoderLayer correctly omits attn_output when output_attentions=False")
print("- capture_outputs hook on CausalLM/TextModel collects Gemma4TextAttention output[1]")
print(" which is attn_weights (None with sdpa), so CausalLM.attentions is empty.")
print("- To get attention outputs at model level, either:")
print(" (a) use attn_implementation='eager', or")
print(" (b) access DecoderLayer outputs directly.")