# tinyehr-sql / app.py
# vidulpanickan's picture
# Update app.py
# 4d71586 verified
import os
import re
import gradio as gr
from huggingface_hub import InferenceClient
# Access token taken from the HF_TOKEN environment variable (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Module-level inference client reused by every request.
client = InferenceClient(
    token=HF_TOKEN,
    model="Qwen/Qwen3-Coder-30B-A3B-Instruct:ovhcloud",
)
# ─────────────────────────────────────────────
# SYSTEM PROMPT (strict, few-shot)
# ─────────────────────────────────────────────
# System message sent with every request: output rules, DuckDB-specific SQL
# rules, and few-shot examples (including the NOT_A_DATA_QUESTION sentinel).
SYSTEM_PROMPT = """\
You are a strict SQL code generator for DuckDB.
YOUR ONLY JOB is to output a single, valid DuckDB SQL query.
ABSOLUTE OUTPUT RULES — violating any rule makes the output wrong:
1. Output ONLY raw SQL. No markdown, no code fences, no backticks, no explanations.
2. Never prefix with "sql", "SQL:", "Here is", or any natural language.
3. Never output anything after the semicolon.
4. If the question cannot be answered from the schema, output exactly: NOT_A_DATA_QUESTION
5. NOT_A_DATA_QUESTION also applies to: greetings, general knowledge, math unrelated to the schema, anything not about querying the provided tables.
SQL RULES:
- Use ONLY table and column names that appear in the schema — never invent names.
- Use DuckDB syntax exclusively. Never use SQLite or MySQL syntax.
- Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
- Prefer the fewest JOINs and subqueries needed to answer the question.
- Never use SELECT * — always name the columns you need.
- Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
- Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
- Date arithmetic: NEVER use julianday(). Use datediff('day', start_col, end_col) for days between two timestamps. Use epoch(end_col - start_col) / 86400 for interval-to-days.
- Identifier quoting: wrap table and column names in double quotes if they start with a digit or contain special characters (e.g. "2b_concept", "my-column").
- String concatenation: use || operator, never CONCAT().
- Current date/time: use current_date or current_timestamp, never NOW().
FEW-SHOT EXAMPLES:
Schema:
CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
CREATE TABLE admissions (subject_id INT, admittime TIMESTAMP, dischtime TIMESTAMP, admission_type VARCHAR);
Q: How many patients above 50 have asthma?
A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';
Q: Show me all patients who died during their hospital stay.
A: SELECT patient_id, age, gender, diagnosis FROM patients WHERE died = true;
Q: What is the average age of female patients?
A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';
Q: Who are the top 10 patients with the longest hospital stay?
A: SELECT a.subject_id, datediff('day', a.admittime, a.dischtime) AS stay_days FROM admissions a WHERE a.dischtime IS NOT NULL ORDER BY stay_days DESC LIMIT 10;
Q: Hello, how are you?
A: NOT_A_DATA_QUESTION
Q: What is the capital of France?
A: NOT_A_DATA_QUESTION
Now answer the user's question using ONLY the schema they provide."""
# ─────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────
# Leading keywords accepted by validate_sql() as the start of a SQL statement.
VALID_SQL_STARTS = (
    "SELECT",
    "WITH",
    "INSERT",
    "UPDATE",
    "DELETE",
    "CREATE",
    "DROP",
    "ALTER",
)
def _statement_end(sql: str) -> int:
    """Return the index just past the first ';' outside quotes, or -1 if none."""
    quote = ""
    for index, char in enumerate(sql):
        if quote:
            # Inside a quoted literal/identifier: only the matching quote ends it.
            # A doubled quote ('' or "") closes and immediately reopens, which
            # leaves the scanner correctly inside the literal.
            if char == quote:
                quote = ""
        elif char in "'\"":
            quote = char
        elif char == ";":
            return index + 1
    return -1


