ar9avg commited on
Commit
44ef33f
·
1 Parent(s): 2f89522
backend/api/demo.py CHANGED
@@ -48,7 +48,7 @@ from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
48
  from rl.error_classifier import classify_error, extract_offending_token
49
  from rl.grader import GraderInput, compute_reward, compute_episode_reward
50
  from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
51
- from gepa.optimizer import get_gepa, QueryResult
52
 
53
  router = APIRouter()
54
 
@@ -136,12 +136,14 @@ async def get_prompt_history():
136
  }
137
  for c in sorted(pareto, key=lambda x: x.generation)
138
  ]
 
139
  return {
140
  "prompt": gepa.get_current_prompt(),
141
  "generation": gepa.current_generation,
142
  "history": history,
143
- "queryCount": len(gepa.get_history()),
144
- "nextOptimizeAt": (len(gepa.get_history()) // 4 + 1) * 4,
 
145
  }
146
 
147
 
 
48
  from rl.error_classifier import classify_error, extract_offending_token
49
  from rl.grader import GraderInput, compute_reward, compute_episode_reward
50
  from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
51
+ from gepa.optimizer import get_gepa, QueryResult, GEPA_OPTIMIZE_EVERY
52
 
53
  router = APIRouter()
54
 
 
136
  }
137
  for c in sorted(pareto, key=lambda x: x.generation)
138
  ]
139
+ query_count = len(gepa.get_history())
140
  return {
141
  "prompt": gepa.get_current_prompt(),
142
  "generation": gepa.current_generation,
143
  "history": history,
144
+ "queryCount": query_count,
145
+ "optimizeEvery": GEPA_OPTIMIZE_EVERY,
146
+ "cycleProgress": query_count % GEPA_OPTIMIZE_EVERY,
147
  }
148
 
149
 
backend/gepa/optimizer.py CHANGED
@@ -26,6 +26,10 @@ GEPA_PATH = _DATA_DIR / "gepa_prompt.json"
26
 
27
  _MODEL = os.environ.get("MODEL_NAME", "gpt-4o-mini")
28
 
 
 
 
 
29
  SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
30
 
31
  Rules:
@@ -165,7 +169,7 @@ class GEPAOptimizer:
165
  return max(c.generation for c in self._pareto_front)
166
 
167
  def should_optimize(self) -> bool:
168
- return len(self._history) > 0 and len(self._history) % 4 == 0
169
 
170
  def reset(self) -> None:
171
  self._history.clear()
 
26
 
27
  _MODEL = os.environ.get("MODEL_NAME", "gpt-4o-mini")
28
 
29
+ # How many queries between each GEPA optimization cycle.
30
+ # Override with the GEPA_OPTIMIZE_EVERY environment variable.
31
+ GEPA_OPTIMIZE_EVERY: int = int(os.environ.get("GEPA_OPTIMIZE_EVERY", "4"))
32
+
33
  SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
34
 
35
  Rules:
 
169
  return max(c.generation for c in self._pareto_front)
170
 
171
  def should_optimize(self) -> bool:
172
+ return len(self._history) > 0 and len(self._history) % GEPA_OPTIMIZE_EVERY == 0
173
 
174
  def reset(self) -> None:
175
  self._history.clear()
frontend/src/components/PromptEvolution.tsx CHANGED
@@ -22,17 +22,18 @@ export function PromptEvolution() {
22
  const generation = promptGeneration
23
 
24
  const [queryCount, setQueryCount] = useState(0)
25
- const [nextAt, setNextAt] = useState(4)
 
26
 
27
  const loadHistory = async () => {
28
  setLoading(true)
29
  try {
30
  const data = await fetchPromptHistory()
31
  setPromptData(data)
32
- if ((data as Record<string, unknown>).queryCount !== undefined) {
33
- setQueryCount((data as Record<string, unknown>).queryCount as number)
34
- setNextAt((data as Record<string, unknown>).nextOptimizeAt as number)
35
- }
36
  } catch {
37
  // noop
38
  } finally {
@@ -80,12 +81,14 @@ export function PromptEvolution() {
80
  <div className="flex flex-col gap-1">
81
  <div className="flex items-center justify-between text-[9px] text-gray-600">
82
  <span>{queryCount} queries processed</span>
83
- <span>{nextAt - queryCount} until next optimization</span>
 
 
84
  </div>
85
  <div className="h-1 bg-white/5 rounded-full overflow-hidden">
86
  <div
87
  className="h-full rounded-full bg-violet-500/50 transition-all duration-500"
88
- style={{ width: `${((queryCount % 4) / 4) * 100}%` }}
89
  />
90
  </div>
91
  </div>
 
22
  const generation = promptGeneration
23
 
24
  const [queryCount, setQueryCount] = useState(0)
25
+ const [optimizeEvery, setOptimizeEvery] = useState(4)
26
+ const [cycleProgress, setCycleProgress] = useState(0)
27
 
28
  const loadHistory = async () => {
29
  setLoading(true)
30
  try {
31
  const data = await fetchPromptHistory()
32
  setPromptData(data)
33
+ const d = data as Record<string, unknown>
34
+ if (d.queryCount !== undefined) setQueryCount(d.queryCount as number)
35
+ if (d.optimizeEvery !== undefined) setOptimizeEvery(d.optimizeEvery as number)
36
+ if (d.cycleProgress !== undefined) setCycleProgress(d.cycleProgress as number)
37
  } catch {
38
  // noop
39
  } finally {
 
81
  <div className="flex flex-col gap-1">
82
  <div className="flex items-center justify-between text-[9px] text-gray-600">
83
  <span>{queryCount} queries processed</span>
84
+ <span className="text-gray-700">
85
+ {cycleProgress}/{optimizeEvery} · optimizes every {optimizeEvery}
86
+ </span>
87
  </div>
88
  <div className="h-1 bg-white/5 rounded-full overflow-hidden">
89
  <div
90
  className="h-full rounded-full bg-violet-500/50 transition-all duration-500"
91
+ style={{ width: `${(cycleProgress / optimizeEvery) * 100}%` }}
92
  />
93
  </div>
94
  </div>