Goated121 committed on
Commit
02bf677
·
verified ·
1 Parent(s): ab9c37e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -121
app.py CHANGED
@@ -1,95 +1,28 @@
1
- import gradio as gr
2
- import faiss
3
  import pickle
4
- import numpy as np
5
- from sentence_transformers import SentenceTransformer
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  import torch
8
- import os
9
- from huggingface_hub import login
10
-
11
- # -----------------------------
12
- # Login with HF_TOKEN (secret)
13
- # -----------------------------
14
- HF_TOKEN = os.environ["HF_TOKEN"] # Must be set in Space secrets
15
- login(HF_TOKEN)
16
-
17
- print("Files in current directory:", os.listdir())
18
-
19
- # -----------------------------
20
- # Load RAG components
21
- # -----------------------------
22
- print("Loading embedding model...")
23
- embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
24
-
25
- print("Loading FAISS index and metadata...")
26
- index = faiss.read_index("faiss_index.bin")
27
- chunks = pickle.load(open("chunks.pkl", "rb"))
28
- metadata = pickle.load(open("metadata.pkl", "rb"))
29
-
30
- # -----------------------------
31
- # Intent detection
32
- # -----------------------------
33
- def detect_query(query):
34
- query = query.lower()
35
- animal = None
36
- topic = None
37
-
38
- if "goat" in query:
39
- animal = "goat"
40
- elif "cow" in query:
41
- animal = "cow"
42
-
43
- if any(word in query for word in ["feed", "diet", "khilana"]):
44
- topic = "feeding"
45
- elif any(word in query for word in ["disease", "bimari"]):
46
- topic = "disease"
47
-
48
- return animal, topic
49
-
50
- # -----------------------------
51
- # Retrieve context (RAG)
52
- # -----------------------------
53
- def retrieve_context(query):
54
- animal, topic = detect_query(query)
55
-
56
- filtered_indices = []
57
- for i, meta in enumerate(metadata):
58
- if animal and meta["animal"] != animal:
59
- continue
60
- if topic and meta["topic"] != topic:
61
- continue
62
- filtered_indices.append(i)
63
-
64
- if not filtered_indices:
65
- filtered_indices = list(range(len(chunks)))
66
 
67
- query_embedding = embed_model.encode([query])
68
- filtered_embeddings = np.array([index.reconstruct(i) for i in filtered_indices])
69
- distances = np.linalg.norm(filtered_embeddings - query_embedding, axis=1)
70
- top_indices = distances.argsort()[:2]
71
 
72
- context = ""
73
- for idx in top_indices:
74
- real_index = filtered_indices[idx]
75
- context += chunks[real_index] + "\n"
76
 
77
- return context.strip()
78
 
79
- # -----------------------------
80
- # Load Qwen3.5‑0.8B‑Base
81
- # -----------------------------
82
- model_name = "Qwen/Qwen3.5-0.8B-Base"
83
- print(f"Loading model {model_name} (may take a while on CPU)...")
84
 
85
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=HF_TOKEN)
86
  model = AutoModelForCausalLM.from_pretrained(
87
- model_name,
88
- torch_dtype=torch.float32,
89
- device_map={"": "cpu"},
90
  use_auth_token=HF_TOKEN
91
  )
92
 
 
93
  generator = pipeline(
94
  "text-generation",
95
  model=model,
@@ -97,45 +30,54 @@ generator = pipeline(
97
  max_new_tokens=150,
98
  do_sample=True,
99
  temperature=0.7,
100
- device=-1 # CPU only
101
  )
102
 
103
- print("LLM loaded successfully!")
104
-
105
- # -----------------------------
106
- # Chat function
107
- # -----------------------------
108
- def chat(user_input):
109
- context = retrieve_context(user_input)
110
- if not context:
111
- return "I don't know."
112
-
113
- prompt = f"""
114
- You are a livestock expert assistant for goats and cows.
115
- Use ONLY the information below to answer.
116
- If answer is not present, say "I don't know".
117
-
118
- Context:
119
- {context}
120
-
121
- Question:
122
- {user_input}
123
-
124
- Answer in short and clear sentences.
125
- """
126
- response = generator(prompt)
127
- text = response[0]["generated_text"]
128
- if prompt.strip() in text:
129
- text = text.split(prompt.strip())[-1].strip()
130
- return text
131
-
132
- # -----------------------------
133
- # Gradio UI
134
- # -----------------------------
135
- gr.Interface(
136
- fn=chat,
137
- inputs="text",
138
- outputs="text",
139
- title="Livestock Chatbot (RAG + Qwen3.5‑0.8B‑Base)",
140
- description="Answers livestock questions using RAG retrieval and the Qwen3.5‑0.8B base model."
141
- ).launch()
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
  import pickle
 
 
 
4
  import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
+ import faiss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# --- Load HF token from Space secrets ---
# Fail fast with an actionable message instead of a bare KeyError when the
# Space secret is missing.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError(
        "HF_TOKEN is not set. Add it to the Space secrets before starting the app."
    )

