Nav772 commited on
Commit
6b1678e
·
verified ·
1 Parent(s): b14bd3f

Update retreiver.py

Browse files
Files changed (1) hide show
  1. retreiver.py +91 -0
retreiver.py CHANGED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ from langchain.docstore.document import Document
3
+
4
+ # Load Dataset
5
+ guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
6
+
7
+ # Convert dataset entries to document
8
+ docs = [
9
+ Document(
10
+ page_content = "\n".join([
11
+ f"Name: {guest['name']}",
12
+ f"Relation: {guest['relation']}",
13
+ f"Description: {guest['description']}",
14
+ f"Email: {guest['email']}"
15
+ ]),
16
+ metadata={"name": guest["name"]}
17
+ )
18
+ for guest in guest_dataset
19
+ ]
20
+ # ---------------------------------------------------------------------------------------------
21
+ from langchain_community.retreivers import BM25Retreiver
22
+ from langchain.tools import Tool
23
+
24
+ bm25_retriever = BM25Retreiver.from_documents(docs)
25
+
26
+ def extract_text(query: str) -> str:
27
+ """ Retrieves detailed information on the guests attending the Gala based on the name and relation."""
28
+ results = bm25_retriever.invoke(query)
29
+ if results:
30
+ return "\n\n".join([doc.page_content for doc in results[:3]])
31
+ else:
32
+ return "No matching information of tthe guests found"
33
+
34
+
35
+ guest_info_tool = Tool(
36
+ name = "guest_info_retriever",
37
+ func = extract_text,
38
+ description = "Retrieves detailed information on thr guests attending the Gala based on the name and the relation"
39
+ )
40
+ # ---------------------------------------------------------------------------------------------
41
+
42
+ from typing import TypeDict, Annotated
43
+ from langgraph.graph.message import add_messages
44
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
45
+ from langgraph.prebuilt import ToolNode
46
+ from langgraph.graph import START, StateGraph
47
+ from langgraph.prebuilt import tools_condition
48
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
49
+
50
+ llm = HuggingFaceEndpoint(
51
+ repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct",
52
+ huggingfacehub_api_token =
53
+ )
54
+
55
+ chat = ChatHuggingFace(llm=llm, verbose=True)
56
+ tools = [guest_info_tool]
57
+ chat_with_tools = chat.bind_tools(tools)
58
+
59
+ # Generate Agentstate & AgentGraph
60
+
61
+ class AgentState(TypeDict):
62
+ messages: Annotated[list[AnyMessage], add_messages]
63
+
64
+ def assistant(state : AgentState):
65
+ retutn {
66
+ "messages" : [chat chat_with_tools.invoke(state["messages"])]
67
+ }
68
+
69
+ builder = StateGraph(AgentState)
70
+
71
+ # Define the nodes
72
+
73
+ builder.add_node("assistant", assistant)
74
+ builder.add_node("tools", ToolNode(tools))
75
+
76
+ # Define Edges
77
+
78
+ builder.add_edge(START, "assistant")
79
+ builder.add_conditional_edges("assistant",
80
+ # If the latest message requires a tool, route to tools
81
+ # Otherwise, provide a direct response
82
+ tools_condition,
83
+ )
84
+ builder.add_edge("tools", "assistant")
85
+ alfred = builder.compile()
86
+
87
+ messages = [HumanMessage(content="Tell me about our guest named 'Lady Ada Lovelace'.")]
88
+ response = alfred.invoke({"messages": messages})
89
+
90
+ print("🎩 Alfred's Response:")
91
+ print(response['messages'][-1].content))