threshold-onehot-decoder / create_safetensors.py
phanerozoic's picture
Upload folder using huggingface_hub
4ee35bd verified
import torch
from safetensors.torch import save_file
# One-Hot Decoder (4-to-2)
#
# Inputs:  y3, y2, y1, y0 (one-hot encoding, exactly one bit set)
# Outputs: a1, a0         (binary value 0-3)
#
#   y3 y2 y1 y0 | a1 a0 | value
#   ------------+-------+------
#    0  0  0  1 |  0  0 |   0
#    0  0  1  0 |  0  1 |   1
#    0  1  0  0 |  1  0 |   2
#    1  0  0  0 |  1  1 |   3
#
# Single layer implementation, one threshold neuron per output bit:
#   a0 = OR(y1, y3) -> 1 when the hot position is 1 or 3
#   a1 = OR(y2, y3) -> 1 when the hot position is 2 or 3
#
# Each weight row has one column per input, in the order (y3, y2, y1, y0)
# used by onehot_decode. With unit weights on the two selected inputs and a
# bias of -1, "weighted sum + bias >= 0" holds as soon as either selected
# input is 1, which is exactly an OR.
weights = {
    # a0 = OR(y1, y3): columns y3 and y1 are active.
    'a0.weight': torch.tensor([[1.0, 0.0, 1.0, 0.0]], dtype=torch.float32),
    'a0.bias': torch.tensor([-1.0], dtype=torch.float32),
    # a1 = OR(y2, y3): columns y3 and y2 are active.
    'a1.weight': torch.tensor([[1.0, 1.0, 0.0, 0.0]], dtype=torch.float32),
    'a1.bias': torch.tensor([-1.0], dtype=torch.float32),
}

# Persist the four tensors next to the script.
save_file(weights, 'model.safetensors')
def onehot_decode(y3, y2, y1, y0):
    """Decode four input bits into (a1, a0) using the stored threshold neurons.

    The inputs are packed into a vector in the order (y3, y2, y1, y0). Each
    output bit is 1 when ``inputs @ weight.T + bias >= 0`` for that neuron's
    entry in the module-level ``weights`` dict, and 0 otherwise.

    Returns:
        A tuple ``(a1, a0)`` of ints, high bit first.
    """
    x = torch.tensor([float(bit) for bit in (y3, y2, y1, y0)])

    bits = []
    # High bit first so the collected list already has the return order.
    for weight_key, bias_key in (('a1.weight', 'a1.bias'), ('a0.weight', 'a0.bias')):
        pre_activation = x @ weights[weight_key].T + weights[bias_key]
        fired = (pre_activation >= 0).item()
        bits.append(int(fired))
    return tuple(bits)
def reference_decode(y3, y2, y1, y0):
    """Reference (non-neural) decoder used to check the threshold network.

    Scans the inputs from y0 up to y3 and returns the 2-bit binary index
    ``(a1, a0)`` of the first one that equals 1. If no input equals 1 the
    result is ``(0, 0)``.
    """
    for position, bit in enumerate((y0, y1, y2, y3)):
        if bit == 1:
            # divmod splits the position 0..3 into (high bit, low bit).
            return divmod(position, 2)
    return 0, 0
# --- Self-check: compare the threshold network with the reference decoder ---
print("Verifying One-Hot Decoder (4-to-2)...")
test_cases = [(0, 0, 0, 1), (0, 0, 1, 0), (0, 1, 0, 0), (1, 0, 0, 0)]
errors = 0
for case in test_cases:
    result = onehot_decode(*case)
    expected = reference_decode(*case)
    if result == expected:
        continue
    errors += 1
    y3, y2, y1, y0 = case
    print(f"ERROR: ({y3},{y2},{y1},{y0}) -> {result}, expected {expected}")

if errors:
    print(f"FAILED: {errors} errors")
else:
    print("All 4 test cases passed!")

# --- Truth table as produced by the threshold network ---
print("\nTruth Table:")
print("y3 y2 y1 y0 | a1 a0 | value")
print("-" * 30)
for y3, y2, y1, y0 in test_cases:
    a1, a0 = onehot_decode(y3, y2, y1, y0)
    val = 2 * a1 + a0  # (a1, a0) read as a 2-bit binary number
    print(f" {y3} {y2} {y1} {y0} | {a1} {a0} | {val}")

# --- Summary statistics over every stored tensor ---
mag = 0  # sum of absolute values of all weights and biases
param_count = 0  # total number of scalar entries
for tensor in weights.values():
    mag += tensor.abs().sum().item()
    param_count += tensor.numel()
# One neuron per '*weight*' entry in the dict.
neuron_count = sum(1 for key in weights if 'weight' in key)

print(f"\nMagnitude: {mag:.0f}")
print(f"Parameters: {param_count}")
print(f"Neurons: {neuron_count}")