| """
|
| Comprehensive unit tests for TorchForge.
|
|
|
| Tests core functionality, governance, monitoring, and deployment.
|
| """
|
|
|
import json
import tempfile
from pathlib import Path

import pytest
import torch
import torch.nn as nn

from torchforge import ForgeModel, ForgeConfig
from torchforge.governance import ComplianceChecker, NISTFramework
from torchforge.monitoring import ModelMonitor
from torchforge.deployment import DeploymentManager
|
|
|
|
|
class SimpleModel(nn.Module):
    """Minimal single-linear-layer network used as the wrapped model in tests.

    Args:
        input_dim: Number of input features per sample (default 10).
        output_dim: Number of output features per sample (default 2).
    """

    def __init__(self, input_dim: int = 10, output_dim: int = 2) -> None:
        super(SimpleModel, self).__init__()
        # The attribute name ``fc`` determines the state-dict keys, so keep it
        # stable for checkpoint round-trips.
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project ``x`` through the single linear layer."""
        projected = self.fc(x)
        return projected
|
|
|
|
|
class TestForgeModel:
    """Test ForgeModel functionality."""

    @staticmethod
    def _build_model(**config_overrides):
        """Return a ``ForgeModel`` wrapping a fresh ``SimpleModel``.

        The config is always named ``test_model`` at version ``1.0.0``; any
        extra keyword arguments (e.g. ``enable_governance=True``) are passed
        straight through to ``ForgeConfig``.
        """
        config = ForgeConfig(
            model_name="test_model", version="1.0.0", **config_overrides
        )
        return ForgeModel(SimpleModel(), config=config)

    def test_model_creation(self):
        """Test basic model creation."""
        model = self._build_model()

        assert model.config.model_name == "test_model"
        assert model.config.version == "1.0.0"
        assert model.model_id is not None

    def test_forward_pass(self):
        """Test that a forward pass maps a (32, 10) batch to (32, 2)."""
        model = self._build_model()

        output = model(torch.randn(32, 10))

        assert output.shape == (32, 2)

    def test_track_prediction(self):
        """Test that tracking one prediction adds one history entry."""
        model = self._build_model(enable_governance=True)

        x = torch.randn(32, 10)
        y = torch.randint(0, 2, (32,))
        output = model(x)

        model.track_prediction(output, y)
        assert len(model.prediction_history) == 1

    def test_checkpoint_save_load(self):
        """Test that a saved checkpoint restores both config and weights."""
        model = self._build_model()
        x = torch.randn(8, 10)

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint.pt"
            model.save_checkpoint(checkpoint_path)
            assert checkpoint_path.exists()

            # Load while the temporary directory still exists: it is removed
            # as soon as this ``with`` block exits.
            loaded_model = ForgeModel.load_checkpoint(
                checkpoint_path,
                SimpleModel(),
            )

        assert loaded_model.config.model_name == "test_model"
        assert loaded_model.config.version == "1.0.0"
        # The fresh SimpleModel starts with its own random weights, so equal
        # outputs show the checkpoint's parameters were actually restored.
        with torch.no_grad():
            assert torch.allclose(loaded_model(x), model(x))

    def test_metrics_collection(self):
        """Test that every forward pass is counted in the metrics summary."""
        model = self._build_model(enable_monitoring=True)
        num_batches = 10

        for _ in range(num_batches):
            model(torch.randn(32, 10))

        metrics = model.get_metrics_summary()
        assert metrics["inference_count"] == num_batches
        assert "latency_mean_ms" in metrics
|
| assert "latency_mean_ms" in metrics
|
|
|
|
|
class TestConfiguration:
    """Test configuration management."""

    def test_config_creation(self):
        """Test that constructor arguments are stored on the config."""
        config = ForgeConfig(
            model_name="test_model",
            version="1.0.0",
            enable_monitoring=True,
            enable_governance=True,
        )

        assert config.model_name == "test_model"
        assert config.version == "1.0.0"
        assert config.enable_monitoring is True
        assert config.enable_governance is True

    def test_config_validation(self):
        """Test that a malformed version string is rejected."""
        # NOTE(review): the exception type ForgeConfig raises is not pinned
        # down here, so any Exception is accepted; narrow it (e.g. to
        # ValueError) once the concrete type is confirmed.
        with pytest.raises(Exception):
            ForgeConfig(model_name="test", version="invalid")

    def test_config_serialization(self):
        """Test configuration serialization to dict, JSON and YAML."""
        config = ForgeConfig(model_name="test_model", version="1.0.0")

        config_dict = config.to_dict()
        assert config_dict["model_name"] == "test_model"
        assert config_dict["version"] == "1.0.0"

        # Parse rather than substring-match so malformed JSON fails the test.
        parsed = json.loads(config.to_json())
        assert parsed["model_name"] == "test_model"

        yaml_str = config.to_yaml()
        assert "test_model" in yaml_str
|
|
|
|
|
class TestGovernance:
    """Test governance and compliance."""

    def test_compliance_checker(self):
        """Test a NIST RMF 1.0 assessment of a governed, monitored model."""
        config = ForgeConfig(
            model_name="test_model",
            version="1.0.0",
            enable_governance=True,
            enable_monitoring=True,
        )
        model = ForgeModel(SimpleModel(), config=config)

        checker = ComplianceChecker(framework=NISTFramework.RMF_1_0)
        report = checker.assess_model(model)

        assert report.model_name == "test_model"
        # The overall score is expected on a 0-100 scale.
        assert 0 <= report.overall_score <= 100
        assert len(report.checks) > 0

    def test_compliance_report_export(self):
        """Test that a compliance report exports to a parseable JSON file."""
        config = ForgeConfig(
            model_name="test_model",
            version="1.0.0",
            enable_governance=True,
        )
        model = ForgeModel(SimpleModel(), config=config)

        checker = ComplianceChecker()
        report = checker.assess_model(model)

        with tempfile.TemporaryDirectory() as tmpdir:
            json_path = Path(tmpdir) / "report.json"
            report.export_json(str(json_path))
            assert json_path.exists()
            # The export must be valid JSON, not merely a file on disk.
            json.loads(json_path.read_text(encoding="utf-8"))
|
|
|
|
|
class TestMonitoring:
    """Tests for ModelMonitor."""

    def test_model_monitor(self):
        """Enabled drift and fairness tracking show up in the health status."""
        forge_config = ForgeConfig(
            model_name="test_model",
            version="1.0.0",
            enable_monitoring=True,
        )
        wrapped = ForgeModel(SimpleModel(), config=forge_config)

        monitor = ModelMonitor(wrapped)
        monitor.enable_drift_detection()
        monitor.enable_fairness_tracking()

        status = monitor.get_health_status()
        assert "status" in status
        assert status["drift_detection"] is True
        assert status["fairness_tracking"] is True
|
|
|
|
|
class TestDeployment:
    """Tests for DeploymentManager."""

    @staticmethod
    def _wrapped_model():
        """Return a ForgeModel around a fresh SimpleModel (test_model v1.0.0)."""
        return ForgeModel(
            SimpleModel(),
            config=ForgeConfig(model_name="test_model", version="1.0.0"),
        )

    def test_deployment_manager(self):
        """Deploying to AWS with autoscaling reports the requested settings."""
        manager = DeploymentManager(
            model=self._wrapped_model(),
            cloud_provider="aws",
            instance_type="ml.m5.large",
        )

        result = manager.deploy(
            enable_autoscaling=True, min_instances=2, max_instances=10
        )

        assert result["status"] == "deployed"
        assert result["cloud_provider"] == "aws"
        assert result["autoscaling_enabled"] is True

    def test_deployment_metrics(self):
        """A deployed model exposes latency and throughput metrics."""
        manager = DeploymentManager(model=self._wrapped_model())
        manager.deploy()

        window_metrics = manager.get_metrics(window="1h")
        for attr_name in ("latency_p95", "requests_per_second"):
            assert hasattr(window_metrics, attr_name)
|
|
|
|
|
class TestIntegration:
    """Integration tests for complete workflows."""

    def test_end_to_end_workflow(self):
        """Test inference tracking, compliance, checkpointing and deployment together.

        No training is performed; the randomly initialised model is used as-is.
        """
        config = ForgeConfig(
            model_name="e2e_model",
            version="1.0.0",
            enable_governance=True,
            enable_monitoring=True,
            enable_optimization=True,
        )
        model = ForgeModel(SimpleModel(), config=config)

        # 1. Run inference repeatedly and record each prediction with labels.
        x = torch.randn(100, 10)
        y = torch.randint(0, 2, (100,))
        num_rounds = 5
        for _ in range(num_rounds):
            output = model(x)
            model.track_prediction(output, y)
        assert len(model.prediction_history) == num_rounds

        # 2. Assess compliance of the tracked model.
        checker = ComplianceChecker()
        report = checker.assess_model(model)
        assert report.overall_score > 0

        # 3. Persist a checkpoint.
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint.pt"
            model.save_checkpoint(checkpoint_path)
            assert checkpoint_path.exists()

        # 4. Deploy with default settings.
        deployment = DeploymentManager(model=model)
        info = deployment.deploy()
        assert info["status"] == "deployed"
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns an exit code; raising SystemExit with it makes test
    # failures visible to the shell/CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|