fix: move global declarations before first use (grpo_train, call_gemini)
Browse files- agent/gemini_client.py +1 -3
- training/grpo_train.py +1 -1
agent/gemini_client.py
CHANGED
|
@@ -292,10 +292,10 @@ async def call_gemini(
|
|
| 292 |
Parsed dict with keys: utterance (str), offer_amount (float|None),
|
| 293 |
tactical_move (str|None). Returns SYNTHETIC_RESPONSE on any error.
|
| 294 |
"""
|
|
|
|
| 295 |
if _is_mock_mode():
|
| 296 |
return _get_mock_response(persona, len(messages), scenario_id)
|
| 297 |
|
| 298 |
-
global _gemini_model_logged
|
| 299 |
if not _gemini_model_logged:
|
| 300 |
logger.info(f"[Gemini] Using model: {GEMINI_MODEL}")
|
| 301 |
_gemini_model_logged = True
|
|
@@ -334,7 +334,6 @@ async def call_gemini(
|
|
| 334 |
try:
|
| 335 |
response = await loop.run_in_executor(None, _call)
|
| 336 |
|
| 337 |
-
global _live_calls, _turn_count
|
| 338 |
_turn_count += 1
|
| 339 |
_live_calls += 1
|
| 340 |
print(
|
|
@@ -372,7 +371,6 @@ async def call_gemini(
|
|
| 372 |
file=sys.stderr,
|
| 373 |
)
|
| 374 |
logger.warning("Gemini API / parse failed after retries — using text fallback")
|
| 375 |
-
global _fallback_calls
|
| 376 |
_fallback_calls += 1
|
| 377 |
if text:
|
| 378 |
return {**SYNTHETIC_RESPONSE, "utterance": text[:300]}
|
|
|
|
| 292 |
Parsed dict with keys: utterance (str), offer_amount (float|None),
|
| 293 |
tactical_move (str|None). Returns SYNTHETIC_RESPONSE on any error.
|
| 294 |
"""
|
| 295 |
+
global _gemini_model_logged, _live_calls, _turn_count, _fallback_calls
|
| 296 |
if _is_mock_mode():
|
| 297 |
return _get_mock_response(persona, len(messages), scenario_id)
|
| 298 |
|
|
|
|
| 299 |
if not _gemini_model_logged:
|
| 300 |
logger.info(f"[Gemini] Using model: {GEMINI_MODEL}")
|
| 301 |
_gemini_model_logged = True
|
|
|
|
| 334 |
try:
|
| 335 |
response = await loop.run_in_executor(None, _call)
|
| 336 |
|
|
|
|
| 337 |
_turn_count += 1
|
| 338 |
_live_calls += 1
|
| 339 |
print(
|
|
|
|
| 371 |
file=sys.stderr,
|
| 372 |
)
|
| 373 |
logger.warning("Gemini API / parse failed after retries — using text fallback")
|
|
|
|
| 374 |
_fallback_calls += 1
|
| 375 |
if text:
|
| 376 |
return {**SYNTHETIC_RESPONSE, "utterance": text[:300]}
|
training/grpo_train.py
CHANGED
|
@@ -173,6 +173,7 @@ def train_grpo(
|
|
| 173 |
|
| 174 |
|
| 175 |
def main() -> None:
|
|
|
|
| 176 |
parser = argparse.ArgumentParser(description="Parlay GRPO fine-tuning")
|
| 177 |
parser.add_argument("--model", default="models/parlay-sft")
|
| 178 |
parser.add_argument("--base_model", default="")
|
|
@@ -185,7 +186,6 @@ def main() -> None:
|
|
| 185 |
args = parser.parse_args()
|
| 186 |
|
| 187 |
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
|
| 188 |
-
global GRPO_GENERATIONS
|
| 189 |
GRPO_GENERATIONS = args.g
|
| 190 |
model_path = args.base_model or args.model
|
| 191 |
train_grpo(model_path, args.data, args.output, args.steps)
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
def main() -> None:
|
| 176 |
+
global GRPO_GENERATIONS
|
| 177 |
parser = argparse.ArgumentParser(description="Parlay GRPO fine-tuning")
|
| 178 |
parser.add_argument("--model", default="models/parlay-sft")
|
| 179 |
parser.add_argument("--base_model", default="")
|
|
|
|
| 186 |
args = parser.parse_args()
|
| 187 |
|
| 188 |
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
|
|
|
|
| 189 |
GRPO_GENERATIONS = args.g
|
| 190 |
model_path = args.base_model or args.model
|
| 191 |
train_grpo(model_path, args.data, args.output, args.steps)
|