# Source: production/tests/load_test.py (upload a32fac2, verified)
"""
Load testing for ml-intern production API using locust.
Usage:
locust -f tests/load_test.py --host http://localhost:8000
"""
import json
import random
import uuid
from locust import HttpUser, task, between
class ChatUser(HttpUser):
    """Simulates a typical chat client: mostly non-streaming completions,
    with occasional streaming requests, health checks, and model listings.

    429 responses are counted as successes throughout, since the load test
    is expected to trip the API's rate limiter by design.
    """

    wait_time = between(0.5, 2.0)

    def on_start(self):
        # One session ID and one correlation ID per simulated user; the
        # correlation ID is reused across requests so a single user's
        # traffic can be traced end-to-end in server logs.
        self.session_id = str(uuid.uuid4())
        self.models = [
            "groq/llama-3.3-70b-versatile",
            "groq/llama-3.1-8b-instant",
            "nim/llama-3-8b",
            "ollama/llama3.1",
        ]
        self.correlation_id = str(uuid.uuid4())

    @task(10)
    def chat_completion(self):
        """Non-streaming completion against a randomly chosen model."""
        model = random.choice(self.models)
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": f"Hello, this is test request {random.randint(1, 1000)}"}
            ],
            "temperature": 0.7,
            "max_tokens": 500,
            "stream": False,
            "session_id": self.session_id,
        }
        headers = {
            "Content-Type": "application/json",
            "X-Correlation-ID": self.correlation_id,
        }
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            headers=headers,
            catch_response=True,
        ) as response:
            if response.status_code == 200:
                # Guard the decode: a malformed 200 body should be recorded
                # as a failure, not crash the locust task with an exception.
                try:
                    data = response.json()
                except ValueError:
                    response.failure("Non-JSON 200 response body")
                    return
                if "content" in data or "id" in data:
                    response.success()
                else:
                    response.failure("Invalid response structure")
            elif response.status_code == 429:
                # Rate limiting is expected under load; not an error.
                response.success()
            else:
                response.failure(f"Unexpected status: {response.status_code}")

    @task(1)
    def streaming_chat(self):
        """Streaming completion; drains the body so the connection is fully
        consumed and returned to the pool (with stream=True the socket
        otherwise stays open until the response is garbage-collected)."""
        model = random.choice(self.models)
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "Count to 10 slowly."}],
            "temperature": 0.7,
            "max_tokens": 500,
            "stream": True,
            "session_id": self.session_id,
        }
        headers = {"Content-Type": "application/json", "X-Correlation-ID": self.correlation_id}
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            headers=headers,
            catch_response=True,
            stream=True,
        ) as response:
            if response.status_code == 200:
                try:
                    # Consume the SSE/chunked stream to completion.
                    for _ in response.iter_content(chunk_size=8192):
                        pass
                except Exception as exc:  # boundary: mid-stream transport errors
                    response.failure(f"Stream aborted: {exc}")
                    return
                response.success()
            elif response.status_code == 429:
                response.success()
            else:
                response.failure(f"Unexpected status: {response.status_code}")

    @task(5)
    def health_check(self):
        """Health probe; 'healthy' and 'degraded' both count as up."""
        with self.client.get("/health", catch_response=True) as response:
            if response.status_code == 200:
                try:
                    data = response.json()
                except ValueError:
                    response.failure("Non-JSON health response")
                    return
                if data.get("status") in ["healthy", "degraded"]:
                    response.success()
                else:
                    response.failure(f"Unhealthy status: {data.get('status')}")
            else:
                response.failure(f"Status: {response.status_code}")

    @task(2)
    def list_models(self):
        """Model catalog listing; default locust pass/fail handling."""
        self.client.get("/v1/models")
class BurstUser(HttpUser):
    """Fires requests nearly back-to-back to stress the rate limiter.

    Both 200 and 429 are treated as success: the point of this user is to
    provoke rate limiting, not to avoid it.
    """

    wait_time = between(0, 0.1)

    # Fixed pool of target models for burst traffic.
    _BURST_MODELS = ("groq/llama-3.3-70b-versatile", "nim/llama-3-8b")

    def on_start(self):
        # Fresh session per simulated user.
        self.session_id = str(uuid.uuid4())

    @task
    def rapid_requests(self):
        """Send a minimal non-streaming completion request."""
        body = {
            "model": random.choice(self._BURST_MODELS),
            "messages": [{"role": "user", "content": "Quick test"}],
            "temperature": 0.7,
            "max_tokens": 100,
            "stream": False,
            "session_id": self.session_id,
        }
        with self.client.post(
            "/v1/chat/completions",
            json=body,
            catch_response=True,
        ) as resp:
            if resp.status_code not in (200, 429):
                resp.failure(f"Status: {resp.status_code}")
            else:
                resp.success()
class CacheUser(HttpUser):
    """Repeatedly sends an identical prompt to exercise the response cache.

    Latency differences between cache hits and misses show up in the locust
    response-time stats for this user class.
    """

    wait_time = between(1, 3)

    def on_start(self):
        self.session_id = str(uuid.uuid4())
        # Fixed prompt so every request after the first should be servable
        # from cache.
        self.fixed_message = "What is the capital of France?"

    @task
    def repeated_query(self):
        """Issue the fixed query and validate the response body parses."""
        payload = {
            "model": "groq/llama-3.3-70b-versatile",
            "messages": [{"role": "user", "content": self.fixed_message}],
            "temperature": 0.7,
            "max_tokens": 100,
            "stream": False,
            "session_id": self.session_id,
        }
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            catch_response=True,
        ) as response:
            if response.status_code == 200:
                # Fix: the original branched on data.get("cached") but both
                # branches were identical (dead conditional). Keep the JSON
                # decode as a body-validity check, guarded so a malformed
                # body registers as a failure instead of crashing the task.
                try:
                    response.json()
                except ValueError:
                    response.failure("Non-JSON 200 response body")
                    return
                response.success()
            else:
                response.failure(f"Status: {response.status_code}")