| """ |
| Mathematical Foundation & Conceptual Documentation |
| ------------------------------------------------- |
| |
| CORE PRINCIPLE: |
| Combines state space models with liquid computing principles to create adaptive |
| continuous-time dynamics for sequence processing. The system learns time constants |
| dynamically based on input characteristics, enabling efficient processing of |
| variable-speed temporal patterns. |
| |
| MATHEMATICAL FOUNDATION: |
| ======================= |
| |
| 1. STATE SPACE MODEL FUNDAMENTALS: |
| Continuous-time: dx/dt = Ax(t) + Bu(t) |
| y(t) = Cx(t) + Du(t) |
| |
| Discrete-time: x[k+1] = A_d·x[k] + B_d·u[k] |
| y[k] = C·x[k] + D·u[k] |
| |
| Where: |
| - x(t): state vector (hidden representation) |
| - u(t): input vector (external signals) |
| - y(t): output vector (observations) |
| - A: state transition matrix (dynamics) |
| - B: input matrix (how inputs affect states) |
| - C: output matrix (how states generate outputs) |
| - D: feedthrough matrix (direct input-output) |
| |
2. LIQUID DYNAMICS WITH ADAPTIVE TIME CONSTANTS:
   dx/dt = -x/τ(x,u) + A·x + B·u

   Where τ(x,u) are adaptive time constants:
   τ(x,u) = τ_base · (1 + α·φ(x,u))

   - τ_base: learnable base time constants
   - α: adaptation rate parameter
   - φ(x,u): neural adaptation function

   Fast time constants → quick adaptation to rapid changes
   Slow time constants → smooth integration of stable patterns

3. CONTINUOUS-TO-DISCRETE CONVERSION:
   Using the matrix exponential and zero-order hold (ZOH):

   A_d = exp(A·Δt)
   B_d = A⁻¹·(A_d - I)·B

   To avoid explicitly inverting A (which may be ill-conditioned), both
   factors can be read off a single augmented matrix exponential:

   [A_d  B_d]       ( [A  B]      )
   [0    I  ]  = exp( [0  0] · Δt )

4. HIPPO MATRIX INITIALIZATION:
   HiPPO (High-order Polynomial Projection Operators) for long-range memory:

   A_ij = { √(2i+1)·√(2j+1)   if i > j
          { -(2i+1)           if i = j
          { 0                 if i < j

   This lower-triangular structure (a scaled HiPPO-LegS-style matrix: decay
   on the diagonal, cross-mode coupling below it, zeros above for causality)
   preserves information over long sequences by projecting the input history
   onto Legendre polynomials.

5. NUMERICAL INTEGRATION:
   Multi-step (forward-Euler-style) integration for stability:
   x(t+Δt) = x(t) + Δt·f(x(t), u(t))

   With adaptive time stepping keyed to the fastest time constant:
   Δt_eff = min(Δt_target, 0.1·min(τ))
   (e.g., Δt_target = 0.1 and min(τ) = 0.05 give Δt_eff = 0.005)

CONCEPTUAL REASONING:
====================

WHY LIQUID + STATE SPACE MODELS?
- Traditional SSMs have fixed dynamics
- Real-world sequences have variable temporal scales
- Liquid dynamics enable adaptive processing speeds
- Continuous-time formulation handles irregular sampling

KEY INNOVATIONS:
1. **Adaptive Time Constants**: Learn processing speed from data
2. **HiPPO Initialization**: Strong memory-retention properties
3. **Continuous-Discrete Bridge**: Seamless time-domain conversion
4. **Multi-Scale Processing**: Handle fast and slow temporal patterns
5. **Efficient Implementation**: Linear complexity in sequence length

APPLICATIONS:
- Long-range sequence modeling (DNA, audio, text)
- Time series with irregular sampling rates
- Speech recognition with variable speaking speeds
- Language modeling with adaptive processing
- Control systems with time-varying dynamics

COMPLEXITY ANALYSIS:
- Time: O(N·d²) where N = sequence length, d = state dimension
- Space: O(d²) for state matrices + O(N·d) for sequence states
- Training: O(N·d²·L) where L = number of layers
- Inference: linear in sequence length (vs. quadratic for full attention)

ADVANTAGES OVER TRANSFORMERS:
- Linear complexity vs. quadratic attention
- Continuous-time formulation handles variable rates
- Built-in inductive bias for temporal dynamics
- Natural handling of unbounded-length sequences
- Memory-efficient processing of long sequences

BIOLOGICAL INSPIRATION:
- Neural membrane time constants in biological circuits
- Adaptive integration windows in cortical processing
- Multiple-timescale dynamics in neural networks
- Continuous-time neural differential equations
"""

from __future__ import annotations

import math
import time
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import linalg
from scipy.signal import cont2discrete


SAFE_MIN: float = -1e6
SAFE_MAX: float = 1e6
EPS: float = 1e-8


def make_safe(
    tensor: torch.Tensor,
    min_val: float = SAFE_MIN,
    max_val: float = SAFE_MAX
) -> torch.Tensor:
    """Clamp tensor values to a safe numerical range, replacing NaN/Inf.

    Args:
        tensor: Input tensor to make numerically safe
        min_val: Minimum allowed value
        max_val: Maximum allowed value

    Returns:
        Numerically safe tensor with values in [min_val, max_val]
    """
    # Map NaN → 0, +inf → max_val, and -inf → min_val before clamping.
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(tensor, min_val, max_val)


def discrete_to_continuous_time(A_discrete: torch.Tensor, dt: float = 1.0) -> torch.Tensor:
    """Convert discrete-time matrix to continuous-time using the matrix logarithm.

    Mathematical Details:
        If A_d = exp(A_c · dt), then A_c = log(A_d) / dt

    Args:
        A_discrete: Discrete-time state transition matrix
        dt: Time step used in discretization

    Returns:
        Continuous-time state matrix
    """
    try:
        # scipy.linalg.logm may return a complex matrix; keep the real part.
        A_continuous = np.real(linalg.logm(A_discrete.detach().cpu().numpy())) / dt
        return torch.tensor(A_continuous, dtype=torch.float32, device=A_discrete.device)
    except Exception:
        # Fallback: near-zero placeholder dynamics if the logarithm fails.
        return torch.eye(A_discrete.shape[0], device=A_discrete.device) * 0.01


def continuous_to_discrete_time(
    A_continuous: torch.Tensor,
    B_continuous: torch.Tensor,
    dt: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert continuous-time system to discrete-time using zero-order hold.

    Mathematical Details:
        Uses the matrix exponential method for exact discretization:
        [A_d  B_d]       ( [A  B]      )
        [0    I  ]  = exp( [0  0] · dt )

    Handles batched matrices by processing each batch element individually
    to avoid SciPy's limitation with multi-dimensional arrays.

    Note: the SciPy round-trip runs on detached NumPy copies, so no gradient
    flows through (A_d, B_d) back to A, B, or the time constants; see the
    differentiable torch.matrix_exp sketch below for an alternative.

    Args:
        A_continuous: Continuous-time state matrix [batch?, state, state]
        B_continuous: Continuous-time input matrix [state, input]
        dt: Time step for discretization

    Returns:
        Tuple of (A_discrete, B_discrete) matrices
    """
    try:
        A_np = A_continuous.detach().cpu().numpy()
        B_np = B_continuous.detach().cpu().numpy()

        if A_np.ndim == 3:
            # Batched case: discretize each batch element separately.
            A_list, B_list = [], []
            for i in range(A_np.shape[0]):
                Ad, Bd, _, _, _ = cont2discrete(
                    (A_np[i], B_np, np.eye(A_np.shape[-1]), 0), dt
                )
                A_list.append(Ad)
                B_list.append(Bd)
            A_discrete = torch.tensor(np.stack(A_list), dtype=torch.float32, device=A_continuous.device)
            B_discrete = torch.tensor(np.stack(B_list), dtype=torch.float32, device=B_continuous.device)
        else:
            # Single system: one SciPy call suffices.
            A_discrete, B_discrete, _, _, _ = cont2discrete(
                (A_np, B_np, np.eye(A_np.shape[0]), 0), dt
            )
            A_discrete = torch.tensor(A_discrete, dtype=torch.float32, device=A_continuous.device)
            B_discrete = torch.tensor(B_discrete, dtype=torch.float32, device=B_continuous.device)

        return A_discrete, B_discrete
    except Exception:
        # Fallback: first-order (forward Euler) approximation of the ZOH:
        # A_d ≈ I + A·dt, B_d ≈ B·dt.
        n = A_continuous.shape[-1]
        eye = torch.eye(n, device=A_continuous.device)
        if A_continuous.dim() == 3:
            eye = eye.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
            B_disc = B_continuous.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
        else:
            B_disc = B_continuous
        A_discrete = eye + A_continuous * dt
        B_discrete = B_disc * dt
        return A_discrete, B_discrete
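

# The helper below is a minimal differentiable sketch, not part of the original
# pipeline: it computes the same zero-order-hold pair with torch.matrix_exp on
# the augmented block matrix [[A, B], [0, 0]], so gradients can flow through
# the discretization into A, B, and the liquid time constants. The name
# `continuous_to_discrete_time_torch` is introduced here for illustration only.
def continuous_to_discrete_time_torch(
    A_continuous: torch.Tensor,
    B_continuous: torch.Tensor,
    dt: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Differentiable ZOH discretization via exp([[A, B], [0, 0]] · dt)."""
    n = A_continuous.shape[-1]
    m = B_continuous.shape[-1]
    B = B_continuous
    if A_continuous.dim() == 3 and B.dim() == 2:
        # Broadcast a shared input matrix across the batch of state matrices.
        B = B.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
    top = torch.cat([A_continuous, B], dim=-1)                   # [.., n, n+m]
    bottom = torch.zeros(*top.shape[:-2], m, n + m,
                         device=top.device, dtype=top.dtype)     # [.., m, n+m]
    block = torch.cat([top, bottom], dim=-2) * dt
    exp_block = torch.matrix_exp(block)
    return exp_block[..., :n, :n], exp_block[..., :n, n:]        # (A_d, B_d)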


class LiquidTimeConstantController(nn.Module):
    """Adaptive time constant controller for liquid dynamics.

    Controls the temporal dynamics of the liquid state by learning context-dependent
    time constants. Fast time constants enable quick adaptation to rapid changes,
    while slow time constants provide stable integration of persistent patterns.

    Mathematical Framework:
        - Base time constants: τ_base = exp(log_τ)
        - Adaptive modulation: τ(x,u) = τ_base · (1 + α·φ(x,u))
        - Neural adaptation: φ(x,u) = tanh(W·[x,u] + b)
        - Stability constraint: τ ∈ [0.01, 10.0]
    """

    def __init__(
        self,
        state_dim: int,
        input_dim: int,
        init_tau: float = 1.0
    ) -> None:
        """Initialize adaptive time constant controller.

        Args:
            state_dim: Dimension of state vector
            input_dim: Dimension of input vector
            init_tau: Initial time constant value
        """
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = input_dim

        # Base time constants, log-parameterized so τ_base = exp(log_τ) > 0.
        self.log_tau = nn.Parameter(torch.ones(state_dim) * math.log(init_tau))

        # Neural adaptation function φ(x,u): maps [state, input] features to a
        # per-dimension modulation in (-1, 1).
        self.tau_adaptation = nn.Sequential(
            nn.Linear(state_dim + input_dim, state_dim * 2),
            nn.LayerNorm(state_dim * 2),
            nn.Tanh(),
            nn.Linear(state_dim * 2, state_dim),
            nn.Tanh()
        )

        # Adaptation rate α, clamped to [0.001, 1.0] at use time.
        self.adaptation_rate = nn.Parameter(torch.tensor(0.1))

    def get_time_constants(
        self,
        state: torch.Tensor,
        input_signal: torch.Tensor
    ) -> torch.Tensor:
        """Compute context-dependent time constants.

        Mathematical Details:
            1. Base time constants: τ_base = exp(log_τ)
            2. Context features: f = [state, input]
            3. Modulation: m = tanh(W·f + b)
            4. Final time constants: τ = τ_base · (1 + α·m)

        Args:
            state: Current liquid state [batch_size, state_dim]
            input_signal: Current input [batch_size, input_dim]

        Returns:
            Adaptive time constants [batch_size, state_dim]
        """
        # Base time constants, clamped for stability.
        base_tau = torch.exp(self.log_tau)
        base_tau = torch.clamp(base_tau, 0.01, 10.0)

        # Context features from the current state and input.
        combined_input = torch.cat([state, input_signal], dim=-1)
        tau_modulation = self.tau_adaptation(combined_input)

        # τ = τ_base · (1 + α·φ(x,u)).
        adaptation_rate = torch.clamp(self.adaptation_rate, 0.001, 1.0)
        modulated_tau = base_tau * (1.0 + adaptation_rate * tau_modulation)

        # Final stability clamp.
        return torch.clamp(modulated_tau, 0.01, 10.0)

    def get_effective_dt(self, tau: torch.Tensor, target_dt: float = 0.1) -> float:
        """Compute effective time step for numerical stability.

        The effective time step is chosen to be much smaller than the fastest
        time constant to ensure numerical stability of the integration.

        Mathematical Constraint:
            Δt_eff ≤ 0.1 · min(τ) for stability

        Args:
            tau: Time constants tensor [batch_size, state_dim]
            target_dt: Desired time step

        Returns:
            Effective time step (scalar)
        """
        # .item() is intentional: the step size is an integration
        # hyper-quantity, not a learned tensor, so no gradient flows through it.
        min_tau_val = torch.min(tau).item()
        effective_dt = max(0.001, min(float(target_dt), min_tau_val * 0.1))
        return effective_dt
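

# Minimal smoke test for the controller (a sketch; all dimensions below are
# arbitrary illustration values, and `_example_time_controller` is a helper
# name introduced here, not part of the original API):
def _example_time_controller() -> None:
    controller = LiquidTimeConstantController(state_dim=4, input_dim=3)
    state = torch.zeros(2, 4)           # batch of 2, zero initial state
    inputs = torch.randn(2, 3)          # random input signals
    tau = controller.get_time_constants(state, inputs)
    assert tau.shape == (2, 4)
    assert (tau >= 0.01).all() and (tau <= 10.0).all()   # stability clamp
    # Effective step is capped at 0.1·min(τ) and never exceeds target_dt.
    print(controller.get_effective_dt(tau, target_dt=0.1))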


class LiquidSSMCore(nn.Module):
    """Core Liquid State Space Model with adaptive continuous-time dynamics.

    Implements a state space model with liquid computing principles where
    time constants adapt based on input characteristics. Combines the
    representational power of SSMs with the adaptability of liquid dynamics.

    Mathematical Framework:
        - Liquid dynamics: dx/dt = -x/τ(x,u) + A·x + B·u
        - Output equation: y = C·x + D·u
        - HiPPO initialization for long-range memory properties
        - Adaptive discretization for numerical integration
    """

    def __init__(
        self,
        state_dim: int,
        input_dim: int,
        output_dim: int,
        dt: float = 0.1,
        init_method: str = 'hippo'
    ) -> None:
        """Initialize Liquid SSM core with adaptive dynamics.

        Args:
            state_dim: Dimension of hidden state vector
            input_dim: Dimension of input vector
            output_dim: Dimension of output vector
            dt: Target time step for integration
            init_method: Initialization method ('hippo' or 'random')
        """
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dt = dt

        # Continuous-time state matrix A (HiPPO-structured by default).
        if init_method == 'hippo':
            self.A_continuous = nn.Parameter(self._init_hippo_matrix(state_dim))
        else:
            self.A_continuous = nn.Parameter(torch.randn(state_dim, state_dim) * 0.1)

        # Input (B), output (C), and feedthrough (D) matrices.
        self.B_continuous = nn.Parameter(torch.randn(state_dim, input_dim) * 0.1)
        self.C = nn.Parameter(torch.randn(output_dim, state_dim) * 0.1)
        self.D = nn.Parameter(torch.zeros(output_dim, input_dim))

        # Adaptive time constant controller (the "liquid" part).
        self.time_controller = LiquidTimeConstantController(state_dim, input_dim, init_tau=1.0)

        # Learnable affine output calibration.
        self.output_scale = nn.Parameter(torch.ones(output_dim))
        self.output_bias = nn.Parameter(torch.zeros(output_dim))

        # Normalize states before the output map for numerical stability.
        self.state_normalizer = nn.LayerNorm(state_dim)

        # Persistent continuous state, carried across forward() calls.
        self.register_buffer('continuous_state', torch.zeros(1, state_dim))

    def _init_hippo_matrix(self, N: int) -> torch.Tensor:
        """Initialize state matrix with HiPPO structure for long-range memory.

        HiPPO (High-order Polynomial Projection Operators) creates a state
        transition matrix that preserves information by projecting the input
        history onto a basis of Legendre polynomials.

        Mathematical Details:
            A_ij = { √(2i+1)·√(2j+1)   if i > j  (coupling strength)
                   { -(2i+1)           if i = j  (decay rate)
                   { 0                 if i < j  (causality)

        Args:
            N: State dimension (number of basis functions)

        Returns:
            HiPPO matrix [N, N]
        """
        A = torch.zeros(N, N)
        for i in range(N):
            for j in range(N):
                if i > j:
                    # Sub-diagonal coupling between basis functions.
                    A[i, j] = math.sqrt(2 * i + 1) * math.sqrt(2 * j + 1)
                elif i == j:
                    # Negative diagonal: per-mode exponential decay.
                    A[i, j] = -(2 * i + 1)
        # Global 0.1 scaling keeps the dynamics mild at initialization.
        return A * 0.1

    def reset_state(self, batch_size: int = 1) -> None:
        """Reset continuous state for new sequence processing.

        Args:
            batch_size: Number of parallel sequences to process
        """
        device = self.A_continuous.device
        self.continuous_state = torch.zeros(batch_size, self.state_dim, device=device)

    def liquid_state_evolution(
        self,
        input_signal: torch.Tensor,
        num_steps: int = 10
    ) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Evolve state using adaptive liquid dynamics with numerical integration.

        Implements the core liquid evolution equation:
            dx/dt = -x/τ(x,u) + A·x + B·u

        Uses multi-step integration for numerical accuracy and adaptive
        time stepping based on the fastest time constant.

        Mathematical Process:
            1. Compute adaptive time constants: τ(x,u)
            2. Form liquid dynamics matrix: A_liquid = A - diag(1/τ)
            3. Discretize system: (A_d, B_d) = discretize(A_liquid, B, Δt)
            4. Integrate: x(k+1) = A_d·x(k) + B_d·u(k)

        Args:
            input_signal: External input [batch_size, input_dim]
            num_steps: Number of integration steps (total horizon = num_steps·Δt_eff)

        Returns:
            Tuple of (evolved_state, time_constants, effective_dt)
        """
        batch_size = input_signal.shape[0]

        # Re-initialize the state buffer if the batch size changed.
        if self.continuous_state.shape[0] != batch_size:
            self.reset_state(batch_size)

        # Adaptive time constants and a stable integration step.
        tau = self.time_controller.get_time_constants(self.continuous_state, input_signal)
        effective_dt = self.time_controller.get_effective_dt(tau, self.dt)

        # Fold the leak term -x/τ into the state matrix:
        # dx/dt = (A - diag(1/τ))·x + B·u
        tau_matrix = torch.diag_embed(1.0 / tau)
        liquid_A = self.A_continuous - tau_matrix

        # Bound entries before discretization to avoid overflow in exp(A·dt).
        liquid_A = make_safe(liquid_A, min_val=-10.0, max_val=10.0)

        # NOTE: the SciPy-based discretization detaches gradients, so A, B,
        # and τ receive no learning signal through this step (see the
        # continuous_to_discrete_time_torch sketch above for an alternative).
        A_discrete, B_discrete = continuous_to_discrete_time(
            liquid_A, self.B_continuous, effective_dt
        )

        current_state = self.continuous_state

        # Apply the discrete update num_steps times (ZOH with constant input).
        if A_discrete.dim() == 3:
            # Batched dynamics: one (A_d, B_d) pair per batch element.
            A_T = A_discrete.transpose(1, 2)
            B_T = B_discrete.transpose(1, 2)
            input_update = torch.bmm(input_signal.unsqueeze(1), B_T).squeeze(1)
            for _ in range(num_steps):
                state_update = torch.bmm(current_state.unsqueeze(1), A_T).squeeze(1)
                current_state = state_update + input_update
                current_state = make_safe(current_state)
        else:
            # Shared dynamics across the batch.
            A_T = A_discrete.T
            B_T = B_discrete.T
            input_update = input_signal @ B_T
            for _ in range(num_steps):
                current_state = current_state @ A_T + input_update
                current_state = make_safe(current_state)

        # Persist the evolved state for the next time step.
        self.continuous_state = current_state

        return current_state, tau, effective_dt

    def compute_output(
        self,
        state: torch.Tensor,
        input_signal: torch.Tensor
    ) -> torch.Tensor:
        """Compute output from state space model: y = C·x + D·u.

        Args:
            state: Current state vector [batch_size, state_dim]
            input_signal: Current input [batch_size, input_dim]

        Returns:
            Output vector [batch_size, output_dim]
        """
        # Normalize the state before projecting to the output space.
        normalized_state = self.state_normalizer(state)

        # y = C·x + D·u
        state_output = torch.matmul(normalized_state, self.C.T)
        direct_output = torch.matmul(input_signal, self.D.T)

        raw_output = state_output + direct_output

        # Learnable affine calibration.
        output = self.output_scale * raw_output + self.output_bias

        return make_safe(output)

    def forward(
        self,
        input_signal: torch.Tensor,
        return_diagnostics: bool = False
    ) -> Dict[str, Union[torch.Tensor, float]]:
        """Complete forward pass through the Liquid SSM.

        Args:
            input_signal: Input vector [batch_size, input_dim]
            return_diagnostics: Whether to return diagnostic information

        Returns:
            Dictionary containing output and optional diagnostics
        """
        # Evolve the liquid state, then read out the observation.
        evolved_state, tau, effective_dt = self.liquid_state_evolution(input_signal)
        output = self.compute_output(evolved_state, input_signal)

        result = {
            'output': output,
            'state': evolved_state
        }

        if return_diagnostics:
            result.update({
                'time_constants': tau,
                'effective_dt': effective_dt,
                'state_norm': torch.norm(evolved_state, dim=-1),
                'adaptation_rate': self.time_controller.adaptation_rate
            })

        return result
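

# Quick structural sanity check for the HiPPO-style initialization (a sketch;
# `_check_hippo_structure` is a helper name introduced here). A triangular
# matrix's eigenvalues are its diagonal entries, so the scaled matrix built in
# _init_hippo_matrix has strictly negative eigenvalues — stable continuous
# dynamics even before the liquid leak term -x/τ is added.
def _check_hippo_structure(N: int = 8) -> None:
    core = LiquidSSMCore(state_dim=N, input_dim=1, output_dim=1)
    A = core.A_continuous.detach()
    assert torch.allclose(A, torch.tril(A))      # zero above the diagonal
    assert (torch.diagonal(A) < 0).all()         # decay on every mode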


class LiquidSSMSequenceLayer(nn.Module):
    """Sequence processing layer using Liquid SSM with residual connections.

    Processes variable-length sequences through a Liquid SSM while maintaining
    adaptive dynamics across time steps. Includes input/output projections,
    residual connections, and sequence-level adaptation mechanisms.

    Architecture:
        Input → Projection → Liquid SSM → Sequence Adaptation → Output Projection → Residual
    """

    def __init__(
        self,
        input_dim: int,
        state_dim: int,
        output_dim: int,
        seq_len: Optional[int] = None
    ) -> None:
        """Initialize Liquid SSM sequence processing layer.

        Args:
            input_dim: Dimension of input features
            state_dim: Dimension of internal state
            output_dim: Dimension of output features
            seq_len: Maximum sequence length (optional)
        """
        super().__init__()
        self.input_dim = input_dim
        self.state_dim = state_dim
        self.output_dim = output_dim
        self.seq_len = seq_len

        # The core operates in the projected feature space, so its input
        # dimension equals state_dim (raw inputs are projected below).
        self.liquid_ssm = LiquidSSMCore(state_dim, state_dim, output_dim)

        # Project raw features into the SSM input space.
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, state_dim),
            nn.LayerNorm(state_dim),
            nn.GELU()
        )

        # Post-process SSM outputs with a small feed-forward block.
        self.output_projection = nn.Sequential(
            nn.Linear(output_dim, output_dim * 2),
            nn.LayerNorm(output_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(output_dim * 2, output_dim)
        )

        # Learnable residual strength, clamped to [0, 1] at use time.
        self.residual_weight = nn.Parameter(torch.tensor(0.1))

        # Per-timestep gating factor derived from the liquid state.
        self.sequence_adapter = nn.Sequential(
            nn.Linear(state_dim, state_dim),
            nn.Tanh(),
            nn.Linear(state_dim, 1),
            nn.Sigmoid()
        )

    def forward(
        self,
        sequence: torch.Tensor,
        return_diagnostics: bool = False
    ) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
        """Process a complete sequence through the Liquid SSM.

        Processes each time step sequentially while maintaining liquid state
        continuity across the sequence. Applies sequence-level adaptation
        and residual connections for improved gradient flow.

        Args:
            sequence: Input sequence [batch_size, seq_len, input_dim]
            return_diagnostics: Whether to return per-timestep diagnostics

        Returns:
            Dictionary containing output sequence and optional diagnostics
        """
        batch_size, seq_len, input_dim = sequence.shape

        # Fresh liquid state at the start of every sequence.
        self.liquid_ssm.reset_state(batch_size)

        outputs = []
        diagnostics = [] if return_diagnostics else None

        # Sequential scan: O(seq_len) steps, state carried between them.
        for t in range(seq_len):
            current_input = sequence[:, t, :]

            # Project raw features into the SSM input space.
            projected_input = self.input_projection(current_input)

            # One liquid SSM step (state persists inside the core).
            ssm_result = self.liquid_ssm(projected_input, return_diagnostics=return_diagnostics)

            # Gate the output by a state-dependent adaptation factor.
            adaptation_factor = self.sequence_adapter(ssm_result['state'])
            adapted_output = ssm_result['output'] * adaptation_factor

            final_output = self.output_projection(adapted_output)

            # Residual connection, active only when dimensions line up
            # (i.e., output_dim == input_dim).
            if final_output.shape == current_input.shape:
                residual_strength = torch.clamp(self.residual_weight, 0.0, 1.0)
                final_output = final_output + residual_strength * current_input

            outputs.append(final_output)

            if return_diagnostics:
                diagnostics.append({
                    'timestep': t,
                    'adaptation_factor': adaptation_factor.mean().item(),
                    **ssm_result
                })

        # Re-assemble per-step outputs into a sequence tensor.
        output_sequence = torch.stack(outputs, dim=1)

        result = {'output': output_sequence}

        if return_diagnostics:
            result['diagnostics'] = diagnostics

        return result
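

# Hypothetical usage of the sequence layer on a random batch (a sketch; the
# dimensions are arbitrary, and `_example_sequence_layer` is a name introduced
# here). output_dim == input_dim so the learned residual path is active.
def _example_sequence_layer() -> None:
    layer = LiquidSSMSequenceLayer(input_dim=32, state_dim=64, output_dim=32)
    sequence = torch.randn(2, 16, 32)    # [batch, seq_len, features]
    with torch.no_grad():
        result = layer(sequence, return_diagnostics=True)
    assert result['output'].shape == (2, 16, 32)
    print(len(result['diagnostics']))    # one diagnostics dict per time step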


class LiquidSSMLanguageModel(nn.Module):
    """Complete language model using Liquid State Space Models.

    Implements a transformer-alternative architecture using Liquid SSMs for
    sequence processing. Provides linear complexity in sequence length while
    maintaining strong representational capabilities through adaptive dynamics.

    Architecture:
        Embeddings → Liquid SSM Layers → Output Head

    Each layer includes:
        - Layer normalization
        - Liquid SSM processing
        - Global adaptation
        - Residual connections
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        state_dim: int = 256,
        num_layers: int = 6,
        max_seq_len: int = 2048
    ) -> None:
        """Initialize Liquid SSM Language Model.

        Args:
            vocab_size: Size of vocabulary
            d_model: Model dimension (embedding/hidden size)
            state_dim: Liquid state dimension
            num_layers: Number of Liquid SSM layers
            max_seq_len: Maximum sequence length
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.state_dim = state_dim
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len

        # Token and learned absolute position embeddings.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Stack of Liquid SSM sequence layers.
        self.liquid_layers = nn.ModuleList([
            LiquidSSMSequenceLayer(d_model, state_dim, d_model)
            for _ in range(num_layers)
        ])

        # Pre-layer normalization for each block.
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(num_layers)
        ])

        # Final normalization and language-model head.
        self.output_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

        # Sequence-level gate applied after each block.
        self.global_adaptation = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid()
        )

        self._init_weights()

    def _init_weights(self) -> None:
        """Xavier-initialize linear layers; small normal init for embeddings."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_diagnostics: bool = False
    ) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
        """Forward pass through the Liquid SSM Language Model.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            labels: Target labels for loss computation [batch_size, seq_len]
            return_diagnostics: Whether to return layer diagnostics

        Returns:
            Dictionary containing logits, loss, and optional diagnostics
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Truncate sequences that exceed the position-embedding table.
        if seq_len > self.max_seq_len:
            input_ids = input_ids[:, :self.max_seq_len]
            seq_len = self.max_seq_len
            if labels is not None:
                labels = labels[:, :self.max_seq_len]

        # Guard against out-of-vocabulary token IDs.
        input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)

        # Token + position embeddings.
        token_emb = self.token_embedding(input_ids)
        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.position_embedding(pos_ids)

        x = token_emb + pos_emb
        x = make_safe(x)

        layer_diagnostics = [] if return_diagnostics else None

        # Pre-norm residual blocks around each Liquid SSM layer.
        for layer_idx, (liquid_layer, layer_norm) in enumerate(zip(self.liquid_layers, self.layer_norms)):
            residual = x

            x = layer_norm(x)

            layer_result = liquid_layer(x, return_diagnostics=return_diagnostics)
            x = layer_result['output']

            # Global gate computed from the sequence mean.
            adaptation = self.global_adaptation(x.mean(dim=1, keepdim=True))
            x = x * adaptation

            x = residual + x
            x = make_safe(x)

            if return_diagnostics:
                layer_diagnostics.append({
                    'layer': layer_idx,
                    'adaptation': adaptation.mean().item(),
                    **layer_result
                })

        # Final norm and projection to vocabulary logits.
        x = self.output_norm(x)
        logits = self.lm_head(x)
        logits = make_safe(logits, min_val=-50, max_val=50)

        # Standard shifted next-token cross-entropy loss.
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )

        result = {
            'logits': logits,
            'loss': loss
        }

        if return_diagnostics:
            result['layer_diagnostics'] = layer_diagnostics

        return result

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_p: float = 0.95,
        return_diagnostics: bool = False
    ) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
        """Generate text using the Liquid SSM with nucleus (top-p) sampling.

        Note: each step re-encodes the full prefix, so generating N tokens
        costs O(N²) core steps overall; see the stateful-stepping sketch
        below the class for the O(N) recurrent view.

        Args:
            input_ids: Prompt token IDs [batch_size, prompt_len]
            max_length: Maximum total sequence length
            temperature: Sampling temperature (higher = more random)
            top_p: Nucleus sampling probability threshold
            return_diagnostics: Whether to return generation diagnostics

        Returns:
            Dictionary containing generated IDs and optional diagnostics
        """
        self.eval()
        generated = input_ids.clone()
        all_diagnostics = [] if return_diagnostics else None

        for step in range(max_length - input_ids.shape[1]):
            # Stop once the context window is full.
            if generated.shape[1] >= self.max_seq_len:
                break

            # Re-encode the full prefix (no state caching across steps).
            outputs = self(generated, return_diagnostics=return_diagnostics)
            logits = outputs['logits']

            if return_diagnostics:
                all_diagnostics.append(outputs.get('layer_diagnostics', []))

            # Temperature-scaled logits for the next position.
            next_token_logits = logits[:, -1, :] / max(temperature, EPS)
            next_token_logits = make_safe(next_token_logits, min_val=-50, max_val=50)

            # Nucleus (top-p) filtering over the sorted distribution.
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Shift the mask right so the first token above the threshold is kept.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False

            # Mask filtered tokens out of the original (unsorted) logits.
            for b in range(next_token_logits.size(0)):
                indices_to_remove = sorted_indices[b][sorted_indices_to_remove[b]]
                next_token_logits[b, indices_to_remove] = -float('inf')

            # Sample from the renormalized distribution.
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_token = torch.clamp(next_token, 0, self.vocab_size - 1)

            generated = torch.cat([generated, next_token], dim=1)

            # Early stop on end-of-sequence (assumes batch_size == 1 and
            # that token ID 2 is the EOS marker).
            if next_token.item() == 2:
                break

        result = {'generated_ids': generated}
        if return_diagnostics:
            result['diagnostics'] = all_diagnostics

        return result
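

# generate() above re-encodes the whole prefix at every step, so decoding N
# tokens costs O(N²) core steps overall. Because the SSM core is a recurrence,
# O(N) decoding with persistent state is possible in principle; this sketch
# (hypothetical, core-only — the full layered model resets its liquid state on
# every call) shows the recurrent view:
def _example_stateful_stepping() -> None:
    core = LiquidSSMCore(state_dim=16, input_dim=16, output_dim=16)
    core.reset_state(batch_size=1)
    stream = torch.randn(10, 1, 16)      # 10 time steps, batch of 1
    with torch.no_grad():
        for u_t in stream:               # one core call per token: O(N) total
            y_t = core(u_t)['output']    # liquid state persists inside core
            assert y_t.shape == (1, 16)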


def test_liquid_ssm() -> bool:
    print("Testing Liquid State Space Model - Continuous-Time Adaptive Sequence Processing")
    print("=" * 90)

    # Small model configuration for a quick functional test.
    vocab_size = 1000
    d_model = 256
    state_dim = 128
    num_layers = 4

    model = LiquidSSMLanguageModel(
        vocab_size=vocab_size,
        d_model=d_model,
        state_dim=state_dim,
        num_layers=num_layers,
        max_seq_len=512
    )

    print("Created Liquid SSM Language Model:")
    print(f"  - Vocabulary size: {vocab_size}")
    print(f"  - Model dimension: {d_model}")
    print(f"  - State dimension: {state_dim}")
    print(f"  - Number of layers: {num_layers}")

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  - Total parameters: {total_params:,} ({total_params/1e6:.1f}M)")

    # Random token batch for the forward-pass sanity check.
    batch_size = 4
    seq_len = 32
    test_input = torch.randint(0, vocab_size, (batch_size, seq_len))
    test_labels = torch.randint(0, vocab_size, (batch_size, seq_len))

    print(f"\nTesting with batch_size={batch_size}, seq_len={seq_len}")

    print("\nExecuting forward pass...")
    outputs = model(test_input, labels=test_labels, return_diagnostics=True)

    print("Forward pass results:")
    print(f"  - Output logits shape: {outputs['logits'].shape}")
    print(f"  - Loss: {outputs['loss'].item():.4f}")

    # Inspect a few per-layer diagnostics.
    print("\nLiquid dynamics analysis:")
    diagnostics = outputs['layer_diagnostics']

    for layer_idx in range(min(3, len(diagnostics))):
        layer_diag = diagnostics[layer_idx]
        print(f"  Layer {layer_idx + 1}:")
        print(f"    - Global adaptation: {layer_diag['adaptation']:.3f}")

        if 'diagnostics' in layer_diag:
            time_constants = [d['time_constants'].mean().item() for d in layer_diag['diagnostics'][:3]]
            print(f"    - Avg time constants: {[f'{tc:.3f}' for tc in time_constants]}")

    # Generation smoke test.
    print("\nTesting text generation...")
    prompt = torch.randint(0, vocab_size, (1, 8))
    generation_result = model.generate(
        prompt,
        max_length=20,
        temperature=1.0,
        return_diagnostics=True
    )

    generated_ids = generation_result['generated_ids']
    print(f"  - Generated sequence length: {generated_ids.shape[1]}")
    print(f"  - Prompt length: {prompt.shape[1]}")
    print(f"  - New tokens generated: {generated_ids.shape[1] - prompt.shape[1]}")

    # Rough wall-clock scaling with sequence length.
    print("\nEfficiency analysis:")

    seq_lengths = [64, 128, 256]
    for test_len in seq_lengths:
        test_seq = torch.randint(0, vocab_size, (1, test_len))

        start_time = time.time()
        with torch.no_grad():
            _ = model(test_seq)
        end_time = time.time()

        processing_time = end_time - start_time
        tokens_per_second = test_len / max(processing_time, EPS)

        print(f"  - Length {test_len}: {processing_time:.3f}s ({tokens_per_second:.0f} tokens/s)")

    print("\nLiquid SSM test completed!")
    print("✓ Continuous-time adaptive dynamics")
    print("✓ Learnable time constants based on content")
    print("✓ Efficient sequence processing")
    print("✓ State space model foundation with liquid adaptation")
    print("✓ Potential transformer alternative with continuous dynamics")

    return True


def adaptive_dynamics_demo() -> None:
    print("\n" + "=" * 70)
    print("ADAPTIVE DYNAMICS DEMONSTRATION")
    print("=" * 70)

    # A small core, evaluated directly (no language-model stack).
    model = LiquidSSMCore(state_dim=16, input_dim=8, output_dim=8)
    model.eval()

    # Input patterns with different temporal characteristics.
    patterns = {
        "Smooth": torch.sin(torch.linspace(0, 2 * math.pi, 8)).unsqueeze(0),
        "Spiky": torch.tensor([0, 1, 0, -1, 0, 1, 0, -1], dtype=torch.float).unsqueeze(0),
        "Constant": torch.ones(1, 8) * 0.5,
        "Random": torch.randn(1, 8)
    }

    print("Testing adaptive time constants with different input patterns:")

    for pattern_name, pattern_input in patterns.items():
        model.reset_state(1)

        with torch.no_grad():
            result = model(pattern_input, return_diagnostics=True)

        time_constants = result['time_constants'].squeeze().tolist()
        adaptation_rate = result['adaptation_rate'].item()

        print(f"\n{pattern_name} pattern:")
        print(f"  Time constants: {[f'{tc:.3f}' for tc in time_constants[:4]]}...")
        print(f"  Adaptation rate: {adaptation_rate:.4f}")
        print(f"  Effective dt: {result['effective_dt']:.4f}")

    print("\n  Adaptive dynamics show how the liquid SSM adjusts to different input characteristics")
    print("  Expected trend after training: smooth inputs → larger time constants,")
    print("  spiky inputs → smaller time constants (untrained values stay near init)")


if __name__ == "__main__":
    test_liquid_ssm()
    adaptive_dynamics_demo()