| from typing import Annotated, Literal |
| from typing_extensions import TypedDict |
|
|
| from langgraph.graph import StateGraph, MessagesState, START, END |
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage |
| from langchain_core.output_parsers import JsonOutputParser |
| from langchain_community.document_transformers import BeautifulSoupTransformer, beautiful_soup_transformer |
|
|
| from langgraph.types import Command |
|
|
| from langchain_groq import ChatGroq |
|
|
| import operator |
| import pprint |
| import os |
| import requests |
| import html2text |
|
|
# Groq credentials come from the environment; None when the variable is unset.
API_KEY = os.environ.get("GROQ_API_KEY")
# Model reply that agent_builder treats as the "finished" signal.
OUT_RES = "<|FINISHED|>"


def _build_html_transformer():
    """Return an html2text converter configured to drop links and images."""
    converter = html2text.HTML2Text()
    converter.ignore_links = True
    converter.ignore_images = True
    return converter


# Shared converters applied by run_api to HTML documents returned by endpoints.
HTML_TRANSFORMER = _build_html_transformer()
BS_TRANSFORMER = BeautifulSoupTransformer()
|
|
|
|
def local_message_add(dict1, dict2):
    """Reducer for the ``local_messages`` channel: merge per-agent message lists.

    Every key of ``dict2`` is merged into a copy of ``dict1``. For an agent id
    already present, the lists are concatenated (existing messages first). A
    new agent id is inserted as-is. Neither argument is mutated, and ``None``
    on either side is treated as empty.

    Args:
        dict1: current mapping of agent id -> list of messages.
        dict2: update mapping of agent id -> list of messages to append.

    Returns:
        A new merged mapping.
    """
    merged = dict(dict1 or {})
    # Hand-off updates carry two agent ids at once (source and target), so
    # every key must be merged, not just the first one.
    for agent_id, new_messages in (dict2 or {}).items():
        if agent_id in merged:
            merged[agent_id] = merged[agent_id] + new_messages
        else:
            merged[agent_id] = new_messages
    return merged
|
|
def variable_state_update(dict1, dict2):
    """Reducer for the ``variables`` channel: shallow-merge ``dict2`` over ``dict1``.

    Keys in ``dict2`` win. A new dict is returned so the previously stored
    state value is never modified in place. ``None`` on either side is
    treated as empty.
    """
    return {**(dict1 or {}), **(dict2 or {})}
|
|
class GeneralStates(TypedDict):
    """Graph state shared by every agent node built by ``build_chain``.

    Fields:
        messages: global chat history of LangChain message objects (nodes
            read ``.content``); updates from nodes are concatenated onto it.
        checkpoints: agent id -> node id that agent should jump straight to
            on its next run. There is no reducer, so each update replaces the
            whole mapping.
        local_messages: agent id -> that agent's private message list, merged
            with ``local_message_add``.
        variables: template / API variables, merged with
            ``variable_state_update``.
    """

    messages: Annotated[list, operator.add]
    checkpoints: dict[str, str]
    local_messages: Annotated[dict, local_message_add]
    variables: Annotated[dict, variable_state_update]
|
|
|
|
def format_sequence(seq, nested=False):
    """Render a container as text, passing any non-container value through.

    Dicts go to ``format_dict``. Lists, tuples, sets and frozensets go to
    ``format_list_like``. Everything else is returned unchanged. ``nested``
    is forwarded to the formatter.
    """
    if isinstance(seq, dict):
        return format_dict(seq, nested=nested)
    if isinstance(seq, (list, tuple, set, frozenset)):
        return format_list_like(seq, nested=nested)
    return seq
| |
| |
|
|
def format_dict(d, nested=False):
    """Render a mapping as ``key: value`` entries joined by ",\\n".

    Top-level calls (``nested=False``) prefix each entry with its 1-based
    position ("1. key: value"); nested calls omit the numbering. Container
    values are rendered recursively through ``format_sequence`` with
    ``nested=True``.
    """
    def render_value(value):
        if isinstance(value, (list, tuple, set, frozenset, dict)):
            return format_sequence(value, nested=True)
        return value

    entries = [f"{key}: {render_value(value)}" for key, value in d.items()]
    if not nested:
        entries = [f"{n}. {entry}" for n, entry in enumerate(entries, start=1)]
    return ",\n".join(entries)
