Spaces:
Sleeping
Sleeping
fix
Browse files
backend/api/demo.py
CHANGED
|
@@ -48,7 +48,7 @@ from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
|
|
| 48 |
from rl.error_classifier import classify_error, extract_offending_token
|
| 49 |
from rl.grader import GraderInput, compute_reward, compute_episode_reward
|
| 50 |
from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
|
| 51 |
-
from gepa.optimizer import get_gepa, QueryResult
|
| 52 |
|
| 53 |
router = APIRouter()
|
| 54 |
|
|
@@ -136,12 +136,14 @@ async def get_prompt_history():
|
|
| 136 |
}
|
| 137 |
for c in sorted(pareto, key=lambda x: x.generation)
|
| 138 |
]
|
|
|
|
| 139 |
return {
|
| 140 |
"prompt": gepa.get_current_prompt(),
|
| 141 |
"generation": gepa.current_generation,
|
| 142 |
"history": history,
|
| 143 |
-
"queryCount":
|
| 144 |
-
"
|
|
|
|
| 145 |
}
|
| 146 |
|
| 147 |
|
|
|
|
| 48 |
from rl.error_classifier import classify_error, extract_offending_token
|
| 49 |
from rl.grader import GraderInput, compute_reward, compute_episode_reward
|
| 50 |
from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
|
| 51 |
+
from gepa.optimizer import get_gepa, QueryResult, GEPA_OPTIMIZE_EVERY
|
| 52 |
|
| 53 |
router = APIRouter()
|
| 54 |
|
|
|
|
| 136 |
}
|
| 137 |
for c in sorted(pareto, key=lambda x: x.generation)
|
| 138 |
]
|
| 139 |
+
query_count = len(gepa.get_history())
|
| 140 |
return {
|
| 141 |
"prompt": gepa.get_current_prompt(),
|
| 142 |
"generation": gepa.current_generation,
|
| 143 |
"history": history,
|
| 144 |
+
"queryCount": query_count,
|
| 145 |
+
"optimizeEvery": GEPA_OPTIMIZE_EVERY,
|
| 146 |
+
"cycleProgress": query_count % GEPA_OPTIMIZE_EVERY,
|
| 147 |
}
|
| 148 |
|
| 149 |
|
backend/gepa/optimizer.py
CHANGED
|
@@ -26,6 +26,10 @@ GEPA_PATH = _DATA_DIR / "gepa_prompt.json"
|
|
| 26 |
|
| 27 |
_MODEL = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
|
| 30 |
|
| 31 |
Rules:
|
|
@@ -165,7 +169,7 @@ class GEPAOptimizer:
|
|
| 165 |
return max(c.generation for c in self._pareto_front)
|
| 166 |
|
| 167 |
def should_optimize(self) -> bool:
|
| 168 |
-
return len(self._history) > 0 and len(self._history) %
|
| 169 |
|
| 170 |
def reset(self) -> None:
|
| 171 |
self._history.clear()
|
|
|
|
| 26 |
|
| 27 |
_MODEL = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 28 |
|
| 29 |
+
# How many queries between each GEPA optimization cycle.
|
| 30 |
+
# Override with the GEPA_OPTIMIZE_EVERY environment variable.
|
| 31 |
+
GEPA_OPTIMIZE_EVERY: int = int(os.environ.get("GEPA_OPTIMIZE_EVERY", "4"))
|
| 32 |
+
|
| 33 |
SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
|
| 34 |
|
| 35 |
Rules:
|
|
|
|
| 169 |
return max(c.generation for c in self._pareto_front)
|
| 170 |
|
| 171 |
def should_optimize(self) -> bool:
|
| 172 |
+
return len(self._history) > 0 and len(self._history) % GEPA_OPTIMIZE_EVERY == 0
|
| 173 |
|
| 174 |
def reset(self) -> None:
|
| 175 |
self._history.clear()
|
frontend/src/components/PromptEvolution.tsx
CHANGED
|
@@ -22,17 +22,18 @@ export function PromptEvolution() {
|
|
| 22 |
const generation = promptGeneration
|
| 23 |
|
| 24 |
const [queryCount, setQueryCount] = useState(0)
|
| 25 |
-
const [
|
|
|
|
| 26 |
|
| 27 |
const loadHistory = async () => {
|
| 28 |
setLoading(true)
|
| 29 |
try {
|
| 30 |
const data = await fetchPromptHistory()
|
| 31 |
setPromptData(data)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
} catch {
|
| 37 |
// noop
|
| 38 |
} finally {
|
|
@@ -80,12 +81,14 @@ export function PromptEvolution() {
|
|
| 80 |
<div className="flex flex-col gap-1">
|
| 81 |
<div className="flex items-center justify-between text-[9px] text-gray-600">
|
| 82 |
<span>{queryCount} queries processed</span>
|
| 83 |
-
<span
|
|
|
|
|
|
|
| 84 |
</div>
|
| 85 |
<div className="h-1 bg-white/5 rounded-full overflow-hidden">
|
| 86 |
<div
|
| 87 |
className="h-full rounded-full bg-violet-500/50 transition-all duration-500"
|
| 88 |
-
style={{ width: `${(
|
| 89 |
/>
|
| 90 |
</div>
|
| 91 |
</div>
|
|
|
|
| 22 |
const generation = promptGeneration
|
| 23 |
|
| 24 |
const [queryCount, setQueryCount] = useState(0)
|
| 25 |
+
const [optimizeEvery, setOptimizeEvery] = useState(4)
|
| 26 |
+
const [cycleProgress, setCycleProgress] = useState(0)
|
| 27 |
|
| 28 |
const loadHistory = async () => {
|
| 29 |
setLoading(true)
|
| 30 |
try {
|
| 31 |
const data = await fetchPromptHistory()
|
| 32 |
setPromptData(data)
|
| 33 |
+
const d = data as Record<string, unknown>
|
| 34 |
+
if (d.queryCount !== undefined) setQueryCount(d.queryCount as number)
|
| 35 |
+
if (d.optimizeEvery !== undefined) setOptimizeEvery(d.optimizeEvery as number)
|
| 36 |
+
if (d.cycleProgress !== undefined) setCycleProgress(d.cycleProgress as number)
|
| 37 |
} catch {
|
| 38 |
// noop
|
| 39 |
} finally {
|
|
|
|
| 81 |
<div className="flex flex-col gap-1">
|
| 82 |
<div className="flex items-center justify-between text-[9px] text-gray-600">
|
| 83 |
<span>{queryCount} queries processed</span>
|
| 84 |
+
<span className="text-gray-700">
|
| 85 |
+
{cycleProgress}/{optimizeEvery} · optimizes every {optimizeEvery}
|
| 86 |
+
</span>
|
| 87 |
</div>
|
| 88 |
<div className="h-1 bg-white/5 rounded-full overflow-hidden">
|
| 89 |
<div
|
| 90 |
className="h-full rounded-full bg-violet-500/50 transition-all duration-500"
|
| 91 |
+
style={{ width: `${(cycleProgress / optimizeEvery) * 100}%` }}
|
| 92 |
/>
|
| 93 |
</div>
|
| 94 |
</div>
|