hhy-huang committed
Commit 19838e5 · 1 Parent(s): b6983f9
Files changed (1): README.md (+145 −3)

README.md CHANGED
@@ -9,8 +9,150 @@ base_model:
 
 # DeepRefine-v1-8B
 
- A general LLM-based reasoning model for agent-compiled knowledge refinement that improves the quality of any pre-constructed knowledge bases with user queries to make it more9suitable for the downstream tasks.
 
- ## Quick Start
 
- Todo.
 
 # DeepRefine-v1-8B
 
+ A general LLM-based reasoning model for agent-compiled knowledge refinement. It improves the quality of any pre-constructed knowledge base, conditioned on user queries, so that the knowledge base better fits the downstream task.
 
+ ## Simple Demo
+ Below is a single refinement pass for a specific query, given pre-collected results from prior interactions with your database.
 
+ ```python
+ DEEPREFINE_JUDGEMENT_SYSTEM_PROMPT = """
+ As an advanced judgement assistant, your task is to judge whether the given question is answerable based on the provided KG context.
+ 
+ Evaluate whether the given question is answerable based on the provided KG context. Output your judgment in the following format:
+ <judge>Yes</judge> or <judge>No</judge>
+ 
+ **Important:** You must think carefully about the question and the KG context before making your judgment. And output your judgment result directly in the specified format.
+ """
+ 
+ DEEPREFINE_JUDGEMENT_USER_PROMPT = """
+ Question: {question}
+ Knowledge Graph (KG) context: {triples_string}
+ """
+ 
+ DEEPREFINE_ERROR_ABDUCTION_SYSTEM_PROMPT = """
+ As an advanced error abduction assistant, your task is to analyze the error reasons based on the given interaction history.
+ 
+ Analyze the reasons of the unanswerable questions based on the given interaction history from the incompleteness, incorrectness, and redundancy perspectives. Output your analysis in the following format:
+ <abduction>...</abduction>
+ 
+ **Important:** You must think carefully about the interaction history before making your analysis. And output your analysis result directly in the specified format.
+ """
+ 
+ DEEPREFINE_ERROR_ABDUCTION_USER_PROMPT = """
+ Interaction history: {interaction_history}
+ """
+ 
+ DEEPREFINE_ACTION_SYSTEM_PROMPT = """
+ As an advanced knowledge graph refinement assistant, your task is to generate a series of actions (**within 10 actions**) to refine the given KG to make it more suitable for answering the given question.
+ 
+ Based on the given KG and the analysed error reasons, refine the given KG to make it more easily for retrieval and answering the given question. You have the following three types of actions to conduct:
+ 
+ - insert_edge(subject, relation, object): Insert a new edge into the KG to complete the missing information.
+ - delete_edge(subject, relation, object): Delete an edge from the KG to remove the redundant information or conflicting information.
+ - replace_node(old_entity, new_entity): Replace an entity in the KG to correct the errors or deal with disambiguation.
+ 
+ Output a series of actions (**within 10 actions**) in the following format:
+ <refinement>insert_edge("...", "...", "...")|delete_edge("...", "...", "...")|replace_node("...", "...")|...</refinement>
+ 
+ **Important:** You must think carefully about the given KG and the analysed error reasons before making your refinement. DO NOT DELETE ANY IRRELEVANT TRIPLES FROM THE ORIGINAL KG. TRY TO KEEP THE ORIGINAL KG AS MUCH AS POSSIBLE. DO NOT GENERATE TOO MANY ACTIONS. And output your refinement result directly in the specified format.
+ """
+ 
+ DEEPREFINE_ACTION_USER_PROMPT = """
+ Original Text: {original_text}
+ KG: {triples_string}
+ Question: {question}
+ Error reasons: {error_reasons}
+ """
+ 
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+ 
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ 
+ model_name = "HaoyuHuang/DeepRefine-v1-8B"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+     device_map="auto" if device == "cuda" else None,
+     trust_remote_code=True,
+ )
+ if device != "cuda":
+     model = model.to(device)
+ 
+ def call_model(system_prompt: str, user_prompt: str, max_new_tokens: int = 8192) -> str:
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": user_prompt},
+     ]
+     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = tokenizer(text, return_tensors="pt").to(model.device)
+ 
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=False,  # greedy decoding
+         )
+ 
+     generated_ids = outputs[0, inputs["input_ids"].shape[-1]:]
+     return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+ 
+ @dataclass
+ class RetrievalStepResult:
+     """Single-step retrieval result, kept for debugging / analysis."""
+     num_hops: int
+     base_top_k: int
+     query: str
+     retrieved_subgraph: List[Dict[str, str]]
+     raw_response: str
+     answerable: bool
+     answer: Optional[str] = None
+ 
+ question = "your_question"
+ triples_string = "retrieved_triples"
+ interaction_history = "list_of_RetrievalStepResult"
+ original_text = "triple_related_original_text"
+ 
+ # Answerability Judgement Call
+ judgement_user_prompt = DEEPREFINE_JUDGEMENT_USER_PROMPT.format(
+     question=question,
+     triples_string=triples_string,
+ )
+ judgement_result = call_model(
+     DEEPREFINE_JUDGEMENT_SYSTEM_PROMPT,
+     judgement_user_prompt,
+ )
+ print("Answerability Judgement:", judgement_result)
+ 
+ # Error Abduction Call
+ abduction_user_prompt = DEEPREFINE_ERROR_ABDUCTION_USER_PROMPT.format(
+     interaction_history=interaction_history,
+ )
+ abduction_result = call_model(
+     DEEPREFINE_ERROR_ABDUCTION_SYSTEM_PROMPT,
+     abduction_user_prompt,
+ )
+ print("Error Abduction:", abduction_result)
+ 
+ # Refinement Actions Generation Call
+ actions_user_prompt = DEEPREFINE_ACTION_USER_PROMPT.format(
+     original_text=original_text,
+     triples_string=triples_string,
+     question=question,
+     error_reasons=abduction_result,
+ )
+ actions_result = call_model(
+     DEEPREFINE_ACTION_SYSTEM_PROMPT,
+     actions_user_prompt,
+ )
+ print("Refinement Actions:", actions_result)
+ 
+ # Parse the refinement actions and connect them to the operators of your own database.
+ # ...
+ ```
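
The model's tagged outputs (`<judge>Yes</judge>` / `<judge>No</judge>` and the pipe-separated action string inside `<refinement>...</refinement>`) need to be parsed before they can be applied to your own store. A minimal regex-based sketch of that step, assuming arguments contain no embedded commas or parentheses; the helper names `parse_judgement` and `parse_actions` are illustrative, not part of the released model:

```python
import re
from typing import List, Tuple

# Matches one action call, e.g. insert_edge("a", "r", "b").
# Assumes arguments contain no commas or parentheses inside the quotes.
ACTION_RE = re.compile(r'(insert_edge|delete_edge|replace_node)\(([^)]*)\)')

def parse_judgement(response: str) -> bool:
    """Return True iff the response contains <judge>Yes</judge>."""
    m = re.search(r'<judge>\s*(Yes|No)\s*</judge>', response, re.IGNORECASE)
    return bool(m) and m.group(1).lower() == "yes"

def parse_actions(response: str) -> List[Tuple[str, ...]]:
    """Extract (action_name, arg, ...) tuples from a <refinement> tag."""
    m = re.search(r'<refinement>(.*?)</refinement>', response, re.DOTALL)
    if not m:
        return []
    actions = []
    for name, args in ACTION_RE.findall(m.group(1)):
        parsed = [a.strip().strip('"').strip("'") for a in args.split(",")]
        actions.append((name, *parsed))
    return actions
```

Each parsed tuple can then be dispatched to the corresponding insert / delete / replace operator of your own graph store.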