|
|
def format_list_like(seq, nested=False):
    """Render a list/tuple/set as items joined by ",\\n".

    Top-level calls number each item ("1. item"); nested calls emit just
    ``str(item)``. Container items are rendered recursively through
    ``format_sequence`` with ``nested=True``.
    """
    def render(element):
        if isinstance(element, (list, tuple, set, frozenset, dict)):
            return format_sequence(element, nested=True)
        return element

    rendered = map(render, seq)
    if nested:
        return ",\n".join(map(str, rendered))
    return ",\n".join(f"{n}. {text}" for n, text in enumerate(rendered, start=1))
|
|
|
|
def format_dict_api(input_dict, combined):
    """Fill ``str.format`` placeholders throughout a request template.

    Strings are formatted with ``combined`` as keyword arguments. Nested dicts
    are processed recursively. Lists and tuples are walked item by item (and
    keep their type), so placeholders inside JSON arrays are filled too. Any
    other value is copied through unchanged. The template itself is not
    modified.

    Raises:
        KeyError / IndexError / ValueError from ``str.format`` when a template
        references a variable missing from ``combined`` or has stray braces.
    """
    def fill(value):
        if isinstance(value, dict):
            return format_dict_api(value, combined)
        if isinstance(value, list):
            return [fill(item) for item in value]
        if isinstance(value, tuple):
            return tuple(fill(item) for item in value)
        if isinstance(value, str):
            return value.format(**combined)
        return value

    return {key: fill(value) for key, value in input_dict.items()}
|
|
|
|
def _fetch_endpoint(endpoint, combined, timeout):
    """Perform one configured HTTP call and return its processed body.

    Only the variables listed in ``endpoint["input_variables"]`` are exposed to
    the ``headers`` / ``params`` / ``request_body`` templates. When
    ``response_type == "json"`` the body is decoded as JSON. Otherwise the text
    body is returned; if it is an HTML document it is optionally converted to
    markdown (``html_to_markdown``) or reduced to the tags listed in
    ``html_tags_to_extract``.

    Raises whatever template rendering, the HTTP call or decoding raises.
    """
    template_vars = {name: combined[name] for name in endpoint.get("input_variables") or []}

    def render(field):
        template = endpoint.get(field)
        return format_dict_api(template, template_vars) if template else None

    reply = requests.request(
        endpoint["method"],
        endpoint["url"],
        headers=render("headers"),
        params=render("params"),
        json=render("request_body"),
        # Without a timeout a stalled server blocks the whole graph forever.
        timeout=endpoint.get("timeout", timeout),
    )

    if endpoint.get("response_type") == "json":
        return reply.json()

    body = reply.text
    # The HTML doctype is case-insensitive ("<!doctype html>" is common), and
    # servers may emit leading whitespace before it.
    if body.lstrip()[:15].lower() == "<!doctype html>":
        if endpoint.get("html_to_markdown"):
            return HTML_TRANSFORMER.handle(body)
        if endpoint.get("html_tags_to_extract"):
            return BS_TRANSFORMER.extract_tags(body, tags=endpoint["html_tags_to_extract"])
    return body


