Goated121 commited on
Commit
b30d6cf
·
verified ·
1 Parent(s): e3c60ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -47
app.py CHANGED
@@ -1,69 +1,152 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
- import torch
3
  import gradio as gr
4
- import pickle
5
  import faiss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # ----------------------------
8
- # Model Setup
9
- # ----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
10
  MODEL_NAME = "Qwen/Qwen3.5-0.8B"
11
 
12
- # Load tokenizer and model
13
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
 
15
- # CPU mode
16
  model = AutoModelForCausalLM.from_pretrained(
17
  MODEL_NAME,
18
- device_map="cpu", # force CPU
19
- torch_dtype=torch.float32 # use float32 for CPU
20
  )
21
 
22
- # Text generation pipeline
23
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
24
-
25
- # ----------------------------
26
- # Load FAISS Index (optional)
27
- # ----------------------------
28
- try:
29
- index = faiss.read_index("faiss_index.bin")
30
- with open("metadata.pkl", "rb") as f:
31
- metadata = pickle.load(f)
32
- except Exception as e:
33
- print("FAISS index or metadata not found:", e)
34
- index = None
35
- metadata = None
36
-
37
- # ----------------------------
38
- # Chat Function
39
- # ----------------------------
40
- def chat_fn(user_input, chat_history=[]):
41
- # If retrieval is enabled
42
- if index:
43
- query_vector = tokenizer(user_input, return_tensors="pt")["input_ids"].float().mean(dim=1).detach().numpy()
44
- D, I = index.search(query_vector, k=3)
45
- retrieved_texts = [metadata[i] for i in I[0]]
46
- context = " ".join(retrieved_texts)
47
- prompt = f"Context: {context}\nUser: {user_input}\nAI:"
48
  else:
49
- prompt = f"User: {user_input}\nAI:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # Generate response
52
- output = generator(prompt, max_length=300, do_sample=True, top_p=0.9, temperature=0.7)
53
- response = output[0]["generated_text"].split("AI:")[-1].strip()
54
 
55
- # Update chat history
56
- chat_history.append((user_input, response))
57
- return chat_history, chat_history
58
 
59
- # ----------------------------
 
 
 
 
 
 
 
 
 
60
  # Gradio UI
61
- # ----------------------------
62
  with gr.Blocks() as demo:
 
 
63
  chatbot = gr.Chatbot()
64
- msg = gr.Textbox(label="Your Message")
65
  btn = gr.Button("Send")
66
 
67
- btn.click(chat_fn, [msg, chatbot], [chatbot, chatbot])
68
 
69
  demo.launch()
 
 
 
1
  import gradio as gr
 
2
  import faiss
3
+ import pickle
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
+ import torch
8
+
9
# -----------------------------
# Load embedding model (for RAG)
# -----------------------------
embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

# -----------------------------
# Load FAISS + data
# -----------------------------
# Fails fast at import time if the index/data files are missing.
index = faiss.read_index("faiss_index.bin")

# Use context managers so the pickle file handles are closed promptly —
# the original `pickle.load(open(...))` pattern leaked the descriptors.
# NOTE(review): pickle is only safe on trusted, locally produced files;
# confirm these artifacts are generated by this project's build step.
with open("chunks.pkl", "rb") as f:
    chunks = pickle.load(f)       # list of retrievable text chunks
with open("metadata.pkl", "rb") as f:
    metadata = pickle.load(f)     # per-chunk dicts; filtered on "animal"/"topic" keys
20
+
21
# -----------------------------
# Intent detection
# -----------------------------
def detect_query(query):
    """Classify a user question into a (animal, topic) intent pair.

    Matching is a case-insensitive substring search. Either element is
    None when no keyword matches. Animal keywords are checked in order
    ("goat" before "cow"); topic keyword groups likewise ("feeding"
    before "disease"), so the first hit wins.
    """
    text = query.lower()

    # First matching animal keyword, or None.
    animal = next((a for a in ("goat", "cow") if a in text), None)

    # First topic whose keyword group has a hit, or None.
    # Keyword lists include Roman-Urdu terms ("khilana", "bimari").
    topic = None
    for label, keywords in (
        ("feeding", ("feed", "diet", "khilana")),
        ("disease", ("disease", "bimari")),
    ):
        if any(word in text for word in keywords):
            topic = label
            break

    return animal, topic
