context-prune / optimizer_ui.py
prithic07's picture
Upgrade RAG project with advanced Context Optimizer environment and baseline inference
0b89610
from __future__ import annotations
import requests
import streamlit as st
_DEFAULT_API_URL = "http://localhost:7860"


def _resolve_api_url(default: str = _DEFAULT_API_URL) -> str:
    """Return the backend base URL configured in Streamlit secrets, or *default*.

    ``st`` always has a ``secrets`` attribute, so ``hasattr`` cannot detect a
    missing secrets file. Reading ``st.secrets`` in that situation raises rather
    than returning the default, so any failure to read secrets falls back to
    *default*. A trailing slash is stripped because request paths are appended
    with a leading ``/``.
    """
    try:
        url = st.secrets.get("API_URL", default)
    except Exception:  # no / unreadable secrets file: keep the local default
        url = default
    return str(url or default).rstrip("/")


API_URL = _resolve_api_url()
st.set_page_config(page_title="rag-context-optimizer", page_icon="R", layout="wide")
st.title("RAG Context Optimizer")
st.caption("Use any prompt, keep the token budget tight, and let the optimizer pick the best evidence per token.")
def api_get(path: str):
    """GET ``API_URL + path`` and return the decoded JSON body.

    A 4xx/5xx status raises ``requests.HTTPError`` via ``raise_for_status``;
    transport errors from ``requests`` propagate unchanged. Times out after
    20 seconds.
    """
    url = f"{API_URL}{path}"
    reply = requests.get(url, timeout=20)
    reply.raise_for_status()
    return reply.json()
def api_post(path: str, payload: dict | None = None):
    """POST *payload* as JSON to ``API_URL + path`` and return the decoded reply.

    A missing or empty *payload* is sent as an empty JSON object. A 4xx/5xx
    status raises ``requests.HTTPError``; times out after 20 seconds.
    """
    body = payload if payload else {}
    url = f"{API_URL}{path}"
    reply = requests.post(url, json=body, timeout=20)
    reply.raise_for_status()
    return reply.json()
def start_episode(task_name: str, query: str, token_budget: int, max_steps: int):
    """Reset the backend environment and cache its response in session state.

    The response of ``POST /reset`` is stored under ``st.session_state["payload"]``,
    which the rest of the page renders from.
    """
    reset_request = dict(
        task_name=task_name,
        custom_query=query,
        token_budget=token_budget,
        max_steps=max_steps,
    )
    st.session_state["payload"] = api_post("/reset", reset_request)
def do_step(payload: dict):
    """Send one action to ``POST /step`` and store the response in session state."""
    step_result = api_post("/step", payload)
    st.session_state["payload"] = step_result
# Load the task presets from the backend. A failed request or an empty preset
# list would otherwise crash the page (the latter via selectbox returning None
# and a KeyError on the lookup below), so report it and end this run instead.
try:
    tasks = api_get("/tasks")
except requests.RequestException as exc:
    st.error(f"Could not load task presets from {API_URL}: {exc}")
    st.stop()
task_map = {task["name"]: task for task in tasks}
if not task_map:
    st.error("The backend returned no task presets.")
    st.stop()
selected_task = st.sidebar.selectbox("Task preset", list(task_map))
task_meta = task_map[selected_task]
# Sidebar inputs. The prompt is written back into session state at the end so
# it becomes the text area's initial value on later reruns; the numeric limits
# start from the selected preset's values.
default_query = st.session_state.get("custom_query", "")
custom_query = st.sidebar.text_area(
    label="Custom prompt",
    value=default_query,
    placeholder="Enter any prompt you want to optimize for minimal token usage.",
    height=180,
)
preset_budget = int(task_meta["token_budget"])
token_budget = st.sidebar.number_input(label="Token budget", min_value=50, value=preset_budget, step=10)
preset_steps = int(task_meta["max_steps"])
max_steps = st.sidebar.number_input(label="Max steps", min_value=1, value=preset_steps, step=1)
st.session_state["custom_query"] = custom_query
# Sidebar buttons: start a fresh episode with the current inputs, or just rerun.
sidebar_cols = st.sidebar.columns(2)
start_col, refresh_col = sidebar_cols
if start_col.button("Start / Reset", use_container_width=True):
    cleaned_query = custom_query.strip()
    if cleaned_query:
        start_episode(selected_task, cleaned_query, int(token_budget), int(max_steps))
        st.rerun()
    else:
        # Refuse to reset with an empty/whitespace-only prompt.
        st.sidebar.error("Enter a custom prompt first.")
if refresh_col.button("Refresh", use_container_width=True):
    st.rerun()
