"""Gradio front end that turns natural-language questions into DuckDB SQL.

The user supplies a question and the schema DDL; a hosted chat model is asked
to produce a single SQL statement, which is streamed to the UI and then
cleaned and sanity-checked before the final value is shown.
"""

import os
import re
from collections.abc import Iterator

import gradio as gr
from huggingface_hub import InferenceClient

HF_TOKEN = os.environ.get("HF_TOKEN")

client = InferenceClient(
    model="Qwen/Qwen3-Coder-30B-A3B-Instruct:ovhcloud",
    token=HF_TOKEN,
)

# ─────────────────────────────────────────────
# SYSTEM PROMPT (strict, few-shot)
# ─────────────────────────────────────────────
SYSTEM_PROMPT = """\
You are a strict SQL code generator for DuckDB.
YOUR ONLY JOB is to output a single, valid DuckDB SQL query.

ABSOLUTE OUTPUT RULES — violating any rule makes the output wrong:
1. Output ONLY raw SQL. No markdown, no code fences, no backticks, no explanations.
2. Never prefix with "sql", "SQL:", "Here is", or any natural language.
3. Never output anything after the semicolon.
4. If the question cannot be answered from the schema, output exactly: NOT_A_DATA_QUESTION
5. NOT_A_DATA_QUESTION also applies to: greetings, general knowledge, math unrelated to the schema,
   anything not about querying the provided tables.

SQL RULES:
- Use ONLY table and column names that appear in the schema — never invent names.
- Use DuckDB syntax exclusively. Never use SQLite or MySQL syntax.
- Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
- Prefer the fewest JOINs and subqueries needed to answer the question.
- Never use SELECT * — always name the columns you need.
- Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
- Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
- Date arithmetic: NEVER use julianday(). Use datediff('day', start_col, end_col)
  for days between two timestamps. Use epoch(end_col - start_col) / 86400 for interval-to-days.
- Identifier quoting: wrap table and column names in double quotes if they start with
  a digit or contain special characters (e.g. "2b_concept", "my-column").
- String concatenation: use || operator, never CONCAT().
- Current date/time: use current_date or current_timestamp, never NOW().

FEW-SHOT EXAMPLES:

Schema:
CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
CREATE TABLE admissions (subject_id INT, admittime TIMESTAMP, dischtime TIMESTAMP, admission_type VARCHAR);

Q: How many patients above 50 have asthma?
A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';

Q: Show me all patients who died during their hospital stay.
A: SELECT patient_id, age, gender, diagnosis FROM patients WHERE died = true;

Q: What is the average age of female patients?
A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';

Q: Who are the top 10 patients with the longest hospital stay?
A: SELECT a.subject_id, datediff('day', a.admittime, a.dischtime) AS stay_days FROM admissions a WHERE a.dischtime IS NOT NULL ORDER BY stay_days DESC LIMIT 10;

Q: Hello, how are you?
A: NOT_A_DATA_QUESTION

Q: What is the capital of France?
A: NOT_A_DATA_QUESTION

Now answer the user's question using ONLY the schema they provide."""

# ─────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────
VALID_SQL_STARTS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER")

# Compiled once at import time; clean_sql runs on every request.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
_OPENING_FENCE_RE = re.compile(r"^```[a-zA-Z]*\n?")
_CLOSING_FENCE_RE = re.compile(r"```$")
_SQL_LABEL_RE = re.compile(r"(?i)^sql\s+")


def _statement_end(sql: str) -> int:
    """Return the index just past the first semicolon outside quotes, or -1.

    Single- and double-quoted spans are skipped so a semicolon inside a
    literal such as '%a;b%' does not cut the statement short. A doubled quote
    inside a literal closes and immediately reopens the span, which leaves the
    scanner in the correct state.
    """
    open_quote = None
    for pos, char in enumerate(sql):
        if open_quote is not None:
            if char == open_quote:
                open_quote = None
        elif char in ("'", '"'):
            open_quote = char
        elif char == ";":
            return pos + 1
    return -1


