Demosthene-OR committed on
Commit
a2110a1
·
1 Parent(s): 5f00917
app.py CHANGED
@@ -1,66 +1,67 @@
1
- # Import necessary modules
2
  import streamlit as st
3
- import streamlit.components.v1 as components # For embedding custom HTML
4
- from generate_knowledge_graph import generate_knowledge_graph
5
 
6
- # Set up Streamlit page configuration
7
  st.set_page_config(
8
- page_icon=None,
9
- layout="wide", # Use wide layout for better graph display
10
- initial_sidebar_state="auto",
11
  menu_items=None
12
  )
13
 
14
- # Set the title of the app
15
  st.title("Knowledge Graph From Text")
16
 
17
- # Sidebar section for user input method
18
  st.sidebar.title("Input document")
19
  input_method = st.sidebar.radio(
20
  "Choose an input method:",
21
- ["Upload txt", "Input text"], # Options for uploading a file or manually inputting text
22
  )
23
 
24
- # Case 1: User chooses to upload a .txt file
25
- if input_method == "Upload txt":
26
- # File uploader widget in the sidebar
27
- uploaded_file = st.sidebar.file_uploader(label="Upload file", type=["txt"])
28
-
29
  if uploaded_file is not None:
30
- # Read the uploaded file content and decode it as UTF-8 text
31
  text = uploaded_file.read().decode("utf-8")
32
-
33
- # Button to generate the knowledge graph
34
- if st.sidebar.button("Generate Knowledge Graph"):
35
- with st.spinner("Generating knowledge graph..."):
36
- # Call the function to generate the graph from the text
37
- net = generate_knowledge_graph(text)
38
- st.success("Knowledge graph generated successfully!")
39
-
40
- # Save the graph to an HTML file
41
- output_file = "knowledge_graph.html"
42
- net.save_graph(output_file)
43
-
44
- # Open the HTML file and display it within the Streamlit app
45
- HtmlFile = open(output_file, 'r', encoding='utf-8')
46
- components.html(HtmlFile.read(), height=1000)
47
-
48
- # Case 2: User chooses to directly input text
49
  else:
50
- # Text area for manual input
51
  text = st.sidebar.text_area("Input text", height=300)
52
 
