Goated121 committed on
Commit
91b4de2
·
verified ·
1 Parent(s): 81026e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -57
app.py CHANGED
@@ -5,15 +5,15 @@ import numpy as np
5
  from sentence_transformers import SentenceTransformer
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  import torch
 
 
 
8
 
9
  # -----------------------------
10
- # Load embedding model (for RAG)
11
  # -----------------------------
12
  embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
13
 
14
- # -----------------------------
15
- # Load FAISS + data
16
- # -----------------------------
17
  index = faiss.read_index("faiss_index.bin")
18
  chunks = pickle.load(open("chunks.pkl", "rb"))
19
  metadata = pickle.load(open("metadata.pkl", "rb"))
@@ -56,10 +56,8 @@ def retrieve_context(query):
56
  filtered_indices = list(range(len(chunks)))
57
 
58
  query_embedding = embed_model.encode([query])
59
-
60
  filtered_embeddings = np.array([index.reconstruct(i) for i in filtered_indices])
61
  distances = np.linalg.norm(filtered_embeddings - query_embedding, axis=1)
62
-
63
  top_indices = distances.argsort()[:2]
64
 
65
  context = ""
@@ -70,52 +68,49 @@ def retrieve_context(query):
70
  return context.strip()
71
 
72
  # -----------------------------
73
- # Load Qwen model (CPU SAFE)
74
  # -----------------------------
75
- MODEL_NAME = "Qwen/Qwen3.5-0.8B"
 
 
76
 
77
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
78
 
79
  model = AutoModelForCausalLM.from_pretrained(
80
- MODEL_NAME,
81
- torch_dtype=torch.float32 # CPU safe
82
  )
83
 
84
  generator = pipeline(
85
  "text-generation",
86
  model=model,
87
  tokenizer=tokenizer,
88
- max_new_tokens=150,
89
  do_sample=True,
90
- temperature=0.6
 
91
  )
92
 
93
- print("Model loaded successfully!")
94
 
95
  # -----------------------------
96
- # Chat function (RAG + LLM)
97
  # -----------------------------
98
- def chat_fn(user_input, history):
99
- if history is None:
100
- history = []
101
-
102
  context = retrieve_context(user_input)
103
 
104
- # No context
105
- if not context:
106
- response = "I don't know."
107
 
108
- else:
109
- # Small context → direct RAG
110
- if len(context.split()) < 50:
111
- response = context.strip()
112
 
113
- else:
114
- prompt = f"""
115
- You are a livestock expert assistant for goats and cows.
116
 
117
  Use ONLY the information below to answer.
118
- If the answer is not present, say "I don't know".
119
 
120
  Context:
121
  {context}
@@ -126,36 +121,22 @@ Question:
126
  Answer in short and clear sentences.
127
  """
128
 
129
- output = generator(
130
- prompt,
131
- max_new_tokens=120, # ✅ remove max_length warning
132
- do_sample=True,
133
- temperature=0.6
134
- )
135
-
136
- text = output[0]["generated_text"]
137
-
138
- if prompt.strip() in text:
139
- text = text.split(prompt.strip())[-1].strip()
140
-
141
- response = text
142
 
143
- # FIXED FORMAT (IMPORTANT)
144
- history.append({"role": "user", "content": user_input})
145
- history.append({"role": "assistant", "content": response})
146
 
147
- return history
148
 
149
  # -----------------------------
150
  # Gradio UI
151
  # -----------------------------
152
- with gr.Blocks() as demo:
153
- gr.Markdown("## 🐐 Livestock Chatbot (RAG + Qwen)")
154
-
155
- chatbot = gr.Chatbot(type="messages") # ✅ REQUIRED
156
- msg = gr.Textbox()
157
- btn = gr.Button("Send")
158
-
159
- btn.click(chat_fn, [msg, chatbot], chatbot)
160
-
161
- demo.launch()
 
5
  from sentence_transformers import SentenceTransformer
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  import torch
8
+ import os
9
+
10
+ print("Files in current directory:", os.listdir())
11
 
12
  # -----------------------------
13
+ # Load RAG components
14
  # -----------------------------
15
  embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
16
 
 
 
 
17
  index = faiss.read_index("faiss_index.bin")
18
  chunks = pickle.load(open("chunks.pkl", "rb"))
19
  metadata = pickle.load(open("metadata.pkl", "rb"))
 
56
  filtered_indices = list(range(len(chunks)))
57
 
58
  query_embedding = embed_model.encode([query])
 
59
  filtered_embeddings = np.array([index.reconstruct(i) for i in filtered_indices])
60
  distances = np.linalg.norm(filtered_embeddings - query_embedding, axis=1)
 
61
  top_indices = distances.argsort()[:2]
62
 
63
  context = ""
 
68
  return context.strip()
69
 
70
  # -----------------------------
71
+ # Load FAST model (CPU friendly)
72
  # -----------------------------
73
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
74
+
75
+ print("Loading fast model...")
76
 
77
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
78
 
79
  model = AutoModelForCausalLM.from_pretrained(
80
+ model_name,
81
+ torch_dtype=torch.float32
82
  )
83
 
84
  generator = pipeline(
85
  "text-generation",
86
  model=model,
87
  tokenizer=tokenizer,
88
+ max_new_tokens=120,
89
  do_sample=True,
90
+ temperature=0.6,
91
+ device=-1 # CPU
92
  )
93
 
94
+ print("Fast LLM loaded successfully!")
95
 
96
  # -----------------------------
97
+ # Chat function
98
  # -----------------------------
99
+ def chat(user_input):
 
 
 
100
  context = retrieve_context(user_input)
101
 
102
+ # Instant response if context is already short
103
+ if context and len(context.split()) < 50:
104
+ return context.strip()
105
 
106
+ if not context:
107
+ return "I don't know."
 
 
108
 
109
+ prompt = f"""
110
+ You are a livestock expert assistant for goat and cows.
 
111
 
112
  Use ONLY the information below to answer.
113
+ If answer is not present, say "I don't know".
114
 
115
  Context:
116
  {context}
 
121
  Answer in short and clear sentences.
122
  """
123
 
124
+ response = generator(prompt)
125
+ text = response[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ # Remove prompt if repeated
128
+ if prompt.strip() in text:
129
+ text = text.split(prompt.strip())[-1].strip()
130
 
131
+ return text
132
 
133
  # -----------------------------
134
  # Gradio UI
135
  # -----------------------------
136
+ gr.Interface(
137
+ fn=chat,
138
+ inputs="text",
139
+ outputs="text",
140
+ title="Livestock Chatbot (RAG + Fast LLM)",
141
+ description="Fast chatbot using RAG + TinyLlama (optimized for CPU)"
142
+ ).launch()