# Until an episode has been started there is nothing to render: show a hint
# and end this script run.
episode_started = "payload" in st.session_state
if not episode_started:
    st.info("Add your prompt in the sidebar and press Start / Reset.")
    st.stop()
payload = st.session_state["payload"]  # latest /reset or /step response
observation = payload["observation"]
# Headline metrics, one column each, read from the current observation.
metric_columns = st.columns(4)
col1, col2, col3, col4 = metric_columns
metric_fields = (
    ("Task", "task_name"),
    ("Budget", "token_budget"),
    ("Used", "total_tokens_used"),
    ("Step", "step_number"),
)
for metric_col, (label, field) in zip(metric_columns, metric_fields):
    metric_col.metric(label, observation[field])

st.subheader("Active Query")
st.info(observation["query"])

feedback = observation.get("last_action_feedback")
if feedback:
    st.warning(feedback)

# The score and its breakdown are only shown once the response carries one.
grader_breakdown = payload.get("info", {}).get("grader_breakdown")
if grader_breakdown:
    st.success(f"Final score: {payload.get('reward', 0):.4f}")
    st.json(grader_breakdown)
# Action bar. Once the environment reports ``done`` the episode is finished, so
# every stepping control is disabled instead of posting more actions to /step.
AUTO_RUN_MAX_STEPS = 20  # hard cap so "Auto Run" can never loop indefinitely
episode_done = bool(payload.get("done", False))
action_cols = st.columns(3)
if action_cols[0].button("Auto Optimize Step", use_container_width=True, disabled=episode_done):
    # Ask the backend optimizer for its next action and apply it.
    suggestion = api_post("/optimize-step")
    do_step(suggestion)
    st.rerun()
if action_cols[1].button("Auto Run", use_container_width=True, disabled=episode_done):
    for _ in range(AUTO_RUN_MAX_STEPS):
        suggestion = api_post("/optimize-step")
        do_step(suggestion)
        # Stop once an answer was submitted or the environment reports done.
        if suggestion.get("action_type") == "submit_answer" or st.session_state["payload"].get("done"):
            break
    st.rerun()
manual_answer = action_cols[2].text_input("Manual answer", value="", disabled=episode_done)
if st.button("Submit Manual Answer", type="primary", use_container_width=True, disabled=episode_done):
    do_step(
        {
            "action_type": "submit_answer",
            # An empty field still submits, using a generic placeholder answer.
            "answer": manual_answer.strip() or "Concise answer synthesized from the selected evidence.",
        }
    )
    st.rerun()
st.subheader("Available Chunks")
# Build the selected-id set once; rebuilding it per chunk made the loop
# O(chunks * selected) for no reason.
selected_ids = set(observation["selected_chunks"])
# Chunk actions are pointless (and post to a finished episode) once done.
chunk_actions_disabled = bool(payload.get("done", False))
chunk_columns = st.columns(2)
for index, chunk in enumerate(observation["available_chunks"]):
    chunk_id = chunk["chunk_id"]
    # Alternate chunks between the two columns.
    container = chunk_columns[index % 2].container(border=True)
    container.markdown(f"**{chunk_id}**")
    container.caption(f"{chunk['domain']} | {chunk['tokens']} tokens")
    container.write(", ".join(chunk["keywords"]))
    c1, c2 = container.columns(2)
    # First button toggles selection; keys embed the chunk id so they are unique.
    if chunk_id in selected_ids:
        if c1.button(
            "Deselect",
            key=f"deselect-{chunk_id}",
            use_container_width=True,
            disabled=chunk_actions_disabled,
        ):
            do_step({"action_type": "deselect_chunk", "chunk_id": chunk_id})
            st.rerun()
    elif c1.button(
        "Select",
        key=f"select-{chunk_id}",
        use_container_width=True,
        disabled=chunk_actions_disabled,
    ):
        do_step({"action_type": "select_chunk", "chunk_id": chunk_id})
        st.rerun()
    if c2.button(
        "Compress 50%",
        key=f"compress-{chunk_id}",
        use_container_width=True,
        disabled=chunk_actions_disabled,
    ):
        do_step(
            {
                "action_type": "compress_chunk",
                "chunk_id": chunk_id,
                "compression_ratio": 0.5,
            }
        )
        st.rerun()
# Raw debugging views: the cached /reset or /step response, then a fresh /state
# fetch. Each source is called lazily, right after its heading is rendered.
for heading, fetch in (
    ("Observation", lambda: payload),
    ("State", lambda: api_get("/state")),
):
    st.subheader(heading)
    st.json(fetch())