53
- if text: # Check if the text area is not empty
54
- if st.sidebar.button("Generate Knowledge Graph"):
55
- with st.spinner("Generating knowledge graph..."):
56
- # Call the function to generate the graph from the input text
57
- net = generate_knowledge_graph(text)
58
- st.success("Knowledge graph generated successfully!")
59
-
60
- # Save the graph to an HTML file
61
- output_file = "knowledge_graph.html"
62
- net.save_graph(output_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- # Open the HTML file and display it within the Streamlit app
65
- HtmlFile = open(output_file, 'r', encoding='utf-8')
66
- components.html(HtmlFile.read(), height=1000)
 
# Streamlit front-end: build a knowledge graph from a text document and
# answer questions about it via semantic search over the graph relations.
import streamlit as st
import streamlit.components.v1 as components  # for embedding the PyVis HTML
from generate_knowledge_graph import generate_knowledge_graph, answer_question_with_graph


st.set_page_config(
    page_icon=None,  # fix: was the string "None", which Streamlit treats as an icon path/emoji
    layout="wide",  # wide layout gives the graph more room
    initial_sidebar_state="auto",
    menu_items=None
)

st.title("Knowledge Graph From Text")

# Sidebar: choose between uploading a .txt file or typing text directly.
st.sidebar.title("Input document")
input_method = st.sidebar.radio(
    "Choose an input method:",
    ("Upload .txt", "Input text")
)

# Text extraction based on user choice
text = ""
if input_method == "Upload .txt":
    uploaded_file = st.sidebar.file_uploader(label="Upload file", type="txt")
    if uploaded_file is not None:
        text = uploaded_file.read().decode("utf-8")
else:
    text = st.sidebar.text_area("Input text", height=300)

# Step 1: build the full knowledge graph. The graph documents are cached in
# session state so the QA section below survives Streamlit reruns.
if st.sidebar.button("1. Generate Knowledge Graph"):
    if text:
        with st.spinner("Generating knowledge graph..."):
            net, graph_docs = generate_knowledge_graph(text)
            st.session_state['graph_docs'] = graph_docs
            st.success("Knowledge graph generated successfully!")

        output_file = "knowledge_graph.html"
        net.save_graph(output_file)
        # fix: context manager closes the file handle (it was leaked before)
        with open(output_file, 'r', encoding='utf-8') as html_file:
            components.html(html_file.read(), height=600)
    else:
        st.sidebar.error("Please provide some text to generate the graph.")

# QA Section — only shown once a graph has been generated.
if 'graph_docs' in st.session_state:
    st.markdown("---")
    st.subheader("Posez une question sur le document")

    col1, col2 = st.columns([3, 1])
    with col1:
        question = st.text_input("Votre question :")
    with col2:
        k_value = st.slider("Relations à analyser (Top K)", min_value=1, max_value=20, value=5)

    if st.button("2. Analyser") and question:
        with st.spinner("Recherche sémantique dans le graphe en cours..."):
            answer, filtered_net = answer_question_with_graph(
                question,
                st.session_state['graph_docs'],
                k_relations=k_value
            )

        st.info(f"**Réponse :** {answer}")

        st.markdown("**Sous-graphe des relations utilisées pour répondre :**")
        # answer_question_with_graph saved the filtered sub-graph to this file.
        with open("filtered_graph.html", 'r', encoding='utf-8') as html_file:
            components.html(html_file.read(), height=450)
app_v1.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Import necessary modules
import streamlit as st
import streamlit.components.v1 as components  # For embedding custom HTML
from generate_knowledge_graph import generate_knowledge_graph

# Set up Streamlit page configuration
st.set_page_config(
    page_icon=None,
    layout="wide",  # Use wide layout for better graph display
    initial_sidebar_state="auto",
    menu_items=None
)

# Set the title of the app
st.title("Knowledge Graph From Text")

# Sidebar section for user input method
st.sidebar.title("Input document")
input_method = st.sidebar.radio(
    "Choose an input method:",
    ["Upload txt", "Input text"],  # Options for uploading a file or manually inputting text
)

# Case 1: User chooses to upload a .txt file
if input_method == "Upload txt":
    # File uploader widget in the sidebar
    uploaded_file = st.sidebar.file_uploader(label="Upload file", type=["txt"])

    if uploaded_file is not None:
        # Read the uploaded file content and decode it as UTF-8 text
        text = uploaded_file.read().decode("utf-8")

        # Button to generate the knowledge graph
        if st.sidebar.button("Generate Knowledge Graph"):
            with st.spinner("Generating knowledge graph..."):
                # Call the function to generate the graph from the text
                net = generate_knowledge_graph(text)
                st.success("Knowledge graph generated successfully!")

            # Save the graph to an HTML file
            output_file = "knowledge_graph.html"
            net.save_graph(output_file)

            # Display the HTML within the Streamlit app.
            # fix: context manager closes the file handle (it was leaked before)
            with open(output_file, 'r', encoding='utf-8') as html_file:
                components.html(html_file.read(), height=1000)

# Case 2: User chooses to directly input text
else:
    # Text area for manual input
    text = st.sidebar.text_area("Input text", height=300)

    if text:  # Check if the text area is not empty
        if st.sidebar.button("Generate Knowledge Graph"):
            with st.spinner("Generating knowledge graph..."):
                # Call the function to generate the graph from the input text
                net = generate_knowledge_graph(text)
                st.success("Knowledge graph generated successfully!")

            # Save the graph to an HTML file
            output_file = "knowledge_graph.html"
            net.save_graph(output_file)

            # Display the HTML within the Streamlit app.
            # fix: context manager closes the file handle (it was leaked before)
            with open(output_file, 'r', encoding='utf-8') as html_file:
                components.html(html_file.read(), height=1000)
generate_knowledge_graph.py CHANGED
@@ -1,127 +1,104 @@
 
1
  from langchain_experimental.graph_transformers import LLMGraphTransformer
2
  from langchain_core.documents import Document
3
- from langchain_openai import ChatOpenAI
 
 
4
  from pyvis.network import Network
5
-
6
  from dotenv import load_dotenv
7
  import os
8
  import asyncio
9
 
10
-
11
- # Load the .env file
12
  load_dotenv()
13
- # Get API key from environment variable
14
  api_key = os.getenv("OPENAI_API_KEY")
15
 
16
  llm = ChatOpenAI(temperature=0, model_name="gpt-4o")
17
-
18
  graph_transformer = LLMGraphTransformer(llm=llm)
19
 
20
-
21
- # Extract graph data from input text
22
  async def extract_graph_data(text):
23
- """
24
- Asynchronously extracts graph data from input text using a graph transformer.
25
-
26
- Args:
27
- text (str): Input text to be processed into graph format.
28
-
29
- Returns:
30
- list: A list of GraphDocument objects containing nodes and relationships.
31
- """
32
  documents = [Document(page_content=text)]
33
  graph_documents = await graph_transformer.aconvert_to_graph_documents(documents)
34
  return graph_documents
35
 
36
-
37
  def visualize_graph(graph_documents):
38
- """
39
- Visualizes a knowledge graph using PyVis based on the extracted graph documents.
40
-
41
- Args:
42
- graph_documents (list): A list of GraphDocument objects with nodes and relationships.
43
-
44
- Returns:
45
- pyvis.network.Network: The visualized network graph object.
46
- """
47
- # Create network
48
- net = Network(height="1200px", width="100%", directed=True,
49
- notebook=False, bgcolor="#222222", font_color="white", filter_menu=True, cdn_resources='remote')
50
-
51
  nodes = graph_documents[0].nodes
52
  relationships = graph_documents[0].relationships
53
 
54
- # Build lookup for valid nodes
55
  node_dict = {node.id: node for node in nodes}
56
-
57
- # Filter out invalid edges and collect valid node IDs
58
  valid_edges = []
59
  valid_node_ids = set()
 
60
  for rel in relationships:
61
  if rel.source.id in node_dict and rel.target.id in node_dict:
62
  valid_edges.append(rel)
63
  valid_node_ids.update([rel.source.id, rel.target.id])
64
 
65
- # Track which nodes are part of any relationship
66
- connected_node_ids = set()
67
- for rel in relationships:
68
- connected_node_ids.add(rel.source.id)
69
- connected_node_ids.add(rel.target.id)
70
-
71
- # Add valid nodes to the graph
72
  for node_id in valid_node_ids:
73
  node = node_dict[node_id]
74
  try:
75
  net.add_node(node.id, label=node.id, title=node.type, group=node.type)
76
  except:
77
- continue # Skip node if error occurs
78
 
79
- # Add valid edges to the graph
80
  for rel in valid_edges:
81
  try:
82
  net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
83
  except:
84
- continue # Skip edge if error occurs
85
-
86
- # Configure graph layout and physics
87
- net.set_options("""
88
- {
89
- "physics": {
90
- "forceAtlas2Based": {
91
- "gravitationalConstant": -100,
92
- "centralGravity": 0.01,
93
- "springLength": 200,
94
- "springConstant": 0.08
95
- },
96
- "minVelocity": 0.75,
97
- "solver": "forceAtlas2Based"
98
- }
99
- }
100
- """)
101
-
102
- output_file = "knowledge_graph.html"
103
- try:
104
- net.save_graph(output_file)
105
- print(f"Graph saved to {os.path.abspath(output_file)}")
106
- return net
107
- except Exception as e:
108
- print(f"Error saving graph: {e}")
109
- return None
110
 
 
 
111
 
112
  def generate_knowledge_graph(text):
113
- """
114
- Generates and visualizes a knowledge graph from input text.
115
-
116
- This function runs the graph extraction asynchronously and then visualizes
117
- the resulting graph using PyVis.
118
-
119
- Args:
120
- text (str): Input text to convert into a knowledge graph.
121
-
122
- Returns:
123
- pyvis.network.Network: The visualized network graph object.
124
- """
125
  graph_documents = asyncio.run(extract_graph_data(text))
126
  net = visualize_graph(graph_documents)
127
- return net
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Module-level setup: bring in dependencies, load credentials, and build the
# shared LLM / graph-transformer instances used by every function below.
import asyncio
import os

from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pyvis.network import Network

# Expects OPENAI_API_KEY to be provided via a .env file or the environment.
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

# Shared chat model and transformer, constructed once at import time.
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")
graph_transformer = LLMGraphTransformer(llm=llm)
async def extract_graph_data(text):
    """Convert raw text into graph form.

    Wraps *text* in a single langchain ``Document`` and hands it to the
    module-level graph transformer.

    Args:
        text (str): Input text to analyse.

    Returns:
        list: GraphDocument objects containing nodes and relationships.
    """
    source_docs = [Document(page_content=text)]
    return await graph_transformer.aconvert_to_graph_documents(source_docs)
def visualize_graph(graph_documents):
    """Build a PyVis network from extracted graph documents.

    Only relationships whose source and target both resolve to known nodes
    are drawn; nodes that take part in no valid relationship are omitted.

    Args:
        graph_documents (list): GraphDocument objects with nodes/relationships.

    Returns:
        pyvis.network.Network: the populated (unsaved) network.
    """
    net = Network(height="600px", width="100%", directed=True, notebook=False, bgcolor="#222222", font_color="white", filter_menu=True, cdn_resources='remote')

    # Generalized: merge nodes/relationships from every document instead of
    # reading only graph_documents[0]. Identical result for the
    # single-document case produced by extract_graph_data, and no
    # IndexError on an empty list.
    nodes = []
    relationships = []
    for doc in graph_documents:
        nodes.extend(doc.nodes)
        relationships.extend(doc.relationships)

    node_dict = {node.id: node for node in nodes}
    valid_edges = []
    valid_node_ids = set()

    for rel in relationships:
        if rel.source.id in node_dict and rel.target.id in node_dict:
            valid_edges.append(rel)
            valid_node_ids.update([rel.source.id, rel.target.id])

    for node_id in valid_node_ids:
        node = node_dict[node_id]
        try:
            net.add_node(node.id, label=node.id, title=node.type, group=node.type)
        except Exception:  # fix: bare except also trapped SystemExit/KeyboardInterrupt
            continue  # skip nodes PyVis rejects

    for rel in valid_edges:
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
        except Exception:  # fix: bare except also trapped SystemExit/KeyboardInterrupt
            continue  # skip edges PyVis rejects

    # ForceAtlas2 layout keeps clusters readable.
    net.set_options('{"physics": {"forceAtlas2Based": {"gravitationalConstant": -100, "centralGravity": 0.01, "springLength": 200, "springConstant": 0.08}, "minVelocity": 0.75, "solver": "forceAtlas2Based"}}')
    return net
def generate_knowledge_graph(text):
    """Extract a knowledge graph from *text* and build its visualization.

    Args:
        text (str): Input text to convert into a knowledge graph.

    Returns:
        tuple: ``(net, graph_documents)`` — the PyVis network plus the raw
        GraphDocument list, so callers can later run QA over the relations.
    """
    docs = asyncio.run(extract_graph_data(text))
    return visualize_graph(docs), docs
def answer_question_with_graph(question, graph_documents, k_relations=5):
    """Answer *question* using only relations from the knowledge graph.

    Each relation is rendered as a French sentence, embedded, and indexed in
    an in-memory FAISS store; the ``k_relations`` semantically closest
    relations become the only context the LLM may use. A PyVis sub-graph of
    the relations actually retrieved is saved to ``filtered_graph.html``.

    Args:
        question (str): User question (the prompt answers in French).
        graph_documents (list): GraphDocuments from generate_knowledge_graph.
        k_relations (int): Number of top relations to retrieve (Top K).

    Returns:
        tuple: ``(answer_text, filtered_network)``.
    """
    all_relationships = []
    for doc in graph_documents:
        all_relationships.extend(doc.relationships)

    if not all_relationships:
        return "Aucune relation trouvée dans le graphe.", visualize_graph(graph_documents)

    # One retrievable document per relation; the index stored in metadata
    # maps retrieved documents back to the relation objects.
    rel_docs = []
    for i, rel in enumerate(all_relationships):
        text_rep = f"L'entité '{rel.source.id}' a pour relation '{rel.type}' avec l'entité '{rel.target.id}'."
        rel_docs.append(Document(page_content=text_rep, metadata={"rel_index": i}))

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vectorstore = FAISS.from_documents(rel_docs, embeddings)
    retrieved_docs = vectorstore.similarity_search(question, k=k_relations)

    used_relationships = [all_relationships[doc.metadata["rel_index"]] for doc in retrieved_docs]
    context = "\n".join([doc.page_content for doc in retrieved_docs])

    # fix: the template was written with four quotes (""""...""""), leaving a
    # stray quote after the triple-quote delimiter — a SyntaxError that made
    # the whole module unimportable. A plain triple-quoted string is intended.
    prompt = PromptTemplate(
        template="""Tu es un assistant expert qui répond aux questions en se basant UNIQUEMENT sur ce sous-ensemble de relations extraites d'un graphe de connaissances.\n\nContexte (Relations pertinentes trouvées) :\n{context}\n\nQuestion : {question}\n\nRéponds de manière claire et concise en français. Si la réponse n'est pas dans le contexte fourni, dis-le explicitement.""",
        input_variables=["context", "question"]
    )

    chain = prompt | llm
    answer = chain.invoke({"context": context, "question": question}).content

    # Build the sub-graph restricted to the relations used for the answer.
    net = Network(height="450px", width="100%", directed=True, bgcolor="#222222", font_color="white")

    nodes_added = set()
    for rel in used_relationships:
        if rel.source.id not in nodes_added:
            net.add_node(rel.source.id, label=rel.source.id, title=rel.source.type, group=rel.source.type)
            nodes_added.add(rel.source.id)
        if rel.target.id not in nodes_added:
            net.add_node(rel.target.id, label=rel.target.id, title=rel.target.type, group=rel.target.type)
            nodes_added.add(rel.target.id)
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type)
        except Exception:  # fix: bare except also trapped SystemExit/KeyboardInterrupt
            pass  # best-effort: skip edges PyVis rejects

    net.set_options('{"physics": {"forceAtlas2Based": {"gravitationalConstant": -50}}}')
    net.save_graph("filtered_graph.html")

    return answer, net