def clean_sql(raw: str) -> str:
    """Reduce raw model output to a bare SQL statement.

    Removes <think>...</think> blocks, a leading/trailing markdown code fence,
    a leading "sql" label, and anything after the first statement-terminating
    semicolon, then trims surrounding whitespace.

    Args:
        raw: The text accumulated from the model.

    Returns:
        The cleaned text. It is not guaranteed to be SQL; see validate_sql.
    """
    # Strip <think>...</think> blocks (some models emit these). Re-strip
    # afterwards: the fence pattern below is anchored at the start of the
    # text, so whitespace left behind by a removed block would hide the fence.
    sql = _THINK_BLOCK_RE.sub("", raw.strip()).strip()

    # Strip markdown code fences.
    sql = _OPENING_FENCE_RE.sub("", sql)
    sql = _CLOSING_FENCE_RE.sub("", sql)

    # Strip a leading "sql" keyword.
    sql = _SQL_LABEL_RE.sub("", sql)

    # Drop any trailing text after the statement's semicolon.
    end = _statement_end(sql)
    if end != -1:
        sql = sql[:end]

    return sql.strip()


def validate_sql(sql: str) -> str:
    """Run a light sanity check on the generated SQL.

    Args:
        sql: Cleaned model output.

    Returns:
        The SQL unchanged if it starts with one of VALID_SQL_STARTS, otherwise
        a user-facing warning string.
    """
    upper = sql.upper().strip()

    if upper == "NOT_A_DATA_QUESTION":
        return (
            "⚠️ That question doesn't appear to be about the database. "
            "Try asking something that can be answered by querying the schema."
        )

    if not upper.startswith(VALID_SQL_STARTS):
        return (
            f"⚠️ The model returned an unexpected response instead of SQL:\n\n{sql}\n\n"
            "Try rephrasing your question to be more specific about the data."
        )

    return sql  # looks good


# ─────────────────────────────────────────────
# MAIN GENERATOR
# ─────────────────────────────────────────────
def generate_sql(question: str, schema_ddl: str) -> Iterator[str]:
    """Stream a SQL answer for *question* against *schema_ddl*.

    Yields the raw accumulated model text after every streamed chunk so the UI
    updates while the model is generating, then yields one final value: the
    cleaned and validated SQL, or a warning/error message.

    Args:
        question: The natural-language question.
        schema_ddl: The schema DDL the query must be written against.

    Yields:
        Successive snapshots of the output text.
    """
    # Normalise first so a missing (None) input is reported like an empty one
    # instead of raising AttributeError.
    question = (question or "").strip()
    schema_ddl = (schema_ddl or "").strip()

    if not question:
        yield "⚠️ Please enter a question."
        return
    if not schema_ddl:
        yield "⚠️ Please provide your schema DDL."
        return

    prompt = (
        f"Database schema:\n{schema_ddl}\n\n"
        f"Question: {question}\n\n"
        "SQL:"
    )

    accumulated = ""
    try:
        for token in client.chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            max_tokens=500,
            temperature=0.0,
            stream=True,
        ):
            # A stream chunk may carry no choices; skip it rather than let an
            # IndexError discard everything generated so far.
            if not token.choices:
                continue
            chunk = token.choices[0].delta.content or ""
            accumulated += chunk
            yield accumulated  # stream raw while typing
    except Exception as e:  # UI boundary: report the failure to the user
        yield f"❌ Error calling model: {e}"
        return

    # Final: clean then validate.
    yield validate_sql(clean_sql(accumulated))


# ─────────────────────────────────────────────
# GRADIO UI
# ─────────────────────────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=[
        gr.Textbox(
            label="Question",
            placeholder="Show me how many patients aged above 50 have asthma",
        ),
        gr.Textbox(
            label="Schema DDL",
            lines=10,
            placeholder="CREATE TABLE patients (...)",
        ),
    ],
    outputs=gr.Textbox(label="Generated SQL"),
    title="TinyEHR Text-to-SQL",
    description="Generate SQL queries for the TinyEHR dataset from natural language.",
    flagging_mode="never",
)

demo.launch()