mayzinoo commited on
Commit
f017e89
·
verified ·
1 Parent(s): d3e89bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -17
app.py CHANGED
@@ -96,32 +96,27 @@ Do not summarize. Copy the official description exactly.
96
 
97
  }
98
 
99
- def generate_output(prompt_type, query):
100
- import re
101
-
102
- # Check for exact SOL code in query
103
  sol_match = re.search(r"\bG\.[A-Z]+\.\d+\b", query)
104
  matched_code = sol_match.group(0) if sol_match else None
105
 
106
  if matched_code:
107
- # Find matching document by metadata
108
- docs_with_code = vectordb.similarity_search_with_score(query)
109
- filtered_docs = [doc for doc, _ in docs_with_code if doc.metadata.get("standard") == matched_code]
110
-
111
- if filtered_docs:
112
- context = "\n\n".join([doc.page_content for doc in filtered_docs])
113
- else:
114
- # fallback to regular retrieval if not found
115
- docs = retriever.get_relevant_documents(query)
116
- context = "\n\n".join([doc.page_content for doc in docs])
117
  else:
118
- # regular query
119
  docs = retriever.get_relevant_documents(query)
120
  context = "\n\n".join([doc.page_content for doc in docs])
121
 
122
- # Run the prompt
123
  chain = LLMChain(llm=llm, prompt=templates[prompt_type])
124
- return chain.run({"context": context, "query": query}).strip()
125
 
126
 
127
 
 
96
 
97
  }
98
 
99
+ def generate_prompt_output(prompt_type, query, retriever, llm):
100
+ # Try to extract SOL code
 
 
101
  sol_match = re.search(r"\bG\.[A-Z]+\.\d+\b", query)
102
  matched_code = sol_match.group(0) if sol_match else None
103
 
104
  if matched_code:
105
+ # Retrieve and filter by metadata
106
+ all_docs = retriever.vectorstore._collection.get(include=['documents', 'metadatas'])
107
+ filtered = []
108
+ for doc_text, metadata in zip(all_docs['documents'], all_docs['metadatas']):
109
+ if metadata.get('standard') == matched_code:
110
+ filtered.append(doc_text)
111
+
112
+ context = "\n\n".join(filtered)
 
 
113
  else:
114
+ # fallback to semantic retrieval
115
  docs = retriever.get_relevant_documents(query)
116
  context = "\n\n".join([doc.page_content for doc in docs])
117
 
 
118
  chain = LLMChain(llm=llm, prompt=templates[prompt_type])
119
+ return chain.run({"context": context, "query": query})
120
 
121
 
122