Spaces:
Running
Running
File size: 2,457 Bytes
import torch
import torch.nn as nn
from typing import Dict, List, Any
import numpy as np
class UniversalityProbe:
    """Probe for feature representations shared across architectures (e.g., DT vs DQN).

    The probe holds two models and exposes three steps: collecting paired
    activations on the same states, correlating DT SAE features with DQN
    activations, and thresholding the correlation matrix into matched pairs.
    """

    def __init__(self, dt_model: nn.Module, dqn_model: nn.Module):
        """Store the two models to compare.

        Args:
            dt_model: The "DT" model (presumably a Decision Transformer --
                TODO confirm against the caller).
            dqn_model: The DQN model.
        """
        self.dt_model = dt_model
        self.dqn_model = dqn_model

    def collect_paired_activations(
        self,
        env_states: torch.Tensor,
        dt_hook_point: str,
        dqn_layer_idx: int
    ) -> Dict[str, torch.Tensor]:
        """Collect activations from both models on the same environment states.

        NOTE: this is still a placeholder. It returns random tensors and does
        not run either model; ``dt_hook_point`` and ``dqn_layer_idx`` are
        accepted but currently unused.

        Args:
            env_states: Batch of states; only ``env_states.shape[0]`` (the
                number of states) is read.
            dt_hook_point: Name of the DT hook point to read (unused for now).
            dqn_layer_idx: Index of the DQN layer to read (unused for now).

        Returns:
            A dict with key ``"dt"`` holding a ``(n_states, 128)`` tensor and
            key ``"dqn"`` holding a ``(n_states, 64)`` tensor.
        """
        n_states = env_states.shape[0]

        # TODO: replace with real DT activations read from a cache/hook.
        dt_acts = torch.randn(n_states, 128)  # Mock

        # TODO: replace with real DQN activations, e.g. something like
        # self.dqn_model.get_layer_activations(env_states, dqn_layer_idx)
        # (that helper is not visible here -- confirm it exists).
        dqn_acts = torch.randn(n_states, 64)  # Mock

        return {
            "dt": dt_acts,
            "dqn": dqn_acts
        }

    def compute_cross_correlation(
        self,
        dt_sae_features: torch.Tensor,
        dqn_activations: torch.Tensor
    ) -> torch.Tensor:
        """Compute Pearson correlations between DT SAE features and DQN activations.

        Both inputs are standardized per column over dim 0 (the shared sample
        axis) and then multiplied, so entry ``[i, j]`` of the result is the
        correlation between DT feature ``i`` and DQN neuron ``j``. A high
        absolute correlation suggests a 'Universal Concept'.

        Args:
            dt_sae_features: ``(n_samples, n_dt_features)`` float tensor.
            dqn_activations: ``(n_samples, n_dqn_neurons)`` float tensor with
                the same number of samples.

        Returns:
            ``(n_dt_features, n_dqn_neurons)`` tensor of correlation
            coefficients. Columns with zero variance yield 0 instead of NaN.

        Raises:
            RuntimeError: Raised by ``torch.matmul`` when the two inputs do
                not have the same number of samples.
        """
        eps = 1e-8  # keeps constant (zero-variance) columns from dividing by zero

        # Use the population std (unbiased=False) so it matches the 1/N
        # averaging below. The default Bessel-corrected std would scale every
        # coefficient by (N-1)/N and produce NaN for a single sample.
        dt_centered = dt_sae_features - dt_sae_features.mean(dim=0)
        dt_feat_norm = dt_centered / (dt_sae_features.std(dim=0, unbiased=False) + eps)

        dqn_centered = dqn_activations - dqn_activations.mean(dim=0)
        dqn_act_norm = dqn_centered / (dqn_activations.std(dim=0, unbiased=False) + eps)

        # Mean over samples of the product of z-scores == Pearson correlation.
        n_samples = dt_feat_norm.shape[0]
        correlation = torch.matmul(dt_feat_norm.t(), dqn_act_norm) / n_samples
        return correlation

    def identify_universal_features(
        self,
        correlation_matrix: torch.Tensor,
        threshold: float = 0.7
    ) -> List[Dict[str, Any]]:
        """Identify (DT feature, DQN neuron) pairs that appear to share a concept.

        A pair is reported when the absolute value of its correlation is
        strictly greater than ``threshold``; the signed correlation is kept in
        the output so anti-correlated pairs can be told apart.

        Args:
            correlation_matrix: 2-D tensor indexed as
                ``[dt_feature_idx, dqn_neuron_idx]``.
            threshold: Minimum absolute correlation (exclusive) for a match.

        Returns:
            One dict per match, in row-major order, with keys
            ``"dt_feature_idx"``, ``"dqn_neuron_idx"`` and ``"correlation"``.
            Empty when nothing exceeds the threshold.
        """
        above_threshold = correlation_matrix.abs() > threshold
        dt_idx, dqn_idx = above_threshold.nonzero(as_tuple=True)

        # Gather every matching value in one indexing op and convert once,
        # instead of calling .item() three times per pair.
        matched_values = correlation_matrix[dt_idx, dqn_idx].tolist()

        return [
            {
                "dt_feature_idx": feature,
                "dqn_neuron_idx": neuron,
                "correlation": value,
            }
            for feature, neuron, value in zip(
                dt_idx.tolist(), dqn_idx.tolist(), matched_values
            )
        ]
|