Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import requests | |
| import streamlit as st | |
| API_URL = st.secrets.get("API_URL", "http://localhost:7860") if hasattr(st, "secrets") else "http://localhost:7860" | |
| st.set_page_config(page_title="rag-context-optimizer", page_icon="R", layout="wide") | |
| st.title("RAG Context Optimizer") | |
| st.caption("Use any prompt, keep the token budget tight, and let the optimizer pick the best evidence per token.") | |
| def api_get(path: str): | |
| response = requests.get(f"{API_URL}{path}", timeout=20) | |
| response.raise_for_status() | |
| return response.json() | |
| def api_post(path: str, payload: dict | None = None): | |
| response = requests.post(f"{API_URL}{path}", json=payload or {}, timeout=20) | |
| response.raise_for_status() | |
| return response.json() | |
| def start_episode(task_name: str, query: str, token_budget: int, max_steps: int): | |
| st.session_state["payload"] = api_post( | |
| "/reset", | |
| { | |
| "task_name": task_name, | |
| "custom_query": query, | |
| "token_budget": token_budget, | |
| "max_steps": max_steps, | |
| }, | |
| ) | |
| def do_step(payload: dict): | |
| st.session_state["payload"] = api_post("/step", payload) | |
| tasks = api_get("/tasks") | |
| task_map = {task["name"]: task for task in tasks} | |
| selected_task = st.sidebar.selectbox("Task preset", list(task_map)) | |
| task_meta = task_map[selected_task] | |
| default_query = st.session_state.get("custom_query", "") | |
| custom_query = st.sidebar.text_area( | |
| "Custom prompt", | |
| value=default_query, | |
| height=180, | |
| placeholder="Enter any prompt you want to optimize for minimal token usage.", | |
| ) | |
| token_budget = st.sidebar.number_input( | |
| "Token budget", | |
| min_value=50, | |
| value=int(task_meta["token_budget"]), | |
| step=10, | |
| ) | |
| max_steps = st.sidebar.number_input( | |
| "Max steps", | |
| min_value=1, | |
| value=int(task_meta["max_steps"]), | |
| step=1, | |
| ) | |
| st.session_state["custom_query"] = custom_query | |
| sidebar_cols = st.sidebar.columns(2) | |
| if sidebar_cols[0].button("Start / Reset", use_container_width=True): | |
| if not custom_query.strip(): | |
| st.sidebar.error("Enter a custom prompt first.") | |
| else: | |
| start_episode(selected_task, custom_query.strip(), int(token_budget), int(max_steps)) | |
| st.rerun() | |
| if sidebar_cols[1].button("Refresh", use_container_width=True): | |
| st.rerun() | |
| if "payload" not in st.session_state: | |
| st.info("Add your prompt in the sidebar and press Start / Reset.") | |
| st.stop() | |
| payload = st.session_state["payload"] | |
| observation = payload["observation"] | |
| col1, col2, col3, col4 = st.columns(4) | |
| col1.metric("Task", observation["task_name"]) | |
| col2.metric("Budget", observation["token_budget"]) | |
| col3.metric("Used", observation["total_tokens_used"]) | |
| col4.metric("Step", observation["step_number"]) | |
| st.subheader("Active Query") | |
| st.info(observation["query"]) | |
| feedback = observation.get("last_action_feedback") | |
| if feedback: | |
| st.warning(feedback) | |
| if payload.get("info", {}).get("grader_breakdown"): | |
| st.success(f"Final score: {payload.get('reward', 0):.4f}") | |
| st.json(payload["info"]["grader_breakdown"]) | |
| action_cols = st.columns(3) | |
| if action_cols[0].button("Auto Optimize Step", use_container_width=True): | |
| suggestion = api_post("/optimize-step") | |
| do_step(suggestion) | |
| st.rerun() | |
| if action_cols[1].button("Auto Run", use_container_width=True): | |
| for _ in range(20): | |
| suggestion = api_post("/optimize-step") | |
| do_step(suggestion) | |
| if suggestion["action_type"] == "submit_answer" or st.session_state["payload"]["done"]: | |
| break | |
| st.rerun() | |
| manual_answer = action_cols[2].text_input("Manual answer", value="") | |
| if st.button("Submit Manual Answer", type="primary", use_container_width=True): | |
| do_step( | |
| { | |
| "action_type": "submit_answer", | |
| "answer": manual_answer.strip() or "Concise answer synthesized from the selected evidence.", | |
| } | |
| ) | |
| st.rerun() | |
| st.subheader("Available Chunks") | |
| chunk_columns = st.columns(2) | |
| for index, chunk in enumerate(observation["available_chunks"]): | |
| selected = chunk["chunk_id"] in set(observation["selected_chunks"]) | |
| container = chunk_columns[index % 2].container(border=True) | |
| container.markdown(f"**{chunk['chunk_id']}**") | |
| container.caption(f"{chunk['domain']} | {chunk['tokens']} tokens") | |
| container.write(", ".join(chunk["keywords"])) | |
| c1, c2 = container.columns(2) | |
| if selected: | |
| if c1.button("Deselect", key=f"deselect-{chunk['chunk_id']}", use_container_width=True): | |
| do_step({"action_type": "deselect_chunk", "chunk_id": chunk["chunk_id"]}) | |
| st.rerun() | |
| else: | |
| if c1.button("Select", key=f"select-{chunk['chunk_id']}", use_container_width=True): | |
| do_step({"action_type": "select_chunk", "chunk_id": chunk["chunk_id"]}) | |
| st.rerun() | |
| if c2.button("Compress 50%", key=f"compress-{chunk['chunk_id']}", use_container_width=True): | |
| do_step( | |
| { | |
| "action_type": "compress_chunk", | |
| "chunk_id": chunk["chunk_id"], | |
| "compression_ratio": 0.5, | |
| } | |
| ) | |
| st.rerun() | |
| st.subheader("Observation") | |
| st.json(payload) | |
| st.subheader("State") | |
| st.json(api_get("/state")) | |