saadkhi committed on
Commit
5d261f7
Β·
verified Β·
1 Parent(s): ac4a697

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -21
app.py CHANGED
@@ -6,12 +6,14 @@ from fastapi import FastAPI
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
9
- # Optimize CPU
 
 
10
  torch.set_num_threads(1)
11
 
12
- app = FastAPI(title="SQL Generator API")
13
 
14
- BASE_MODEL = "distilgpt2"
15
 
16
  print("Loading model...")
17
 
@@ -22,12 +24,6 @@ model.eval()
22
 
23
  print("Model ready")
24
 
25
- # ─────────────────────────
26
- # Request Schema
27
- # ─────────────────────────
28
- class Query(BaseModel):
29
- text: str
30
-
31
  # ─────────────────────────
32
  # SQL FILTER
33
  # ─────────────────────────
@@ -40,9 +36,6 @@ SQL_KEYWORDS = [
40
  def is_sql_related(text):
41
  return any(k in text.lower() for k in SQL_KEYWORDS)
42
 
43
- # ─────────────────────────
44
- # Generator
45
- # ─────────────────────────
46
  SYSTEM_PROMPT = """
47
  You are an expert SQL generator.
48
  Only output SQL query.
@@ -50,10 +43,10 @@ Only output SQL query.
50
 
51
  def generate_sql(user_input: str):
52
  if not user_input.strip():
53
- return "Empty input."
54
 
55
  if not is_sql_related(user_input):
56
- return "Only SQL-related queries allowed."
57
 
58
  prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
59
 
@@ -62,9 +55,9 @@ def generate_sql(user_input: str):
62
  with torch.no_grad():
63
  output = model.generate(
64
  **inputs,
65
- max_new_tokens=80,
66
- temperature=0.2,
67
- do_sample=True,
68
  pad_token_id=tokenizer.eos_token_id,
69
  )
70
 
@@ -74,13 +67,30 @@ def generate_sql(user_input: str):
74
  return result
75
 
76
  # ─────────────────────────
77
- # Routes
78
  # ─────────────────────────
 
 
 
79
  @app.get("/")
80
  def root():
81
- return {"status": "API is running"}
82
 
83
  @app.post("/generate")
84
  def generate(query: Query):
85
- result = generate_sql(query.text)
86
- return {"result": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
9
+ import gradio as gr
10
+ import threading
11
+
12
  torch.set_num_threads(1)
13
 
14
+ app = FastAPI()
15
 
16
+ BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
17
 
18
  print("Loading model...")
19
 
 
24
 
25
  print("Model ready")
26
 
 
 
 
 
 
 
27
  # ─────────────────────────
28
  # SQL FILTER
29
  # ─────────────────────────
 
36
def is_sql_related(text):
    """Return True if *text* contains any known SQL keyword (case-insensitive)."""
    lowered = text.lower()
    for keyword in SQL_KEYWORDS:
        if keyword in lowered:
            return True
    return False
38
 
 
 
 
39
  SYSTEM_PROMPT = """
40
  You are an expert SQL generator.
41
  Only output SQL query.
 
43
 
44
  def generate_sql(user_input: str):
45
  if not user_input.strip():
46
+ return "Enter SQL question."
47
 
48
  if not is_sql_related(user_input):
49
+ return "Only SQL/database questions allowed."
50
 
51
  prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
52
 
 
55
  with torch.no_grad():
56
  output = model.generate(
57
  **inputs,
58
+ max_new_tokens=120,
59
+ temperature=0.1,
60
+ do_sample=False,
61
  pad_token_id=tokenizer.eos_token_id,
62
  )
63
 
 
67
  return result
68
 
69
  # ─────────────────────────
70
+ # FastAPI Routes
71
  # ─────────────────────────
72
class Query(BaseModel):
    """Request body schema for POST /generate."""

    # Natural-language question to be converted into a SQL query.
    text: str
74
+
75
  @app.get("/")
76
  def root():
77
+ return {"status": "API running"}
78
 
79
@app.post("/generate")
def generate(query: Query):
    """Generate SQL (or a refusal message) for the question in the request body."""
    sql_text = generate_sql(query.text)
    return {"result": sql_text}
82
+
83
+ # ─────────────────────────
84
+ # Gradio UI (for testing)
85
+ # ─────────────────────────
86
def launch_gradio():
    """Serve a small Gradio test UI that drives generate_sql directly.

    Binds to 0.0.0.0:7861 so the UI is reachable alongside the FastAPI app
    (which presumably runs on a different port — confirm deployment config).
    """
    demo = gr.Interface(
        fn=generate_sql,
        inputs=gr.Textbox(lines=3, label="SQL Question"),
        outputs=gr.Textbox(lines=6, label="Generated SQL"),
        title="SQL Generator Test UI",
    )
    demo.launch(server_name="0.0.0.0", server_port=7861)

# Run Gradio on a background *daemon* thread: the original non-daemon thread
# would keep the process alive after the main server exits, blocking shutdown.
threading.Thread(target=launch_gradio, daemon=True).start()