"""One-hot decoder (4-to-2) built as a single-layer threshold network.

Inputs are ``y3, y2, y1, y0`` (one-hot encoding, exactly one bit set) and
outputs are ``a1, a0`` (the binary value 0-3 of the set position)::

    y3 y2 y1 y0 | a1 a0 | value
    ------------+-------+------
     0  0  0  1 |  0  0 |  0
     0  0  1  0 |  0  1 |  1
     0  1  0  0 |  1  0 |  2
     1  0  0  0 |  1  1 |  3

Single-layer implementation (each output neuron fires when
``w . x + b >= 0``):

    a0 = OR(y1, y3)   -> 1 when position 1 or 3
    a1 = OR(y2, y3)   -> 1 when position 2 or 3

Running this file as a script saves the weights to ``model.safetensors``,
verifies them against a reference decoder and prints a truth table.
"""

import torch

# Every weight row is indexed in the input order [y3, y2, y1, y0].
weights = {}

# a0 = OR(y1, y3): y3 + y1 - 1 >= 0 iff at least one of them is set.
weights['a0.weight'] = torch.tensor([[1.0, 0.0, 1.0, 0.0]], dtype=torch.float32)
weights['a0.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# a1 = OR(y2, y3): y3 + y2 - 1 >= 0 iff at least one of them is set.
weights['a1.weight'] = torch.tensor([[1.0, 1.0, 0.0, 0.0]], dtype=torch.float32)
weights['a1.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# The four valid one-hot inputs, as (y3, y2, y1, y0).
TEST_CASES = [(0, 0, 0, 1), (0, 0, 1, 0), (0, 1, 0, 0), (1, 0, 0, 0)]


def _fire(inp, name):
    """Return 1 if neuron *name* fires (pre-activation >= 0) on *inp*, else 0."""
    pre = inp @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
    return int((pre >= 0).item())


def onehot_decode(y3, y2, y1, y0):
    """Decode a 4-bit one-hot input using the threshold network.

    Args:
        y3, y2, y1, y0: Input bits (0 or 1), most significant position first.

    Returns:
        Tuple ``(a1, a0)`` of output bits. For the all-zero input both
        neurons stay below threshold, so the result is ``(0, 0)``.
    """
    inp = torch.tensor([float(y3), float(y2), float(y1), float(y0)])
    return _fire(inp, 'a1'), _fire(inp, 'a0')


def reference_decode(y3, y2, y1, y0):
    """Plain-Python reference decoder used to check the network.

    Checks positions from lowest to highest and returns ``(a1, a0)`` for the
    first set bit; returns ``(0, 0)`` when no bit is set.
    """
    if y0 == 1:
        return 0, 0
    if y1 == 1:
        return 0, 1
    if y2 == 1:
        return 1, 0
    if y3 == 1:
        return 1, 1
    return 0, 0


def main():
    """Save the weights, verify them, and print a truth table and stats.

    Returns:
        The number of test cases where the network disagreed with the
        reference decoder.
    """
    # Imported here so that the decoder itself only needs torch; safetensors
    # is required only for exporting the weights.
    from safetensors.torch import save_file

    save_file(weights, 'model.safetensors')

    print("Verifying One-Hot Decoder (4-to-2)...")
    errors = 0
    for y3, y2, y1, y0 in TEST_CASES:
        result = onehot_decode(y3, y2, y1, y0)
        expected = reference_decode(y3, y2, y1, y0)
        if result != expected:
            errors += 1
            print(f"ERROR: ({y3},{y2},{y1},{y0}) -> {result}, expected {expected}")

    if errors == 0:
        print("All 4 test cases passed!")
    else:
        print(f"FAILED: {errors} errors")

    print("\nTruth Table:")
    print("y3 y2 y1 y0 | a1 a0 | value")
    print("-" * 30)
    for y3, y2, y1, y0 in TEST_CASES:
        a1, a0 = onehot_decode(y3, y2, y1, y0)
        val = a1 * 2 + a0
        print(f" {y3} {y2} {y1} {y0} | {a1} {a0} | {val}")

    mag = sum(t.abs().sum().item() for t in weights.values())
    print(f"\nMagnitude: {mag:.0f}")
    print(f"Parameters: {sum(t.numel() for t in weights.values())}")
    print(f"Neurons: {len([k for k in weights.keys() if 'weight' in k])}")
    return errors


if __name__ == "__main__":
    main()