| import torch |
| import torch.nn as nn |
| import math |
| import datetime |
| from typing import Dict, List, Any, Optional |
|
|
class ContextualTimeEncoder(nn.Module):
    """Encodes Unix timestamps (seconds, UTC) into fixed-size embeddings.

    Three signals are concatenated and linearly projected to ``output_dim``:
      * a transformer-style sinusoidal encoding of the raw timestamp,
      * a cyclical (sin/cos) encoding of the UTC hour of day,
      * a cyclical (sin/cos) encoding of the UTC day of week.
    """

    def __init__(self, output_dim: int = 128, dtype: torch.dtype = torch.float32):
        """
        Args:
            output_dim (int): The final dimension of the output embedding.
                Must be at least 12 so each component gets a sin/cos pair.
            dtype (torch.dtype): The data type for the model's parameters
                (e.g., torch.float16 for mixed precision).

        Raises:
            ValueError: If ``output_dim`` is smaller than 12.
        """
        super().__init__()
        self.dtype = dtype
        if output_dim < 12:
            raise ValueError(f"output_dim must be at least 12, but got {output_dim}")

        # Split the budget roughly 2:1:1 between timestamp, hour, and day.
        ts_dim = output_dim // 2
        hour_dim = output_dim // 4
        day_dim = output_dim - ts_dim - hour_dim

        # Each encoder emits paired sin/cos columns, so round widths up to even.
        self.ts_dim = ts_dim + (ts_dim % 2)
        self.hour_dim = hour_dim + (hour_dim % 2)
        self.day_dim = day_dim + (day_dim % 2)

        # May exceed output_dim by the padding above; the projection maps back.
        total_internal_dim = self.ts_dim + self.hour_dim + self.day_dim

        self.projection = nn.Linear(total_internal_dim, output_dim)

        self.to(dtype)

        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[ContextualTimeEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def _sinusoidal_encode(self, values: torch.Tensor, d_model: int) -> torch.Tensor:
        """Transformer-style sinusoidal encoding.

        Args:
            values: 1-D tensor of raw values (timestamps).
            d_model: Width of the encoding; must be even.

        Returns:
            Tensor of shape ``(len(values), d_model)`` in float64 —
            double precision is kept because Unix timestamps (~1.7e9)
            lose sub-second resolution in float32.
        """
        device = values.device
        half_dim = d_model // 2
        div_term = torch.exp(
            torch.arange(0, half_dim, device=device, dtype=torch.float64)
            * -(math.log(10000.0) / half_dim)
        )
        args = values.double().unsqueeze(-1) * div_term
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    def _cyclical_encode(self, values: torch.Tensor, d_model: int, max_val: float) -> torch.Tensor:
        """Maps values on a ``[0, max_val)`` cycle to sin/cos features.

        Args:
            values: 1-D tensor of cyclic values (e.g. hours in [0, 24)).
            d_model: Width of the encoding; must be even.
            max_val: Period of the cycle.

        Returns:
            Tensor of shape ``(len(values), d_model)`` in float32.
        """
        norm_values = (values.float() / max_val) * 2 * math.pi
        half_dim = d_model // 2
        # The same angle fills all half_dim columns (expand is a free view;
        # sin/cos materialize fresh tensors anyway).
        angles = norm_values.unsqueeze(-1).expand(-1, half_dim)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
        """
        Args:
            timestamps: Tensor of Unix timestamps (seconds since epoch, UTC),
                of any shape and on any device.

        Returns:
            Tensor of shape ``timestamps.shape + (output_dim,)`` in ``self.dtype``.
        """
        device = self.projection.weight.device

        original_shape = timestamps.shape
        # Fix: move the input onto the model's device up front so every
        # intermediate lives on the same device as the projection (the
        # original computed ts_encoding on the input device and the
        # hour/day tensors on the model device, crashing torch.cat
        # whenever the two differed).
        timestamps_flat = timestamps.flatten().to(device=device, dtype=torch.float64)

        ts_encoding = self._sinusoidal_encode(timestamps_flat, self.ts_dim)

        # Vectorized UTC calendar math (replaces a per-element Python
        # datetime loop that forced a host round-trip per timestamp).
        # Unix time ignores leap seconds and UTC has no DST, so the hour
        # is plain modular arithmetic; the epoch (1970-01-01) was a
        # Thursday, i.e. weekday() == 3. Floor division + remainder keep
        # negative (pre-epoch) timestamps correct as well.
        hours = (timestamps_flat // 3600.0) % 24.0
        days = ((timestamps_flat // 86400.0) + 3.0) % 7.0

        hour_encoding = self._cyclical_encode(hours, self.hour_dim, max_val=24.0)
        day_encoding = self._cyclical_encode(days, self.day_dim, max_val=7.0)

        # Fix: cast every component to self.dtype BEFORE concatenating —
        # ts_encoding is float64 while the cyclical encodings are float32,
        # and the original relied on implicit promotion inside torch.cat.
        combined_encoding = torch.cat(
            [
                ts_encoding.to(self.dtype),
                hour_encoding.to(self.dtype),
                day_encoding.to(self.dtype),
            ],
            dim=1,
        )
        projected = self.projection(combined_encoding)

        output_shape = original_shape + (self.projection.out_features,)
        return projected.view(output_shape)
|
|
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean of token embeddings over the sequence dimension.

    Positions where ``attention_mask`` is 0 contribute nothing; the token
    count is clamped to 1e-9 so fully-masked rows yield zeros, not NaNs.
    """
    # (batch, seq, 1) float mask broadcasts against (batch, seq, hidden).
    weights = attention_mask.float().unsqueeze(-1)
    weighted_sum = (last_hidden_state * weights).sum(dim=1)
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_count
|
|