Spaces:
Sleeping
Sleeping
Alexander Sanchez committed on
Commit ·
c64e5be
1
Parent(s): b07c8a2
added mixtral-8x7b
Browse files- app.py +12 -4
- rag_corrector.py +5 -5
app.py
CHANGED
|
@@ -56,15 +56,15 @@ DEMO_EXAMPLES = [
|
|
| 56 |
|
| 57 |
# ── Función principal ─────────────────────────────────────────────────────────
|
| 58 |
|
| 59 |
-
def corregir(htr_text: str, top_k: int, mostrar_prompt: bool):
|
| 60 |
if not htr_text.strip():
|
| 61 |
-
return "", "", "", "", "
|
| 62 |
|
| 63 |
if not os.getenv("OPENAI_API_KEY"):
|
| 64 |
return "", "", "", "", " Falta OPENAI_API_KEY en el fichero .env"
|
| 65 |
|
| 66 |
try:
|
| 67 |
-
result = corrector.correct(htr_text, top_k=int(top_k))
|
| 68 |
except Exception as e:
|
| 69 |
return "", "", "", "", f" Error al llamar a la API: {e}"
|
| 70 |
|
|
@@ -235,6 +235,14 @@ with gr.Blocks(
|
|
| 235 |
minimum=1, maximum=10, value=5, step=1,
|
| 236 |
label="Documents retrieved (k)",
|
| 237 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
show_prompt = gr.Checkbox(label="Show RAG prompt", value=False)
|
| 239 |
btn_corregir = gr.Button("β¦ Correct with RAG", variant="primary")
|
| 240 |
|
|
@@ -263,7 +271,7 @@ with gr.Blocks(
|
|
| 263 |
|
| 264 |
btn_corregir.click(
|
| 265 |
fn=corregir,
|
| 266 |
-
inputs=[htr_input, top_k_slider, show_prompt],
|
| 267 |
outputs=[corrected_out, docs_out, analysis_out, diff_out, status_out, prompt_out],
|
| 268 |
)
|
| 269 |
|
|
|
|
| 56 |
|
| 57 |
# ── Función principal ─────────────────────────────────────────────────────────
|
| 58 |
|
| 59 |
+
def corregir(htr_text: str, top_k: int, mostrar_prompt: bool, model: str):
|
| 60 |
if not htr_text.strip():
|
| 61 |
+
return "", "", "", "", " Introduce un texto HTR para corregir."
|
| 62 |
|
| 63 |
if not os.getenv("OPENAI_API_KEY"):
|
| 64 |
return "", "", "", "", " Falta OPENAI_API_KEY en el fichero .env"
|
| 65 |
|
| 66 |
try:
|
| 67 |
+
result = corrector.correct(htr_text, top_k=int(top_k), model= model)
|
| 68 |
except Exception as e:
|
| 69 |
return "", "", "", "", f" Error al llamar a la API: {e}"
|
| 70 |
|
|
|
|
| 235 |
minimum=1, maximum=10, value=5, step=1,
|
| 236 |
label="Documents retrieved (k)",
|
| 237 |
)
|
| 238 |
+
model_selector = gr.Dropdown(
|
| 239 |
+
label="Modelo LLM",
|
| 240 |
+
choices=[
|
| 241 |
+
"llama-3.3-70b-versatile",
|
| 242 |
+
"mixtral-8x7b-32768",
|
| 243 |
+
],
|
| 244 |
+
value="llama-3.3-70b-versatile",
|
| 245 |
+
)
|
| 246 |
show_prompt = gr.Checkbox(label="Show RAG prompt", value=False)
|
| 247 |
btn_corregir = gr.Button("β¦ Correct with RAG", variant="primary")
|
| 248 |
|
|
|
|
| 271 |
|
| 272 |
btn_corregir.click(
|
| 273 |
fn=corregir,
|
| 274 |
+
inputs=[htr_input, top_k_slider, show_prompt, model_selector],
|
| 275 |
outputs=[corrected_out, docs_out, analysis_out, diff_out, status_out, prompt_out],
|
| 276 |
)
|
| 277 |
|
rag_corrector.py
CHANGED
|
@@ -53,7 +53,7 @@ class RAGCorrector:
|
|
| 53 |
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL", "https://api.x.ai/v1"),)
|
| 54 |
# ── API pública ──────────────────────────────────────────────────────────
|
| 55 |
|
| 56 |
-
def correct(self, htr_text: str, top_k: int = TOP_K) -> Dict:
|
| 57 |
"""
|
| 58 |
Corrige un texto HTR usando RAG.
|
| 59 |
|
|
@@ -71,7 +71,7 @@ class RAGCorrector:
|
|
| 71 |
|
| 72 |
prompt = self._build_prompt(htr_text, retrieved, htr_errors, grafia_warns)
|
| 73 |
|
| 74 |
-
corrected = self._call_llm(prompt)
|
| 75 |
|
| 76 |
return {
|
| 77 |
"corrected": corrected,
|
|
@@ -79,7 +79,7 @@ class RAGCorrector:
|
|
| 79 |
"retrieved": retrieved,
|
| 80 |
"htr_errors": htr_errors,
|
| 81 |
"grafia_warns": grafia_warns,
|
| 82 |
-
"model": MODEL,
|
| 83 |
}
|
| 84 |
|
| 85 |
# ── Detección de patrones ────────────────────────────────────────────────
|
|
@@ -159,9 +159,9 @@ class RAGCorrector:
|
|
| 159 |
|
| 160 |
# ── Llamada al LLM ───────────────────────────────────────────────────────
|
| 161 |
|
| 162 |
-
def _call_llm(self, user_prompt: str) -> str:
|
| 163 |
response = self.client.chat.completions.create(
|
| 164 |
-
model=MODEL,
|
| 165 |
temperature=0.1, # baja temperatura: reproducibilidad
|
| 166 |
max_tokens=1024,
|
| 167 |
messages=[
|
|
|
|
| 53 |
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL", "https://api.x.ai/v1"),)
|
| 54 |
# ── API pública ──────────────────────────────────────────────────────────
|
| 55 |
|
| 56 |
+
def correct(self, htr_text: str, top_k: int = TOP_K, model: str = None) -> Dict:
|
| 57 |
"""
|
| 58 |
Corrige un texto HTR usando RAG.
|
| 59 |
|
|
|
|
| 71 |
|
| 72 |
prompt = self._build_prompt(htr_text, retrieved, htr_errors, grafia_warns)
|
| 73 |
|
| 74 |
+
corrected = self._call_llm(prompt, model=model or MODEL)
|
| 75 |
|
| 76 |
return {
|
| 77 |
"corrected": corrected,
|
|
|
|
| 79 |
"retrieved": retrieved,
|
| 80 |
"htr_errors": htr_errors,
|
| 81 |
"grafia_warns": grafia_warns,
|
| 82 |
+
"model": model or MODEL,
|
| 83 |
}
|
| 84 |
|
| 85 |
# ── Detección de patrones ────────────────────────────────────────────────
|
|
|
|
| 159 |
|
| 160 |
# ── Llamada al LLM ───────────────────────────────────────────────────────
|
| 161 |
|
| 162 |
+
def _call_llm(self, user_prompt: str, model: str = MODEL) -> str:
|
| 163 |
response = self.client.chat.completions.create(
|
| 164 |
+
model=model, # usa el modelo que llega, no el de .env
|
| 165 |
temperature=0.1, # baja temperatura: reproducibilidad
|
| 166 |
max_tokens=1024,
|
| 167 |
messages=[
|