import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional


# Bounds used by the numerical-safety helpers below.
SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8


def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Replace NaN/Inf entries and clamp the tensor to a finite range."""
    zero = torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
    maxv = torch.tensor(max_val, device=tensor.device, dtype=tensor.dtype)
    minv = torch.tensor(min_val, device=tensor.device, dtype=tensor.dtype)
    tensor = torch.where(torch.isnan(tensor), zero, tensor)
    # Map +inf to max_val and -inf to min_val; the original sent both to
    # max_val, which the final clamp cannot undo for -inf.
    tensor = torch.where(torch.isposinf(tensor), maxv, tensor)
    tensor = torch.where(torch.isneginf(tensor), minv, tensor)
    return torch.clamp(tensor, min_val, max_val)


def safe_complex_division(numerator, denominator, eps=EPS):
    """Divide complex tensors while bounding the denominator away from zero.

    Multiplying by the conjugate reduces the operation to a real division by
    |denominator|^2, which can then be clamped.
    """
    denominator_conj = torch.conj(denominator)
    norm_sq = torch.real(denominator * denominator_conj)
    norm_sq = torch.clamp(norm_sq, min=eps)
    return (numerator * denominator_conj) / norm_sq
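

# A minimal sanity check for the two helpers above. This demo is an editorial
# sketch, not part of the original module; the test values are arbitrary.
def _demo_safety_helpers():
    x = torch.tensor([1.0, float('nan'), float('inf'), float('-inf')])
    print(make_safe(x))  # approx. [1., 0., 1e6, -1e6]

    num = torch.tensor([1.0 + 1.0j])
    den = torch.tensor([0.0 + 0.0j])  # degenerate denominator
    print(safe_complex_division(num, den))  # finite (zero), thanks to the eps clamp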
class MobiusTransform(nn.Module):
    """Learnable Mobius (fractional linear) transform z -> (a*z + b) / (c*z + d).

    Each complex coefficient is stored as a real 2-vector [real, imag] so that
    the optimizer only ever sees real-valued parameters.
    """

    def __init__(self, learnable=True, init_identity=True):
        super().__init__()
        self.learnable = learnable

        if init_identity:
            # a = d = 1, b = c = 0 is the identity transform.
            a_init, b_init, c_init, d_init = 1.0, 0.0, 0.0, 1.0
        else:
            a_init, d_init = 1.0, 1.0
            b_init, c_init = 0.1, 0.1

        if learnable:
            self.a = nn.Parameter(torch.tensor([a_init, 0.0]))
            self.b = nn.Parameter(torch.tensor([b_init, 0.0]))
            self.c = nn.Parameter(torch.tensor([c_init, 0.0]))
            self.d = nn.Parameter(torch.tensor([d_init, 0.0]))
        else:
            self.register_buffer('a', torch.tensor([a_init, 0.0]))
            self.register_buffer('b', torch.tensor([b_init, 0.0]))
            self.register_buffer('c', torch.tensor([c_init, 0.0]))
            self.register_buffer('d', torch.tensor([d_init, 0.0]))

    def to_complex(self, param):
        return torch.complex(param[0], param[1])
    def get_determinant(self):
        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)
        return a_complex * d_complex - b_complex * c_complex

    def normalize_parameters(self):
        """Reset near-degenerate coefficients and bound their magnitudes."""
        if self.learnable:
            with torch.no_grad():
                det = torch.abs(self.get_determinant())
                if det < EPS:
                    # det ~ 0 means the map has collapsed to a constant;
                    # restart near the identity with damped off-diagonal terms.
                    one = torch.tensor([1.0, 0.0], device=self.a.device, dtype=self.a.dtype)
                    self.a.copy_(one)
                    self.d.copy_(one)
                    self.b.mul_(0.1)
                    self.c.mul_(0.1)
                for p in (self.a, self.b, self.c, self.d):
                    p.clamp_(-10.0, 10.0)
    def transform(self, z):
        """Apply w = (a*z + b) / (c*z + d)."""
        self.normalize_parameters()

        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)

        numerator = a_complex * z + b_complex
        denominator = c_complex * z + d_complex
        return safe_complex_division(numerator, denominator)

    def inverse_transform(self, w):
        """Apply the algebraic inverse z = (d*w - b) / (-c*w + a)."""
        self.normalize_parameters()

        a_complex = self.to_complex(self.a)
        b_complex = self.to_complex(self.b)
        c_complex = self.to_complex(self.c)
        d_complex = self.to_complex(self.d)

        numerator = d_complex * w - b_complex
        denominator = -c_complex * w + a_complex
        return safe_complex_division(numerator, denominator)
    def get_transform_info(self):
        det = self.get_determinant()
        a_c, b_c = self.to_complex(self.a), self.to_complex(self.b)
        c_c, d_c = self.to_complex(self.c), self.to_complex(self.d)
        zero = torch.zeros((), device=det.device, dtype=det.dtype)
        # A Mobius transform is the identity iff b = c = 0 and a = d (up to a
        # shared scale); |det| ~ 1 alone only indicates normalization.
        is_identity = bool(
            torch.allclose(b_c, zero, atol=1e-6)
            and torch.allclose(c_c, zero, atol=1e-6)
            and torch.allclose(a_c, d_c, atol=1e-6)
        )
        return {
            'determinant': det,
            'is_identity': is_identity,
            'parameters': {'a': a_c, 'b': b_c, 'c': c_c, 'd': d_c}
        }
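

# Editorial sketch (not part of the original module): verify that
# inverse_transform undoes transform away from the pole of the map.
def _demo_mobius_roundtrip():
    m = MobiusTransform(learnable=False, init_identity=False)
    z = torch.tensor([0.5 + 0.5j, -1.0 + 2.0j])
    z_back = m.inverse_transform(m.transform(z))
    print(torch.allclose(z, z_back, atol=1e-4))  # expect True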
class ComplexStateMarkovChain(nn.Module):
    """Markov chain whose states sit at learnable points in the complex plane.

    Transition probabilities combine learned base logits with a kernel over
    pairwise distances between the (Mobius-transformed) state positions.
    """

    def __init__(self, num_states, state_embedding_dim=64, distance_kernel='gaussian'):
        super().__init__()
        self.num_states = num_states
        self.state_embedding_dim = state_embedding_dim
        self.distance_kernel = distance_kernel

        # Complex-valued positions of the states; PyTorch supports complex
        # parameters directly.
        self.state_positions = nn.Parameter(
            torch.complex(
                torch.randn(num_states) * 2.0,
                torch.randn(num_states) * 2.0
            )
        )

        self.state_embeddings = nn.Parameter(torch.randn(num_states, state_embedding_dim) * 0.1)

        self.base_transition_logits = nn.Parameter(torch.randn(num_states, num_states) * 0.1)
        self.distance_scale = nn.Parameter(torch.tensor(1.0))
        self.distance_bias = nn.Parameter(torch.tensor(0.0))

        if distance_kernel == 'gaussian':
            self.kernel_width = nn.Parameter(torch.tensor(1.0))
        elif distance_kernel == 'inverse':
            self.kernel_power = nn.Parameter(torch.tensor(1.0))
    def compute_transformed_distances(self, mobius_transform):
        """Return pairwise distances between Mobius-transformed state positions."""
        transformed_positions = mobius_transform.transform(self.state_positions)

        pos_i = transformed_positions.unsqueeze(0)  # (1, N)
        pos_j = transformed_positions.unsqueeze(1)  # (N, 1)
        distances = torch.abs(pos_i - pos_j)        # (N, N), symmetric

        return distances, transformed_positions

    def distance_to_probability(self, distances):
        """Map distances to unnormalized similarities (large when states are close)."""
        distances = torch.clamp(distances, min=EPS)

        if self.distance_kernel == 'gaussian':
            width = torch.clamp(self.kernel_width, min=0.1, max=10.0)
            prob_contrib = torch.exp(-distances**2 / (2 * width**2))
        elif self.distance_kernel == 'inverse':
            power = torch.clamp(self.kernel_power, min=0.5, max=3.0)
            prob_contrib = 1.0 / (distances**power + EPS)
        else:
            # Fallback: linear kernel, truncated at zero.
            prob_contrib = torch.clamp(1.0 - distances, min=0.0)

        return prob_contrib
    def compute_transition_matrix(self, mobius_transform):
        distances, transformed_positions = self.compute_transformed_distances(mobius_transform)

        # Kernel similarities, scaled and shifted, act as a geometric bias on
        # the learned base logits.
        similarity = self.distance_to_probability(distances)
        scale = torch.clamp(self.distance_scale, min=0.1, max=10.0)
        bias = torch.clamp(self.distance_bias, min=-5.0, max=5.0)
        scaled_similarity = scale * similarity + bias

        transition_logits = self.base_transition_logits + scaled_similarity
        # Small diagonal boost to encourage self-transitions.
        transition_logits = transition_logits + torch.eye(self.num_states, device=transition_logits.device) * 0.05

        # Row-stochastic matrix: row i is the distribution over next states.
        transition_matrix = F.softmax(transition_logits, dim=1)

        return transition_matrix, transformed_positions
    def forward(self, initial_state, num_steps, mobius_transform):
        """Propagate a state distribution num_steps times under the current geometry."""
        if initial_state.dim() == 1:
            current_state = initial_state.unsqueeze(0)
        else:
            current_state = initial_state

        transition_matrix, transformed_positions = self.compute_transition_matrix(mobius_transform)

        trajectory = [current_state.clone()]
        state_positions = [transformed_positions[current_state.argmax(dim=-1)]]

        for step in range(num_steps):
            # Row-vector update: p_{t+1} = p_t @ P.
            current_state = torch.matmul(current_state, transition_matrix)
            trajectory.append(current_state.clone())

            most_likely_states = current_state.argmax(dim=-1)
            state_positions.append(transformed_positions[most_likely_states])

        return {
            'trajectory': torch.stack(trajectory),
            'final_state': current_state,
            'state_positions': torch.stack(state_positions),
            'transition_matrix': transition_matrix,
            'transformed_positions': transformed_positions
        }
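

# Editorial usage sketch (assumed sizes, not part of the original module):
# propagate a one-hot distribution and confirm the matrix is row-stochastic.
def _demo_markov_chain():
    chain = ComplexStateMarkovChain(num_states=4)
    mobius = MobiusTransform(learnable=False, init_identity=True)
    init = F.one_hot(torch.tensor([0]), num_classes=4).float()
    out = chain(init, num_steps=3, mobius_transform=mobius)
    print(out['trajectory'].shape)              # (4, 1, 4): steps+1 x batch x states
    print(out['transition_matrix'].sum(dim=1))  # each row sums to ~1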
class MobiusMarkovSystem(nn.Module):
    """Couples a Markov chain with a Mobius transform that evolves over time.

    At each evolution step the mean state embedding nudges the Mobius
    coefficients, which reshapes the geometry and hence the transition matrix.
    """

    def __init__(self, num_states, state_embedding_dim=64, evolution_steps=10):
        super().__init__()
        self.num_states = num_states
        self.evolution_steps = evolution_steps

        self.mobius_transform = MobiusTransform(learnable=True, init_identity=True)
        self.markov_chain = ComplexStateMarkovChain(num_states, state_embedding_dim)

        # Produces 8 real numbers: updates for the 4 complex coefficients.
        self.mobius_evolution = nn.Sequential(
            nn.Linear(state_embedding_dim, state_embedding_dim),
            nn.Tanh(),
            nn.Linear(state_embedding_dim, 8),
        )

        self.state_encoder = nn.Sequential(
            nn.Linear(num_states, state_embedding_dim),
            nn.LayerNorm(state_embedding_dim),
            nn.ReLU(),
            nn.Linear(state_embedding_dim, state_embedding_dim)
        )

        self.state_decoder = nn.Sequential(
            nn.Linear(state_embedding_dim, state_embedding_dim),
            nn.ReLU(),
            nn.Linear(state_embedding_dim, num_states),
            nn.Softmax(dim=-1)
        )

        self.geometry_controller = nn.Parameter(torch.tensor(0.1))
    def evolve_mobius_parameters(self, state_embedding):
        """Nudge the Mobius coefficients based on the current state embedding.

        The update is applied under no_grad, so it acts as a heuristic
        dynamical rule rather than a gradient-trained pathway.
        """
        evolution_signal = self.mobius_evolution(state_embedding)
        evolution_rate = torch.clamp(self.geometry_controller, 0.01, 1.0)
        if self.mobius_transform.learnable:
            with torch.no_grad():
                updates = (evolution_signal.view(4, 2) * evolution_rate * 0.01)\
                    .to(device=self.mobius_transform.a.device, dtype=self.mobius_transform.a.dtype)
                self.mobius_transform.a.add_(updates[0])
                self.mobius_transform.b.add_(updates[1])
                self.mobius_transform.c.add_(updates[2])
                self.mobius_transform.d.add_(updates[3])
                self.mobius_transform.normalize_parameters()
    def forward(self, initial_state, return_full_trajectory=False):
        evolution_history = {
            'states': [],
            'geometries': [],
            'transition_matrices': [],
            'transformed_positions': []
        }

        current_state = initial_state

        for step in range(self.evolution_steps):
            state_embedding = self.state_encoder(current_state)

            # Let the mean embedding of the batch steer the geometry.
            self.evolve_mobius_parameters(state_embedding.mean(dim=0))

            markov_output = self.markov_chain.forward(
                current_state,
                num_steps=1,
                mobius_transform=self.mobius_transform
            )

            current_state = markov_output['final_state']

            if return_full_trajectory:
                evolution_history['states'].append(current_state.clone())
                evolution_history['geometries'].append(self.mobius_transform.get_transform_info())
                evolution_history['transition_matrices'].append(markov_output['transition_matrix'])
                evolution_history['transformed_positions'].append(markov_output['transformed_positions'])

        final_embedding = self.state_encoder(current_state)
        final_prediction = self.state_decoder(final_embedding)

        output = {
            'final_state': current_state,
            'final_prediction': final_prediction,
            'final_embedding': final_embedding,
            'final_geometry': self.mobius_transform.get_transform_info()
        }

        if return_full_trajectory:
            output['evolution_history'] = evolution_history

        return output
    def predict_sequence(self, initial_state, sequence_length):
        """Roll the full system forward, collecting one prediction per step."""
        predictions = []
        current_state = initial_state

        for _ in range(sequence_length):
            output = self.forward(current_state)
            predictions.append(output['final_prediction'])
            current_state = output['final_state']

        return torch.stack(predictions)
    def get_system_info(self):
        return {
            'num_states': self.num_states,
            'evolution_steps': self.evolution_steps,
            'current_geometry': self.mobius_transform.get_transform_info(),
            'state_positions': self.markov_chain.state_positions,
            'geometry_evolution_rate': self.geometry_controller.item()
        }
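

# Editorial end-to-end smoke test (assumed sizes; illustrative only, not part
# of the original module).
def _demo_full_system():
    torch.manual_seed(0)
    system = MobiusMarkovSystem(num_states=6, state_embedding_dim=32, evolution_steps=5)
    init = F.one_hot(torch.tensor([0, 3]), num_classes=6).float()  # batch of 2
    out = system(init, return_full_trajectory=True)
    print(out['final_prediction'].shape)         # torch.Size([2, 6]); rows sum to ~1
    print(out['final_geometry']['is_identity'])  # geometry drifts off the identity
    seq = system.predict_sequence(init, sequence_length=3)
    print(seq.shape)                             # torch.Size([3, 2, 6])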