Spaces:
Build error
Build error
| # Copyright 2025 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| from typing import Literal, Optional, Union, Any | |
| from google import genai | |
| from google.genai import types | |
| import termcolor | |
| from google.genai.types import ( | |
| Part, | |
| GenerateContentConfig, | |
| Content, | |
| Candidate, | |
| FunctionResponse, | |
| FinishReason, | |
| ) | |
| import time | |
| from rich.console import Console | |
| from rich.table import Table | |
| from computers import EnvState, Computer | |
# How many of the most recent user turns keep their Computer Use screenshots
# in the conversation history; screenshots in older turns are stripped.
MAX_RECENT_TURN_WITH_SCREENSHOTS = 5

# Names of the predefined (built-in) Computer Use functions. Function
# responses with these names are the ones whose screenshot parts are pruned
# from older turns.
PREDEFINED_COMPUTER_USE_FUNCTIONS = [
    # Browser lifecycle.
    "open_web_browser",
    # Pointer actions.
    "click_at",
    "hover_at",
    # Keyboard input.
    "type_text_at",
    # Scrolling.
    "scroll_document",
    "scroll_at",
    # Waiting and history navigation.
    "wait_5_seconds",
    "go_back",
    "go_forward",
    # Direct navigation.
    "search",
    "navigate",
    # Misc. input.
    "key_combination",
    "drag_and_drop",
]
# Result type of BrowserAgent.handle_action: built-in Computer Use actions
# return an EnvState, while custom functions declared for the model return a
# plain dict.
FunctionResponseT = Union[EnvState, dict]

# Shared rich console used for status spinners and the reasoning table.
console: Console = Console()
# Example custom function exposed to the model. Its docstring is passed to
# types.FunctionDeclaration.from_callable and becomes the description the
# model sees, so keep it accurate.
def multiply_numbers(x: float, y: float) -> dict:
    """Multiplies two numbers.
    Args:
        x (float): The first number.
        y (float): The second number.
    Returns:
        dict: The result of multiplication.
    """
    product = x * y
    return {"result": product}
| # def multiply_numbers(x: float, y: float) -> dict: | |
| # """Multiplies two numbers. | |
| # Args: | |
| # x (float): The first number. | |
| # y (float): The second number. | |
| # Returns: | |
| # dict: The result of multiplication. | |
| # """ | |
| # return {"result": x * y} | |
# When a function declaration is created with
# types.FunctionDeclaration.from_callable, the entire __doc__ is turned into
# the description. Even a Google-style docstring is not parsed any further
# into per-parameter or return-value descriptions.
# If the meaning of the parameters is complex, explain it in the docstring text.
def evaluate_answer(query: str, answer: str) -> None:
    """
    Evaluates the answer based on the given query.
    """
    # Placeholder: no evaluation is implemented yet, so both arguments are
    # intentionally unused.
    del query, answer
    return None
