# oracle/models/helper_encoders.py
# (uploaded by zirobtc via huggingface_hub, revision 86523f8)
import torch
import torch.nn as nn
import math
import datetime
from typing import Dict, List, Any, Optional
class ContextualTimeEncoder(nn.Module):
    """Encodes Unix timestamps into fixed-size embeddings.

    The embedding concatenates three views of each timestamp:
      * a sinusoidal encoding of the raw timestamp (long-range position),
      * a cyclical (sin/cos) encoding of the UTC hour of day,
      * a cyclical (sin/cos) encoding of the UTC weekday,
    then projects the concatenation to ``output_dim`` with a linear layer.
    """

    def __init__(self, output_dim: int = 128, dtype: torch.dtype = torch.float32):
        """
        Encodes a Unix timestamp with support for mixed precision.

        Args:
            output_dim (int): The final dimension of the output embedding.
            dtype (torch.dtype): The data type for the model's parameters
                (e.g., torch.float16).

        Raises:
            ValueError: If ``output_dim`` is smaller than 12 (the split into
                timestamp/hour/day sub-encodings needs at least that much room).
        """
        super().__init__()
        self.dtype = dtype
        if output_dim < 12:
            raise ValueError(f"output_dim must be at least 12, but got {output_dim}")
        ts_dim = output_dim // 2
        hour_dim = output_dim // 4
        day_dim = output_dim - ts_dim - hour_dim
        # Each sub-encoding is a sin/cos pair, so its width must be even;
        # round odd widths up. The projection absorbs any overshoot.
        self.ts_dim = ts_dim + (ts_dim % 2)
        self.hour_dim = hour_dim + (hour_dim % 2)
        self.day_dim = day_dim + (day_dim % 2)
        total_internal_dim = self.ts_dim + self.hour_dim + self.day_dim
        self.projection = nn.Linear(total_internal_dim, output_dim)
        # Cast the entire module to the specified dtype
        self.to(dtype)
        # Log params
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[ContextualTimeEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def _sinusoidal_encode(self, values: torch.Tensor, d_model: int) -> torch.Tensor:
        """Transformer-style sinusoidal encoding of raw values.

        Computed in float64: Unix timestamps (~1.7e9) lose sub-second
        precision in float32, which would alias the high-frequency channels.
        """
        device = values.device
        half_dim = d_model // 2
        div_term = torch.exp(
            torch.arange(0, half_dim, device=device, dtype=torch.float64)
            * -(math.log(10000.0) / half_dim)
        )
        args = values.double().unsqueeze(-1) * div_term
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    def _cyclical_encode(self, values: torch.Tensor, d_model: int, max_val: float) -> torch.Tensor:
        """Sin/cos encoding of a value on a cycle of period ``max_val``.

        ``values`` is expected to be 1-D (forward() flattens its input).
        """
        norm_values = (values.float() / max_val) * 2 * math.pi
        half_dim = d_model // 2
        # expand() broadcasts the single angle across half_dim channels
        # without materializing copies (the original used two repeat()s).
        angles = norm_values.unsqueeze(-1).expand(-1, half_dim)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
        """Embed a tensor of Unix timestamps.

        Args:
            timestamps: Tensor of Unix timestamps (seconds), any shape [*].

        Returns:
            Tensor of shape [*, output_dim] in ``self.dtype``.
        """
        device = self.projection.weight.device
        # 1. Store original shape (e.g., [B, L]) and flatten.
        original_shape = timestamps.shape
        # Move onto the module's device immediately so all three encodings
        # live on the same device (previously the sinusoidal encoding stayed
        # on the input's device and torch.cat failed on device mismatch).
        # float64 preserves precision for large Unix timestamps.
        timestamps_flat = timestamps.flatten().to(device=device, dtype=torch.float64)
        # 2. Sinusoidal encode the raw timestamp.
        ts_encoding = self._sinusoidal_encode(timestamps_flat, self.ts_dim)
        # 3. UTC hour-of-day and weekday, fully vectorized. This replaces a
        #    per-element Python datetime.fromtimestamp loop (one host sync per
        #    timestamp). floor/remainder match Python's // and % semantics,
        #    so negative (pre-epoch) timestamps also agree with
        #    datetime.fromtimestamp(..., tz=utc).
        hours = torch.remainder(torch.floor(timestamps_flat / 3600.0), 24.0).float()
        # Epoch day 0 (1970-01-01) was a Thursday, i.e. weekday() == 3.
        days = torch.remainder(torch.floor(timestamps_flat / 86400.0) + 3.0, 7.0).float()
        # 4. Cyclical encode.
        hour_encoding = self._cyclical_encode(hours, self.hour_dim, max_val=24.0)
        day_encoding = self._cyclical_encode(days, self.day_dim, max_val=7.0)
        # 5. Combine and project (cat promotes to float64; cast to model dtype).
        combined_encoding = torch.cat([ts_encoding, hour_encoding, day_encoding], dim=1)
        projected = self.projection(combined_encoding.to(self.dtype))  # [N_total, output_dim]
        # 6. Reshape to match the input (e.g., [B, L, output_dim]).
        output_shape = original_shape + (self.projection.out_features,)
        return projected.view(output_shape)
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings over the sequence dimension, ignoring padding.

    Args:
        last_hidden_state: Token embeddings of shape [batch, seq_len, hidden].
        attention_mask: Mask of shape [batch, seq_len]; 1 marks real tokens,
            0 marks padding.

    Returns:
        Pooled embeddings of shape [batch, hidden]; each row is the average of
        that sequence's unmasked token embeddings.
    """
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    token_sum = (last_hidden_state * weights).sum(dim=1)
    # Clamp avoids division by zero for fully-masked (all-padding) rows.
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_sum / token_count