raazkumar committed · verified · Commit a32fac2 · 1 Parent(s): 7f8d8d7

Upload production/tests/load_test.py

Files changed (1)
  1. production/tests/load_test.py +163 -0
production/tests/load_test.py ADDED
"""
Load testing for ml-intern production API using locust.

Usage:
    locust -f tests/load_test.py --host http://localhost:8000
"""

import random
import uuid

from locust import HttpUser, task, between


class ChatUser(HttpUser):
    """Simulates a typical chat user with moderate pacing."""

    wait_time = between(0.5, 2.0)

    def on_start(self):
        self.session_id = str(uuid.uuid4())
        self.models = [
            "groq/llama-3.3-70b-versatile",
            "groq/llama-3.1-8b-instant",
            "nim/llama-3-8b",
            "ollama/llama3.1",
        ]
        self.correlation_id = str(uuid.uuid4())

    @task(10)
    def chat_completion(self):
        model = random.choice(self.models)
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": f"Hello, this is test request {random.randint(1, 1000)}"},
            ],
            "temperature": 0.7,
            "max_tokens": 500,
            "stream": False,
            "session_id": self.session_id,
        }
        headers = {
            "Content-Type": "application/json",
            "X-Correlation-ID": self.correlation_id,
        }
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            headers=headers,
            catch_response=True,
        ) as response:
            if response.status_code == 200:
                data = response.json()
                if "content" in data or "id" in data:
                    response.success()
                else:
                    response.failure("Invalid response structure")
            elif response.status_code == 429:
                # Rate limiting is expected under load; don't count it as a failure.
                response.success()
            else:
                response.failure(f"Unexpected status: {response.status_code}")

    @task(1)
    def streaming_chat(self):
        model = random.choice(self.models)
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "Count to 10 slowly."}],
            "temperature": 0.7,
            "max_tokens": 500,
            "stream": True,
            "session_id": self.session_id,
        }
        headers = {"Content-Type": "application/json", "X-Correlation-ID": self.correlation_id}
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            headers=headers,
            catch_response=True,
            stream=True,
        ) as response:
            # 429 is an acceptable outcome here too, same as above.
            if response.status_code in [200, 429]:
                response.success()
            else:
                response.failure(f"Unexpected status: {response.status_code}")

    @task(5)
    def health_check(self):
        with self.client.get("/health", catch_response=True) as response:
            if response.status_code == 200:
                data = response.json()
                if data.get("status") in ["healthy", "degraded"]:
                    response.success()
                else:
                    response.failure(f"Unhealthy status: {data.get('status')}")
            else:
                response.failure(f"Status: {response.status_code}")

    @task(2)
    def list_models(self):
        self.client.get("/v1/models")


class BurstUser(HttpUser):
    """Fires requests with almost no pacing to exercise rate limiting."""

    wait_time = between(0, 0.1)

    def on_start(self):
        self.session_id = str(uuid.uuid4())

    @task
    def rapid_requests(self):
        model = random.choice(["groq/llama-3.3-70b-versatile", "nim/llama-3-8b"])
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "Quick test"}],
            "temperature": 0.7,
            "max_tokens": 100,
            "stream": False,
            "session_id": self.session_id,
        }
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            catch_response=True,
        ) as response:
            if response.status_code in [200, 429]:
                response.success()
            else:
                response.failure(f"Status: {response.status_code}")


class CacheUser(HttpUser):
    """Sends the same prompt repeatedly so cache hits can be observed."""

    wait_time = between(1, 3)

    def on_start(self):
        self.session_id = str(uuid.uuid4())
        self.fixed_message = "What is the capital of France?"

    @task
    def repeated_query(self):
        payload = {
            "model": "groq/llama-3.3-70b-versatile",
            "messages": [{"role": "user", "content": self.fixed_message}],
            "temperature": 0.7,
            "max_tokens": 100,
            "stream": False,
            "session_id": self.session_id,
        }
        with self.client.post(
            "/v1/chat/completions",
            json=payload,
            catch_response=True,
        ) as response:
            if response.status_code == 200:
                # Any 200 counts as success; the "cached" flag in the body only
                # indicates whether the response was served from cache.
                response.success()
            else:
                response.failure(f"Status: {response.status_code}")