muskan singh committed
Commit a35bcd0 · 1 Parent(s): ef4ebed

gemma fix

Files changed (1)
  1. inference.py +22 -8
inference.py CHANGED
@@ -201,20 +201,24 @@ def run_workflow(workflow_id: str) -> float:
     obs_text = obs_to_text(obs)
     history.append({"role": "user", "content": obs_text})
 
-    # Trim history to avoid context overflow
-    # if len(history) > 20:
-    #     history = history[-20:]
-    # Trim history - always keep an even number so roles alternate correctly
+    # Trim history keep last 20, ensure it starts with a user message
     if len(history) > 20:
         history = history[-20:]
-    # Ensure history starts with a user message (Gemma requires strict alternation)
     if history and history[0]["role"] != "user":
-        history = history[1:]
+        history = history[1:]
+
+    # Inject system prompt into first user message (Gemma/models without system role)
+    messages_for_llm = list(history)
+    if messages_for_llm:
+        messages_for_llm[0] = {
+            "role": "user",
+            "content": SYSTEM_PROMPT + "\n\n---\n\n" + messages_for_llm[0]["content"],
+        }
 
     try:
         response = llm_client.chat.completions.create(
             model = MODEL_NAME,
-            messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history,
+            messages = messages_for_llm,
             temperature = 0.0,
             max_tokens = 300,
         )
@@ -331,11 +335,21 @@ async def run_workflow_generator(
     history.append({"role": "user", "content": obs_text})
     if len(history) > 20:
         history = history[-20:]
+    if history and history[0]["role"] != "user":
+        history = history[1:]
+
+    # Inject system prompt into first user message (Gemma/models without system role)
+    messages_for_llm = list(history)
+    if messages_for_llm:
+        messages_for_llm[0] = {
+            "role": "user",
+            "content": SYSTEM_PROMPT + "\n\n---\n\n" + messages_for_llm[0]["content"],
+        }
 
     try:
         response = llm_client.chat.completions.create(
             model = MODEL_NAME,
+            messages = messages_for_llm,
             temperature = 0.0,
             max_tokens = 300,
         )
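
Read as a standalone helper, the committed change amounts to the sketch below: trim the rolling history, repair user-first alternation, and fold the system prompt into the first user turn, since Gemma's chat template rejects a "system" role and expects strictly alternating user/assistant turns. This is a minimal restatement under assumptions, not code from the commit: the name prepare_messages, the Message alias, the max_messages parameter, and the demo history are illustrative, and messages are assumed to be OpenAI-style {"role": ..., "content": ...} dicts. Like the commit, it drops at most one leading non-user message after trimming.

# Minimal sketch, assuming OpenAI-style message dicts; names are illustrative.
from typing import Dict, List

Message = Dict[str, str]


def prepare_messages(
    history: List[Message],
    system_prompt: str,
    max_messages: int = 20,  # the commit hardcodes a 20-message window
) -> List[Message]:
    # Bound the context window by keeping only the most recent messages.
    if len(history) > max_messages:
        history = history[-max_messages:]
    # Gemma-style chat templates expect strict user/assistant alternation
    # starting with a user turn, so drop one leading non-user message.
    if history and history[0]["role"] != "user":
        history = history[1:]

    # Shallow-copy before editing so the caller's running history is untouched
    # and the system prompt never accumulates across turns.
    messages = list(history)
    if messages:
        # Fold the system prompt into the first user message instead of
        # sending a separate {"role": "system"} message.
        messages[0] = {
            "role": "user",
            "content": system_prompt + "\n\n---\n\n" + messages[0]["content"],
        }
    return messages


if __name__ == "__main__":
    history = [
        {"role": "assistant", "content": "stale turn"},  # breaks user-first order
        {"role": "user", "content": "observation: step 2"},
        {"role": "assistant", "content": "action: click(...)"},
        {"role": "user", "content": "observation: step 3"},
    ]
    msgs = prepare_messages(history, "You are a careful workflow agent.")
    assert msgs[0]["role"] == "user"
    assert msgs[0]["content"].startswith("You are a careful workflow agent.")

The returned list is what the commit passes as messages_for_llm to llm_client.chat.completions.create. Doing the injection on a copy, as the commit does with list(history), is the key design choice: the stored history keeps only real turns, and the system prompt is re-prepended fresh on every model call rather than being baked into the conversation.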