# ChatBot / app.py
# Source: Hugging Face Space by Goated121 (commit d35835a, verified)
import gradio as gr
import faiss
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import os
print("Files in current directory:", os.listdir())
# -----------------------------
# Load RAG components
# -----------------------------
# Sentence embedder: must match the model used to build faiss_index.bin offline.
embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
# Pre-built FAISS index plus the parallel chunk texts and per-chunk metadata.
index = faiss.read_index("faiss_index.bin")
# Use context managers so the pickle file handles are closed promptly;
# the previous pickle.load(open(...)) calls leaked the file objects.
with open("chunks.pkl", "rb") as f:
    chunks = pickle.load(f)
with open("metadata.pkl", "rb") as f:
    metadata = pickle.load(f)
# -----------------------------
# Intent detection
# -----------------------------
def detect_query(query):
    """Infer (animal, topic) intent keywords from a user query.

    Matching is a case-insensitive substring search; the first matching
    label in each table wins. Either element of the returned tuple is
    None when no keyword for that category is present.
    """
    text = query.lower()

    animal_table = [
        ("goat", ("goat",)),
        ("cow", ("cow",)),
    ]
    topic_table = [
        ("feeding", ("feed", "diet", "khilana")),
        ("disease", ("disease", "bimari")),
    ]

    def _first_match(table):
        # Return the label of the first entry with any keyword in the text.
        for label, keywords in table:
            if any(word in text for word in keywords):
                return label
        return None

    return _first_match(animal_table), _first_match(topic_table)
# -----------------------------
# Retrieve context (RAG)
# -----------------------------
def retrieve_context(query):
    """Return the two most relevant text chunks for *query* as one string.

    Chunks are first narrowed by the detected animal/topic metadata; if
    that filter removes everything, the full corpus is used. Surviving
    chunks are ranked by Euclidean distance between the query embedding
    and each chunk's embedding reconstructed from the FAISS index.
    """
    animal, topic = detect_query(query)

    # Keep only indices whose metadata matches the detected intent;
    # a None animal/topic acts as a wildcard.
    candidates = [
        i for i, meta in enumerate(metadata)
        if (not animal or meta["animal"] == animal)
        and (not topic or meta["topic"] == topic)
    ]
    if not candidates:
        # Filter was too strict — fall back to searching every chunk.
        candidates = list(range(len(chunks)))

    query_vec = embed_model.encode([query])
    # Pull the stored vectors back out of the index for the candidates only.
    candidate_vecs = np.array([index.reconstruct(i) for i in candidates])
    dists = np.linalg.norm(candidate_vecs - query_vec, axis=1)

    # Two nearest candidates, joined into a single context string.
    best = dists.argsort()[:2]
    return "\n".join(chunks[candidates[pos]] for pos in best).strip()
# -----------------------------
# Load Qwen model (CPU only, no accelerate)
# -----------------------------
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# Load tokenizer and weights in full float32 so inference works on a
# plain CPU without accelerate / device_map support.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # CPU only
)

# Text-generation pipeline with modest sampling defaults.
sampling_defaults = {
    "max_new_tokens": 150,
    "do_sample": True,
    "temperature": 0.6,
}
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 pins the pipeline to CPU
    **sampling_defaults,
)
print("LLM loaded successfully!")
# -----------------------------
# Chat function
# -----------------------------
def chat(user_input):
    """Answer a livestock question with RAG-grounded generation.

    Retrieves the most relevant chunks for *user_input*, builds a
    context-restricted prompt, and returns the model's short answer.
    Returns "I don't know." when no context could be retrieved.
    """
    context = retrieve_context(user_input)
    if not context:
        return "I don't know."

    prompt = f"""
You are a livestock expert assistant for goat and cows.
Use ONLY the information below to answer.
If answer is not present, say "I don't know".
Context:
{context}
Question:
{user_input}
Answer in short and clear sentences.
"""
    # return_full_text=False makes the pipeline return only the newly
    # generated tokens. This replaces the previous fragile approach of
    # substring-matching the prompt back out of the output, which broke
    # whenever the model echoed the prompt with altered whitespace.
    response = generator(
        prompt,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.6,
        return_full_text=False,
    )
    return response[0]["generated_text"].strip()
# -----------------------------
# Gradio UI
# -----------------------------
# Simple single-textbox interface wired straight to the chat function.
demo = gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    title="Livestock Chatbot (RAG + Qwen)",
    description="This chatbot answers livestock questions using RAG retrieval and Qwen model generation.",
)
demo.launch()