saadkhi commited on
Commit
cc1250f
Β·
verified Β·
1 Parent(s): 119ad27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -30
app.py CHANGED
@@ -5,19 +5,17 @@ import torch
5
  import gradio as gr
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
 
8
  torch.set_num_threads(1)
9
 
10
- # BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
11
- BASE_MODEL = "tiiuae/falcon-rw-1b"
12
 
13
  print("Loading model...")
14
 
15
- model = AutoModelForCausalLM.from_pretrained(
16
- BASE_MODEL,
17
- torch_dtype=torch.float32
18
- )
19
-
20
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
 
 
21
  model.eval()
22
 
23
  print("Model ready")
@@ -36,52 +34,44 @@ def is_sql_related(text):
36
  return any(k in text for k in SQL_KEYWORDS)
37
 
38
  # ─────────────────────────
39
- # GENERATION
40
  # ─────────────────────────
41
  SYSTEM_PROMPT = """
42
  You are an expert SQL generator.
43
-
44
  Rules:
45
  - Only respond to SQL or database related questions.
46
- - If the question is not about SQL or databases, refuse.
47
  - Output ONLY SQL query.
48
- - Do not explain.
49
  """
50
 
 
 
 
51
  def generate_sql(user_input):
52
 
53
  if not user_input.strip():
54
  return "Enter SQL question."
55
 
56
- # HARD GUARD
57
  if not is_sql_related(user_input):
58
- return "I only respond to SQL and database related questions. If you want, I can craft helpful database queries for you."
59
 
60
- prompt = f"""
61
- {SYSTEM_PROMPT}
62
-
63
- User request: {user_input}
64
- SQL:
65
- """
66
 
67
  inputs = tokenizer(prompt, return_tensors="pt")
68
 
69
  with torch.no_grad():
70
  output = model.generate(
71
  **inputs,
72
- max_new_tokens=120,
73
- temperature=0.1,
74
- do_sample=False,
75
  pad_token_id=tokenizer.eos_token_id,
76
  )
77
 
78
  text = tokenizer.decode(output[0], skip_special_tokens=True)
79
 
80
- # return only SQL part
81
  result = text.split("SQL:")[-1].strip()
82
-
83
- # extra safety: remove explanations
84
- result = result.split("\n\n")[0]
85
 
86
  return result
87
 
@@ -96,17 +86,17 @@ demo = gr.Interface(
96
  placeholder="Find duplicate emails in users table"
97
  ),
98
  outputs=gr.Textbox(
99
- lines=8,
100
  label="Generated SQL"
101
  ),
102
  title="AI SQL Generator (Portfolio Project)",
103
- description="This model ONLY responds to SQL/database queries.",
104
  examples=[
105
  ["Find duplicate emails in users table"],
106
  ["Top 5 highest paid employees"],
107
  ["Count orders per customer last month"],
108
- ["Write a joke about cats"] # will be blocked
109
  ],
110
  )
111
 
112
- demo.launch(server_name="0.0.0.0")
 
5
  import gradio as gr
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
+ # Reduce CPU pressure
9
  torch.set_num_threads(1)
10
 
11
+ # βœ… Use lightweight model (IMPORTANT)
12
+ BASE_MODEL = "distilgpt2"
13
 
14
  print("Loading model...")
15
 
 
 
 
 
 
16
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
17
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
18
+
19
  model.eval()
20
 
21
  print("Model ready")
 
34
  return any(k in text for k in SQL_KEYWORDS)
35
 
36
  # ─────────────────────────
37
+ # PROMPT
38
  # ─────────────────────────
39
  SYSTEM_PROMPT = """
40
  You are an expert SQL generator.
 
41
  Rules:
42
  - Only respond to SQL or database related questions.
 
43
  - Output ONLY SQL query.
44
+ - No explanation.
45
  """
46
 
47
+ # ─────────────────────────
48
+ # GENERATION
49
+ # ─────────────────────────
50
  def generate_sql(user_input):
51
 
52
  if not user_input.strip():
53
  return "Enter SQL question."
54
 
 
55
  if not is_sql_related(user_input):
56
+ return "Only SQL/database questions are allowed."
57
 
58
+ prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
 
 
 
 
 
59
 
60
  inputs = tokenizer(prompt, return_tensors="pt")
61
 
62
  with torch.no_grad():
63
  output = model.generate(
64
  **inputs,
65
+ max_new_tokens=80,
66
+ temperature=0.2,
67
+ do_sample=True,
68
  pad_token_id=tokenizer.eos_token_id,
69
  )
70
 
71
  text = tokenizer.decode(output[0], skip_special_tokens=True)
72
 
 
73
  result = text.split("SQL:")[-1].strip()
74
+ result = result.split("\n")[0]
 
 
75
 
76
  return result
77
 
 
86
  placeholder="Find duplicate emails in users table"
87
  ),
88
  outputs=gr.Textbox(
89
+ lines=6,
90
  label="Generated SQL"
91
  ),
92
  title="AI SQL Generator (Portfolio Project)",
93
+ description="Only SQL/database queries are supported.",
94
  examples=[
95
  ["Find duplicate emails in users table"],
96
  ["Top 5 highest paid employees"],
97
  ["Count orders per customer last month"],
98
+ ["Write a joke about cats"]
99
  ],
100
  )
101
 
102
+ demo.launch(server_name="0.0.0.0", server_port=7860)