Alexander Sanchez committed on
Commit
c64e5be
·
1 Parent(s): b07c8a2

added mixtral-8x7b

Browse files
Files changed (2) hide show
  1. app.py +12 -4
  2. rag_corrector.py +5 -5
app.py CHANGED
@@ -56,15 +56,15 @@ DEMO_EXAMPLES = [
56
 
57
  # ── Función principal ─────────────────────────────────────────────────────────
58
 
59
- def corregir(htr_text: str, top_k: int, mostrar_prompt: bool):
60
  if not htr_text.strip():
61
- return "", "", "", "", "⚠ Introduce un texto HTR para corregir."
62
 
63
  if not os.getenv("OPENAI_API_KEY"):
64
  return "", "", "", "", " Falta OPENAI_API_KEY en el fichero .env"
65
 
66
  try:
67
- result = corrector.correct(htr_text, top_k=int(top_k))
68
  except Exception as e:
69
  return "", "", "", "", f" Error al llamar a la API: {e}"
70
 
@@ -235,6 +235,14 @@ with gr.Blocks(
235
  minimum=1, maximum=10, value=5, step=1,
236
  label="Documents retrieved (k)",
237
  )
 
 
 
 
 
 
 
 
238
  show_prompt = gr.Checkbox(label="Show RAG prompt", value=False)
239
  btn_corregir = gr.Button("✦ Correct with RAG", variant="primary")
240
 
@@ -263,7 +271,7 @@ with gr.Blocks(
263
 
264
  btn_corregir.click(
265
  fn=corregir,
266
- inputs=[htr_input, top_k_slider, show_prompt],
267
  outputs=[corrected_out, docs_out, analysis_out, diff_out, status_out, prompt_out],
268
  )
269
 
 
56
 
57
  # ── Función principal ─────────────────────────────────────────────────────────
58
 
59
+ def corregir(htr_text: str, top_k: int, mostrar_prompt: bool, model: str):
60
  if not htr_text.strip():
61
+ return "", "", "", "", " Introduce un texto HTR para corregir."
62
 
63
  if not os.getenv("OPENAI_API_KEY"):
64
  return "", "", "", "", " Falta OPENAI_API_KEY en el fichero .env"
65
 
66
  try:
67
+ result = corrector.correct(htr_text, top_k=int(top_k), model= model)
68
  except Exception as e:
69
  return "", "", "", "", f" Error al llamar a la API: {e}"
70
 
 
235
  minimum=1, maximum=10, value=5, step=1,
236
  label="Documents retrieved (k)",
237
  )
238
+ model_selector = gr.Dropdown(
239
+ label="Modelo LLM",
240
+ choices=[
241
+ "llama-3.3-70b-versatile",
242
+ "mixtral-8x7b-32768",
243
+ ],
244
+ value="llama-3.3-70b-versatile",
245
+ )
246
  show_prompt = gr.Checkbox(label="Show RAG prompt", value=False)
247
  btn_corregir = gr.Button("✦ Correct with RAG", variant="primary")
248
 
 
271
 
272
  btn_corregir.click(
273
  fn=corregir,
274
+ inputs=[htr_input, top_k_slider, show_prompt, model_selector],
275
  outputs=[corrected_out, docs_out, analysis_out, diff_out, status_out, prompt_out],
276
  )
277
 
rag_corrector.py CHANGED
@@ -53,7 +53,7 @@ class RAGCorrector:
53
  self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL", "https://api.x.ai/v1"),)
54
  # ── API pública ──────────────────────────────────────────────────────────
55
 
56
- def correct(self, htr_text: str, top_k: int = TOP_K) -> Dict:
57
  """
58
  Corrige un texto HTR usando RAG.
59
 
@@ -71,7 +71,7 @@ class RAGCorrector:
71
 
72
  prompt = self._build_prompt(htr_text, retrieved, htr_errors, grafia_warns)
73
 
74
- corrected = self._call_llm(prompt)
75
 
76
  return {
77
  "corrected": corrected,
@@ -79,7 +79,7 @@ class RAGCorrector:
79
  "retrieved": retrieved,
80
  "htr_errors": htr_errors,
81
  "grafia_warns": grafia_warns,
82
- "model": MODEL,
83
  }
84
 
85
  # ── Detección de patrones ────────────────────────────────────────────────
@@ -159,9 +159,9 @@ class RAGCorrector:
159
 
160
  # ── Llamada al LLM ───────────────────────────────────────────────────────
161
 
162
- def _call_llm(self, user_prompt: str) -> str:
163
  response = self.client.chat.completions.create(
164
- model=MODEL,
165
  temperature=0.1, # baja temperatura: reproducibilidad
166
  max_tokens=1024,
167
  messages=[
 
53
  self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL", "https://api.x.ai/v1"),)
54
  # ── API pública ──────────────────────────────────────────────────────────
55
 
56
+ def correct(self, htr_text: str, top_k: int = TOP_K, model: str = None) -> Dict:
57
  """
58
  Corrige un texto HTR usando RAG.
59
 
 
71
 
72
  prompt = self._build_prompt(htr_text, retrieved, htr_errors, grafia_warns)
73
 
74
+ corrected = self._call_llm(prompt, model=model or MODEL)
75
 
76
  return {
77
  "corrected": corrected,
 
79
  "retrieved": retrieved,
80
  "htr_errors": htr_errors,
81
  "grafia_warns": grafia_warns,
82
+ "model": model or MODEL,
83
  }
84
 
85
  # ── Detección de patrones ────────────────────────────────────────────────
 
159
 
160
  # ── Llamada al LLM ───────────────────────────────────────────────────────
161
 
162
+ def _call_llm(self, user_prompt: str, model: str = MODEL) -> str:
163
  response = self.client.chat.completions.create(
164
+ model=model, # usa el modelo que llega, no el de .env
165
  temperature=0.1, # baja temperatura: reproducibilidad
166
  max_tokens=1024,
167
  messages=[