| |
| |
| |
import hashlib
import math
from collections import defaultdict, deque
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
# Finite range that make_safe() clamps tensor values into by default.
SAFE_MIN = -1e6
SAFE_MAX = 1e6
# Lower bound applied to vector norms so divisions never hit zero.
EPS = 1e-8


def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Return ``tensor`` with NaN/inf removed and values clamped to a range.

    Args:
        tensor: Input tensor; it is not modified in place.
        min_val: Lower clamp bound; also the replacement for ``-inf``.
        max_val: Upper clamp bound; also the replacement for ``+inf``.

    Returns:
        A tensor in which NaN is 0, ``+inf`` is ``max_val``, ``-inf`` is
        ``min_val`` and every other value lies in ``[min_val, max_val]``.
    """
    # Replace each infinity according to its sign. A sign-blind isinf()
    # replacement would turn -inf into max_val before the clamp runs.
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(tensor, min_val, max_val)
|
|
def safe_cosine_similarity(a, b, dim=-1, eps=EPS):
    """Cosine similarity of ``a`` and ``b`` along ``dim`` with guarded norms.

    Both inputs are cast to float32 when they are not already. Each norm is
    clamped to at least ``eps`` before dividing, so zero vectors give 0
    instead of NaN. The reduced dimension is kept (size 1) in the result.
    """
    lhs = a if a.dtype == torch.float32 else a.float()
    rhs = b if b.dtype == torch.float32 else b.float()

    dot_product = torch.sum(lhs * rhs, dim=dim, keepdim=True)
    lhs_norm = torch.norm(lhs, dim=dim, keepdim=True).clamp(min=eps)
    rhs_norm = torch.norm(rhs, dim=dim, keepdim=True).clamp(min=eps)
    return dot_product / (lhs_norm * rhs_norm)
|
|
def _fit_to_dim(vector, vector_dim):
    """Zero-pad or truncate a 1-D tensor to exactly ``vector_dim`` entries."""
    missing = vector_dim - len(vector)
    if missing > 0:
        padding = torch.zeros(missing, dtype=torch.float32, device=vector.device)
        return torch.cat([vector, padding])
    return vector[:vector_dim]


def item_to_vector(item, vector_dim=64):
    """Convert an item into a 1-D float32 tensor of length ``vector_dim``.

    The encoding depends on the item's type:

    * ``str``: the MD5 digest bytes scaled to ``[0, 1]``, then zero-padded
      or truncated to ``vector_dim``.
    * ``int`` / ``float``: interleaved ``sin`` / ``cos`` values of the number
      at frequencies ``10000 ** (-2 * i / vector_dim)``.
    * tensor: flattened, cast to float32, then zero-padded or truncated; the
      result stays on the tensor's device.
    * anything else: a standard-normal vector drawn from a CPU generator
      seeded from the MD5 digest of ``str(item)``.

    The result is passed through ``make_safe`` before it is returned.
    """
    if isinstance(item, str):
        digest = hashlib.md5(item.encode()).digest()
        vector = torch.tensor([b / 255.0 for b in digest], dtype=torch.float32)
        vector = _fit_to_dim(vector, vector_dim)
    elif isinstance(item, (int, float)):
        vector = torch.zeros(vector_dim, dtype=torch.float32)
        for i in range(vector_dim // 2):
            freq = 10000 ** (-2 * i / vector_dim)
            vector[2 * i] = math.sin(item * freq)
            vector[2 * i + 1] = math.cos(item * freq)
    elif torch.is_tensor(item):
        vector = _fit_to_dim(item.flatten().float(), vector_dim)
    else:
        # The built-in hash() of a str is salted per interpreter process, so
        # it would give the same item a different vector on every run. Derive
        # the seed from a content digest to keep the encoding reproducible.
        digest = hashlib.md5(str(item).encode()).digest()
        seed = int.from_bytes(digest[:4], "big") % (2**31)
        gen = torch.Generator(device='cpu')
        gen.manual_seed(seed)
        vector = torch.randn(vector_dim, generator=gen, dtype=torch.float32)

    return make_safe(vector)
|
|
| |
| |
|
|
class LearnableHashFunction(nn.Module):
    """Small network that maps an item vector to per-bit hash probabilities.

    The network output is scaled by ``tanh(hebbian_weights)``, shifted by
    ``activation_threshold`` and squashed by a steep sigmoid. The Hebbian
    gains and the co-activation matrix are adjusted in place by
    ``hebbian_update`` and decayed by ``apply_forgetting``.
    """

    def __init__(self, input_dim, hash_output_bits=32, learning_rate=0.01):
        super().__init__()
        self.input_dim = input_dim
        self.hash_output_bits = hash_output_bits
        self.learning_rate = learning_rate

        hidden_dim = input_dim * 2
        self.hash_network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hash_output_bits),
            nn.Tanh(),
        )

        # Per-bit gain (applied through tanh) and the step size of its update.
        self.hebbian_weights = nn.Parameter(torch.full((hash_output_bits,), 0.1))
        self.plasticity_rate = nn.Parameter(torch.tensor(learning_rate))

        # Ring buffer holding the last 100 activation patterns.
        self.register_buffer('activity_history', torch.zeros(100, hash_output_bits))
        self.register_buffer('history_pointer', torch.tensor(0, dtype=torch.long))

        self.coactivation_matrix = nn.Parameter(torch.eye(hash_output_bits) * 0.1)

        self.activation_threshold = nn.Parameter(torch.zeros(hash_output_bits))

    def compute_hash_activation(self, item_vector):
        """Return ``(hash_probs, modulated_hash)`` for ``item_vector``.

        A 1-D input is treated as a batch of one and the batch dimension is
        squeezed away again. The input is moved to the network's device and
        cast to float32 first.
        """
        batch = item_vector.unsqueeze(0) if item_vector.dim() == 1 else item_vector
        network_device = next(self.hash_network.parameters()).device
        batch = batch.to(network_device, dtype=torch.float32)

        raw_hash = self.hash_network(batch).squeeze(0)
        modulated_hash = raw_hash * torch.tanh(self.hebbian_weights)

        # The factor 10 makes the sigmoid steep around the threshold.
        hash_probs = torch.sigmoid((modulated_hash - self.activation_threshold) * 10.0)
        return hash_probs, modulated_hash

    def get_hash_bits(self, item_vector, deterministic=False):
        """Return 0/1 hash bits: thresholded at 0.5, or Bernoulli-sampled."""
        probs, _ = self.compute_hash_activation(item_vector)
        if not deterministic:
            return torch.bernoulli(probs)
        return (probs > 0.5).float()

    def hebbian_update(self, item_vector, co_occurring_items=None):
        """Record the activation of ``item_vector`` and strengthen active bits.

        Writes the activation into the history ring buffer, adds a delta
        proportional to ``|modulated_hash| * hash_probs`` to the Hebbian
        weights (kept within [-2, 2]) and, when ``co_occurring_items`` is
        given, updates the co-activation matrix. Returns the activation
        probabilities computed before the update.
        """
        hash_probs, modulated_hash = self.compute_hash_activation(item_vector)
        slots = self.activity_history.size(0)

        with torch.no_grad():
            write_at = int(self.history_pointer.item()) % slots
            self.activity_history[write_at].copy_(hash_probs.detach())
            self.history_pointer.add_(1)
            self.history_pointer.remainder_(slots)

        rate = torch.clamp(self.plasticity_rate, 0.001, 0.1)
        hebbian_delta = rate * torch.abs(modulated_hash) * hash_probs

        with torch.no_grad():
            self.hebbian_weights.data.add_(hebbian_delta * 0.05)
            self.hebbian_weights.data.clamp_(-2.0, 2.0)

        if co_occurring_items is not None:
            self.update_coactivation_matrix(hash_probs, co_occurring_items)

        return hash_probs

    def update_coactivation_matrix(self, current_activation, co_occurring_items):
        """Add the outer product with each co-occurring item's activation.

        Each update is scaled by 0.01 and the matrix is clamped to [-1, 1]
        after every item.
        """
        with torch.no_grad():
            for other in co_occurring_items:
                other_vector = item_to_vector(other, self.input_dim).to(current_activation.device)
                other_activation, _ = self.compute_hash_activation(other_vector)

                outer = torch.outer(current_activation, other_activation)
                self.coactivation_matrix.data.add_(0.01 * outer)
                self.coactivation_matrix.data.clamp_(-1.0, 1.0)

    def get_similar_patterns(self, item_vector, top_k=5):
        """Return up to ``top_k`` ``(history_index, cosine_similarity)`` pairs.

        Rows of the history buffer whose sum is not positive are skipped.
        The pairs are sorted by similarity, highest first.
        """
        current_probs, _ = self.compute_hash_activation(item_vector)
        query_row = current_probs.unsqueeze(0)

        scored = []
        for slot, past_pattern in enumerate(self.activity_history):
            if not (torch.sum(past_pattern) > 0):
                continue
            score = safe_cosine_similarity(query_row, past_pattern.unsqueeze(0)).squeeze()
            scored.append((slot, float(score.item())))

        return sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_k]

    def apply_forgetting(self, forget_rate=0.99):
        """Scale the Hebbian weights and co-activation matrix by ``forget_rate``."""
        with torch.no_grad():
            for param in (self.hebbian_weights, self.coactivation_matrix):
                param.data.mul_(forget_rate)
|
|
| |
| |
|
|
class HebbianBloomFilter(nn.Module):
    """Bloom-filter-like membership structure driven by learnable hashes.

    Each ``LearnableHashFunction`` yields a 32-bit pattern that is folded
    into one index of a bit array. ``add`` sets those bits and then applies
    a Hebbian update to every hash function.

    NOTE: because ``add`` changes the hash functions after the bits are
    written, a later ``query`` for the same item may map to different
    indices. Unlike a classic Bloom filter, false negatives are possible.
    """

    def __init__(self, capacity=10000, error_rate=0.01, vector_dim=64, num_hash_functions=8):
        super().__init__()
        self.capacity = capacity
        self.error_rate = error_rate
        self.vector_dim = vector_dim
        self.num_hash_functions = num_hash_functions

        self.bit_array_size = self._calculate_bit_array_size(capacity, error_rate)

        self.hash_functions = nn.ModuleList([
            LearnableHashFunction(vector_dim, hash_output_bits=32)
            for _ in range(num_hash_functions)
        ])

        self.register_buffer('bit_array', torch.zeros(self.bit_array_size))
        self.register_buffer('confidence_array', torch.zeros(self.bit_array_size))

        # Plain dicts keyed by str(item); they are not part of the state_dict.
        self.stored_items = {}
        self.item_vectors = {}

        self.register_buffer('access_counts', torch.zeros(self.bit_array_size))
        self.register_buffer('total_items_added', torch.tensor(0, dtype=torch.long))

        # Kept as parameters so existing checkpoints keep loading.
        self.association_strength = nn.Parameter(torch.tensor(0.1))
        self.confidence_threshold = nn.Parameter(torch.tensor(0.5))

        self.decay_rate = nn.Parameter(torch.tensor(0.999))

    def _calculate_bit_array_size(self, capacity, error_rate):
        """Return the bit count ``-n * ln(p) / ln(2)^2``, never less than 1.

        The lower bound of 1 keeps ``index % bit_array_size`` well defined
        when ``capacity`` is 0 or the formula otherwise rounds down to 0.
        """
        size = int(-capacity * math.log(error_rate) / (math.log(2) ** 2))
        return max(1, size)

    def _get_bit_indices(self, item_vector):
        """Return ``(indices, confidences)`` with one entry per hash function.

        The index is the hash bits read as a binary number, modulo the bit
        array size. The confidence is ``2 * mean(|p - 0.5|)`` over the
        activation probabilities.
        """
        indices = []
        confidences = []

        # Pure lookup: one forward pass per hash function, no autograd graph.
        with torch.no_grad():
            for hash_func in self.hash_functions:
                hash_probs, _ = hash_func.compute_hash_activation(item_vector)
                hash_bits = hash_probs > 0.5

                weights = (1 << torch.arange(len(hash_bits), device=hash_bits.device, dtype=torch.int64))
                bit_index = int((hash_bits.to(dtype=torch.int64) * weights).sum().item())
                bit_index = bit_index % self.bit_array_size

                confidence = torch.mean(torch.abs(hash_probs - 0.5)) * 2

                indices.append(bit_index)
                confidences.append(confidence.item())

        return indices, confidences

    def add(self, item, associated_items=None):
        """Insert ``item`` and return the bit indices that were set.

        Also stores the item and its vector under ``str(item)``, raises the
        per-bit confidence and access counters, and runs a Hebbian update on
        every hash function. A non-empty ``associated_items`` additionally
        triggers ``_learn_associations``.
        """
        item_vector = item_to_vector(item, self.vector_dim)

        item_key = str(item)
        self.stored_items[item_key] = item
        self.item_vectors[item_key] = item_vector

        indices, confidences = self._get_bit_indices(item_vector)

        with torch.no_grad():
            for idx, conf in zip(indices, confidences):
                self.bit_array[idx] = 1.0
                self.confidence_array[idx] = max(float(self.confidence_array[idx].item()), conf)
                self.access_counts[idx] += 1

        for hash_func in self.hash_functions:
            hash_func.hebbian_update(item_vector, associated_items)

        with torch.no_grad():
            self.total_items_added.add_(1)

        if associated_items:
            self._learn_associations(item, associated_items)

        return indices

    def _learn_associations(self, primary_item, associated_items):
        """Reinforce the hash functions for sufficiently similar associates.

        For every associated item whose cosine similarity to the primary
        item exceeds 0.5, each hash function receives one more Hebbian
        update with that item as the co-occurring item.
        """
        primary_vector = item_to_vector(primary_item, self.vector_dim)

        for assoc_item in associated_items:
            assoc_vector = item_to_vector(assoc_item, self.vector_dim)

            similarity = safe_cosine_similarity(
                primary_vector.unsqueeze(0),
                assoc_vector.unsqueeze(0)
            ).squeeze()

            # The similarity does not depend on the hash function, so it is
            # tested once per associated item.
            if not float(similarity.item()) > 0.5:
                continue

            for hash_func in self.hash_functions:
                hash_func.hebbian_update(primary_vector, [assoc_item])

    def query(self, item, return_confidence=False):
        """Return whether every bit index of ``item`` is set.

        With ``return_confidence=True`` the result is
        ``(is_member, confidence)``. The confidence is the mean of the stored
        bit confidence, the current hash confidence and the access count
        divided by 10 (capped at 1).
        """
        item_vector = item_to_vector(item, self.vector_dim)
        indices, confidences = self._get_bit_indices(item_vector)

        is_member = all(self.bit_array[idx].item() > 0 for idx in indices)

        if not return_confidence:
            return is_member

        bit_confidences = [self.confidence_array[idx].item() for idx in indices]
        access_values = [self.access_counts[idx].item() for idx in indices]

        bit_conf = np.mean(bit_confidences) if bit_confidences else 0.0
        hash_conf = np.mean(confidences) if confidences else 0.0
        # Guarded like the two means above so an empty index list gives 0
        # instead of NaN.
        access_conf = np.mean(access_values) if access_values else 0.0
        access_conf = min(access_conf / 10.0, 1.0)

        overall_confidence = (bit_conf + hash_conf + access_conf) / 3.0
        return is_member, overall_confidence

    def find_similar_items(self, query_item, top_k=5):
        """Return up to ``top_k`` ``(stored_item, score)`` pairs, best first.

        The score is ``0.6 * cosine similarity of the item vectors`` plus
        ``0.4 * co-activation similarity`` averaged over the hash functions.
        """
        query_vector = item_to_vector(query_item, self.vector_dim)
        alpha, beta = 0.6, 0.4
        similarities = []

        with torch.no_grad():
            # The query side of the co-activation term is the same for every
            # stored item, so it is computed once per hash function.
            coact_weights = []
            for hash_func in self.hash_functions:
                q_act, _ = hash_func.compute_hash_activation(query_vector)
                q_weight = torch.matmul(hash_func.coactivation_matrix.t(), q_act)
                coact_weights.append(q_weight)

            for item_key, item_vector in self.item_vectors.items():
                base_sim = safe_cosine_similarity(
                    query_vector.unsqueeze(0),
                    item_vector.unsqueeze(0)
                ).squeeze().item()

                co_sim_sum = 0.0
                for hash_func, q_weight in zip(self.hash_functions, coact_weights):
                    i_act, _ = hash_func.compute_hash_activation(item_vector)
                    co_sim_sum += torch.dot(q_weight, i_act).item() / max(1, len(i_act))
                co_sim = co_sim_sum / max(1, len(self.hash_functions))

                score = alpha * base_sim + beta * co_sim
                similarities.append((self.stored_items[item_key], score))

        similarities.sort(key=lambda x: x[1], reverse=True)
        return similarities[:top_k]

    def get_hash_statistics(self):
        """Return a dict of filter-level and per-hash-function statistics."""
        stats = {
            'total_items': int(self.total_items_added.item()),
            'bit_array_utilization': (self.bit_array > 0).float().mean().item(),
            'average_confidence': self.confidence_array.mean().item(),
            'hash_function_stats': []
        }

        for i, hash_func in enumerate(self.hash_functions):
            stats['hash_function_stats'].append({
                'function_id': i,
                'hebbian_weights_mean': hash_func.hebbian_weights.mean().item(),
                'plasticity_rate': hash_func.plasticity_rate.item(),
                'activation_threshold_mean': hash_func.activation_threshold.mean().item()
            })

        return stats

    def apply_temporal_decay(self):
        """Decay confidences and access counts, then clear weak bits.

        The decay rate is clamped to [0.9, 0.999]. Bits whose confidence
        falls below 0.1 are reset, and every hash function forgets at the
        same rate.
        """
        decay_rate = torch.clamp(self.decay_rate, 0.9, 0.999)

        with torch.no_grad():
            self.confidence_array.mul_(decay_rate)
            self.access_counts.mul_(decay_rate)

            low_confidence_mask = self.confidence_array < 0.1
            self.bit_array[low_confidence_mask] = 0.0
            self.confidence_array[low_confidence_mask] = 0.0

        for hash_func in self.hash_functions:
            hash_func.apply_forgetting(float(decay_rate.item()))

    def optimize_structure(self):
        """Lower every activation threshold by a usage-dependent amount.

        The amount is 0.01 times the fraction of bit positions whose access
        count is above the mean. Thresholds are kept within [-1, 1].
        """
        with torch.no_grad():
            high_access_ratio = (self.access_counts > self.access_counts.mean()).float().mean().item()
            adjustment = -0.01 * high_access_ratio
            for hash_func in self.hash_functions:
                hash_func.activation_threshold.data.add_(adjustment)
                hash_func.activation_threshold.data.clamp_(-1.0, 1.0)
|
|
| |
| |
|
|
class AssociativeHebbianBloomSystem(nn.Module):
    """Ensemble of ``HebbianBloomFilter`` instances behind a learned router.

    ``add_item`` writes each item to the best-scoring filters chosen by
    ``filter_selector``. ``query_item`` asks every filter and combines the
    votes.
    """

    # Each item is written to at most this many filters, so the same number
    # of positive votes is what query_item requires for membership.
    FILTERS_PER_ITEM = 2

    def __init__(self, capacity=10000, vector_dim=64, num_filters=3):
        super().__init__()
        self.capacity = capacity
        self.vector_dim = vector_dim
        self.num_filters = num_filters

        self.filters = nn.ModuleList([
            HebbianBloomFilter(
                capacity=capacity // num_filters,
                error_rate=0.01,
                vector_dim=vector_dim,
                num_hash_functions=6
            ) for _ in range(num_filters)
        ])

        self.filter_selector = nn.Sequential(
            nn.Linear(vector_dim, vector_dim // 2),
            nn.ReLU(),
            nn.Linear(vector_dim // 2, num_filters),
            nn.Softmax(dim=-1)
        )

        # Not used by any method of this class; kept so existing
        # checkpoints keep loading.
        self.global_association_net = nn.Sequential(
            nn.Linear(vector_dim * 2, vector_dim),
            nn.Tanh(),
            nn.Linear(vector_dim, 1),
            nn.Sigmoid()
        )

        self.register_buffer('global_access_count', torch.tensor(0, dtype=torch.long))

    def add_item(self, item, category=None, associated_items=None):
        """Add ``item`` to the best-scoring filters.

        The selector score of each filter is reduced by 0.1 times its load
        (items added divided by capacity) before the top filters are chosen.

        Args:
            item: The item to store.
            category: Accepted for interface compatibility; currently unused.
            associated_items: Optional items passed on to each filter's
                ``add``.

        Returns:
            A list of ``(filter_index, bit_indices)`` tuples, one per filter
            the item was written to.
        """
        item_vector = item_to_vector(item, self.vector_dim)
        # Item vectors are created on the CPU unless the item is a tensor;
        # move the vector to the selector's device before the forward pass.
        selector_device = next(self.filter_selector.parameters()).device

        # Only the ranking of the scores is used, so no graph is needed.
        with torch.no_grad():
            filter_weights = self.filter_selector(
                item_vector.unsqueeze(0).to(selector_device)
            ).squeeze(0)
            loads = torch.tensor(
                [f.total_items_added.item() / max(1, f.capacity) for f in self.filters],
                dtype=filter_weights.dtype,
                device=filter_weights.device,
            )
            filter_weights = filter_weights - 0.1 * loads

            top_k_filters = min(self.FILTERS_PER_ITEM, self.num_filters)
            _, top_indices = torch.topk(filter_weights, top_k_filters)

        added_to_filters = []
        for filter_idx in top_indices.tolist():
            indices = self.filters[filter_idx].add(item, associated_items)
            added_to_filters.append((filter_idx, indices))

        with torch.no_grad():
            self.global_access_count.add_(1)

        return added_to_filters

    def query_item(self, item, return_detailed=False):
        """Return whether enough filters report ``item`` as a member.

        An item is only ever written to ``min(FILTERS_PER_ITEM, n)`` of the
        ``n`` filters, so that many positive votes are required. For up to
        three filters this equals a strict majority. With four or more
        filters a strict majority could never be reached by a stored item.

        With ``return_detailed=True`` a dict is returned containing the
        decision, the mean confidence, the per-filter results and the vote
        counts.
        """
        results = []
        confidences = []
        for filter_obj in self.filters:
            is_member, confidence = filter_obj.query(item, return_confidence=True)
            results.append(is_member)
            confidences.append(confidence)

        total_filters = len(self.filters)
        positive_votes = sum(results)
        avg_confidence = np.mean(confidences)

        required_votes = min(self.FILTERS_PER_ITEM, total_filters)
        ensemble_decision = total_filters > 0 and positive_votes >= required_votes

        if return_detailed:
            return {
                'is_member': ensemble_decision,
                'confidence': avg_confidence,
                'individual_results': list(zip(results, confidences)),
                'positive_votes': positive_votes,
                'total_filters': total_filters
            }

        return ensemble_decision

    def find_associations(self, query_item, top_k=10):
        """Return up to ``top_k`` ``(str(item), score)`` pairs, best first.

        The candidates come from every filter's ``find_similar_items``.
        Items whose ``str()`` is equal are merged and keep their highest
        score.
        """
        best_scores = {}
        for filter_obj in self.filters:
            for item, similarity in filter_obj.find_similar_items(query_item, top_k):
                item_key = str(item)
                if item_key in best_scores:
                    best_scores[item_key] = max(best_scores[item_key], similarity)
                else:
                    best_scores[item_key] = similarity

        ranked_items = sorted(best_scores.items(), key=lambda x: x[1], reverse=True)
        return ranked_items[:top_k]

    def system_maintenance(self):
        """Decay and re-tune every filter.

        Global optimization also runs whenever the global access count is a
        multiple of 1000, which includes a count of 0.
        """
        for filter_obj in self.filters:
            filter_obj.apply_temporal_decay()
            filter_obj.optimize_structure()

        if int(self.global_access_count.item()) % 1000 == 0:
            self._global_optimization()

    def _global_optimization(self):
        """Log the optimization pass and return each filter's utilization.

        No rebalancing is performed. The utilizations are returned instead
        of being discarded so that callers can inspect them.
        """
        print("Performing global Hebbian Bloom system optimization...")

        return [
            filter_obj.get_hash_statistics()['bit_array_utilization']
            for filter_obj in self.filters
        ]

    def get_system_statistics(self):
        """Get comprehensive system statistics."""
        stats = {
            'global_access_count': int(self.global_access_count.item()),
            'num_filters': self.num_filters,
            'filter_statistics': []
        }

        for i, filter_obj in enumerate(self.filters):
            filter_stats = filter_obj.get_hash_statistics()
            filter_stats['filter_id'] = i
            stats['filter_statistics'].append(filter_stats)

        return stats
|
|
|
|
| |
|
|