# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import Any, Literal, Optional, Union

import termcolor
from google import genai
from google.genai import types
from google.genai.types import (
    Part,
    GenerateContentConfig,
    Content,
    Candidate,
    FunctionResponse,
    FinishReason,
)
from rich.console import Console
from rich.table import Table

from computers import EnvState, Computer

# Screenshots are kept only for this many of the most recent user turns that
# carry them; older turns have their screenshot payload dropped.
MAX_RECENT_TURN_WITH_SCREENSHOTS = 5

# Names of the built-in Computer Use functions. Only function responses with
# one of these names are considered when pruning old screenshots.
PREDEFINED_COMPUTER_USE_FUNCTIONS = [
    "open_web_browser",
    "click_at",
    "hover_at",
    "type_text_at",
    "scroll_document",
    "scroll_at",
    "wait_5_seconds",
    "go_back",
    "go_forward",
    "search",
    "navigate",
    "key_combination",
    "drag_and_drop",
]

# Shared rich console used for status spinners and table output.
console = Console()

# Built-in Computer Use tools will return "EnvState".
# Custom provided functions will return "dict".
FunctionResponseT = Union[EnvState, dict]


# Example custom function. NOTE: the docstring below is runtime text, not just
# documentation - it is used as the function description when a declaration is
# built from this callable - so keep its wording deliberate.
def multiply_numbers(x: float, y: float) -> dict:
    """Multiplies two numbers.

    Args:
        x (float): The first number.
        y (float): The second number.

    Returns:
        dict: The result of multiplication.
    """
    return {"result": x * y}


# def multiply_numbers(x: float, y: float) -> dict:
#     """Multiplies two numbers.
#     Args:
#         x (float): The first number.
#         y (float): The second number.
#     Returns:
#         dict: The result of multiplication.
#     """
#     return {"result": x * y}


# NOTE: when a function declaration is created with
# types.FunctionDeclaration.from_callable, everything in __doc__ is turned into
# the description. Even if the docstring follows Google style, it is not parsed
# further into per-parameter / return-value descriptions. If a parameter has a
# complex meaning, explain it explicitly in the docstring text.
def evaluate_answer(query: str, answer: str) -> None:
    """
    Evaluates the answer based on the given query.
    """
    return None


