Yihan-Wang's picture
Upload folder using huggingface_hub
4d2fcd2 verified
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Literal, Optional, Union, Any
from google import genai
from google.genai import types
import termcolor
from google.genai.types import (
Part,
GenerateContentConfig,
Content,
Candidate,
FunctionResponse,
FinishReason,
)
import time
from rich.console import Console
from rich.table import Table
from computers import EnvState, Computer
# How many of the most recent user turns keep their Computer Use screenshots
# in the conversation history; screenshot parts of older turns are cleared.
MAX_RECENT_TURN_WITH_SCREENSHOTS = 5

# Names of the built-in Gemini Computer Use actions. Function responses whose
# name appears here are the ones treated as screenshot-bearing when old
# screenshots are pruned from the history.
PREDEFINED_COMPUTER_USE_FUNCTIONS = """
    open_web_browser
    click_at hover_at type_text_at
    scroll_document scroll_at
    wait_5_seconds
    go_back go_forward search navigate
    key_combination drag_and_drop
""".split()
# Shared rich console; BrowserAgent uses it for status spinners and for
# printing the reasoning/function-call table.
console = Console()
# Result type of BrowserAgent.handle_action:
# Built-in Computer Use tools will return "EnvState".
# Custom provided functions will return "dict".
FunctionResponseT = Union[EnvState, dict]
def multiply_numbers(x: float, y: float) -> dict:
    """Multiplies two numbers.
    Args:
        x (float): The first number.
        y (float): The second number.
    Returns:
        dict: The result of multiplication.
    """
    # The docstring above doubles as the tool description sent to the model
    # (BrowserAgent registers this function with
    # types.FunctionDeclaration.from_callable), so keep it model-readable.
    product = x * y
    return dict(result=product)
# def multiply_numbers(x: float, y: float) -> dict:
# """Multiplies two numbers.
# Args:
# x (float): The first number.
# y (float): The second number.
# Returns:
# dict: The result of multiplication.
# """
# return {"result": x * y}
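# When a function declaration is created with types.FunctionDeclaration.from_callable,
# the entire __doc__ string becomes the description. Even if it follows Google style,
# it is NOT further parsed into separate parameter and return-value descriptions.
# If a parameter's meaning is complex, explain it explicitly in the docstring text.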
def evaluate_answer(query: str, answer: str) -> None:
    """
    Evaluates the answer based on the given query.
    """
    # No evaluation logic yet: both arguments are ignored and the function
    # falls through to an implicit ``return None``.
class BrowserAgent:
    """Agent loop that drives a browser ``Computer`` with Gemini Computer Use.

    The conversation history lives in ``self._contents``. Each iteration asks
    the model for the next step, executes the returned function calls
    (built-in Computer Use actions against ``browser_computer``, custom
    functions such as ``multiply_numbers`` locally) and appends the results
    to the history as a user turn.
    """

    def __init__(
        self,
        browser_computer: Computer,
        query: str,
        system_prompt: str,
        model_name: str,
        verbose: bool = True,
    ):
        """Creates the agent and its Gemini client.

        Args:
            browser_computer: Environment that executes the browser actions.
            query: The user task; it becomes the first user turn.
            system_prompt: Sent to the model as the system instruction.
            model_name: Gemini model used for every request.
            verbose: When True, show spinners and a reasoning/function-call
                table on the console.

        The client reads ``GEMINI_API_KEY``, ``USE_VERTEXAI`` ("true" or "1",
        case-insensitive, selects Vertex AI), ``VERTEXAI_PROJECT`` and
        ``VERTEXAI_LOCATION`` from the environment.
        """
        self._browser_computer = browser_computer
        self._query = query
        self._model_name = model_name
        self._verbose = verbose
        # Text of the model's final turn (the one without function calls).
        self.final_reasoning = None
        self._client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
            vertexai=os.environ.get("USE_VERTEXAI", "0").lower() in ["true", "1"],
            project=os.environ.get("VERTEXAI_PROJECT"),
            location=os.environ.get("VERTEXAI_LOCATION"),
        )
        self._contents: list[Content] = [
            Content(
                role="user",
                parts=[
                    Part(text=self._query),
                ],
            )
        ]

        # Exclude any predefined functions here.
        # scroll_document is temporarily disabled.
        excluded_predefined_functions = ["scroll_document"]

        # Add your own custom functions here.
        custom_functions = [
            # For example:
            types.FunctionDeclaration.from_callable(
                client=self._client, callable=multiply_numbers
            )
        ]

        self._generate_content_config = GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=excluded_predefined_functions,
                    ),
                ),
                types.Tool(function_declarations=custom_functions),
            ],
            system_instruction=[system_prompt],  # The system prompt.
        )

    def handle_action(self, action: types.FunctionCall) -> FunctionResponseT:
        """Executes one model function call and returns its result.

        Coordinates (and the scroll magnitude) arrive on the model's
        0-1000 normalized grid and are converted to pixels first.

        Returns:
            An ``EnvState`` for built-in Computer Use actions, or a ``dict``
            for custom functions.

        Raises:
            ValueError: For an unsupported function name or an unknown
                scroll direction.
        """
        if action.name == "open_web_browser":
            return self._browser_computer.open_web_browser()
        elif action.name == "click_at":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            return self._browser_computer.click_at(
                x=x,
                y=y,
            )
        elif action.name == "hover_at":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            return self._browser_computer.hover_at(
                x=x,
                y=y,
            )
        elif action.name == "type_text_at":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            press_enter = action.args.get("press_enter", False)
            clear_before_typing = action.args.get("clear_before_typing", True)
            return self._browser_computer.type_text_at(
                x=x,
                y=y,
                text=action.args["text"],
                press_enter=press_enter,
                clear_before_typing=clear_before_typing,
            )
        elif action.name == "scroll_document":
            return self._browser_computer.scroll_document(action.args["direction"])
        elif action.name == "scroll_at":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            magnitude = action.args.get("magnitude", 800)
            direction = action.args["direction"]
            # The magnitude is normalized along the scroll axis.
            if direction in ("up", "down"):
                magnitude = self.denormalize_y(magnitude)
            elif direction in ("left", "right"):
                magnitude = self.denormalize_x(magnitude)
            else:
                raise ValueError(f"Unknown direction: {direction}")
            return self._browser_computer.scroll_at(
                x=x, y=y, direction=direction, magnitude=magnitude
            )
        elif action.name == "wait_5_seconds":
            return self._browser_computer.wait_5_seconds()
        elif action.name == "go_back":
            return self._browser_computer.go_back()
        elif action.name == "go_forward":
            return self._browser_computer.go_forward()
        elif action.name == "search":
            return self._browser_computer.search()
        elif action.name == "navigate":
            return self._browser_computer.navigate(action.args["url"])
        elif action.name == "key_combination":
            # e.g. "control+c" -> ["control", "c"]
            return self._browser_computer.key_combination(
                action.args["keys"].split("+")
            )
        elif action.name == "drag_and_drop":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            destination_x = self.denormalize_x(action.args["destination_x"])
            destination_y = self.denormalize_y(action.args["destination_y"])
            return self._browser_computer.drag_and_drop(
                x=x,
                y=y,
                destination_x=destination_x,
                destination_y=destination_y,
            )
        # Handle the custom function declarations here.
        elif action.name == multiply_numbers.__name__:
            return multiply_numbers(x=action.args["x"], y=action.args["y"])
        else:
            raise ValueError(f"Unsupported function: {action}")

    def get_model_response(
        self, max_retries=5, base_delay_s=1
    ) -> types.GenerateContentResponse:
        """Requests the next model turn, retrying with exponential backoff.

        Args:
            max_retries: Total number of attempts.
            base_delay_s: Delay before the first retry; it doubles each time.

        Raises:
            Exception: Whatever the last attempt raised, after all attempts
                have failed.
        """
        for attempt in range(max_retries):
            try:
                response = self._client.models.generate_content(
                    model=self._model_name,
                    contents=self._contents,
                    config=self._generate_content_config,
                )
                return response  # Return response on success
            except Exception as e:
                print(e)
                if attempt < max_retries - 1:
                    delay = base_delay_s * (2**attempt)
                    message = (
                        f"Generating content failed on attempt {attempt + 1}. "
                        f"Retrying in {delay} seconds...\n"
                    )
                    termcolor.cprint(
                        message,
                        color="yellow",
                    )
                    time.sleep(delay)
                else:
                    termcolor.cprint(
                        f"Generating content failed after {max_retries} attempts.\n",
                        color="red",
                    )
                    raise

    def get_text(self, candidate: Candidate) -> Optional[str]:
        """Joins all text parts of the candidate with spaces, or returns None."""
        if not candidate.content or not candidate.content.parts:
            return None
        # Gemini may split a multimodal reply into several parts; besides text
        # and function_call parts there can be others, for example:
        # {
        #   "content": {
        #     "parts": [
        #       {"text": "I’ll start by searching Google for relevant documentation."},
        #       {"function_call": {...}},
        #       {"text": "Now that I found it, here’s the explanation:"},
        #       {"text": "In Python, decorators are higher-order functions..."}
        #     ]
        #   }
        # }
        text = [part.text for part in candidate.content.parts if part.text]
        return " ".join(text) or None

    def extract_function_calls(self, candidate: Candidate) -> list[types.FunctionCall]:
        """Returns every function call in the candidate, in part order."""
        if not candidate.content or not candidate.content.parts:
            return []
        return [
            part.function_call
            for part in candidate.content.parts
            if part.function_call
        ]

    def _status(self, message: str):
        """Returns a rich spinner for ``message`` when verbose, else a no-op context."""
        import contextlib

        if self._verbose:
            return console.status(message, spinner_style=None)
        return contextlib.nullcontext()

    def _prune_old_screenshots(self) -> None:
        """Clears screenshots from all but the most recent screenshot turns.

        Walks the history newest-first, counting user turns that contain a
        function response (with parts) from a predefined Computer Use action.
        Once more than MAX_RECENT_TURN_WITH_SCREENSHOTS such turns have been
        seen, the ``parts`` (screenshot images) of those responses are set to
        None; the response payload itself is kept.
        """
        turn_with_screenshots_found = 0
        for content in reversed(self._contents):
            if content.role != "user" or not content.parts:
                continue
            screenshot_responses = [
                part.function_response
                for part in content.parts
                if part.function_response
                and part.function_response.parts
                and part.function_response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS
            ]
            if not screenshot_responses:
                continue
            turn_with_screenshots_found += 1
            if turn_with_screenshots_found > MAX_RECENT_TURN_WITH_SCREENSHOTS:
                for function_response in screenshot_responses:
                    function_response.parts = None

    def _print_step_table(
        self, reasoning: Optional[str], function_calls: list[types.FunctionCall]
    ) -> None:
        """Prints the model's reasoning next to its function calls as a table."""
        function_call_strs = []
        for function_call in function_calls:
            function_call_str = f"Name: {function_call.name}"
            if function_call.args:
                function_call_str += "\nArgs:"
                for key, value in function_call.args.items():
                    function_call_str += f"\n {key}: {value}"
            function_call_strs.append(function_call_str)

        table = Table(expand=True)
        table.add_column(
            "Gemini Computer Use Reasoning",
            header_style="magenta",
            ratio=1,
            no_wrap=False,  # Allow line wrapping.
            overflow="fold",  # Fold overflowing text onto the next line.
        )
        table.add_column("Function Call(s)", header_style="cyan", ratio=1)
        table.add_row(reasoning, "\n".join(function_call_strs))
        console.print(table)
        print()

    def run_one_iteration_modify(self):
        """Runs one model step without console tables or prompts (for UIs).

        Returns:
            A ``(reasoning, status, function_responses_list)`` tuple:
            ``reasoning`` is the model text, or an error/info message;
            ``status`` is "CONTINUE" or "COMPLETE";
            ``function_responses_list`` has one dict per executed call with
            the keys "screenshot", "action" and "response".
        """
        # Generate a response from the model.
        try:
            response = self.get_model_response()
        except Exception:
            # get_model_response has already printed the underlying error.
            return "Error occurred", "COMPLETE", []

        if not response.candidates:
            return "No candidates in response", "COMPLETE", []

        candidate = response.candidates[0]
        # Append the model turn to conversation history.
        if candidate.content:
            self._contents.append(candidate.content)

        reasoning = self.get_text(candidate)
        function_calls = self.extract_function_calls(candidate)

        # Retry the request in case of malformed FCs, as run_one_iteration does.
        if (
            not function_calls
            and not reasoning
            and candidate.finish_reason == FinishReason.MALFORMED_FUNCTION_CALL
        ):
            return reasoning, "CONTINUE", []

        if not function_calls:
            self.final_reasoning = reasoning
            return reasoning, "COMPLETE", []

        # NOTE(review): unlike run_one_iteration, a "safety_decision" in the
        # call args is neither confirmed nor acknowledged here; confirm how the
        # calling UI is expected to handle require_confirmation.
        function_responses: list[FunctionResponse] = []
        function_responses_list = []
        for function_call in function_calls:
            try:
                fc_result = self.handle_action(function_call)
                screenshot_base64 = ""
                response = {}
                if isinstance(fc_result, EnvState):  # Built-in browser action.
                    screenshot_base64 = fc_result.screenshot
                    response = {"url": fc_result.url}
                    function_responses.append(
                        FunctionResponse(
                            name=function_call.name,
                            response={"url": fc_result.url},
                            parts=[
                                types.FunctionResponsePart(
                                    inline_data=types.FunctionResponseBlob(
                                        mime_type="image/png", data=screenshot_base64
                                    )
                                )
                            ],
                        )
                    )
                elif isinstance(fc_result, dict):  # Custom function.
                    response = fc_result
                    function_responses.append(
                        FunctionResponse(name=function_call.name, response=fc_result)
                    )
                function_responses_list.append(
                    {
                        "screenshot": screenshot_base64,
                        "action": function_call.name,
                        "response": response,
                    }
                )
            except Exception as e:
                return (
                    f"Error handling action {function_call.name}: {str(e)}",
                    "COMPLETE",
                    function_responses_list,
                )

        # Send every response for this model turn in ONE user turn (as
        # run_one_iteration does) so the number of function-response parts
        # matches the number of function calls in the preceding model turn;
        # the Gemini API rejects mismatched counts for parallel calls.
        if function_responses:
            self._contents.append(
                Content(
                    role="user",
                    parts=[Part(function_response=fr) for fr in function_responses],
                )
            )

        self._prune_old_screenshots()
        return reasoning, "CONTINUE", function_responses_list

    def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
        """Runs one interactive model step and returns the loop status.

        Shows the model's reasoning and function calls (when verbose), asks
        the user to confirm calls flagged by the safety service, executes the
        calls and records their results in the conversation history.

        Raises:
            ValueError: If the model response has no candidates.
        """
        # Generate a response from the model.
        try:
            with self._status("Generating response from Gemini Computer Use..."):
                response = self.get_model_response()
        except Exception:
            # get_model_response has already printed the underlying error.
            return "COMPLETE"

        if not response.candidates:
            print("Response has no candidates!")
            print(response)
            raise ValueError("Empty response")

        # Extract the text and function call from the response.
        candidate = response.candidates[0]
        # Append the model turn to conversation history.
        if candidate.content:
            self._contents.append(candidate.content)

        reasoning = self.get_text(candidate)
        function_calls = self.extract_function_calls(candidate)

        # Retry the request in case of malformed FCs.
        if (
            not function_calls
            and not reasoning
            and candidate.finish_reason == FinishReason.MALFORMED_FUNCTION_CALL
        ):
            return "CONTINUE"

        if not function_calls:
            print(f"Agent Loop Complete: {reasoning}")
            self.final_reasoning = reasoning
            return "COMPLETE"

        if self._verbose:
            self._print_step_table(reasoning, function_calls)

        function_responses = []
        for function_call in function_calls:
            extra_fr_fields = {}
            if function_call.args and (
                safety := function_call.args.get("safety_decision")
            ):
                decision = self._get_safety_confirmation(safety)
                if decision == "TERMINATE":
                    print("Terminating agent loop")
                    return "COMPLETE"
                # Explicitly mark the safety check as acknowledged.
                extra_fr_fields["safety_acknowledgement"] = "true"

            with self._status("Sending command to Computer..."):
                fc_result = self.handle_action(function_call)

            if isinstance(fc_result, EnvState):
                function_responses.append(
                    FunctionResponse(
                        name=function_call.name,
                        response={
                            "url": fc_result.url,
                            **extra_fr_fields,
                        },
                        parts=[
                            types.FunctionResponsePart(
                                inline_data=types.FunctionResponseBlob(
                                    mime_type="image/png", data=fc_result.screenshot
                                )
                            )
                        ],
                    )
                )
            elif isinstance(fc_result, dict):
                function_responses.append(
                    FunctionResponse(name=function_call.name, response=fc_result)
                )

        self._contents.append(
            Content(
                role="user",
                parts=[Part(function_response=fr) for fr in function_responses],
            )
        )

        self._prune_old_screenshots()
        return "CONTINUE"

    def _get_safety_confirmation(
        self, safety: dict[str, Any]
    ) -> Literal["CONTINUE", "TERMINATE"]:
        """Asks the user on stdin whether to proceed with a flagged action.

        Args:
            safety: The call's "safety_decision" argument; its "decision" must
                be "require_confirmation", and its "explanation" is printed.

        Returns:
            "TERMINATE" if the user answers no, otherwise "CONTINUE".

        Raises:
            ValueError: If the decision is not "require_confirmation".
        """
        if safety["decision"] != "require_confirmation":
            raise ValueError(f"Unknown safety decision: {safety['decision']}")
        termcolor.cprint(
            "Safety service requires explicit confirmation!",
            color="yellow",
            attrs=["bold"],
        )
        print(safety["explanation"])
        decision = ""
        # Keep asking until the answer is a recognizable yes/no.
        while decision.lower() not in ("y", "n", "ye", "yes", "no"):
            decision = input("Do you wish to proceed? [Yes]/[No]\n").strip()
        if decision.lower() in ("n", "no"):
            return "TERMINATE"
        return "CONTINUE"

    def agent_loop(self):
        """Runs interactive iterations until one returns "COMPLETE"."""
        status = "CONTINUE"
        while status == "CONTINUE":
            status = self.run_one_iteration()

    def agent_loop_yield(self):
        """Yields ``(reasoning, status, function_responses_list)`` per step.

        Stops after yielding the first step whose status is "COMPLETE".
        """
        status = "CONTINUE"
        while status == "CONTINUE":
            reasoning, status, function_responses_list = self.run_one_iteration_modify()
            yield reasoning, status, function_responses_list

    def denormalize_x(self, x: int) -> int:
        """Maps an x value on the 0-1000 grid to pixels using screen_size()[0]."""
        return int(x / 1000 * self._browser_computer.screen_size()[0])

    def denormalize_y(self, y: int) -> int:
        """Maps a y value on the 0-1000 grid to pixels using screen_size()[1]."""
        return int(y / 1000 * self._browser_computer.screen_size()[1])