""" Load testing for ml-intern production API using locust. Usage: locust -f tests/load_test.py --host http://localhost:8000 """ import json import random import uuid from locust import HttpUser, task, between class ChatUser(HttpUser): wait_time = between(0.5, 2.0) def on_start(self): self.session_id = str(uuid.uuid4()) self.models = [ "groq/llama-3.3-70b-versatile", "groq/llama-3.1-8b-instant", "nim/llama-3-8b", "ollama/llama3.1", ] self.correlation_id = str(uuid.uuid4()) @task(10) def chat_completion(self): model = random.choice(self.models) payload = { "model": model, "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": f"Hello, this is test request {random.randint(1, 1000)}"} ], "temperature": 0.7, "max_tokens": 500, "stream": False, "session_id": self.session_id, } headers = { "Content-Type": "application/json", "X-Correlation-ID": self.correlation_id, } with self.client.post( "/v1/chat/completions", json=payload, headers=headers, catch_response=True, ) as response: if response.status_code == 200: data = response.json() if "content" in data or "id" in data: response.success() else: response.failure("Invalid response structure") elif response.status_code == 429: response.success() else: response.failure(f"Unexpected status: {response.status_code}") @task(1) def streaming_chat(self): model = random.choice(self.models) payload = { "model": model, "messages": [{"role": "user", "content": "Count to 10 slowly."}], "temperature": 0.7, "max_tokens": 500, "stream": True, "session_id": self.session_id, } headers = {"Content-Type": "application/json", "X-Correlation-ID": self.correlation_id} with self.client.post( "/v1/chat/completions", json=payload, headers=headers, catch_response=True, stream=True, ) as response: if response.status_code == 200: response.success() elif response.status_code == 429: response.success() else: response.failure(f"Unexpected status: {response.status_code}") @task(5) def health_check(self): with self.client.get("/health", catch_response=True) as response: if response.status_code == 200: data = response.json() if data.get("status") in ["healthy", "degraded"]: response.success() else: response.failure(f"Unhealthy status: {data.get('status')}") else: response.failure(f"Status: {response.status_code}") @task(2) def list_models(self): self.client.get("/v1/models") class BurstUser(HttpUser): wait_time = between(0, 0.1) def on_start(self): self.session_id = str(uuid.uuid4()) @task def rapid_requests(self): model = random.choice(["groq/llama-3.3-70b-versatile", "nim/llama-3-8b"]) payload = { "model": model, "messages": [{"role": "user", "content": "Quick test"}], "temperature": 0.7, "max_tokens": 100, "stream": False, "session_id": self.session_id, } with self.client.post( "/v1/chat/completions", json=payload, catch_response=True, ) as response: if response.status_code in [200, 429]: response.success() else: response.failure(f"Status: {response.status_code}") class CacheUser(HttpUser): wait_time = between(1, 3) def on_start(self): self.session_id = str(uuid.uuid4()) self.fixed_message = "What is the capital of France?" @task def repeated_query(self): payload = { "model": "groq/llama-3.3-70b-versatile", "messages": [{"role": "user", "content": self.fixed_message}], "temperature": 0.7, "max_tokens": 100, "stream": False, "session_id": self.session_id, } with self.client.post( "/v1/chat/completions", json=payload, catch_response=True, ) as response: if response.status_code == 200: data = response.json() if data.get("cached"): response.success() else: response.success() else: response.failure(f"Status: {response.status_code}")