class BrowserAgent:
    """Agent loop that lets a Gemini Computer Use model drive a browser.

    The agent keeps the conversation history in ``self._contents``, asks the
    model for the next step, executes every function call the model returns
    against ``browser_computer`` (or a custom Python function), and feeds the
    results back to the model until the model answers without function calls.

    Attributes:
        final_reasoning: Text of the last model turn that contained no
            function calls, or ``None`` if the loop has not finished that way.
    """

    def __init__(
        self,
        browser_computer: Computer,
        query: str,
        system_prompt: str,
        model_name: str,
        verbose: bool = True,
    ):
        """Creates the client, the initial history and the generation config.

        Args:
            browser_computer: Object that executes the browser actions.
            query: The user request; becomes the first user turn.
            system_prompt: Passed to the model as the system instruction.
            model_name: Model identifier used for ``generate_content``.
            verbose: When true, show status spinners and the per-step table.
        """
        self._browser_computer = browser_computer
        self._query = query
        self._model_name = model_name
        self._verbose = verbose
        self.final_reasoning = None
        self._client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
            vertexai=os.environ.get("USE_VERTEXAI", "0").lower() in ["true", "1"],
            project=os.environ.get("VERTEXAI_PROJECT"),
            location=os.environ.get("VERTEXAI_LOCATION"),
        )
        self._contents: list[Content] = [
            Content(
                role="user",
                parts=[
                    Part(text=self._query),
                ],
            )
        ]

        # Exclude any predefined functions here.
        # "scroll_document" is disabled for now.
        excluded_predefined_functions = ["scroll_document"]

        # Add your own custom functions here.
        custom_functions = [
            # For example:
            types.FunctionDeclaration.from_callable(
                client=self._client, callable=multiply_numbers
            )
        ]

        self._generate_content_config = GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=excluded_predefined_functions,
                    ),
                ),
                types.Tool(function_declarations=custom_functions),
            ],
            # The system prompt supplied by the caller.
            system_instruction=[system_prompt],
        )

    def _denormalize_point(
        self, args: dict, x_key: str = "x", y_key: str = "y"
    ) -> tuple:
        """Reads ``args[x_key]`` / ``args[y_key]`` and converts them to pixels."""
        return self.denormalize_x(args[x_key]), self.denormalize_y(args[y_key])

    def handle_action(self, action: types.FunctionCall) -> FunctionResponseT:
        """Handles the action and returns the environment state.

        Args:
            action: Function call emitted by the model.

        Returns:
            Whatever the matching ``browser_computer`` method returns for a
            predefined function, or the ``dict`` returned by a custom function.

        Raises:
            ValueError: If the function name is not supported, or if
                ``scroll_at`` is given a direction other than
                up / down / left / right.
            KeyError: If a required argument is missing from ``action.args``.
        """
        name = action.name
        args = action.args
        computer = self._browser_computer

        if name == "open_web_browser":
            return computer.open_web_browser()
        if name == "click_at":
            x, y = self._denormalize_point(args)
            return computer.click_at(x=x, y=y)
        if name == "hover_at":
            x, y = self._denormalize_point(args)
            return computer.hover_at(x=x, y=y)
        if name == "type_text_at":
            x, y = self._denormalize_point(args)
            press_enter = args.get("press_enter", False)
            clear_before_typing = args.get("clear_before_typing", True)
            return computer.type_text_at(
                x=x,
                y=y,
                text=args["text"],
                press_enter=press_enter,
                clear_before_typing=clear_before_typing,
            )
        if name == "scroll_document":
            return computer.scroll_document(args["direction"])
        if name == "scroll_at":
            x, y = self._denormalize_point(args)
            magnitude = args.get("magnitude", 800)
            direction = args["direction"]
            # The magnitude is scaled along the axis being scrolled.
            if direction in ("up", "down"):
                magnitude = self.denormalize_y(magnitude)
            elif direction in ("left", "right"):
                magnitude = self.denormalize_x(magnitude)
            else:
                raise ValueError(f"Unknown direction: {direction}")
            return computer.scroll_at(
                x=x, y=y, direction=direction, magnitude=magnitude
            )
        if name == "wait_5_seconds":
            return computer.wait_5_seconds()
        if name == "go_back":
            return computer.go_back()
        if name == "go_forward":
            return computer.go_forward()
        if name == "search":
            return computer.search()
        if name == "navigate":
            return computer.navigate(args["url"])
        if name == "key_combination":
            # Keys arrive as a single "+"-separated string.
            return computer.key_combination(args["keys"].split("+"))
        if name == "drag_and_drop":
            x, y = self._denormalize_point(args)
            destination_x, destination_y = self._denormalize_point(
                args, "destination_x", "destination_y"
            )
            return computer.drag_and_drop(
                x=x,
                y=y,
                destination_x=destination_x,
                destination_y=destination_y,
            )
        # Handle the custom function declarations here.
        if name == multiply_numbers.__name__:
            return multiply_numbers(x=args["x"], y=args["y"])
        raise ValueError(f"Unsupported function: {action}")

    def get_model_response(
        self, max_retries=5, base_delay_s=1
    ) -> types.GenerateContentResponse:
        """Calls the model with the current history, retrying on failure.

        The delay between attempts doubles each time, starting at
        ``base_delay_s`` seconds.

        Args:
            max_retries: Total number of attempts; must be at least 1.
            base_delay_s: Delay in seconds before the second attempt.

        Returns:
            The response of the first successful attempt.

        Raises:
            ValueError: If ``max_retries`` is smaller than 1 (previously this
                fell through and returned ``None``).
            Exception: The error of the last attempt is re-raised unchanged.
        """
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")

        for attempt in range(max_retries):
            try:
                return self._client.models.generate_content(
                    model=self._model_name,
                    contents=self._contents,
                    config=self._generate_content_config,
                )
            except Exception as e:
                print(e)
                if attempt >= max_retries - 1:
                    termcolor.cprint(
                        f"Generating content failed after {max_retries} attempts.\n",
                        color="red",
                    )
                    raise
                delay = base_delay_s * (2**attempt)
                message = (
                    f"Generating content failed on attempt {attempt + 1}. "
                    f"Retrying in {delay} seconds...\n"
                )
                termcolor.cprint(
                    message,
                    color="yellow",
                )
                time.sleep(delay)

    def get_text(self, candidate: Candidate) -> Optional[str]:
        """Extracts the text from the candidate.

        Returns:
            All non-empty text parts joined by a single space, or ``None``
            when the candidate has no content, no parts, or no text.
        """
        if not candidate.content or not candidate.content.parts:
            return None
        # In multimodal scenarios Gemini returns its content in segments; text
        # and function calls may be interleaved with each other, e.g.
        # {
        #   "content": {
        #     "parts": [
        #       {"text": "I'll start by searching Google for relevant documentation."},
        #       {"function_call": {...}},
        #       {"text": "Now that I found it, here's the explanation:"},
        #       {"text": "In Python, decorators are higher-order functions..."}
        #     ]
        #   }
        # }
        text = [part.text for part in candidate.content.parts if part.text]
        return " ".join(text) or None

    def extract_function_calls(self, candidate: Candidate) -> list[types.FunctionCall]:
        """Extracts the function calls from the candidate, in part order."""
        if not candidate.content or not candidate.content.parts:
            return []
        return [
            part.function_call
            for part in candidate.content.parts
            if part.function_call
        ]

    def _build_function_response(
        self,
        function_call: types.FunctionCall,
        fc_result: FunctionResponseT,
        extra_fields: Optional[dict] = None,
    ) -> Optional[FunctionResponse]:
        """Wraps the result of ``handle_action`` in a ``FunctionResponse``.

        An ``EnvState`` becomes a response carrying the URL (plus
        ``extra_fields``) and the screenshot as an inline PNG part; a ``dict``
        is passed through as the response. Any other result yields ``None``.
        """
        if isinstance(fc_result, EnvState):
            return FunctionResponse(
                name=function_call.name,
                response={
                    "url": fc_result.url,
                    **(extra_fields or {}),
                },
                parts=[
                    types.FunctionResponsePart(
                        inline_data=types.FunctionResponseBlob(
                            mime_type="image/png", data=fc_result.screenshot
                        )
                    )
                ],
            )
        if isinstance(fc_result, dict):
            return FunctionResponse(name=function_call.name, response=fc_result)
        return None

    def _prune_old_screenshots(self) -> None:
        """Drops screenshots from all but the most recent user turns.

        Walks the history from newest to oldest and counts user turns that
        still carry a screenshot of a predefined Computer Use function. Once
        more than ``MAX_RECENT_TURN_WITH_SCREENSHOTS`` such turns have been
        seen, the screenshot parts of the older turns are set to ``None``.
        """
        turns_with_screenshots = 0
        for content in reversed(self._contents):
            if content.role != "user" or not content.parts:
                continue
            screenshot_responses = [
                part.function_response
                for part in content.parts
                if part.function_response
                and part.function_response.parts
                and part.function_response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS
            ]
            if not screenshot_responses:
                continue
            turns_with_screenshots += 1
            if turns_with_screenshots > MAX_RECENT_TURN_WITH_SCREENSHOTS:
                for function_response in screenshot_responses:
                    function_response.parts = None

    def _run_with_status(self, message: str, func, *args):
        """Runs ``func(*args)``, showing a console spinner when verbose."""
        if self._verbose:
            with console.status(message, spinner_style=None):
                return func(*args)
        return func(*args)

    def _print_function_calls(self, reasoning, function_calls) -> None:
        """Prints a two-column table: model reasoning and function calls."""
        function_call_strs = []
        for function_call in function_calls:
            function_call_str = f"Name: {function_call.name}"
            if function_call.args:
                function_call_str += "\nArgs:"
                for key, value in function_call.args.items():
                    function_call_str += f"\n  {key}: {value}"
            function_call_strs.append(function_call_str)

        table = Table(expand=True)
        table.add_column(
            "Gemini Computer Use Reasoning",
            header_style="magenta",
            ratio=1,
            no_wrap=False,  # allow wrapping
            overflow="fold",  # fold text that does not fit
        )
        table.add_column("Function Call(s)", header_style="cyan", ratio=1)
        table.add_row(reasoning, "\n".join(function_call_strs))
        console.print(table)
        print()

    def run_one_iteration_modify(self):
        """Runs one non-interactive agent step and reports what happened.

        Returns:
            A ``(reasoning, status, function_responses_list)`` tuple.
            ``status`` is ``"CONTINUE"`` or ``"COMPLETE"``; each list entry is
            a dict with the keys ``"screenshot"``, ``"action"`` and
            ``"response"`` describing one executed function call.

        NOTE(review): unlike ``run_one_iteration`` this path never asks for a
        safety confirmation and sends no ``safety_acknowledgement``.
        """
        # Generate a response from the model.
        try:
            response = self.get_model_response()
        except Exception:
            return "Error occurred", "COMPLETE", []

        if not response.candidates:
            return "No candidates in response", "COMPLETE", []

        candidate = response.candidates[0]
        if candidate.content:
            self._contents.append(candidate.content)

        reasoning = self.get_text(candidate)
        function_calls = self.extract_function_calls(candidate)

        # Retry the request in case of malformed FCs (same rule as
        # run_one_iteration).
        if (
            not function_calls
            and not reasoning
            and candidate.finish_reason == FinishReason.MALFORMED_FUNCTION_CALL
        ):
            return reasoning, "CONTINUE", []

        if not function_calls:
            self.final_reasoning = reasoning
            return reasoning, "COMPLETE", []

        # Process each function call and collect results. All responses go
        # into ONE user turn, matching run_one_iteration, so the turn answers
        # every function call of the preceding model turn.
        function_responses = []
        function_responses_list = []
        for function_call in function_calls:
            try:
                fc_result = self.handle_action(function_call)
                function_response = self._build_function_response(
                    function_call, fc_result
                )
            except Exception as e:
                # Keep the results gathered so far in the history, then stop.
                if function_responses:
                    self._contents.append(
                        Content(
                            role="user",
                            parts=[
                                Part(function_response=fr)
                                for fr in function_responses
                            ],
                        )
                    )
                return (
                    f"Error handling action {function_call.name}: {str(e)}",
                    "COMPLETE",
                    function_responses_list,
                )

            if function_response is not None:
                function_responses.append(function_response)

            if isinstance(fc_result, EnvState):
                # Browser action.
                screenshot_base64 = fc_result.screenshot
                result_payload = {"url": fc_result.url}
            elif isinstance(fc_result, dict):
                # Custom function.
                screenshot_base64 = ""
                result_payload = fc_result
            else:
                screenshot_base64 = ""
                result_payload = {}

            function_responses_list.append(
                {
                    "screenshot": screenshot_base64,
                    "action": function_call.name,
                    "response": result_payload,
                }
            )

        if function_responses:
            self._contents.append(
                Content(
                    role="user",
                    parts=[Part(function_response=fr) for fr in function_responses],
                )
            )

        # Only keep screenshots in the few most recent turns.
        self._prune_old_screenshots()
        return reasoning, "CONTINUE", function_responses_list

    def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
        """Runs one interactive agent step.

        Returns:
            ``"CONTINUE"`` when function calls were executed (or a malformed
            function call should be retried), ``"COMPLETE"`` when the model
            gave a final answer, the request failed, or the user declined a
            safety confirmation.

        Raises:
            ValueError: If the response contains no candidates.
        """
        # Generate a response from the model.
        try:
            response = self._run_with_status(
                "Generating response from Gemini Computer Use...",
                self.get_model_response,
            )
        except Exception:
            return "COMPLETE"

        if not response.candidates:
            print("Response has no candidates!")
            print(response)
            raise ValueError("Empty response")

        # Extract the text and function call from the response.
        candidate = response.candidates[0]
        # Append the model turn to conversation history.
        if candidate.content:
            self._contents.append(candidate.content)

        reasoning = self.get_text(candidate)
        function_calls = self.extract_function_calls(candidate)

        # Retry the request in case of malformed FCs.
        if (
            not function_calls
            and not reasoning
            and candidate.finish_reason == FinishReason.MALFORMED_FUNCTION_CALL
        ):
            return "CONTINUE"

        if not function_calls:
            print(f"Agent Loop Complete: {reasoning}")
            self.final_reasoning = reasoning
            return "COMPLETE"

        # Print the function calls and any reasoning.
        if self._verbose:
            self._print_function_calls(reasoning, function_calls)

        function_responses = []
        for function_call in function_calls:
            extra_fr_fields = {}
            if function_call.args and (
                safety := function_call.args.get("safety_decision")
            ):
                decision = self._get_safety_confirmation(safety)
                if decision == "TERMINATE":
                    print("Terminating agent loop")
                    return "COMPLETE"
                # Explicitly mark the safety check as acknowledged.
                extra_fr_fields["safety_acknowledgement"] = "true"

            fc_result = self._run_with_status(
                "Sending command to Computer...", self.handle_action, function_call
            )
            function_response = self._build_function_response(
                function_call, fc_result, extra_fr_fields
            )
            if function_response is not None:
                function_responses.append(function_response)

        self._contents.append(
            Content(
                role="user",
                parts=[Part(function_response=fr) for fr in function_responses],
            )
        )

        # Only keep screenshots in the few most recent turns.
        self._prune_old_screenshots()
        return "CONTINUE"

    def _get_safety_confirmation(
        self, safety: dict[str, Any]
    ) -> Literal["CONTINUE", "TERMINATE"]:
        """Asks the user on stdin whether a flagged action may proceed.

        Args:
            safety: Safety decision with the keys ``"decision"`` and
                ``"explanation"``.

        Returns:
            ``"TERMINATE"`` if the user answers n/no, otherwise ``"CONTINUE"``.

        Raises:
            ValueError: If the decision is not ``"require_confirmation"``.
        """
        if safety["decision"] != "require_confirmation":
            raise ValueError(f"Unknown safety decision: {safety['decision']}")
        termcolor.cprint(
            "Safety service requires explicit confirmation!",
            color="yellow",
            attrs=["bold"],
        )
        print(safety["explanation"])
        decision = ""
        # Keep asking until one of the accepted answers is entered.
        while decision.lower() not in ("y", "n", "ye", "yes", "no"):
            decision = input("Do you wish to proceed? [Yes]/[No]\n")
        if decision.lower() in ("n", "no"):
            return "TERMINATE"
        return "CONTINUE"

    def agent_loop(self):
        """Runs ``run_one_iteration`` until it stops returning "CONTINUE"."""
        status = "CONTINUE"
        while status == "CONTINUE":
            status = self.run_one_iteration()

    def agent_loop_yield(self):
        """Generator variant of the loop.

        Yields the ``(reasoning, status, function_responses_list)`` tuple of
        every ``run_one_iteration_modify`` step, including the final one.
        """
        status = "CONTINUE"
        while status == "CONTINUE":
            reasoning, status, function_responses_list = (
                self.run_one_iteration_modify()
            )
            yield reasoning, status, function_responses_list

    def denormalize_x(self, x: int) -> int:
        """Scales ``x`` from a 0-1000 range to the screen width, truncated."""
        return int(x / 1000 * self._browser_computer.screen_size()[0])

    def denormalize_y(self, y: int) -> int:
        """Scales ``y`` from a 0-1000 range to the screen height, truncated."""
        return int(y / 1000 * self._browser_computer.screen_size()[1])