# --- Model configuration ---
# NOTE(review): confirm the repo id "Qwen/Qwen3.5-0.8B-Base" exists on the Hub.
MODEL_NAME = "Qwen/Qwen3.5-0.8B-Base"

print(f"Loading model {MODEL_NAME} on CPU...")

# --- Load tokenizer ---
# `use_auth_token` is a deprecated alias of `token=` in recent transformers;
# kept as-is to avoid changing the required library version.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN)
 
 
 
18
 
19
# --- Load model (CPU only) ---
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    use_auth_token=HF_TOKEN,  # deprecated alias of `token=` in newer transformers
)

# --- Setup text-generation pipeline ---
# When `model` is passed as an object (not a hub id string), the pipeline
# cannot infer which tokenizer to use, so it must be supplied explicitly.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=150,
    do_sample=True,
    temperature=0.7,
    device=-1,  # -1 => force CPU
)

print("Model loaded successfully!")
37
+
38
# --- Load FAISS index and metadata ---
# Both artifacts must be uploaded alongside app.py; retrieval degrades to
# no-context answers when they are absent (see retrieve_docs).
index = None
metadata = None
if os.path.exists("faiss_index.bin") and os.path.exists("metadata.pkl"):
    print("Loading FAISS index and metadata...")
    index = faiss.read_index("faiss_index.bin")
    with open("metadata.pkl", "rb") as f:
        metadata = pickle.load(f)
    # Robustness: a stale metadata.pkl silently misaligns vector ids -> entries.
    if index.ntotal != len(metadata):
        print(
            f"WARNING: index has {index.ntotal} vectors but metadata has "
            f"{len(metadata)} entries; retrieval results may be misaligned."
        )
    print("FAISS index loaded.")
else:
    print("FAISS index or metadata not found. Make sure you uploaded faiss_index.bin and metadata.pkl")

# --- Embeddings model for query ---
# Third-party import kept mid-file (after the heavy LLM load) as in the original.
from sentence_transformers import SentenceTransformer
embed_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
53
+
54
# --- RAG retrieval function ---
def retrieve_docs(query, top_k=3):
    """Return up to `top_k` metadata entries most similar to `query`.

    Returns an empty list when the FAISS index/metadata were not loaded
    at startup, so callers degrade gracefully to no-context generation.
    """
    if index is None or metadata is None:
        return []
    q_embed = embed_model.encode([query])
    distances, idxs = index.search(q_embed, top_k)
    # FAISS pads the result with -1 when the index holds fewer than top_k
    # vectors; a raw metadata[-1] lookup would silently wrap to the last entry.
    return [metadata[i] for i in idxs[0] if i != -1]
62
+
63
# --- Chatbot function ---
def chat(query):
    """Answer `query` with the LLM, grounded on retrieved context.

    An empty retrieval result yields an empty context rather than an error.
    """
    # Retrieve relevant docs.
    # NOTE(review): assumes metadata entries are strings — confirm against
    # how metadata.pkl was built, or "\n".join will fail on non-str entries.
    retrieved_docs = retrieve_docs(query)
    context = "\n".join(retrieved_docs) if retrieved_docs else ""

    # Combine context with user query.
    prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"

    # Generate response.
    output = generator(prompt)
    text = output[0]["generated_text"]
    # text-generation pipelines echo the prompt in generated_text; return
    # only the completion (the previous version of this app stripped it too).
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()
75
+
76
# --- Example usage ---
if __name__ == "__main__":
    while True:
        try:
            query = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of a traceback.
            break
        if query.lower() in ["exit", "quit"]:
            break
        if not query.strip():
            continue  # ignore blank lines rather than querying the model
        answer = chat(query)
        print("Bot:", answer)