class BrowserAgent:
    """Runs a Gemini Computer Use agent loop against a browser ``Computer``.

    The conversation history lives in ``self._contents``. Each iteration asks
    the model for the next turn, executes any function calls it returns (the
    predefined Computer Use actions or the custom functions registered in
    ``__init__``), and appends the results to the history as function
    responses.
    """

    def __init__(
        self,
        browser_computer: Computer,
        query: str,
        system_prompt: str,
        model_name: str,
        verbose: bool = True,
    ):
        """Create the agent and its Gemini client.

        Args:
            browser_computer: Browser that the model's actions are executed on.
            query: Initial user request; it becomes the first user turn.
            system_prompt: System instruction sent with every request.
            model_name: Gemini model name passed to ``generate_content``.
            verbose: If True, show status spinners and the per-turn
                reasoning/function-call table in ``run_one_iteration``.

        The client is configured from the ``GEMINI_API_KEY``, ``USE_VERTEXAI``
        ("true" or "1" enables Vertex AI), ``VERTEXAI_PROJECT`` and
        ``VERTEXAI_LOCATION`` environment variables.
        """
        self._browser_computer = browser_computer
        self._query = query
        self._model_name = model_name
        self._verbose = verbose
        # Text of the model turn that ended the loop (a turn without function calls).
        self.final_reasoning: Optional[str] = None
        self._client = genai.Client(
            api_key=os.environ.get("GEMINI_API_KEY"),
            vertexai=os.environ.get("USE_VERTEXAI", "0").lower() in ["true", "1"],
            project=os.environ.get("VERTEXAI_PROJECT"),
            location=os.environ.get("VERTEXAI_LOCATION"),
        )
        self._contents: list[Content] = [
            Content(
                role="user",
                parts=[
                    Part(text=self._query),
                ],
            )
        ]
        # Predefined Computer Use functions the model may not call.
        # scroll_document is temporarily disabled.
        excluded_predefined_functions = ["scroll_document"]
        # Custom functions offered to the model alongside Computer Use.
        # Each one must also be dispatched in handle_action.
        custom_functions = [
            types.FunctionDeclaration.from_callable(
                client=self._client, callable=multiply_numbers
            )
        ]
        self._generate_content_config = GenerateContentConfig(
            temperature=1,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=excluded_predefined_functions,
                    ),
                ),
                types.Tool(function_declarations=custom_functions),
            ],
            system_instruction=[system_prompt],
        )

    def handle_action(self, action: types.FunctionCall) -> FunctionResponseT:
        """Execute one function call from the model and return its result.

        Coordinates (and scroll magnitudes) arrive in the model's 1000-based
        coordinate space and are converted to pixels with ``denormalize_x`` /
        ``denormalize_y``.

        Returns:
            An ``EnvState`` for browser actions, or a ``dict`` for custom
            functions.

        Raises:
            ValueError: For an unknown function name or scroll direction.
            KeyError: If a required argument is missing from ``action.args``.
        """
        if action.name == "open_web_browser":
            return self._browser_computer.open_web_browser()
        elif action.name == "click_at":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            return self._browser_computer.click_at(x=x, y=y)
        elif action.name == "hover_at":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            return self._browser_computer.hover_at(x=x, y=y)
        elif action.name == "type_text_at":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            press_enter = action.args.get("press_enter", False)
            clear_before_typing = action.args.get("clear_before_typing", True)
            return self._browser_computer.type_text_at(
                x=x,
                y=y,
                text=action.args["text"],
                press_enter=press_enter,
                clear_before_typing=clear_before_typing,
            )
        elif action.name == "scroll_document":
            return self._browser_computer.scroll_document(action.args["direction"])
        elif action.name == "scroll_at":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            magnitude = action.args.get("magnitude", 800)
            direction = action.args["direction"]
            # The magnitude is scaled along the axis of the scroll direction.
            if direction in ("up", "down"):
                magnitude = self.denormalize_y(magnitude)
            elif direction in ("left", "right"):
                magnitude = self.denormalize_x(magnitude)
            else:
                raise ValueError(f"Unknown direction: {direction}")
            return self._browser_computer.scroll_at(
                x=x, y=y, direction=direction, magnitude=magnitude
            )
        elif action.name == "wait_5_seconds":
            return self._browser_computer.wait_5_seconds()
        elif action.name == "go_back":
            return self._browser_computer.go_back()
        elif action.name == "go_forward":
            return self._browser_computer.go_forward()
        elif action.name == "search":
            return self._browser_computer.search()
        elif action.name == "navigate":
            return self._browser_computer.navigate(action.args["url"])
        elif action.name == "key_combination":
            # e.g. "control+c" -> ["control", "c"]
            return self._browser_computer.key_combination(
                action.args["keys"].split("+")
            )
        elif action.name == "drag_and_drop":
            x = self.denormalize_x(action.args["x"])
            y = self.denormalize_y(action.args["y"])
            destination_x = self.denormalize_x(action.args["destination_x"])
            destination_y = self.denormalize_y(action.args["destination_y"])
            return self._browser_computer.drag_and_drop(
                x=x,
                y=y,
                destination_x=destination_x,
                destination_y=destination_y,
            )
        # Custom function declarations are dispatched here.
        elif action.name == multiply_numbers.__name__:
            return multiply_numbers(x=action.args["x"], y=action.args["y"])
        else:
            raise ValueError(f"Unsupported function: {action}")

    def get_model_response(
        self, max_retries=5, base_delay_s=1
    ) -> types.GenerateContentResponse:
        """Request the next model turn, retrying with exponential backoff.

        The delay before retry ``attempt + 1`` is ``base_delay_s * 2**attempt``
        seconds.

        Raises:
            ValueError: If ``max_retries`` is less than 1 (otherwise the
                method would silently return None).
            Exception: The last error from ``generate_content`` once all
                attempts have failed.
        """
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        for attempt in range(max_retries):
            try:
                return self._client.models.generate_content(
                    model=self._model_name,
                    contents=self._contents,
                    config=self._generate_content_config,
                )
            except Exception as e:
                print(e)
                if attempt < max_retries - 1:
                    delay = base_delay_s * (2**attempt)
                    message = (
                        f"Generating content failed on attempt {attempt + 1}. "
                        f"Retrying in {delay} seconds...\n"
                    )
                    termcolor.cprint(
                        message,
                        color="yellow",
                    )
                    time.sleep(delay)
                else:
                    termcolor.cprint(
                        f"Generating content failed after {max_retries} attempts.\n",
                        color="red",
                    )
                    raise

    def get_text(self, candidate: Candidate) -> Optional[str]:
        """Join the text parts of a candidate with spaces.

        Returns None when the candidate has no content, no parts, or no
        non-empty text parts.
        """
        if not candidate.content or not candidate.content.parts:
            return None
        # In multimodal responses Gemini may return content in several parts:
        # text can be interleaved with function_call parts and other part
        # types, e.g. [text, function_call, text, text]. Only the text parts
        # are collected here.
        text = [part.text for part in candidate.content.parts if part.text]
        return " ".join(text) or None

    def extract_function_calls(self, candidate: Candidate) -> list[types.FunctionCall]:
        """Return every function call in the candidate, in part order."""
        if not candidate.content or not candidate.content.parts:
            return []
        return [
            part.function_call
            for part in candidate.content.parts
            if part.function_call
        ]

    @staticmethod
    def _to_function_response(
        name: str, fc_result: Any, extra_fields: Optional[dict] = None
    ) -> Optional[FunctionResponse]:
        """Wrap a handle_action result as a FunctionResponse for the model.

        An ``EnvState`` becomes ``{"url": ..., **extra_fields}`` plus the
        screenshot attached as an image/png part. A ``dict`` is sent as-is
        (``extra_fields`` are not added). Any other result yields None.
        """
        if isinstance(fc_result, EnvState):
            return FunctionResponse(
                name=name,
                response={"url": fc_result.url, **(extra_fields or {})},
                parts=[
                    types.FunctionResponsePart(
                        inline_data=types.FunctionResponseBlob(
                            mime_type="image/png", data=fc_result.screenshot
                        )
                    )
                ],
            )
        if isinstance(fc_result, dict):
            return FunctionResponse(name=name, response=fc_result)
        return None

    def _prune_old_screenshots(self) -> None:
        """Keep screenshots only in the most recent user turns that have them.

        Walks the history newest-first. Every user turn containing a
        predefined Computer Use function response with image parts counts as
        one screenshot turn. Beyond MAX_RECENT_TURN_WITH_SCREENSHOTS such
        turns, those image parts are dropped; the response payload stays.
        """
        turns_with_screenshots = 0
        for content in reversed(self._contents):
            if content.role != "user" or not content.parts:
                continue
            screenshot_parts = [
                part
                for part in content.parts
                if part.function_response
                and part.function_response.parts
                and part.function_response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS
            ]
            if not screenshot_parts:
                continue
            turns_with_screenshots += 1
            if turns_with_screenshots > MAX_RECENT_TURN_WITH_SCREENSHOTS:
                for part in screenshot_parts:
                    part.function_response.parts = None

    def _spinner(self, message: str):
        """Return a rich status spinner when verbose, otherwise a no-op context."""
        from contextlib import nullcontext

        if self._verbose:
            return console.status(message, spinner_style=None)
        return nullcontext()

    def run_one_iteration_modify(
        self,
    ) -> tuple[Optional[str], Literal["COMPLETE", "CONTINUE"], list[dict[str, Any]]]:
        """Run one non-interactive agent step (used by ``agent_loop_yield``).

        Returns:
            ``(reasoning, status, results)``.
            - ``reasoning`` is the model's text for this turn, or an error or
              stop message.
            - ``status`` is "CONTINUE" or "COMPLETE".
            - ``results`` has one dict per executed function call, with keys
              "screenshot" (the EnvState screenshot, or "" for custom
              functions), "action" (the function name) and "response" (the
              payload that was sent back to the model).

        Actions carrying a ``safety_decision`` are NOT executed. This path
        has no way to ask the user for confirmation, so the step stops with
        status "COMPLETE" instead.
        """
        try:
            response = self.get_model_response()
        except Exception:
            # get_model_response has already printed the error and retry log.
            return "Error occurred", "COMPLETE", []
        if not response.candidates:
            return "No candidates in response", "COMPLETE", []
        candidate = response.candidates[0]
        if candidate.content:
            self._contents.append(candidate.content)
        reasoning = self.get_text(candidate)
        function_calls = self.extract_function_calls(candidate)

        # Retry the request when the model produced a malformed function call
        # (same policy as run_one_iteration).
        if (
            not function_calls
            and not reasoning
            and candidate.finish_reason == FinishReason.MALFORMED_FUNCTION_CALL
        ):
            return reasoning, "CONTINUE", []
        if not function_calls:
            self.final_reasoning = reasoning
            return reasoning, "COMPLETE", []

        results: list[dict[str, Any]] = []
        response_parts: list[Part] = []
        failure: Optional[str] = None
        for function_call in function_calls:
            safety = (
                function_call.args.get("safety_decision")
                if function_call.args
                else None
            )
            if safety:
                # Never execute an action that needs explicit user confirmation.
                explanation = (
                    safety.get("explanation", "") if isinstance(safety, dict) else ""
                )
                failure = (
                    f"Action {function_call.name} requires user confirmation "
                    f"and was not executed: {explanation}"
                )
                break
            try:
                fc_result = self.handle_action(function_call)
                screenshot = ""
                payload: Any = {}
                if isinstance(fc_result, EnvState):  # Browser action.
                    screenshot = fc_result.screenshot
                    payload = {"url": fc_result.url}
                elif isinstance(fc_result, dict):  # Custom function.
                    payload = fc_result
                function_response = self._to_function_response(
                    function_call.name, fc_result
                )
                if function_response is not None:
                    response_parts.append(Part(function_response=function_response))
                results.append(
                    {
                        "screenshot": screenshot,
                        "action": function_call.name,
                        "response": payload,
                    }
                )
            except Exception as e:
                failure = f"Error handling action {function_call.name}: {str(e)}"
                break

        # All function responses of one model turn go into a single user turn,
        # matching run_one_iteration.
        if response_parts:
            self._contents.append(Content(role="user", parts=response_parts))
        if failure is not None:
            return failure, "COMPLETE", results

        self._prune_old_screenshots()
        return reasoning, "CONTINUE", results

    def _print_turn(
        self, reasoning: Optional[str], function_calls: list[types.FunctionCall]
    ) -> None:
        """Print the model's reasoning and requested function calls as a table."""
        function_call_strs = []
        for function_call in function_calls:
            function_call_str = f"Name: {function_call.name}"
            if function_call.args:
                function_call_str += "\nArgs:"
                for key, value in function_call.args.items():
                    function_call_str += f"\n {key}: {value}"
            function_call_strs.append(function_call_str)

        table = Table(expand=True)
        table.add_column(
            "Gemini Computer Use Reasoning",
            header_style="magenta",
            ratio=1,
            no_wrap=False,  # Allow the reasoning text to wrap.
            overflow="fold",  # Fold overflowing text onto the next line.
        )
        table.add_column("Function Call(s)", header_style="cyan", ratio=1)
        table.add_row(reasoning, "\n".join(function_call_strs))
        console.print(table)
        print()

    def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:
        """Run one interactive agent step (used by ``agent_loop``).

        If the model asks for confirmation via ``safety_decision``, the user
        is prompted on stdin.

        Returns:
            "CONTINUE" to request another step, or "COMPLETE" when the model
            stopped calling functions, the request failed, or the user
            declined a safety confirmation.

        Raises:
            ValueError: If the response has no candidates.
        """
        try:
            with self._spinner("Generating response from Gemini Computer Use..."):
                response = self.get_model_response()
        except Exception:
            # get_model_response has already printed the error and retry log.
            return "COMPLETE"

        if not response.candidates:
            print("Response has no candidates!")
            print(response)
            raise ValueError("Empty response")

        candidate = response.candidates[0]
        # Append the model turn to the conversation history.
        if candidate.content:
            self._contents.append(candidate.content)
        reasoning = self.get_text(candidate)
        function_calls = self.extract_function_calls(candidate)

        # Retry the request in case of malformed function calls.
        if (
            not function_calls
            and not reasoning
            and candidate.finish_reason == FinishReason.MALFORMED_FUNCTION_CALL
        ):
            return "CONTINUE"
        if not function_calls:
            print(f"Agent Loop Complete: {reasoning}")
            self.final_reasoning = reasoning
            return "COMPLETE"

        if self._verbose:
            self._print_turn(reasoning, function_calls)

        function_responses = []
        for function_call in function_calls:
            extra_fr_fields = {}
            if function_call.args and (
                safety := function_call.args.get("safety_decision")
            ):
                decision = self._get_safety_confirmation(safety)
                if decision == "TERMINATE":
                    print("Terminating agent loop")
                    return "COMPLETE"
                # Explicitly mark the safety check as acknowledged.
                extra_fr_fields["safety_acknowledgement"] = "true"
            with self._spinner("Sending command to Computer..."):
                fc_result = self.handle_action(function_call)
            function_response = self._to_function_response(
                function_call.name, fc_result, extra_fr_fields
            )
            if function_response is not None:
                function_responses.append(function_response)

        self._contents.append(
            Content(
                role="user",
                parts=[Part(function_response=fr) for fr in function_responses],
            )
        )
        self._prune_old_screenshots()
        return "CONTINUE"

    def _get_safety_confirmation(
        self, safety: dict[str, Any]
    ) -> Literal["CONTINUE", "TERMINATE"]:
        """Ask the user on stdin whether a flagged action may proceed.

        Raises:
            ValueError: If ``safety["decision"]`` is not "require_confirmation".
        """
        if safety["decision"] != "require_confirmation":
            raise ValueError(f"Unknown safety decision: {safety['decision']}")
        termcolor.cprint(
            "Safety service requires explicit confirmation!",
            color="yellow",
            attrs=["bold"],
        )
        print(safety["explanation"])
        decision = ""
        while decision.lower() not in ("y", "n", "ye", "yes", "no"):
            decision = input("Do you wish to proceed? [Yes]/[No]\n")
        if decision.lower() in ("n", "no"):
            return "TERMINATE"
        return "CONTINUE"

    def agent_loop(self):
        """Run interactive iterations until one returns "COMPLETE"."""
        status = "CONTINUE"
        while status == "CONTINUE":
            status = self.run_one_iteration()

    def agent_loop_yield(self):
        """Yield ``(reasoning, status, results)`` for each non-interactive step.

        Iteration stops after the first step whose status is not "CONTINUE".
        """
        status = "CONTINUE"
        while status == "CONTINUE":
            reasoning, status, function_responses_list = self.run_one_iteration_modify()
            yield reasoning, status, function_responses_list

    def denormalize_x(self, x: int) -> int:
        """Convert a 1000-based model x coordinate to screen pixels."""
        return int(x / 1000 * self._browser_computer.screen_size()[0])

    def denormalize_y(self, y: int) -> int:
        """Convert a 1000-based model y coordinate to screen pixels."""
        return int(y / 1000 * self._browser_computer.screen_size()[1])