| import torch
|
| from safetensors.torch import save_file
|
|
|
# Threshold-neuron parameters for the 4-to-2 one-hot decoder.
# onehot_decode feeds the input vector in the order [y3, y2, y1, y0]; an output
# bit is 1 when (x @ weight.T + bias) >= 0.
weights = {
    # a0 (low bit): weight 1 on y3 and y1, bias -1 -> set when y3 or y1 is 1.
    'a0.weight': torch.tensor([[1.0, 0.0, 1.0, 0.0]], dtype=torch.float32),
    'a0.bias': torch.tensor([-1.0], dtype=torch.float32),
    # a1 (high bit): weight 1 on y3 and y2, bias -1 -> set when y3 or y2 is 1.
    'a1.weight': torch.tensor([[1.0, 1.0, 0.0, 0.0]], dtype=torch.float32),
    'a1.bias': torch.tensor([-1.0], dtype=torch.float32),
}
|
|
|
# Serialize every tensor in `weights` to model.safetensors (relative path, so
# it lands in the current working directory).
save_file(tensors=weights, filename='model.safetensors')
|
|
|
def onehot_decode(y3, y2, y1, y0, *, params=None):
    """Decode a 4-line one-hot input into a 2-bit code with threshold neurons.

    Each output bit is a single neuron that outputs 1 when
    ``x @ weight.T + bias >= 0``, where ``x = [y3, y2, y1, y0]``.

    Args:
        y3, y2, y1, y0: Input lines, most significant first. The script only
            exercises 0/1 values; each is converted with ``float()``.
        params: Optional mapping holding 'a0.weight', 'a0.bias', 'a1.weight'
            and 'a1.bias' tensors. Defaults to the module-level ``weights``.

    Returns:
        Tuple ``(a1, a0)`` of ints: the high and low output bits.
    """
    if params is None:
        params = weights
    raw = [float(y3), float(y2), float(y1), float(y0)]

    def _gate(name):
        # One threshold neuron: 1 iff weight . x + bias >= 0.
        weight = params[f'{name}.weight']
        bias = params[f'{name}.bias']
        # Build x in the weight's dtype. torch.tensor would otherwise use the
        # global default dtype, and matmul does not promote mismatched dtypes.
        x = torch.tensor(raw, dtype=weight.dtype)
        return int((x @ weight.T + bias >= 0).item())

    a0 = _gate('a0')
    a1 = _gate('a1')
    return a1, a0
|
|
|
def reference_decode(y3, y2, y1, y0):
    """Plain-Python priority encoder used as ground truth for the decoder.

    The lines are scanned from y0 up to y3. The first one equal to 1 wins,
    and its position (0-3) is returned as a 2-bit tuple ``(high, low)``.
    If no line equals 1, ``(0, 0)`` is returned.
    """
    for index, line in enumerate((y0, y1, y2, y3)):
        if line == 1:
            # Split the position into its high and low bits.
            return index >> 1, index & 1
    return 0, 0
|
|
|
# Compare the threshold-network decoder against the reference encoder on
# every one-hot input and report any mismatches.
print("Verifying One-Hot Decoder (4-to-2)...")
errors = 0
# All four one-hot inputs, listed as (y3, y2, y1, y0).
test_cases = [(0, 0, 0, 1), (0, 0, 1, 0), (0, 1, 0, 0), (1, 0, 0, 0)]
for y3, y2, y1, y0 in test_cases:
    result = onehot_decode(y3, y2, y1, y0)
    expected = reference_decode(y3, y2, y1, y0)
    if result != expected:
        errors += 1
        print(f"ERROR: ({y3},{y2},{y1},{y0}) -> {result}, expected {expected}")

if errors == 0:
    # Derive the count from test_cases so the message stays accurate if the
    # case list changes.
    print(f"All {len(test_cases)} test cases passed!")
else:
    print(f"FAILED: {errors} errors")
|
|
|
# Print the decoder's output for each one-hot input, with the decoded value.
print("\nTruth Table:")
print("y3 y2 y1 y0 | a1 a0 | value")
print("-" * 30)
for bits in test_cases:
    a1, a0 = onehot_decode(*bits)
    # Reassemble the 2-bit code as an integer; a1 is the high bit.
    val = (a1 << 1) | a0
    inputs = " ".join(str(b) for b in bits)
    print(f" {inputs} | {a1} {a0} | {val}")
|
|
|
# Summary statistics over all tensors in `weights`:
#   magnitude  - sum of absolute values of every parameter
#   parameters - total element count
#   neurons    - number of '*weight*' tensors (one per output neuron here)
mag = sum(tensor.abs().sum().item() for tensor in weights.values())
param_count = sum(tensor.numel() for tensor in weights.values())
neuron_count = sum('weight' in key for key in weights)
print(f"\nMagnitude: {mag:.0f}")
print(f"Parameters: {param_count}")
print(f"Neurons: {neuron_count}")
|
|
|