Lzy01241010 committed on
Commit
c412028
·
1 Parent(s): 97b3442

agent: Azure OpenAI support for visit extractor + condenser

Browse files

Mirrors inference/tool_visit.py. When AZURE_OPENAI_ENDPOINT is set, we use
AzureOpenAI() (with AZURE_OPENAI_API_VERSION) and the effective model name
becomes AZURE_OPENAI_DEPLOYMENT; otherwise we fall back to the plain
OpenAI client + SUMMARY_MODEL_NAME / MEMORY_MODEL_NAME. The condenser trigger
gate accepts either MEMORY_MODEL_NAME or AZURE_OPENAI_DEPLOYMENT.

Files changed (1) hide show
  1. app.py +64 -19
app.py CHANGED
@@ -1567,16 +1567,52 @@ MEMORY_TOKEN_THRESHOLD = int(
1567
  or "16000"
1568
  )
1569
 
1570
-
1571
- def _get_openai_client(api_key: str, base_url: Optional[str]):
1572
- """Lazy import so the Space still imports if `openai` isn't installed yet."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1573
  try:
1574
  from openai import OpenAI
1575
  except Exception:
1576
- return None
1577
  if not api_key:
1578
- return None
1579
- return OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
 
 
 
 
 
 
 
 
 
 
 
1580
 
1581
 
1582
  def _approx_token_count(text: str) -> int:
@@ -1616,22 +1652,28 @@ _LAST_EXTRACT_ERR: Optional[str] = None
1616
 
1617
  def _llm_extract(webpage_content: str, goal: str) -> Optional[str]:
1618
  """Run the SUMMARY model as the visit extractor. Mirrors
1619
- inference/prompt.py:build_visit_extractor_messages + tool_visit's call."""
 
1620
  global _LAST_EXTRACT_ERR
1621
  _LAST_EXTRACT_ERR = None
1622
- if not SUMMARY_MODEL_NAME:
1623
- _LAST_EXTRACT_ERR = "SUMMARY_MODEL_NAME env var not set"
1624
- return None
1625
  if not SUMMARY_API_KEY:
1626
- _LAST_EXTRACT_ERR = "API_KEY / SUMMARY_OPENAI_API_KEY env var not set"
1627
  return None
1628
- client = _get_openai_client(SUMMARY_API_KEY, SUMMARY_API_BASE)
 
 
1629
  if client is None:
1630
  _LAST_EXTRACT_ERR = "openai client could not be constructed (package missing?)"
1631
  return None
 
 
 
 
 
 
1632
  try:
