| """ |
| Load testing for ml-intern production API using locust. |
| |
| Usage: |
| locust -f tests/load_test.py --host http://localhost:8000 |
| """ |
|
|
| import json |
| import random |
| import uuid |
|
|
| from locust import HttpUser, task, between |
|
|
|
|
class ChatUser(HttpUser):
    """Simulates a typical chat client.

    Weighted mix: mostly non-streaming chat completions, occasional
    streaming requests, frequent health checks, and model listings.
    """

    wait_time = between(0.5, 2.0)

    def on_start(self):
        """Assign stable per-user session and correlation IDs for tracing."""
        self.session_id = str(uuid.uuid4())
        self.models = [
            "groq/llama-3.3-70b-versatile",
            "groq/llama-3.1-8b-instant",
            "nim/llama-3-8b",
            "ollama/llama3.1",
        ]
        self.correlation_id = str(uuid.uuid4())

    @task(10)
    def chat_completion(self):
        """Non-streaming chat completion against a randomly chosen model.

        A 200 with a plausible JSON body or a 429 (rate limited, expected
        under load) counts as success; anything else — including a 200
        whose body is not valid JSON — is a controlled failure.
        """
        model = random.choice(self.models)
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": f"Hello, this is test request {random.randint(1, 1000)}"}
            ],
            "temperature": 0.7,
            "max_tokens": 500,
            "stream": False,
            "session_id": self.session_id,
        }
        headers = {
            "Content-Type": "application/json",
            "X-Correlation-ID": self.correlation_id,
        }
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            headers=headers,
            catch_response=True,
        ) as response:
            if response.status_code == 200:
                # Guard the parse: without this, a 200 with a non-JSON body
                # makes response.json() raise, and locust records an
                # unhandled exception instead of a named failure.
                try:
                    data = response.json()
                except (json.JSONDecodeError, ValueError):
                    response.failure("200 response with invalid JSON body")
                    return
                if "content" in data or "id" in data:
                    response.success()
                else:
                    response.failure("Invalid response structure")
            elif response.status_code == 429:
                # Rate limiting is expected behavior during a load test.
                response.success()
            else:
                response.failure(f"Unexpected status: {response.status_code}")

    @task(1)
    def streaming_chat(self):
        """Streaming chat completion; drains the body to time the full stream."""
        model = random.choice(self.models)
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "Count to 10 slowly."}],
            "temperature": 0.7,
            "max_tokens": 500,
            "stream": True,
            "session_id": self.session_id,
        }
        headers = {"Content-Type": "application/json", "X-Correlation-ID": self.correlation_id}
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            headers=headers,
            catch_response=True,
            stream=True,
        ) as response:
            if response.status_code == 200:
                # Consume the stream: otherwise the recorded latency covers
                # only the response headers (time-to-first-byte), and the
                # connection is not cleanly returned to the pool.
                for _chunk in response.iter_content(chunk_size=None):
                    pass
                response.success()
            elif response.status_code == 429:
                response.success()
            else:
                response.failure(f"Unexpected status: {response.status_code}")

    @task(5)
    def health_check(self):
        """GET /health; accepts 'healthy' or 'degraded' as passing states."""
        with self.client.get("/health", catch_response=True) as response:
            if response.status_code == 200:
                # Same JSON guard as chat_completion: a malformed health
                # body should be a named failure, not an exception.
                try:
                    data = response.json()
                except (json.JSONDecodeError, ValueError):
                    response.failure("200 response with invalid JSON body")
                    return
                if data.get("status") in ["healthy", "degraded"]:
                    response.success()
                else:
                    response.failure(f"Unhealthy status: {data.get('status')}")
            else:
                response.failure(f"Status: {response.status_code}")

    @task(2)
    def list_models(self):
        """GET /v1/models; default locust pass/fail handling is sufficient."""
        self.client.get("/v1/models")
|
|
|
|
class BurstUser(HttpUser):
    """Fires completions back-to-back with near-zero think time.

    Intended to probe rate limiting and burst handling: both 200 and 429
    are treated as acceptable outcomes.
    """

    wait_time = between(0, 0.1)

    def on_start(self):
        # One session ID per simulated user, reused for every burst request.
        self.session_id = str(uuid.uuid4())

    @task
    def rapid_requests(self):
        """POST a minimal completion request as fast as the wait_time allows."""
        chosen_model = random.choice(["groq/llama-3.3-70b-versatile", "nim/llama-3-8b"])
        body = {
            "model": chosen_model,
            "messages": [{"role": "user", "content": "Quick test"}],
            "temperature": 0.7,
            "max_tokens": 100,
            "stream": False,
            "session_id": self.session_id,
        }
        with self.client.post(
            "/v1/chat/completions",
            json=body,
            catch_response=True,
        ) as resp:
            # Rate-limited (429) responses are the point of this scenario,
            # so they count as success alongside 200.
            if resp.status_code in (200, 429):
                resp.success()
            else:
                resp.failure(f"Status: {resp.status_code}")
|
|
|
|
class CacheUser(HttpUser):
    """Repeats one identical query to exercise the server-side response cache."""

    wait_time = between(1, 3)

    def on_start(self):
        self.session_id = str(uuid.uuid4())
        # Identical prompt every time so repeat requests are cache-hit candidates.
        self.fixed_message = "What is the capital of France?"

    @task
    def repeated_query(self):
        """Send the same completion request over and over.

        Any 200 is a success regardless of cache status: the original code
        branched on data.get("cached") but called success() in both arms,
        so the dead branch (and the response.json() parse that could raise
        on a non-JSON 200) has been removed.

        NOTE(review): temperature=0.7 may defeat exact-match caching if the
        server keys the cache on request parameters — confirm the cache key
        semantics against the API implementation.
        """
        payload = {
            "model": "groq/llama-3.3-70b-versatile",
            "messages": [{"role": "user", "content": self.fixed_message}],
            "temperature": 0.7,
            "max_tokens": 100,
            "stream": False,
            "session_id": self.session_id,
        }
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            catch_response=True,
        ) as response:
            if response.status_code == 200:
                response.success()
            else:
                response.failure(f"Status: {response.status_code}")
|