def run_api(api_endpoints, variables, response, input_message, chain_id, timeout=60):
    """Call each configured API endpoint and record its outcome.

    Args:
        api_endpoints: endpoint configs (see ``_fetch_endpoint``); falsy means none.
        variables: current template variables. The results are also merged
            into this dict in place, so later templates of the same agent can
            reference them.
        response: ``None`` for input-side calls. Otherwise this is the agent
            output: a dict is merged into the template variables, and any
            other value is exposed as ``output_message``.
        input_message: exposed to templates as ``input_message``.
        chain_id: agent id embedded in the result keys.
        timeout: default per-request timeout in seconds; an endpoint's own
            ``timeout`` key overrides it.

    Returns:
        ``{"<input|output>_<name>_<chain_id>_success": body}`` for each
        successful call and ``{"..._error": message}`` for each failed one.
    """
    if not api_endpoints:
        return {}

    combined = dict(variables)
    # Only None marks an input-side call; an empty reply ("" or {}) is still output.
    if response is None:
        endpoint_type = "input"
    else:
        endpoint_type = "output"
        if isinstance(response, dict):
            combined |= response
        else:
            combined["output_message"] = response
    combined["input_message"] = input_message

    results = {}
    for endpoint in api_endpoints:
        name = endpoint["name"]
        try:
            value, outcome = _fetch_endpoint(endpoint, combined, timeout), "success"
        except Exception as exc:  # best effort: record the failure, keep calling the rest
            # Keep state values plain strings; str.format renders an exception
            # through str() anyway.
            value, outcome = str(exc), "error"
        results[f"{endpoint_type}_{name}_{chain_id}_{outcome}"] = value

    variables.update(results)
    return results
|
|
|
|
def _handoff_update(states, chain_id, target_id, forwarded_message, checkpoint):
    """Build the state update that hands control from ``chain_id`` to ``target_id``.

    If ``chain_id`` already has a local conversation, a
    "Switch to Agent <target_id>" marker is appended to it, and
    ``forwarded_message`` seeds ``target_id``'s local conversation. A truthy
    ``checkpoint`` stores ``{chain_id: target_id}`` in ``checkpoints`` so the
    next run of ``chain_id`` jumps straight to the target.
    """
    update = {}
    if (states.get("local_messages") or {}).get(chain_id):
        update["local_messages"] = {
            chain_id: [AIMessage(f"Switch to Agent {target_id}")],
            target_id: [forwarded_message],
        }
    if checkpoint:
        update["checkpoints"] = {chain_id: target_id}
    return update


def _store_output_variables(content, output_variables, api_dict):
    """Copy ``content`` into every configured output variable except "messages".

    Returns True when the reply should also be emitted as a chat message:
    either no output variables are configured, or "messages" is one of them.
    """
    if not output_variables:
        return True
    api_dict["variables"].update(
        {name: content for name in output_variables if name != "messages"}
    )
    return "messages" in output_variables