def clean_sql(raw: str) -> str:
    """Reduce a raw model response to a bare SQL statement.

    Removes ``<think>...</think>`` blocks, a leading/trailing markdown code
    fence, a leading ``sql`` / ``SQL:`` label, and anything after the first
    statement-terminating semicolon.

    Args:
        raw: The full text returned by the model.

    Returns:
        The cleaned text with surrounding whitespace removed. Text without a
        semicolon (for example the NOT_A_DATA_QUESTION sentinel) is returned
        whole.
    """
    # Drop thinking blocks first and re-trim: the patterns below are anchored
    # at the start of the string, so whitespace left behind by a removed
    # <think> block would otherwise stop them from matching.
    sql = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()

    # Strip markdown code fences
    sql = re.sub(r"^```[a-zA-Z]*\n?", "", sql)
    sql = re.sub(r"```$", "", sql).strip()

    # Strip a leading "sql" label, with or without a colon ("sql SELECT", "SQL: SELECT")
    sql = re.sub(r"^sql(?:\s*:\s*|\s+)", "", sql, flags=re.IGNORECASE)

    # Cut trailing text after the statement terminator; semicolons inside
    # quoted literals do not count as terminators.
    end = _statement_end(sql)
    if end != -1:
        sql = sql[:end]
    return sql.strip()
def validate_sql(sql: str) -> str:
    """Run a light sanity check on the cleaned model output.

    Args:
        sql: Model output after cleaning.

    Returns:
        ``sql`` unchanged when its first word is one of VALID_SQL_STARTS;
        otherwise a user-facing warning string, either for the
        NOT_A_DATA_QUESTION sentinel or for any other non-SQL response.
    """
    upper = sql.upper().strip()

    # Accept the sentinel even when the model appends punctuation to it
    # (e.g. "NOT_A_DATA_QUESTION;").
    if upper.rstrip(";. \t\r\n") == "NOT_A_DATA_QUESTION":
        return (
            "⚠️ That question doesn't appear to be about the database. "
            "Try asking something that can be answered by querying the schema."
        )

    # Compare the whole leading word rather than a prefix, so prose such as
    # "Without the schema..." or "Selected rows..." is not mistaken for SQL.
    leading_word = re.match(r"\w+", upper)
    if leading_word is None or leading_word.group(0) not in VALID_SQL_STARTS:
        return (
            f"⚠️ The model returned an unexpected response instead of SQL:\n\n{sql}\n\n"
            "Try rephrasing your question to be more specific about the data."
        )
    return sql  # looks good
# ─────────────────────────────────────────────
# MAIN GENERATOR
# ─────────────────────────────────────────────
def generate_sql(question: str, schema_ddl: str):
    """Stream a SQL query answering *question* for the given schema.

    This is a generator: it yields the raw accumulated model output after each
    streamed chunk, then yields the cleaned and validated result last.

    Args:
        question: The natural-language question.
        schema_ddl: The schema DDL the query must be written against.

    Yields:
        A warning string when either input is missing or blank; otherwise the
        growing raw response followed by the final cleaned/validated text, or
        an error string if the model call raises.
    """
    # Treat None like an empty field instead of crashing on .strip()
    question = (question or "").strip()
    schema_ddl = (schema_ddl or "").strip()

    if not question:
        yield "⚠️ Please enter a question."
        return
    if not schema_ddl:
        yield "⚠️ Please provide your schema DDL."
        return

    prompt = (
        f"Database schema:\n{schema_ddl}\n\n"
        f"Question: {question}\n\n"
        "SQL:"
    )
    accumulated = ""
    try:
        for token in client.chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            max_tokens=500,
            temperature=0.0,
            stream=True,
        ):
            # A streamed chunk may carry no choices; indexing it would raise
            # IndexError and discard the output gathered so far.
            if not token.choices:
                continue
            chunk = token.choices[0].delta.content or ""
            accumulated += chunk
            yield accumulated  # stream raw while typing
    except Exception as e:
        # Top-level UI boundary: show the failure to the user instead of crashing.
        yield f"❌ Error calling model: {e}"
        return

    # Final: clean then validate
    yield validate_sql(clean_sql(accumulated))
# ─────────────────────────────────────────────
# GRADIO UI
# ─────────────────────────────────────────────
# Input widgets, in the order generate_sql() receives them.
question_input = gr.Textbox(
    label="Question",
    placeholder="Show me how many patients aged above 50 have asthma",
)
schema_input = gr.Textbox(
    label="Schema DDL",
    lines=10,
    placeholder="CREATE TABLE patients (...)",
)

# Single output widget that receives each value yielded by generate_sql().
sql_output = gr.Textbox(label="Generated SQL")

demo = gr.Interface(
    fn=generate_sql,
    inputs=[question_input, schema_input],
    outputs=sql_output,
    title="TinyEHR Text-to-SQL",
    description="Generate SQL queries for the TinyEHR dataset from natural language.",
    flagging_mode="never",
)

demo.launch()