40
+
41
# -----------------------------
# Retrieve context (RAG)
# -----------------------------
def retrieve_context(query):
    """Return the text of up to two chunks most relevant to *query*.

    Chunks are pre-filtered by the detected (animal, topic) intent using
    the module-level `metadata` list; if nothing survives the filter,
    all chunks are considered. Relevance is L2 distance between the
    query embedding and vectors reconstructed from the FAISS index.
    """
    animal, topic = detect_query(query)

    # Indices of chunks whose metadata matches the detected intent.
    # A None intent element imposes no constraint.
    candidates = [
        i for i, meta in enumerate(metadata)
        if (not animal or meta["animal"] == animal)
        and (not topic or meta["topic"] == topic)
    ]
    if not candidates:
        candidates = list(range(len(chunks)))

    query_vec = embed_model.encode([query])  # shape (1, dim)

    # NOTE(review): assumes the FAISS index supports reconstruct() and
    # that L2 distance matches the embedding space's metric — confirm
    # against the index-building script.
    candidate_vecs = np.array([index.reconstruct(i) for i in candidates])
    dists = np.linalg.norm(candidate_vecs - query_vec, axis=1)

    # Two nearest candidates, mapped back to their original chunk ids.
    nearest = dists.argsort()[:2]
    pieces = [chunks[candidates[pos]] for pos in nearest]
    return "\n".join(pieces).strip()
71
+
72
# -----------------------------
# Load Qwen model (CPU SAFE)
# -----------------------------
# NOTE(review): "Qwen/Qwen3.5-0.8B" does not look like a published Qwen
# checkpoint id on the Hub (Qwen2.5-0.5B / Qwen3-0.6B are) — confirm
# the repo exists, otherwise from_pretrained will fail at startup.
MODEL_NAME = "Qwen/Qwen3.5-0.8B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# float32 keeps inference numerically safe on CPU (no half-precision).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32  # CPU safe
)

# Shared generation settings: sampled decoding, capped at 150 new tokens.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=150,
    do_sample=True,
    temperature=0.6
)

print("Model loaded successfully!")
95
# -----------------------------
# Chat function (RAG + LLM)
# -----------------------------
def chat_fn(user_input, history):
    """Handle one chat turn: retrieve context, produce a reply, append
    the (user, bot) pair to *history* and return it.

    Replies come from one of three paths: a strict "I don't know."
    fallback when no context was retrieved, the retrieved text verbatim
    when it is short (< 50 words — skips the LLM entirely), or the LLM
    constrained to answer only from the retrieved context.
    """
    if history is None:
        history = []

    context = retrieve_context(user_input)

    if not context:
        # 🔥 No grounding material at all → strict fallback.
        response = "I don't know."
    elif len(context.split()) < 50:
        # 🔥 Small context → return it directly (FAST RAG).
        response = context.strip()
    else:
        # 🔥 Use LLM with strict instruction.
        # NOTE(review): literal indentation inside this f-string could not
        # be recovered from the diff rendering — verify against the repo.
        prompt = f"""
You are a livestock expert assistant for goats and cows.

Use ONLY the information below to answer.
If the answer is not present, say "I don't know".

Context:
{context}

Question:
{user_input}

Answer in short and clear sentences.
"""

        generated = generator(prompt)
        text = generated[0]["generated_text"]

        # Drop the echoed prompt when the pipeline returns the full text.
        marker = prompt.strip()
        if marker in text:
            text = text.split(marker)[-1].strip()

        response = text

    history.append((user_input, response))
    return history
139
+
140
# -----------------------------
# Gradio UI
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🐐 Livestock Chatbot (RAG + Qwen)")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Ask about goats or cows")
    btn = gr.Button("Send")

    # Send on button click AND on Enter in the textbox — the original
    # only wired the button, so Enter silently did nothing.
    btn.click(chat_fn, [msg, chatbot], chatbot)
    msg.submit(chat_fn, [msg, chatbot], chatbot)

demo.launch()