Spaces:
Running
Running
| import os | |
| import re | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
# Hugging Face access token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Model id passed to the inference client.
_MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct:ovhcloud"

# Module-wide inference client that generate_sql streams completions from.
client = InferenceClient(model=_MODEL_ID, token=HF_TOKEN)
# ─────────────────────────────────────────────
# SYSTEM PROMPT (strict, few-shot)
# ─────────────────────────────────────────────
# System message sent with every request. It pins the output format (raw DuckDB
# SQL or the NOT_A_DATA_QUESTION sentinel), lists the SQL dialect rules and ends
# with few-shot examples. The em dashes below replace mis-encoded characters
# that previously leaked into the prompt text.
SYSTEM_PROMPT = """\
You are a strict SQL code generator for DuckDB.
YOUR ONLY JOB is to output a single, valid DuckDB SQL query.
ABSOLUTE OUTPUT RULES — violating any rule makes the output wrong:
1. Output ONLY raw SQL. No markdown, no code fences, no backticks, no explanations.
2. Never prefix with "sql", "SQL:", "Here is", or any natural language.
3. Never output anything after the semicolon.
4. If the question cannot be answered from the schema, output exactly: NOT_A_DATA_QUESTION
5. NOT_A_DATA_QUESTION also applies to: greetings, general knowledge, math unrelated to the schema, anything not about querying the provided tables.
SQL RULES:
- Use ONLY table and column names that appear in the schema — never invent names.
- Use DuckDB syntax exclusively. Never use SQLite or MySQL syntax.
- Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
- Prefer the fewest JOINs and subqueries needed to answer the question.
- Never use SELECT * — always name the columns you need.
- Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
- Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
- Date arithmetic: NEVER use julianday(). Use datediff('day', start_col, end_col) for days between two timestamps. Use epoch(end_col - start_col) / 86400 for interval-to-days.
- Identifier quoting: wrap table and column names in double quotes if they start with a digit or contain special characters (e.g. "2b_concept", "my-column").
- String concatenation: use || operator, never CONCAT().
- Current date/time: use current_date or current_timestamp, never NOW().
FEW-SHOT EXAMPLES:
Schema:
CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
CREATE TABLE admissions (subject_id INT, admittime TIMESTAMP, dischtime TIMESTAMP, admission_type VARCHAR);
Q: How many patients above 50 have asthma?
A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';
Q: Show me all patients who died during their hospital stay.
A: SELECT patient_id, age, gender, diagnosis FROM patients WHERE died = true;
Q: What is the average age of female patients?
A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';
Q: Who are the top 10 patients with the longest hospital stay?
A: SELECT a.subject_id, datediff('day', a.admittime, a.dischtime) AS stay_days FROM admissions a WHERE a.dischtime IS NOT NULL ORDER BY stay_days DESC LIMIT 10;
Q: Hello, how are you?
A: NOT_A_DATA_QUESTION
Q: What is the capital of France?
A: NOT_A_DATA_QUESTION
Now answer the user's question using ONLY the schema they provide."""
# ─────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────
# Leading keywords that mark a model response as a SQL statement.
VALID_SQL_STARTS = (
    "SELECT",
    "WITH",
    "INSERT",
    "UPDATE",
    "DELETE",
    "CREATE",
    "DROP",
    "ALTER",
)
def _statement_end(sql: str) -> int:
    """Return the index just past the first ``;`` outside quotes, or -1 if none."""
    open_quote = ""
    for index, char in enumerate(sql):
        if open_quote:
            # A doubled quote ('' or "") closes and immediately reopens the
            # literal, so escaped quotes keep us inside it as intended.
            if char == open_quote:
                open_quote = ""
        elif char in "'\"":
            open_quote = char
        elif char == ";":
            return index + 1
    return -1


def clean_sql(raw: str) -> str:
    """Reduce a raw model response to a bare SQL statement.

    Removes ``<think>...</think>`` blocks, markdown code fences, a leading
    ``sql`` / ``SQL:`` label, and anything after the statement's terminating
    semicolon.

    Args:
        raw: Text exactly as returned by the model.

    Returns:
        The cleaned text with surrounding whitespace removed. Responses that
        contain no SQL (for example the ``NOT_A_DATA_QUESTION`` sentinel) come
        back otherwise unchanged.
    """
    # Drop reasoning blocks first and only then trim, so whitespace left
    # behind by a removed block cannot hide a fence from the patterns below.
    text = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL).strip()

    # An opening fence is ``` plus an optional language tag that must end the
    # line; without that newline requirement the tag pattern would swallow the
    # first SQL keyword in input such as ```SELECT 1;```.
    opening = r"```(?:[a-zA-Z0-9_+-]*[ \t]*\n)?"
    fenced = re.search(opening + r"(.*?)```", text, flags=re.DOTALL)
    if fenced:
        # A complete fenced block wins over any prose surrounding it.
        text = fenced.group(1)
    else:
        # Unterminated or stray fences: strip whatever is at either edge.
        text = re.sub(r"^" + opening, "", text)
        text = re.sub(r"```$", "", text)
    text = text.strip()

    # Leading label: "sql SELECT ..." as well as "SQL: SELECT ...".
    text = re.sub(r"(?i)^sql(?:\s*:\s*|\s+)", "", text)

    # Keep only the first statement; semicolons inside quoted literals or
    # identifiers do not terminate it.
    end = _statement_end(text)
    if end != -1:
        text = text[:end]
    return text.strip()