1633
  resp = client.chat.completions.create(
1634
- model=SUMMARY_MODEL_NAME,
1635
  messages=[
1636
  {
1637
  "role": "user",
@@ -1650,9 +1692,12 @@ def _llm_extract(webpage_content: str, goal: str) -> Optional[str]:
1650
 
1651
  def _llm_condense(events_text: str, prev_state: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
1652
  """Run the MEMORY model as the State Summarizer. Returns a parsed JSON
1653
- state dict, or None if condensation failed."""
1654
- client = _get_openai_client(MEMORY_API_KEY, MEMORY_API_BASE)
1655
- if client is None or not MEMORY_MODEL_NAME:
 
 
 
1656
  return None
1657
  user_payload = json.dumps(
1658
  {
@@ -1663,7 +1708,7 @@ def _llm_condense(events_text: str, prev_state: Optional[Dict[str, Any]]) -> Opt
1663
  )
1664
  try:
1665
  resp = client.chat.completions.create(
1666
- model=MEMORY_MODEL_NAME,
1667
  messages=[
1668
  {"role": "system", "content": MEMORY_SYSTEM_PROMPT},
1669
  {"role": "user", "content": user_payload},
@@ -1980,7 +2025,7 @@ def build_research_agent(
1980
  # context as [system, original_question, RESEARCH_STATE_SUMMARY].
1981
  if (
1982
  strategy == "condenser"
1983
- and MEMORY_MODEL_NAME
1984
  and MEMORY_API_KEY
1985
  and turn > 1
1986
  and _messages_token_count(messages) > MEMORY_TOKEN_THRESHOLD
@@ -2016,7 +2061,7 @@ def build_research_agent(
2016
  yield _emit()
2017
  elif (
2018
  strategy == "condenser"
2019
- and (not MEMORY_MODEL_NAME or not MEMORY_API_KEY)
2020
  and state.trusted_notes
2021
  and turn > 1
2022
  and turn % 3 == 0
 
1567
  or "16000"
1568
  )
1569
 
1570
+ # Azure OpenAI support — mirrors inference/tool_visit.py logic. When
1571
+ # AZURE_OPENAI_ENDPOINT is set, we use AzureOpenAI() instead of OpenAI()
1572
+ # and AZURE_OPENAI_DEPLOYMENT overrides the per-purpose model name.
1573
+ AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "").strip()
1574
+ AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION", "").strip() or "2024-06-01"
1575
+ AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT", "").strip()
1576
+
1577
+
1578
+ def _get_chat_client_and_model(
1579
+ api_key: str, base_url: Optional[str], fallback_model_name: str
1580
+ ) -> Tuple[Optional[Any], str]:
1581
+ """Construct an OpenAI-compatible chat client. Auto-switches to
1582
+ AzureOpenAI when AZURE_OPENAI_ENDPOINT is configured; in that case the
1583
+ effective model name becomes AZURE_OPENAI_DEPLOYMENT (Azure uses
1584
+ deployment names, not raw model ids). Returns (client, model_name)."""
1585
+ if AZURE_OPENAI_ENDPOINT:
1586
+ try:
1587
+ from openai import AzureOpenAI
1588
+ except Exception:
1589
+ return None, fallback_model_name
1590
+ if not api_key:
1591
+ return None, fallback_model_name
1592
+ client = AzureOpenAI(
1593
+ api_key=api_key,
1594
+ azure_endpoint=AZURE_OPENAI_ENDPOINT,
1595
+ api_version=AZURE_OPENAI_API_VERSION,
1596
+ )
1597
+ return client, (AZURE_OPENAI_DEPLOYMENT or fallback_model_name)
1598
  try:
1599
  from openai import OpenAI
1600
  except Exception:
1601
+ return None, fallback_model_name
1602
  if not api_key:
1603
+ return None, fallback_model_name
1604
+ client = (
1605
+ OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
1606
+ )
1607
+ return client, fallback_model_name
1608
+
1609
+
1610
+ # Backwards-compat shim: older callers asked for just a client without
1611
+ # a model name. Kept so we don't break anything if a future patch imports
1612
+ # it. Note: it delegates to _get_chat_client_and_model, so it returns an AzureOpenAI client when AZURE_OPENAI_ENDPOINT is set.
1613
+ def _get_openai_client(api_key: str, base_url: Optional[str]):
1614
+ client, _ = _get_chat_client_and_model(api_key, base_url, fallback_model_name="")
1615
+ return client
1616
 
1617
 
1618
  def _approx_token_count(text: str) -> int:
 
1652
 
1653
  def _llm_extract(webpage_content: str, goal: str) -> Optional[str]:
1654
  """Run the SUMMARY model as the visit extractor. Mirrors
1655
+ inference/prompt.py:build_visit_extractor_messages + tool_visit's call.
1656
+ Picks AzureOpenAI when AZURE_OPENAI_ENDPOINT is set."""
1657
  global _LAST_EXTRACT_ERR
1658
  _LAST_EXTRACT_ERR = None
 
 
 
1659
  if not SUMMARY_API_KEY:
1660
+ _LAST_EXTRACT_ERR = "API_KEY env var not set"
1661
  return None
1662
+ client, model_name = _get_chat_client_and_model(
1663
+ SUMMARY_API_KEY, SUMMARY_API_BASE, SUMMARY_MODEL_NAME
1664
+ )
1665
  if client is None:
1666
  _LAST_EXTRACT_ERR = "openai client could not be constructed (package missing?)"
1667
  return None
1668
+ if not model_name:
1669
+ _LAST_EXTRACT_ERR = (
1670
+ "no model name (set SUMMARY_MODEL_NAME or, on Azure, "
1671
+ "AZURE_OPENAI_DEPLOYMENT)"
1672
+ )
1673
+ return None
1674
  try:
1675
  resp = client.chat.completions.create(
1676
+ model=model_name,
1677
  messages=[
1678
  {
1679
  "role": "user",
 
1692
 
1693
  def _llm_condense(events_text: str, prev_state: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
1694
  """Run the MEMORY model as the State Summarizer. Returns a parsed JSON
1695
+ state dict, or None if condensation failed. Picks AzureOpenAI when
1696
+ AZURE_OPENAI_ENDPOINT is set."""
1697
+ client, model_name = _get_chat_client_and_model(
1698
+ MEMORY_API_KEY, MEMORY_API_BASE, MEMORY_MODEL_NAME
1699
+ )
1700
+ if client is None or not model_name:
1701
  return None
1702
  user_payload = json.dumps(
1703
  {
 
1708
  )
1709
  try:
1710
  resp = client.chat.completions.create(
1711
+ model=model_name,
1712
  messages=[
1713
  {"role": "system", "content": MEMORY_SYSTEM_PROMPT},
1714
  {"role": "user", "content": user_payload},
 
2025
  # context as [system, original_question, RESEARCH_STATE_SUMMARY].
2026
  if (
2027
  strategy == "condenser"
2028
+ and (MEMORY_MODEL_NAME or AZURE_OPENAI_DEPLOYMENT)
2029
  and MEMORY_API_KEY
2030
  and turn > 1
2031
  and _messages_token_count(messages) > MEMORY_TOKEN_THRESHOLD
 
2061
  yield _emit()
2062
  elif (
2063
  strategy == "condenser"
2064
+ and not ((MEMORY_MODEL_NAME or AZURE_OPENAI_DEPLOYMENT) and MEMORY_API_KEY)
2065
  and state.trusted_notes
2066
  and turn > 1
2067
  and turn % 3 == 0