3v324v23 commited on
Commit
9ea8001
·
1 Parent(s): 9f531ed

test: Add unit and integration tests with Pytest

Browse files
Files changed (4) hide show
  1. pytest.ini +4 -0
  2. tests/__init__.py +0 -0
  3. tests/test_api.py +108 -0
  4. tests/test_unit.py +51 -0
pytest.ini ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [pytest]
2
+ addopts = -v --cov=app --cov-report=term-missing
3
+ testpaths = tests
4
+ python_files = test_*.py
tests/__init__.py ADDED
File without changes
tests/test_api.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from fastapi.testclient import TestClient
3
+ from sqlalchemy import create_engine
4
+ from sqlalchemy.orm import sessionmaker
5
+ from sqlalchemy.pool import StaticPool
6
+
7
+ from app.main import app
8
+ from app.core.database import Base, get_db
9
+ from app.core.config import settings
10
+
11
+ # Setup in-memory SQLite database for testing
12
+ SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
13
+
14
+ engine = create_engine(
15
+ SQLALCHEMY_DATABASE_URL,
16
+ connect_args={"check_same_thread": False},
17
+ poolclass=StaticPool,
18
+ )
19
+ TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
20
+
21
+ def override_get_db():
22
+ try:
23
+ db = TestingSessionLocal()
24
+ yield db
25
+ finally:
26
+ db.close()
27
+
28
+ app.dependency_overrides[get_db] = override_get_db
29
+
30
+ client = TestClient(app)
31
+
32
+ @pytest.fixture(scope="module")
33
+ def test_db():
34
+ # Create tables
35
+ Base.metadata.create_all(bind=engine)
36
+ yield
37
+ # Drop tables
38
+ Base.metadata.drop_all(bind=engine)
39
+
40
+ def test_health_check():
41
+ response = client.get("/health")
42
+ assert response.status_code == 200
43
+ assert response.json() == {"status": "healthy"}
44
+
45
+ def test_predict_unauthorized():
46
+ response = client.post("/predict", json={})
47
+ assert response.status_code == 422
48
+
49
+ def test_predict_invalid_key():
50
+ response = client.post("/predict", headers={"X-API-KEY": "wrong_key"}, json={})
51
+ assert response.status_code == 401
52
+
53
+ def test_predict_success(test_db):
54
+ # Valid input data based on updated schema
55
+ payload = {
56
+ "age": 30,
57
+ "genre": "M",
58
+ "revenu_mensuel": 5000,
59
+ "statut_marital": "Célibataire",
60
+ "departement": "R&D",
61
+ "poste": "Ingénieur",
62
+ "nombre_experiences_precedentes": 2,
63
+ "nombre_heures_travailless": 40,
64
+ "annee_experience_totale": 5,
65
+ "annees_dans_l_entreprise": 2,
66
+ "annees_dans_le_poste_actuel": 1,
67
+ "satisfaction_employee_environnement": 3,
68
+ "note_evaluation_precedente": 3,
69
+ "niveau_hierarchique_poste": 2,
70
+ "satisfaction_employee_nature_travail": 3,
71
+ "satisfaction_employee_equipe": 4,
72
+ "satisfaction_employee_equilibre_pro_perso": 3,
73
+ "note_evaluation_actuelle": 3,
74
+ "heure_supplementaires": "Non",
75
+ "augementation_salaire_precedente": "10-15%",
76
+ "nombre_participation_pee": 0,
77
+ "nb_formations_suivies": 1,
78
+ "nombre_employee_sous_responsabilite": 0,
79
+ "distance_domicile_travail": 10,
80
+ "niveau_education": 3,
81
+ "domaine_etude": "Sciences",
82
+ "ayant_enfants": "Non",
83
+ "frequence_deplacement": "Rare",
84
+ "annees_depuis_la_derniere_promotion": 1,
85
+ "annes_sous_responsable_actuel": 1
86
+ }
87
+
88
+ response = client.post(
89
+ "/predict",
90
+ headers={"X-API-KEY": settings.API_KEY},
91
+ json=payload
92
+ )
93
+
94
+ if response.status_code != 200:
95
+ print(f"Response error: {response.json()}")
96
+
97
+ assert response.status_code == 200
98
+ data = response.json()
99
+ assert "prediction" in data
100
+ assert "probability" in data
101
+
102
+ # Verify DB insertion
103
+ db = TestingSessionLocal()
104
+ log = db.query(Base.metadata.tables["prediction_logs"]).first()
105
+ assert log is not None
106
+ assert log.age == 30
107
+ assert log.genre == "M"
108
+ db.close()
tests/test_unit.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from app.models.schemas import InputSchema
3
+ from app.services.ml_service import ml_service
4
+
5
+ def test_input_schema_validation():
6
+ # Test valid data
7
+ valid_data = {
8
+ "age": 30,
9
+ "genre": "M",
10
+ "revenu_mensuel": 5000,
11
+ "statut_marital": "Célibataire",
12
+ "departement": "R&D",
13
+ "poste": "Ingénieur",
14
+ "nombre_experiences_precedentes": 2,
15
+ "nombre_heures_travailless": 40,
16
+ "annee_experience_totale": 5,
17
+ "annees_dans_l_entreprise": 2,
18
+ "annees_dans_le_poste_actuel": 1,
19
+ "satisfaction_employee_environnement": 3,
20
+ "note_evaluation_precedente": 3,
21
+ "niveau_hierarchique_poste": 2,
22
+ "satisfaction_employee_nature_travail": 3,
23
+ "satisfaction_employee_equipe": 4,
24
+ "satisfaction_employee_equilibre_pro_perso": 3,
25
+ "note_evaluation_actuelle": 3,
26
+ "heure_supplementaires": "Non",
27
+ "augementation_salaire_precedente": "10-15%",
28
+ "nombre_participation_pee": 0,
29
+ "nb_formations_suivies": 1,
30
+ "nombre_employee_sous_responsabilite": 0,
31
+ "distance_domicile_travail": 10,
32
+ "niveau_education": 3,
33
+ "domaine_etude": "Sciences",
34
+ "ayant_enfants": "Non",
35
+ "frequence_deplacement": "Rare",
36
+ "annees_depuis_la_derniere_promotion": 1,
37
+ "annes_sous_responsable_actuel": 1
38
+ }
39
+ schema = InputSchema(**valid_data)
40
+ assert schema.age == 30
41
+ assert schema.genre == "M"
42
+
43
+ # Test invalid data (missing field)
44
+ invalid_data = valid_data.copy()
45
+ del invalid_data["age"]
46
+ with pytest.raises(ValueError):
47
+ InputSchema(**invalid_data)
48
+
49
+ def test_model_loading():
50
+ assert ml_service.model is not None
51
+ assert ml_service.expected_features is not None