def agent_builder(states: GeneralStates, chain: dict, row: int, depth: int):
    """Run the agent described by ``chain`` as one LangGraph node.

    Steps, in order:
      1. If ``checkpoints`` holds a saved hand-off for this agent, jump there.
      2. Call the agent's input API endpoints; results become variables.
      3. If the latest message equals the ``condition`` of an
         input-conditioned child, hand off to that child.
      4. Produce a response: a rendered template (``is_template``), a JSON
         object (``output_collector``), or a plain LLM reply. With
         ``loop_input_variables`` this runs once per item, and the variables
         are collected into lists.
      5. Call the output API endpoints, then route to an output-conditioned
         child or store/emit the reply.

    ``row`` and ``depth`` are supplied by ``build_chain`` but currently unused.

    Returns:
        A state-update dict, or a ``Command`` routing to a child node.
    """
    model_config = chain.get("agent")
    print("[MODEL CONFIG]", model_config)
    chain_id = chain["id"]
    children = chain.get("child") or []
    checkpoints = states.get("checkpoints") or {}

    print("[STATES]", states)

    # A hand-off saved by an earlier turn short-circuits this agent.
    if chain_id in checkpoints:
        return Command(goto=checkpoints[chain_id])

    api_dict = {"variables": {}}
    last_message = states["messages"][-1]

    # Private copy: run_api() adds API results to it so later templates can use
    # them, without mutating the graph state in place.
    variables = dict(states.get("variables") or {})
    variables["input_message"] = last_message.content

    api_dict["variables"].update(
        run_api(model_config.get("input_api_endpoints"), variables, None,
                last_message.content, chain_id)
    )

    # Input-conditioned children fire on an exact match of the latest message.
    incoming = last_message.content.strip()
    for c in children:
        if c.get("condition_from") == "input" and incoming == c.get("condition"):
            update = _handoff_update(states, chain_id, c["id"], last_message, c.get("checkpoint"))
            return Command(goto=c["id"], update=update | api_dict)

    # Use the agent's own conversation if it has one, otherwise the global
    # history. A new list is built so the stored local history is not appended
    # to in place.
    local_history = (states.get("local_messages") or {}).get(chain_id)
    messages = (local_history + [last_message]) if local_history else states["messages"]
    history = messages[:-1]

    input_var = model_config.get("input_variables")
    output_variables = model_config.get("output_variables")
    output_endpoints = model_config.get("output_api_endpoints")

    def call_output_apis(output):
        """Run the output endpoints and fold their results into api_dict."""
        api_dict["variables"].update(
            run_api(output_endpoints, variables, output, last_message.content, chain_id)
        )

    if model_config.get("is_template"):
        text = model_config.get("prompt")
        if input_var:
            text = text.format(**{var: variables[var] for var in input_var})
        response = AIMessage(text)
        call_output_apis(response.content)
        if not _store_output_variables(response.content, output_variables, api_dict):
            return api_dict
        return {"messages": [response], "local_messages": {chain_id: [response]}} | api_dict

    # Built once per node run rather than once per loop iteration.
    model = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=model_config.get("creativity"),
        max_tokens=None,
        timeout=None,
        max_retries=2,
        api_key=API_KEY,
    )
    routes = model_config.get("routes")
    output_collector = model_config.get("output_collector")
    loop_vars = chain.get("loop_input_variables") or []

    def run_agent(i):
        """Invoke the model once; in loop mode ``i`` selects each loop variable's item."""
        if input_var:
            print("[AGENT ID]", chain_id)
            print("[INPUT VARIABLES]", input_var)
            print("[VARIABLES]", variables)
            values = {
                var: variables[var][i] if var in loop_vars else variables[var]
                for var in input_var
            }
            # Same per-iteration values as the prompt, so a forwarded hand-off
            # message matches what the model actually saw.
            user_input = "\n".join(str(values[var]) for var in input_var)
            prompt = model_config.get("prompt").format(**values)
        else:
            user_input = last_message.content
            prompt = model_config.get("prompt") + "\n\n" + last_message.content

        if routes:
            header = f"YOU MUST GENERATE OUTPUT STRICTLY one of the following list : [{', '.join(routes)}]\n\n"
            descriptions = model_config.get("routes_description")
            if descriptions:
                header += "HERE IS THE CONDITIONS FOR EACH OUTPUT:\n"
                header += "\n".join(f"{x}: {y}" for x, y in zip(routes, descriptions))
                header += "\n\n"
            prompt = header + prompt
        elif output_collector:
            footer = "YOU MUST GENERATE OUTPUT STRICTLY IN THE FOLLOWING JSON FORMAT, REMEMBER TO ADD {} BEFORE AND AFTER JSON CODE:\n"
            footer += "\n".join(output_collector)
            footer += "\n\n"
            prompt = prompt + "\n\n" + footer

            parsed = (model | JsonOutputParser()).invoke(history + [HumanMessage(content=prompt)])
            if output_variables:
                # Filter into a new dict: deleting keys while iterating the
                # parsed dict raised "dictionary changed size during iteration".
                parsed = {k: v for k, v in parsed.items() if k in output_variables}
            call_output_apis(parsed)
            return {"variables": parsed | api_dict["variables"]}

        response = model.invoke(history + [HumanMessage(content=prompt)])
        reply = response.content.strip()

        for c in children:
            if c.get("condition_from") == "output" and reply == c.get("condition"):
                update = _handoff_update(states, chain_id, c["id"], HumanMessage(user_input), c.get("checkpoint"))
                call_output_apis(response.content)
                emit = _store_output_variables(response.content, output_variables, api_dict)
                # With output_variables configured, the reply only joins the
                # chat history when "messages" is listed. The channel reducer
                # concatenates lists, so the value must be a list.
                if output_variables and emit:
                    api_dict["messages"] = [response]
                # Always route: a matched child must not be dropped just
                # because the reply went into variables only.
                return Command(goto=c["id"], update=update | api_dict)

        # The finish sentinel is honoured whether or not this agent has children.
        if reply == OUT_RES:
            call_output_apis(None)
            return dict(api_dict)

        call_output_apis(response.content)
        if not _store_output_variables(response.content, output_variables, api_dict):
            return api_dict
        return {"messages": [response], "local_messages": {chain_id: [response]}} | api_dict

    if not loop_vars:
        return run_agent(-1)

    max_loop = min(len(variables[name]) for name in loop_vars)
    # NOTE(review): api_dict is shared across iterations, so input-API results
    # gathered above are appended once per iteration. Confirm this is intended.
    collected = {}
    for i in range(max_loop):
        result = run_agent(i)
        if isinstance(result, Command):
            # A routing decision ends the loop; later iterations cannot be
            # merged into a Command.
            return result
        for key, value in (result.get("variables") or {}).items():
            bucket = collected.setdefault(key, [])
            if isinstance(value, list):
                bucket.extend(value)
            else:
                bucket.append(value)
    return {"variables": collected}