def validate_sql(sql: str) -> str:
    """Run a light sanity check on the generated SQL.

    Args:
        sql: Model output after cleaning.

    Returns:
        ``sql`` unchanged when its first keyword is one of
        ``VALID_SQL_STARTS``; otherwise a user-facing warning string, either
        for the ``NOT_A_DATA_QUESTION`` sentinel or for any other non-SQL
        response.
    """
    normalized = sql.upper().strip()

    # Accept the sentinel even when the model decorates it with a trailing
    # "." or ";".
    if normalized.rstrip(".; \t\n") == "NOT_A_DATA_QUESTION":
        return (
            "⚠️ That question doesn't appear to be about the database. "
            "Try asking something that can be answered by querying the schema."
        )

    # Match whole keywords only, so prose such as "Without ..." or
    # "Selected ..." is not mistaken for a WITH / SELECT statement.
    keywords = "|".join(re.escape(keyword) for keyword in VALID_SQL_STARTS)
    if not re.match(rf"(?:{keywords})\b", normalized):
        return (
            f"⚠️ The model returned an unexpected response instead of SQL:\n\n{sql}\n\n"
            "Try rephrasing your question to be more specific about the data."
        )

    return sql  # looks good
# ─────────────────────────────────────────────
# MAIN GENERATOR
# ─────────────────────────────────────────────
def generate_sql(question: str, schema_ddl: str):
    """Stream a DuckDB query answering ``question`` against ``schema_ddl``.

    Generator passed as ``fn`` to the Gradio interface.

    Args:
        question: Natural-language question from the user.
        schema_ddl: ``CREATE TABLE`` statements describing the tables.

    Yields:
        A single warning when either input is blank. Otherwise the raw model
        text accumulated so far after each streamed chunk, followed by one
        final value: the cleaned and validated SQL, or an error message if
        the model call raised.
    """
    # Treat a missing value like an empty one instead of failing on .strip().
    question = (question or "").strip()
    schema_ddl = (schema_ddl or "").strip()

    if not question:
        yield "⚠️ Please enter a question."
        return
    if not schema_ddl:
        yield "⚠️ Please provide your schema DDL."
        return

    prompt = (
        f"Database schema:\n{schema_ddl}\n\n"
        f"Question: {question}\n\n"
        "SQL:"
    )

    accumulated = ""
    try:
        stream = client.chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            max_tokens=500,
            temperature=0.0,
            stream=True,
        )
        for event in stream:
            # A chunk without choices carries no text; indexing it would raise
            # and discard everything streamed so far.
            if not event.choices:
                continue
            accumulated += event.choices[0].delta.content or ""
            yield accumulated  # stream raw while typing
    except Exception as e:
        # Boundary with the remote service: report the failure in the UI
        # rather than letting the callback crash.
        yield f"❌ Error calling model: {e}"
        return

    # Final: clean then validate
    yield validate_sql(clean_sql(accumulated))
# ─────────────────────────────────────────────
# GRADIO UI
# ─────────────────────────────────────────────
# Input widgets, in the positional order generate_sql expects them.
_question_box = gr.Textbox(
    label="Question",
    placeholder="Show me how many patients aged above 50 have asthma",
)
_schema_box = gr.Textbox(
    label="Schema DDL",
    lines=10,
    placeholder="CREATE TABLE patients (...)",
)

# Single text output that receives every value generate_sql yields.
_sql_box = gr.Textbox(label="Generated SQL")

# Wire the generator to the widgets and start the app.
demo = gr.Interface(
    fn=generate_sql,
    inputs=[_question_box, _schema_box],
    outputs=_sql_box,
    title="TinyEHR Text-to-SQL",
    description="Generate SQL queries for the TinyEHR dataset from natural language.",
    flagging_mode="never",
)
demo.launch()