generate_knowledge_graph_v1.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Archived first version of the knowledge-graph module (no QA support).
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from pyvis.network import Network

from dotenv import load_dotenv
import os
import asyncio


# Load the .env file
load_dotenv()
# Get API key from environment variable
api_key = os.getenv("OPENAI_API_KEY")

llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

graph_transformer = LLMGraphTransformer(llm=llm)


# Extract graph data from input text
async def extract_graph_data(text):
    """
    Asynchronously extracts graph data from input text using a graph transformer.

    Args:
        text (str): Input text to be processed into graph format.

    Returns:
        list: A list of GraphDocument objects containing nodes and relationships.
    """
    documents = [Document(page_content=text)]
    graph_documents = await graph_transformer.aconvert_to_graph_documents(documents)
    return graph_documents


def visualize_graph(graph_documents):
    """
    Visualizes a knowledge graph using PyVis based on the extracted graph documents.

    Args:
        graph_documents (list): A list of GraphDocument objects with nodes and relationships.

    Returns:
        pyvis.network.Network: The visualized network graph object, or None if
        saving the HTML file fails.
    """
    # Create network
    net = Network(height="1200px", width="100%", directed=True,
                  notebook=False, bgcolor="#222222", font_color="white", filter_menu=True, cdn_resources='remote')

    nodes = graph_documents[0].nodes
    relationships = graph_documents[0].relationships

    # Build lookup for valid nodes
    node_dict = {node.id: node for node in nodes}

    # Filter out invalid edges and collect valid node IDs
    valid_edges = []
    valid_node_ids = set()
    for rel in relationships:
        if rel.source.id in node_dict and rel.target.id in node_dict:
            valid_edges.append(rel)
            valid_node_ids.update([rel.source.id, rel.target.id])

    # fix: removed a dead loop that built an unused `connected_node_ids` set

    # Add valid nodes to the graph
    for node_id in valid_node_ids:
        node = node_dict[node_id]
        try:
            net.add_node(node.id, label=node.id, title=node.type, group=node.type)
        except Exception:  # fix: bare except also trapped SystemExit/KeyboardInterrupt
            continue  # Skip node if error occurs

    # Add valid edges to the graph
    for rel in valid_edges:
        try:
            net.add_edge(rel.source.id, rel.target.id, label=rel.type.lower())
        except Exception:  # fix: bare except also trapped SystemExit/KeyboardInterrupt
            continue  # Skip edge if error occurs

    # Configure graph layout and physics
    net.set_options("""
    {
        "physics": {
            "forceAtlas2Based": {
                "gravitationalConstant": -100,
                "centralGravity": 0.01,
                "springLength": 200,
                "springConstant": 0.08
            },
            "minVelocity": 0.75,
            "solver": "forceAtlas2Based"
        }
    }
    """)

    output_file = "knowledge_graph.html"
    try:
        net.save_graph(output_file)
        print(f"Graph saved to {os.path.abspath(output_file)}")
        return net
    except Exception as e:
        print(f"Error saving graph: {e}")
        return None


def generate_knowledge_graph(text):
    """
    Generates and visualizes a knowledge graph from input text.

    This function runs the graph extraction asynchronously and then visualizes
    the resulting graph using PyVis.

    Args:
        text (str): Input text to convert into a knowledge graph.

    Returns:
        pyvis.network.Network: The visualized network graph object.
    """
    graph_documents = asyncio.run(extract_graph_data(text))
    net = visualize_graph(graph_documents)
    return net
requirements.txt CHANGED
@@ -11,3 +11,6 @@ pyvis>=0.3.2
11
 
12
  # Web UI
13
  streamlit>=1.32.0
 
 
 
 
11
 
12
  # Web UI
13
  streamlit>=1.32.0
14
+
15
+ faiss-cpu
16
+ tiktoken