|
|
|
|
def route(states, routes):
    """Choose the next node from the stripped text of the latest message.

    Returns that text when it is one of ``routes``; otherwise ``END``.
    """
    last_text = states["messages"][-1].content.strip()
    return last_text if last_text in routes else END
|
|
def _make_router(condition_ids):
    """Return an ``add_conditional_edges`` path function bound to ``condition_ids``.

    A factory is used instead of a lambda written inside the loop. Such a
    lambda looks the name up at call time, so every node would route with the
    last list assigned.
    """
    def _route_for_node(states):
        return route(states, condition_ids)

    return _route_for_node


def build_chain(chains, checkpointer, parent_name=None, depth=0):
    """Build and compile the LangGraph graph for a nested chain configuration.

    Every chain dict becomes a node (id ``chain["id"]``) that runs
    ``agent_builder``. Children with a truthy ``condition`` are reached through
    a conditional edge: ``route`` matches the last message's text against
    their ids and falls back to END. Children without a condition get a plain
    edge. Nodes without children get an edge to END. Every top-level chain is
    wired from START.

    ``parent_name`` is accepted for compatibility but unused.

    Returns:
        The compiled graph.
    """
    print("[BUILD CHAIN] START....")
    builder = StateGraph(GeneralStates)
    wired = set()  # ids whose child edges have already been added
    # Depth-first walk; a frame is (sibling list, index into it, depth of that list).
    stack = [(chains, 0, depth)]

    while stack:
        siblings, i, current_depth = stack.pop()
        print("STACK", i)
        if i >= len(siblings):
            continue

        c = siblings[i]
        c_id = c["id"]

        try:
            print("ADDED NODE!")
            builder.add_node(
                c_id,
                lambda states, c=c, i=i, depth=current_depth: agent_builder(states, c, i, depth),
            )
        except ValueError as e:
            # The same id listed more than once (or a reserved name): keep the first node.
            print("[ERROR]", e)

        if i + 1 < len(siblings):
            stack.append((siblings, i + 1, current_depth))

        children = c.get("child") or []
        if not children:
            builder.add_edge(c_id, END)
            continue
        if c_id in wired:
            # Already wired from another occurrence of this id. Re-adding its
            # branch raises ValueError, and re-expanding a recursive
            # configuration would never terminate.
            continue
        wired.add(c_id)
        stack.append((children, 0, current_depth + 1))

        condition_ids = []
        for x in children:
            if x.get("condition"):
                condition_ids.append(x["id"])
            else:
                builder.add_edge(c_id, x["id"])

        if condition_ids:
            builder.add_conditional_edges(
                c_id, _make_router(condition_ids), path_map=condition_ids + [END]
            )

    print("SET STARTING POINTS")
    for start_point in chains:
        builder.add_edge(START, start_point["id"])
    print("[NODES]", builder.nodes)
    print("[EDGES]", builder.edges)
    return builder.compile(checkpointer=checkpointer)
|
|