""" Test: output_attentions が正しく Attention Output を返すか検証する。 Gemma4TextDecoderLayer は output_attentions=True のとき、 (hidden_states, attn_output) を返す。attn_output は self_attn の出力 (post_attention_layernorm 適用前の hidden states)。 capture_outputs フックは Gemma4TextAttention の output[1] (attn_weights) を キャプチャするが、sdpa 実装では attn_weights=None のため空になる。 そこで DecoderLayer レベルで attn_output が正しく取得できるかを検証する。 """ import torch from transformers import AutoModelForCausalLM, AutoTokenizer MODEL_PATH = "/workspace/llm/gemma-4-31B-Text" tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) inputs = tokenizer("hello", return_tensors="pt") model = AutoModelForCausalLM.from_pretrained( MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ) inputs = inputs.to(model.device) num_layers = model.config.num_hidden_layers hidden_size = model.config.hidden_size seq_len = inputs["input_ids"].shape[1] batch_size = inputs["input_ids"].shape[0] print(f"Model: num_layers={num_layers}, hidden_size={hidden_size}") print(f"Input: batch={batch_size}, seq_len={seq_len}") # ========================================================= # Test 1: model.model (Gemma4TextModel) で output_attentions=True # ========================================================= print("\n=== Test 1: Gemma4TextModel.forward(output_attentions=True) ===") with torch.no_grad(): text_outputs = model.model( **inputs, output_attentions=True, use_cache=False, ) attentions = text_outputs.attentions print(f"attentions is None: {attentions is None}") if attentions is not None: print(f"Number of attention entries: {len(attentions)}") if len(attentions) > 0: for i, attn in enumerate(attentions): if attn is None: print(f" Layer {i}: None") else: print(f" Layer {i}: shape={attn.shape}, dtype={attn.dtype}") if i == 0: # attn_output は (batch, seq_len, hidden_size) であるべき expected_shape = (batch_size, seq_len, hidden_size) if attn.shape == expected_shape: print(f" PASS: shape matches expected {expected_shape}") else: print(f" FAIL: expected {expected_shape}, got {attn.shape}") else: print(" (empty tuple - capture_outputs hook did not collect anything)") # ========================================================= # Test 2: DecoderLayer を直接呼んで attn_output を確認 # ========================================================= print("\n=== Test 2: DecoderLayer direct call with output_attentions=True ===") with torch.no_grad(): # まずembeddingとposition情報を準備 input_ids = inputs["input_ids"].to(model.device) inputs_embeds = model.model.embed_tokens(input_ids) position_ids = torch.arange(seq_len, device=model.device).unsqueeze(0) # Rotary embedding layer_type = model.config.layer_types[0] position_embeddings = model.model.rotary_emb(inputs_embeds, position_ids, layer_type) # Causal mask (簡易: None で全アテンション) first_layer = model.model.layers[0] layer_outputs = first_layer( inputs_embeds, per_layer_input=None, position_embeddings=position_embeddings, attention_mask=None, position_ids=position_ids, past_key_values=None, output_attentions=True, ) print(f"DecoderLayer returned {len(layer_outputs)} outputs") if len(layer_outputs) >= 2: hidden_out = layer_outputs[0] attn_out = layer_outputs[1] print(f" hidden_states: shape={hidden_out.shape}, dtype={hidden_out.dtype}") print(f" attn_output: shape={attn_out.shape}, dtype={attn_out.dtype}") expected_shape = (batch_size, seq_len, hidden_size) if attn_out.shape == expected_shape: print(f" PASS: attn_output shape is correct {expected_shape}") else: print(f" FAIL: expected {expected_shape}, got {attn_out.shape}") 
    # Check that attn_output is not all zeros
    if attn_out.abs().sum() > 0:
        print(f"  PASS: attn_output is non-zero (norm={attn_out.float().norm().item():.4f})")
    else:
        print("  FAIL: attn_output is all zeros")

    # Check that hidden_states and attn_output differ
    # (attn_output comes before the layernorm + residual, so it should
    # differ from hidden_states)
    if not torch.equal(hidden_out, attn_out):
        print("  PASS: attn_output differs from hidden_states (as expected)")
    else:
        print("  FAIL: attn_output is identical to hidden_states")
else:
    print(f"  FAIL: expected 2 outputs, got {len(layer_outputs)}")

# =========================================================
# Test 3: with output_attentions=False, attn_output must not be returned
# =========================================================
print("\n=== Test 3: DecoderLayer with output_attentions=False ===")
with torch.no_grad():
    layer_outputs_no_attn = first_layer(
        inputs_embeds,
        per_layer_input=None,
        position_embeddings=position_embeddings,
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=False,
    )

print(f"DecoderLayer returned {len(layer_outputs_no_attn)} outputs")
if len(layer_outputs_no_attn) == 1:
    print("  PASS: only hidden_states returned (no attn_output)")
else:
    print(f"  FAIL: expected 1 output, got {len(layer_outputs_no_attn)}")

# =========================================================
# Test 4: verify output_attentions propagation through the CausalLM
# =========================================================
print("\n=== Test 4: Gemma4ForCausalLM output_attentions propagation ===")
with torch.no_grad():
    causal_outputs = model(**inputs, output_attentions=True, use_cache=False)

attentions_causal = causal_outputs.attentions
print(f"CausalLM attentions is None: {attentions_causal is None}")
if attentions_causal is not None:
    print(f"CausalLM attentions length: {len(attentions_causal)}")
    if len(attentions_causal) == num_layers:
        print(f"  PASS: got {num_layers} layers of attention output")
    elif len(attentions_causal) == 0:
        print("  FAIL: empty tuple (capture_outputs hook could not collect attn_weights from sdpa)")
        print("  NOTE: This is a known issue - sdpa does not return attention weights.")
        print("  Use attn_implementation='eager' to get attention weights via this path.")
    else:
        print(f"  Got {len(attentions_causal)} (expected {num_layers})")

# =========================================================
# Summary
# =========================================================
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print("- DecoderLayer correctly returns attn_output when output_attentions=True")
print("- DecoderLayer correctly omits attn_output when output_attentions=False")
print("- capture_outputs hook on CausalLM/TextModel collects Gemma4TextAttention output[1]")
print("  which is attn_weights (None with sdpa), so CausalLM.attentions is empty.")
print("- To get attention outputs at model level, either:")
print("  (a) use attn_implementation='eager', or")
print("  (b) access DecoderLayer outputs directly.")
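
# =========================================================
# Optional sketch (disabled by default): summary option (a).
# Not part of the pass/fail checks above. This is an illustrative,
# untested sketch; it assumes the checkpoint accepts
# attn_implementation="eager" (a standard from_pretrained kwarg in
# recent transformers) and that eager attention weights come back as
# (batch, num_heads, seq_len, seq_len). Reloading a 31B model doubles
# memory use, so the sdpa model is freed first.
# =========================================================
RUN_EAGER_SKETCH = False
if RUN_EAGER_SKETCH:
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    eager_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    with torch.no_grad():
        eager_outputs = eager_model(
            **inputs.to(eager_model.device),
            output_attentions=True,
            use_cache=False,
        )
    print(f"eager attentions length: {len(eager_outputs.attentions)}")
    for i, attn in enumerate(eager_outputs.attentions):
        # Expected per layer: (batch, num_heads, seq_len, seq_len)
        print(f"  Layer {i}: shape={tuple(attn.shape)}")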