sh4shv4t commited on
Commit
8ec5193
·
1 Parent(s): 48756ef

fix: move global declarations before first use (grpo_train, call_gemini)

Browse files
Files changed (2) hide show
  1. agent/gemini_client.py +1 -3
  2. training/grpo_train.py +1 -1
agent/gemini_client.py CHANGED
@@ -292,10 +292,10 @@ async def call_gemini(
292
  Parsed dict with keys: utterance (str), offer_amount (float|None),
293
  tactical_move (str|None). Returns SYNTHETIC_RESPONSE on any error.
294
  """
 
295
  if _is_mock_mode():
296
  return _get_mock_response(persona, len(messages), scenario_id)
297
 
298
- global _gemini_model_logged
299
  if not _gemini_model_logged:
300
  logger.info(f"[Gemini] Using model: {GEMINI_MODEL}")
301
  _gemini_model_logged = True
@@ -334,7 +334,6 @@ async def call_gemini(
334
  try:
335
  response = await loop.run_in_executor(None, _call)
336
 
337
- global _live_calls, _turn_count
338
  _turn_count += 1
339
  _live_calls += 1
340
  print(
@@ -372,7 +371,6 @@ async def call_gemini(
372
  file=sys.stderr,
373
  )
374
  logger.warning("Gemini API / parse failed after retries — using text fallback")
375
- global _fallback_calls
376
  _fallback_calls += 1
377
  if text:
378
  return {**SYNTHETIC_RESPONSE, "utterance": text[:300]}
 
292
  Parsed dict with keys: utterance (str), offer_amount (float|None),
293
  tactical_move (str|None). Returns SYNTHETIC_RESPONSE on any error.
294
  """
295
+ global _gemini_model_logged, _live_calls, _turn_count, _fallback_calls
296
  if _is_mock_mode():
297
  return _get_mock_response(persona, len(messages), scenario_id)
298
 
 
299
  if not _gemini_model_logged:
300
  logger.info(f"[Gemini] Using model: {GEMINI_MODEL}")
301
  _gemini_model_logged = True
 
334
  try:
335
  response = await loop.run_in_executor(None, _call)
336
 
 
337
  _turn_count += 1
338
  _live_calls += 1
339
  print(
 
371
  file=sys.stderr,
372
  )
373
  logger.warning("Gemini API / parse failed after retries — using text fallback")
 
374
  _fallback_calls += 1
375
  if text:
376
  return {**SYNTHETIC_RESPONSE, "utterance": text[:300]}
training/grpo_train.py CHANGED
@@ -173,6 +173,7 @@ def train_grpo(
173
 
174
 
175
  def main() -> None:
 
176
  parser = argparse.ArgumentParser(description="Parlay GRPO fine-tuning")
177
  parser.add_argument("--model", default="models/parlay-sft")
178
  parser.add_argument("--base_model", default="")
@@ -185,7 +186,6 @@ def main() -> None:
185
  args = parser.parse_args()
186
 
187
  logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
188
- global GRPO_GENERATIONS
189
  GRPO_GENERATIONS = args.g
190
  model_path = args.base_model or args.model
191
  train_grpo(model_path, args.data, args.output, args.steps)
 
173
 
174
 
175
  def main() -> None:
176
+ global GRPO_GENERATIONS
177
  parser = argparse.ArgumentParser(description="Parlay GRPO fine-tuning")
178
  parser.add_argument("--model", default="models/parlay-sft")
179
  parser.add_argument("--base_model", default="")
 
186
  args = parser.parse_args()
187
 
188
  logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
 
189
  GRPO_GENERATIONS = args.g
190
  model_path = args.base_model or args.model
191
  train_grpo(model_path, args.data, args.output, args.steps)