| """
|
| Inference code for mod3-verified threshold network.
|
|
|
| This network computes MOD-3 (Hamming weight mod 3) on 8-bit binary inputs.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from safetensors.torch import load_file
|
|
|
|
|
def heaviside(x):
    """Unit step applied elementwise.

    Returns a float32 tensor shaped like ``x`` that is 1.0 wherever
    ``x >= 0`` (zero itself maps to 1.0) and 0.0 elsewhere.
    """
    is_non_negative = torch.ge(x, 0)
    return is_non_negative.to(torch.float32)
|
|
|
|
|
class Mod3Network(nn.Module):
    """
    Verified threshold network for MOD-3 computation.

    Architecture: 8 -> 9 -> 2 -> 3
    - Layer 1: Thermometer encoding (9 neurons detect HW >= k)
    - Layer 2: MOD-3 detection using (1,1,-2) weight pattern
    - Output: 3-class classification

    The hidden layers use a Heaviside step activation. The parameter values
    come from a weights file (see ``from_safetensors``); freshly constructed
    instances hold PyTorch's default random initialization.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 9)
        self.layer2 = nn.Linear(9, 2)
        self.output = nn.Linear(2, 3)

    def forward(self, x):
        """Forward pass with Heaviside activation.

        Args:
            x: Tensor whose last dimension holds the 8 input bits. It is cast
               to float32 before the first layer.

        Returns:
            Tensor whose last dimension has size 3 (one score per class).
        """
        x = x.float()
        x = heaviside(self.layer1(x))
        x = heaviside(self.layer2(x))
        return self.output(x)

    def predict(self, x):
        """Get predicted class (0, 1, or 2) for each input row.

        Runs without building an autograd graph, since only the argmax
        index is returned.
        """
        with torch.no_grad():
            return self.forward(x).argmax(dim=-1)

    @classmethod
    def from_safetensors(cls, path):
        """Load model from safetensors file.

        Args:
            path: Path to a safetensors file containing the tensors
                ``layer1.weight``, ``layer1.bias``, ``layer2.weight``,
                ``layer2.bias``, ``output.weight`` and ``output.bias``.
                Any additional tensors in the file are ignored.

        Returns:
            A ``Mod3Network`` populated with the stored parameters.

        Raises:
            KeyError: If one of the expected tensors is missing from the file.
            RuntimeError: If a stored tensor's shape does not match the
                architecture.
        """
        model = cls()
        weights = load_file(path)

        # Pick exactly the tensors this architecture owns. Indexing raises
        # KeyError on a missing tensor, and extra tensors are ignored.
        state = {name: weights[name] for name, _ in model.named_parameters()}

        # load_state_dict validates shapes and copies values into the existing
        # float32 parameters. Assigning .data directly would silently accept a
        # wrong shape and could swap in another dtype, which would then clash
        # with the float32 input produced by forward().
        model.load_state_dict(state)
        return model
|
|
|
|
|
def mod3_reference(x):
    """Ground-truth labels: sum of each row along the last axis, modulo 3.

    For 0/1 inputs this is the Hamming weight mod 3. The result is an
    int64 tensor with the last dimension of ``x`` reduced away.
    """
    hamming_weight = torch.sum(x, dim=-1)
    residue = torch.remainder(hamming_weight, 3)
    return residue.to(torch.long)
|
|
|
|
|
def verify(model, verbose=True):
    """Verify model on all 256 inputs.

    Compares ``model.predict`` against ``mod3_reference`` for every 8-bit
    input pattern.

    Args:
        model: Object exposing ``predict(inputs)``. It receives a ``(256, 8)``
            float tensor of bits and should return one class index per row.
        verbose: When true, print the accuracy and, if anything is wrong, up
            to the first 10 mismatching input indices.

    Returns:
        True if all 256 predictions match the reference, otherwise False.
    """
    # Row i holds the bits of the integer i, least-significant bit in column 0.
    codes = torch.arange(256).unsqueeze(1)  # shape (256, 1)
    shifts = torch.arange(8)                # shape (8,), broadcast across rows
    inputs = ((codes >> shifts) & 1).float()

    targets = mod3_reference(inputs)
    # Pure evaluation; no autograd graph is needed.
    with torch.no_grad():
        predictions = model.predict(inputs)

    correct = (predictions == targets).sum().item()

    if verbose:
        print(f"Verification: {correct}/256 ({100*correct/256:.1f}%)")

        if correct < 256:
            errors = (predictions != targets).nonzero(as_tuple=True)[0]
            print(f"Errors at indices: {errors[:10].tolist()}")

    return correct == 256
|
|
|
|
|
def demo(path='model.safetensors'):
    """Demonstration of MOD-3 computation.

    Loads the network, verifies it on all 256 inputs, and prints one
    example prediction for each Hamming weight from 0 to 8.

    Args:
        path: safetensors file to load. Defaults to ``'model.safetensors'``,
            the previously hard-coded path, resolved against the current
            working directory.
    """
    print("Loading mod3-verified model...")
    model = Mod3Network.from_safetensors(path)

    print("\nVerifying on all 256 inputs...")
    verify(model)

    print("\nExample predictions:")
    # One example per Hamming weight k = 0..8: k leading ones, then zeros.
    test_cases = [[1] * k + [0] * (8 - k) for k in range(9)]

    for bits in test_cases:
        x = torch.tensor([bits], dtype=torch.float32)
        hw = sum(bits)
        with torch.no_grad():
            pred = model.predict(x).item()
        expected = hw % 3
        status = "OK" if pred == expected else "ERROR"
        print(f" {bits} -> HW={hw}, pred={pred}, expected={expected} [{status}]")
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the demonstration only when executed directly,
    # not when this module is imported.
    demo()
|
|
|