| import torch |
| from vllm import LLM, SamplingParams |
|
|
| |
# Build the engine once at import time; main() reads this module-level name.
# NOTE(review): enforce_eager=True is presumably what keeps the Python-level
# forward hook / patched forward on the execution path -- confirm against the
# vLLM version in use.
model = LLM(
    model="Qwen/Qwen2.5-0.5B",
    tensor_parallel_size=1,
    enforce_eager=True,
    trust_remote_code=True,
)

# Capture slots written by capture_layer2_hook: one for the unpatched pass and
# one for the pass where the layer carries the `_is_patched` marker.
layer2_original = layer2_patched = None
|
|
def capture_layer2_hook(module, input, output):
    """Forward hook that stores a CPU copy of the hooked module's output.

    When ``output`` is a tuple only its first element is recorded; otherwise
    ``output`` itself is recorded. The copy is detached, cloned and moved to
    the CPU so it stays valid after the forward pass.

    The destination is chosen by the *presence* of a ``_is_patched``
    attribute on ``module`` (its value is not inspected): present ->
    ``layer2_patched``, absent -> ``layer2_original``.

    Args:
        module: The module the hook is registered on.
        input: Positional inputs of the forward call (unused).
        output: Return value of the forward call (tensor or tuple).
    """
    global layer2_original, layer2_patched

    if isinstance(output, tuple):
        hidden_state = output[0]
    else:
        hidden_state = output
    snapshot = hidden_state.detach().clone().cpu()

    if hasattr(module, '_is_patched'):
        layer2_patched = snapshot
        return
    layer2_original = snapshot
|
|
def make_qwen_hook():
    """Return a replacement ``forward`` for a gated MLP with fused gate/up.

    The returned function is meant to be bound to an MLP object that exposes
    ``gate_up_proj`` and ``down_proj`` callables, each returning a 2-tuple
    whose first item is the projected tensor (the second item is ignored).
    """

    def qwen_forward(self, x):
        # The fused projection packs gate and up side by side on the last dim:
        # first half is the gate, second half is the up branch.
        fused, _ = self.gate_up_proj(x)
        half = fused.size(-1) // 2
        gate, up = fused[..., :half], fused[..., half:]

        # SiLU-gated product, then project back down.
        gated = torch.nn.functional.silu(gate) * up
        projected, _ = self.down_proj(gated)
        return projected

    return qwen_forward
|
|
def main():
    """Check that the hand-written MLP forward reproduces layer 2's output.

    Runs the same prompt twice through the module-level ``model``: once
    untouched and once with layer 2's ``mlp.forward`` replaced by the function
    from ``make_qwen_hook``. A forward hook on layer 2 records the output of
    each pass, and the two recordings are compared with ``torch.allclose``.

    The hook, the patched ``forward`` and the ``_is_patched`` marker are
    always undone, even when generation raises, so the shared model is never
    left in a modified state.

    Raises:
        RuntimeError: If the hook recorded nothing for either pass.
    """
    global layer2_original, layer2_patched

    # Drop captures from any earlier call so a stale tensor cannot hide a
    # pass in which the hook never fired.
    layer2_original = None
    layer2_patched = None

    sentence = "hello world"
    sampling_params = SamplingParams(temperature=0, max_tokens=1)

    # NOTE(review): this attribute chain reaches into vLLM internals and is
    # likely version-specific -- confirm against the installed vLLM.
    layer2 = model.llm_engine.model_executor.driver_worker.model_runner.model.model.layers[2]
    mlp = layer2.mlp

    hook = layer2.register_forward_hook(capture_layer2_hook)
    try:
        print("=== Getting original hidden states ===")
        model.generate([sentence], sampling_params)

        print("=== Applying patch and getting patched hidden states ===")
        original_forward = mlp.forward
        # Bind the plain function to the MLP instance so `self` is supplied.
        mlp.forward = make_qwen_hook().__get__(mlp, mlp.__class__)
        # The hook routes its capture based on the presence of this marker.
        layer2._is_patched = True
        try:
            model.generate([sentence], sampling_params)
        finally:
            mlp.forward = original_forward
            delattr(layer2, '_is_patched')
    finally:
        hook.remove()

    if layer2_original is None or layer2_patched is None:
        raise RuntimeError(
            "Forward hook on layer 2 did not capture both passes "
            f"(original captured: {layer2_original is not None}, "
            f"patched captured: {layer2_patched is not None})"
        )

    print("=== Comparison ===")
    print(f"Original shape: {layer2_original.shape}")
    print(f"Patched shape: {layer2_patched.shape}")

    if torch.allclose(layer2_original, layer2_patched, rtol=1e-4, atol=1e-6):
        print("✅ PATCH IS CORRECT: Hidden states match!")
    else:
        abs_diff = torch.abs(layer2_original - layer2_patched)
        max_diff = torch.max(abs_diff).item()
        mean_diff = torch.mean(abs_diff).item()
        print(f"❌ PATCH IS INCORRECT: Max diff = {max_diff:.6f}, Mean diff = {mean_diff:.6f}")

    print(f"\nOriginal hidden states (first 10 values):\n{layer2_original.flatten()[:10]}")
    print(f"\nPatched hidden states (first 10 values):\n{layer2_patched.flatten()[:10]}")
|
|
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()