| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class ProjectionHead(nn.Module):
    """Bias-free linear projection followed by a pointwise nonlinearity.

    Maps ``in_dim`` features to ``out_dim`` features; the weight matrix is
    initialized from N(0, 0.02**2). Supported activations (case-insensitive):
    tanh, relu, gelu, sigmoid.
    """

    def __init__(self, in_dim: int = 512, out_dim: int = 3008, activation: str = "tanh"):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        # Dispatch table instead of an if/elif chain; instantiated lazily below.
        factories = {
            "tanh": nn.Tanh,
            "relu": nn.ReLU,
            "gelu": nn.GELU,
            "sigmoid": nn.Sigmoid,
        }
        key = activation.lower()
        try:
            self.act = factories[key]()
        except KeyError:
            raise ValueError(f"Unsupported activation: {activation}") from None
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and apply the configured activation."""
        projected = self.linear(x)
        return self.act(projected)
|
|
|
|
def binarize_sign(x: torch.Tensor) -> torch.Tensor:
    """Map each element to a {0, 1} bit: 1 where ``x > 0``, else 0.

    Zeros and negatives both map to 0. Result is int64, on the same
    device as ``x``.
    """
    positive = torch.gt(x, 0)
    return positive.long()
|
|
|
|
def pack_bits_64(bits: torch.Tensor, dim_features: int) -> torch.Tensor:
    """Pack {0,1} bits into int64 words, 64 bits per word.

    Bit ``j`` of word ``w`` holds feature ``w * 64 + j`` (little-endian
    within a word). Bit 63 lands on the int64 sign bit, so packed values
    may be negative; this is fine for XOR/popcount-based Hamming math.

    Args:
        bits: Tensor (N, D) with 0/1 entries (cast to int64 if needed).
        dim_features: D; must be divisible by 64.
    Returns:
        Tensor (N, D // 64) int64 on the same device as ``bits``.
    Raises:
        ValueError: if ``dim_features`` is not a multiple of 64.
    """
    # Raise instead of assert so validation survives `python -O`.
    if dim_features % 64 != 0:
        raise ValueError("proj_dim must be divisible by 64 for bit-packing")
    if bits.dtype != torch.int64:
        bits = bits.to(torch.int64)
    n = bits.size(0)
    words = dim_features // 64
    bits = bits.view(n, words, 64)
    # Shift each bit to its in-word position, then sum to merge into one word.
    shifts = torch.arange(64, device=bits.device, dtype=torch.int64)
    packed = (bits << shifts).sum(-1)
    return packed.contiguous()
|
|
|
|
| def _popcount64(x: torch.Tensor) -> torch.Tensor: |
| """Compute population count for each int64 element using bit hacks. |
| Returns same shape tensor with counts in int64. |
| """ |
| |
| m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device) |
| m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device) |
| m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device) |
| x = x - ((x >> 1) & m1) |
| x = (x & m2) + ((x >> 2) & m2) |
| x = (x + (x >> 4)) & m4 |
| x = x + (x >> 8) |
| x = x + (x >> 16) |
| x = x + (x >> 32) |
| return x & torch.tensor(0x7F, dtype=torch.int64, device=x.device) |
|
|
|
|
def hamming_distance_packed(a_words: torch.Tensor, b_words: torch.Tensor, block: int = 1024) -> torch.Tensor:
    """Pairwise Hamming distances between two sets of bit-packed codes.

    XORs every row of ``a_words`` against every row of ``b_words`` and sums
    popcounts of the differing bits. Rows of A are processed in blocks so
    the (block, Nb, W) intermediate bounds peak memory.

    Args:
        a_words: (Na, W) int64 packed codes.
        b_words: (Nb, W) int64 packed codes.
        block: number of A rows handled per iteration.
    Returns:
        dist: (Na, Nb) int64 distances, on ``a_words``' device.
    Raises:
        ValueError: on non-int64 inputs or mismatched word counts.
    """
    # Raise instead of assert so validation survives `python -O`.
    if a_words.dtype != torch.int64 or b_words.dtype != torch.int64:
        raise ValueError("packed codes must be int64")
    Na, W = a_words.shape
    Nb, Wb = b_words.shape
    if W != Wb:
        raise ValueError(f"word-count mismatch: {W} != {Wb}")
    device = a_words.device
    out = torch.empty((Na, Nb), dtype=torch.int64, device=device)
    for s in range(0, Na, block):
        e = min(Na, s + block)
        aw = a_words[s:e]
        # (blk, 1, W) ^ (1, Nb, W) -> differing bits for every (a, b) pair.
        xor = aw.unsqueeze(1) ^ b_words.unsqueeze(0)
        pc = _popcount64(xor)
        out[s:e] = pc.sum(-1)
    return out
|
|
|
|