| |
| |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| from collections import defaultdict |
| from typing import List, Dict, Tuple, Optional |
|
|
SAFE_MIN = -1e6  # default lower bound applied by make_safe
SAFE_MAX = 1e6   # default upper bound applied by make_safe
EPS = 1e-8       # small positive floor (e.g. for temperatures and log arguments)


def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Replace non-finite values and clamp ``tensor`` into ``[min_val, max_val]``.

    NaN becomes 0, +inf becomes ``max_val``, -inf becomes ``min_val``, and all
    remaining values are clamped to the same range.

    Args:
        tensor: Input tensor (any shape).
        min_val: Lower bound, also used as the replacement for -inf.
        max_val: Upper bound, also used as the replacement for +inf.

    Returns:
        A new tensor with the same shape as ``tensor``.
    """
    # nan_to_num maps each infinity to the bound with the matching sign;
    # the previous torch.where-based version sent -inf to +max_val.
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(tensor, min_val, max_val)
|
|
def safe_softmax(x, dim=-1, temperature=1.0):
    """Temperature-scaled softmax computed in float32 on sanitised logits.

    The input is cast to float32 and passed through ``make_safe`` with bounds
    [-50, 50] (NaN -> 0, infinities clipped), then divided by the temperature
    (floored at ``EPS``) and normalised with a max-shifted softmax.

    Args:
        x: Logits tensor.
        dim: Dimension to normalise over.
        temperature: Python number or tensor; a tensor is detached and read as
            a scalar, so no gradient flows into it.

    Returns:
        float32 tensor of probabilities with the same shape as ``x``.
    """
    logits = make_safe(x.to(dtype=torch.float32), min_val=-50, max_val=50)

    if isinstance(temperature, torch.Tensor):
        temperature = temperature.detach().cpu().item()
    logits = logits / max(float(temperature), EPS)

    # Explicit max-shift before the softmax.
    logits = logits - logits.amax(dim=dim, keepdim=True)
    return F.softmax(logits, dim=dim)
|
|
| |
| |
|
|
class LiquidDynamicsCore(nn.Module):
    """Recurrent "liquid" state advanced by explicit Euler steps.

    The state lives in the ``liquid_state`` buffer with shape
    ``(batch, state_dim)``; each call to ``evolve_liquid`` adds
    ``dt * dynamics`` to it, where the dynamics are gated by a confidence weight
    and mixed with confidence-dependent Gaussian exploration noise.
    """

    def __init__(self, state_dim, input_dim, liquid_time_constant=1.0):
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = input_dim
        # float() so integer arguments work: only floating tensors can require grad.
        self.liquid_time_constant = nn.Parameter(torch.tensor(float(liquid_time_constant)))

        self.W_rec = nn.Parameter(torch.randn(state_dim, state_dim) * 0.1)
        self.W_in = nn.Parameter(torch.randn(state_dim, input_dim) * 0.1)
        self.bias = nn.Parameter(torch.zeros(state_dim))

        self.activation = nn.Tanh()

        # Buffer: moves with .to()/.cuda() and is included in state_dict.
        self.register_buffer('liquid_state', torch.zeros(1, state_dim))

        # Learnable scales of the exploration noise added in evolve_liquid.
        self.noise_scale = nn.Parameter(torch.tensor(0.1))
        self.exploration_rate = nn.Parameter(torch.tensor(0.05))

    def reset_state(self, batch_size=1):
        """Replace the state with fresh zeros of shape ``(batch_size, state_dim)``.

        A new tensor is always allocated (keeping the old device and dtype).
        Zeroing in place would keep the autograd history of the previous
        forward pass attached to the buffer, so the next backward() would
        traverse an already-freed graph.
        """
        self.liquid_state = torch.zeros(
            batch_size, self.state_dim,
            device=self.liquid_state.device,
            dtype=self.liquid_state.dtype,
        )

    def evolve_liquid(self, input_signal, confidence_weight=1.0, dt=0.1):
        """Advance the state by one Euler step and return a copy of it.

        Args:
            input_signal: Tensor whose first dim is the batch and last dim is
                ``input_dim`` (it is multiplied by ``W_in.T``). The state is
                reset if its batch size differs.
            confidence_weight: Python number or tensor. A 1-D tensor is treated
                as per-sample and unsqueezed to ``(batch, 1)``. It scales the
                deterministic dynamics, and ``1 - confidence_weight`` scales the
                exploration noise.
            dt: Euler step size.

        Returns:
            A clone of the updated state.

        Note:
            Noise is sampled on every call; there is no train/eval switch here.
        """
        batch_size = input_signal.shape[0]
        if self.liquid_state.shape[0] != batch_size:
            self.reset_state(batch_size)

        state = self.liquid_state
        tau = torch.clamp(self.liquid_time_constant, 0.1, 10.0)

        recurrent_input = torch.matmul(self.activation(state), self.W_rec.T)
        external_input = torch.matmul(input_signal, self.W_in.T)
        dynamics = -state / tau + recurrent_input + external_input + self.bias

        if isinstance(confidence_weight, torch.Tensor):
            if confidence_weight.dim() == 1:
                confidence_weight = confidence_weight.unsqueeze(-1)
            confidence_weight = confidence_weight.to(device=state.device, dtype=state.dtype)
        else:
            confidence_weight = torch.tensor(confidence_weight, device=state.device, dtype=state.dtype)

        exploration_noise = torch.randn_like(state) * self.noise_scale
        exploration_strength = (1.0 - confidence_weight) * self.exploration_rate

        modulated_dynamics = confidence_weight * dynamics + exploration_strength * exploration_noise

        # Out-of-place update, rebound to the buffer. An in-place add_ would
        # invalidate tensors autograd saved from the previous state (e.g. the
        # squared state behind 'state_energy' in get_liquid_features).
        self.liquid_state = state + dt * make_safe(modulated_dynamics)

        return self.liquid_state.clone()

    def get_liquid_features(self):
        """Return summary features of the current state.

        Keys: ``'raw_state'`` (a copy), ``'activated_state'`` (tanh of the
        state), ``'state_energy'`` (sum of squares over the last dim, kept as
        ``(batch, 1)``) and ``'state_entropy'`` (see ``_compute_state_entropy``).
        """
        return {
            'raw_state': self.liquid_state.clone(),
            'activated_state': self.activation(self.liquid_state),
            'state_energy': torch.sum(self.liquid_state ** 2, dim=-1, keepdim=True),
            'state_entropy': self._compute_state_entropy()
        }

    def _compute_state_entropy(self):
        """Shannon entropy of softmax(state) over the last dim, shape ``(batch, 1)``."""
        state_probs = safe_softmax(self.liquid_state, dim=-1, temperature=1.0)
        entropy = -torch.sum(state_probs * torch.log(state_probs + EPS), dim=-1, keepdim=True)
        return entropy
|
|
| |
| |
|
|
class BayesianConfidenceNetwork(nn.Module):
    """Read discrete beliefs off a state vector and summarise them as a confidence.

    ``num_variables`` categorical variables with ``num_states_per_var`` states
    each are predicted from ``liquid_features['activated_state']``. They are
    refined by a few rounds of message passing through learned conditional
    tables and combined with an uncertainty estimate computed from
    ``liquid_features['raw_state']`` into a confidence in [0, 1].
    """

    def __init__(self, state_dim, num_variables=5, num_states_per_var=3):
        super().__init__()
        self.state_dim = state_dim
        self.num_variables = num_variables
        self.num_states_per_var = num_states_per_var

        # Maps the state to one logit per (variable, state) pair.
        self.feature_extractor = nn.Sequential(
            nn.Linear(state_dim, state_dim * 2),
            nn.LayerNorm(state_dim * 2),
            nn.ReLU(),
            nn.Linear(state_dim * 2, num_variables * num_states_per_var)
        )

        # One table per variable: maps the flattened beliefs of all *other*
        # variables to logits over this variable's states.
        self.conditional_prob_tables = nn.ParameterList([
            nn.Parameter(torch.randn(num_states_per_var, num_states_per_var * (num_variables - 1)) * 0.1)
            for _ in range(num_variables)
        ])

        # Prior logits (softmaxed before use); all-ones gives a uniform prior.
        self.priors = nn.Parameter(torch.ones(num_variables, num_states_per_var))

        self.confidence_net = nn.Sequential(
            nn.Linear(num_variables, num_variables * 2),
            nn.ReLU(),
            nn.Linear(num_variables * 2, 1),
            nn.Sigmoid()
        )

        self.uncertainty_estimator = nn.Sequential(
            nn.Linear(state_dim, state_dim),
            nn.ReLU(),
            nn.Linear(state_dim, 1),
            nn.Sigmoid()
        )

    def extract_variable_beliefs(self, liquid_features):
        """Return beliefs of shape ``(batch, num_variables, num_states_per_var)``."""
        liquid_state = liquid_features['activated_state']

        evidence = self.feature_extractor(liquid_state)
        evidence = evidence.view(-1, self.num_variables, self.num_states_per_var)

        return safe_softmax(evidence, dim=-1)

    def bayesian_inference(self, variable_beliefs, num_iterations=3):
        """Refine per-variable beliefs by synchronous message passing.

        Each iteration updates every variable from the *previous* iteration's
        beliefs. The update is posterior ∝ evidence × conditional, where the
        conditional comes from the other variables' beliefs.

        Args:
            variable_beliefs: ``(batch, num_variables, num_states_per_var)``
                evidence distributions, as from ``extract_variable_beliefs``.
            num_iterations: Number of refinement rounds (default 3).

        Returns:
            Posterior beliefs with the same shape as ``variable_beliefs``.
        """
        batch_size = variable_beliefs.shape[0]

        current_beliefs = safe_softmax(self.priors.unsqueeze(0).expand(batch_size, -1, -1), dim=-1)

        for _ in range(num_iterations):
            new_beliefs = current_beliefs.clone()

            for var_idx in range(self.num_variables):
                evidence = variable_beliefs[:, var_idx, :]

                if self.num_variables > 1:
                    other_var_beliefs = torch.cat([
                        current_beliefs[:, :var_idx].flatten(1),
                        current_beliefs[:, var_idx + 1:].flatten(1)
                    ], dim=1)
                else:
                    other_var_beliefs = variable_beliefs.new_zeros(batch_size, 0)

                if other_var_beliefs.shape[1] > 0:
                    cond_probs = torch.matmul(other_var_beliefs, self.conditional_prob_tables[var_idx].T)
                    cond_probs = safe_softmax(cond_probs, dim=-1)
                else:
                    cond_probs = torch.ones_like(evidence) / self.num_states_per_var

                # Multiply the two distributions and renormalise, in log space
                # for stability. Softmax-ing the raw product (values in [0, 1])
                # would squash every posterior towards uniform.
                log_posterior = torch.log(evidence + EPS) + torch.log(cond_probs + EPS)
                new_beliefs[:, var_idx, :] = safe_softmax(log_posterior, dim=-1)

            current_beliefs = new_beliefs

        return current_beliefs

    def compute_confidence(self, beliefs, liquid_features):
        """Blend three confidence signals into one value in [0, 1], shape ``(batch, 1)``.

        Weights: 0.4 × (1 - mean normalised belief entropy), 0.3 × the learned
        ``confidence_net`` on per-variable entropies, and 0.3 × (1 - the
        uncertainty estimated from ``liquid_features['raw_state']``).
        """
        belief_entropy = -torch.sum(beliefs * torch.log(beliefs + EPS), dim=-1)
        avg_entropy = belief_entropy.mean(dim=-1, keepdim=True)

        max_entropy = math.log(self.num_states_per_var)
        if max_entropy > 0:
            entropy_confidence = 1.0 - (avg_entropy / max_entropy)
        else:
            # A single-state variable is always certain; avoids 0/0 -> NaN.
            entropy_confidence = torch.ones_like(avg_entropy)

        nn_confidence = self.confidence_net(belief_entropy)

        liquid_uncertainty = self.uncertainty_estimator(liquid_features['raw_state'])
        state_confidence = 1.0 - liquid_uncertainty

        total_confidence = 0.4 * entropy_confidence + 0.3 * nn_confidence + 0.3 * state_confidence

        return torch.clamp(total_confidence, 0.0, 1.0)

    def forward(self, liquid_features):
        """Return posterior beliefs, confidence and the raw per-variable beliefs."""
        variable_beliefs = self.extract_variable_beliefs(liquid_features)

        posterior_beliefs = self.bayesian_inference(variable_beliefs)

        confidence = self.compute_confidence(posterior_beliefs, liquid_features)

        return {
            'beliefs': posterior_beliefs,
            'confidence': confidence,
            'variable_beliefs': variable_beliefs
        }
|
|
| |
| |
|
|
class LiquidBayesChain(nn.Module):
    """Chain of confidence-gated liquid updates followed by a Bayesian read-out.

    Each forward pass resets the liquid state and runs ``num_chain_steps``
    updates of ``LiquidDynamicsCore``. Every step after the first is gated by
    the confidence of ``BayesianConfidenceNetwork`` on the current state. The
    final activated state is then mapped to ``output_dim`` logits, which a
    second Bayesian network scores for uncertainty.
    """

    def __init__(self, input_dim, state_dim, output_dim, num_chain_steps=3):
        super().__init__()
        if num_chain_steps < 1:
            # forward() reads chain_states[-1]; zero steps would fail later with IndexError.
            raise ValueError(f"num_chain_steps must be at least 1, got {num_chain_steps}")
        self.input_dim = input_dim
        self.state_dim = state_dim
        self.output_dim = output_dim
        self.num_chain_steps = num_chain_steps

        self.liquid_core = LiquidDynamicsCore(state_dim, input_dim)
        self.bayesian_confidence = BayesianConfidenceNetwork(state_dim)

        self.final_predictor = nn.Sequential(
            nn.Linear(state_dim, state_dim * 2),
            nn.LayerNorm(state_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(state_dim * 2, output_dim)
        )

        self.final_bayesian = BayesianConfidenceNetwork(output_dim, num_variables=3, num_states_per_var=4)

        # Logits of the per-step weights used to average step confidences.
        self.step_weights = nn.Parameter(torch.ones(num_chain_steps))

    def single_chain_step(self, input_signal, step_idx=0, prev_confidence=None):
        """Advance the liquid state by one chain step and score the result.

        Args:
            input_signal: Input forwarded to ``evolve_liquid``.
            step_idx: Position in the chain. Step 0 evolves with confidence 1.0.
                Later steps are gated by the Bayesian confidence of the state as
                it stands before this update.
            prev_confidence: Optional confidence already computed for the
                current (not yet updated) state, e.g. the previous step's
                ``'confidence'``. When given for ``step_idx > 0``, it replaces a
                recomputation by the Bayesian network.

        Returns:
            dict with ``'liquid_state'``, ``'liquid_features'``,
            ``'bayes_output'`` and ``'confidence'`` (all after the update).
        """
        if step_idx == 0:
            confidence_weight = 1.0
        elif prev_confidence is not None:
            confidence_weight = prev_confidence
        else:
            pre_update_features = self.liquid_core.get_liquid_features()
            confidence_weight = self.bayesian_confidence(pre_update_features)['confidence']

        liquid_state = self.liquid_core.evolve_liquid(input_signal, confidence_weight=confidence_weight)

        liquid_features = self.liquid_core.get_liquid_features()
        bayes_output = self.bayesian_confidence(liquid_features)

        return {
            'liquid_state': liquid_state,
            'liquid_features': liquid_features,
            'bayes_output': bayes_output,
            'confidence': bayes_output['confidence']
        }

    def forward(self, input_signal, return_chain_states=False):
        """Run the full chain and return the prediction with confidence information.

        Returns:
            dict with ``'prediction'`` (logits), ``'final_confidence'``
            (softmax(step_weights)-weighted average of per-step confidences),
            ``'final_beliefs'``, ``'prediction_uncertainty'`` (1 - confidence of
            ``final_bayesian``) and, if requested, ``'chain_states'``.
        """
        batch_size = input_signal.shape[0]

        self.liquid_core.reset_state(batch_size)

        chain_states = []
        prev_confidence = None
        for step in range(self.num_chain_steps):
            # Nothing touches the liquid state between steps, and the Bayesian
            # network has no stochastic layers. So the previous step's
            # post-update confidence equals what this step would recompute.
            step_output = self.single_chain_step(input_signal, step_idx=step, prev_confidence=prev_confidence)
            step_output['step_idx'] = step
            chain_states.append(step_output)
            prev_confidence = step_output['confidence']

        final_liquid_state = chain_states[-1]['liquid_features']['activated_state']
        prediction_logits = self.final_predictor(final_liquid_state)

        prediction_features = {
            'raw_state': prediction_logits,
            'activated_state': torch.tanh(prediction_logits)
        }
        final_bayes = self.final_bayesian(prediction_features)

        step_weights = safe_softmax(self.step_weights, dim=0)
        weighted_confidence = sum(
            weight * state['confidence']
            for weight, state in zip(step_weights, chain_states)
        )

        output = {
            'prediction': prediction_logits,
            'final_confidence': weighted_confidence,
            'final_beliefs': final_bayes['beliefs'],
            'prediction_uncertainty': 1.0 - final_bayes['confidence']
        }

        if return_chain_states:
            output['chain_states'] = chain_states

        return output

    def predict_with_uncertainty(self, input_signal):
        """Run ``forward`` and collect prediction, confidences and per-step liquid entropies."""
        output = self.forward(input_signal, return_chain_states=True)

        uncertainty_info = {
            'prediction': output['prediction'],
            'confidence': output['final_confidence'],
            'prediction_uncertainty': output['prediction_uncertainty'],
            'chain_confidences': [state['confidence'] for state in output['chain_states']],
            'liquid_entropies': [state['liquid_features']['state_entropy'] for state in output['chain_states']]
        }

        return uncertainty_info
|
|
| |
|
|