File size: 3,495 Bytes
48a449e
02bf677
b30d6cf
 
 
 
 
91b4de2
 
 
b30d6cf
 
91b4de2
b30d6cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e540d02
b30d6cf
d35835a
625b503
d35835a
ce0c959
91b4de2
d35835a
ce0c959
625127c
b30d6cf
 
 
 
e540d02
b30d6cf
d35835a
 
b30d6cf
 
e540d02
b30d6cf
 
91b4de2
b30d6cf
91b4de2
b30d6cf
 
91b4de2
 
81026e1
91b4de2
 
b30d6cf
 
91b4de2
b30d6cf
 
 
 
 
 
02bf677
b30d6cf
 
e540d02
91b4de2
b30d6cf
91b4de2
 
 
81026e1
91b4de2
b30d6cf
 
48a449e
b30d6cf
91b4de2
 
 
 
e540d02
 
91b4de2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio as gr
import faiss
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import os

print("Files in current directory:", os.listdir())

# -----------------------------
# Load RAG components
# -----------------------------
# Sentence-embedding model used for query/chunk similarity.
embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

# Prebuilt FAISS index plus the text chunks and per-chunk metadata it indexes.
index = faiss.read_index("faiss_index.bin")

# Use context managers so the pickle files are closed deterministically
# instead of leaking the anonymous file objects from pickle.load(open(...)).
with open("chunks.pkl", "rb") as f:
    chunks = pickle.load(f)
with open("metadata.pkl", "rb") as f:
    metadata = pickle.load(f)

# -----------------------------
# Intent detection
# -----------------------------
def detect_query(query):
    """Infer the animal and topic a user query is about.

    Returns a ``(animal, topic)`` tuple; each element is a label string
    ("goat"/"cow", "feeding"/"disease") or None when no keyword matches.
    Matching is case-insensitive substring search.
    """
    lowered = query.lower()

    # "goat" takes precedence over "cow" when both appear.
    animal = "goat" if "goat" in lowered else ("cow" if "cow" in lowered else None)

    # First matching topic wins; feeding keywords are checked before disease.
    topic = None
    for label, keywords in (
        ("feeding", ("feed", "diet", "khilana")),
        ("disease", ("disease", "bimari")),
    ):
        if any(kw in lowered for kw in keywords):
            topic = label
            break

    return animal, topic

# -----------------------------
# Retrieve context (RAG)
# -----------------------------
def retrieve_context(query):
    """Return up to two stored chunks most similar to *query*.

    Candidates are first narrowed by the animal/topic metadata detected
    from the query; if the filter matches nothing, every chunk is a
    candidate. Ranking is by L2 distance between sentence embeddings
    reconstructed from the FAISS index.
    """
    animal, topic = detect_query(query)

    # Indices of chunks whose metadata matches the detected intent.
    candidates = [
        i
        for i, meta in enumerate(metadata)
        if (not animal or meta["animal"] == animal)
        and (not topic or meta["topic"] == topic)
    ]
    if not candidates:
        candidates = list(range(len(chunks)))

    # Embed the query and compare against the candidates' stored vectors.
    query_vec = embed_model.encode([query])
    candidate_vecs = np.array([index.reconstruct(i) for i in candidates])
    dists = np.linalg.norm(candidate_vecs - query_vec, axis=1)

    # Take the two nearest candidates and join their text.
    nearest = dists.argsort()[:2]
    return "\n".join(chunks[candidates[pos]] for pos in nearest).strip()

# -----------------------------
# Load Qwen model (CPU only, no accelerate)
# -----------------------------
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# float32 keeps the weights usable on CPU; half precision is slow or
# unsupported on many CPUs.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32  # CPU only
)

# Sampling text-generation pipeline; max_new_tokens caps the answer length.
# NOTE(review): these generation kwargs are repeated per-call in chat(),
# which overrides the defaults set here.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=150,
    do_sample=True,
    temperature=0.6,
    device=-1  # ensures CPU is used
)

print("LLM loaded successfully!")

# -----------------------------
# Chat function
# -----------------------------
def chat(user_input):
    """Answer a livestock question via RAG retrieval + LLM generation.

    Retrieves context for the question, prompts the model to answer only
    from that context, and strips the echoed prompt from the output.
    """
    context = retrieve_context(user_input)
    if not context:
        return "I don't know."

    prompt = f"""
You are a livestock expert assistant for goat and cows.

Use ONLY the information below to answer.
If answer is not present, say "I don't know".

Context:
{context}

Question:
{user_input}

Answer in short and clear sentences.
"""
    output = generator(prompt, max_new_tokens=150, do_sample=True, temperature=0.6)
    answer = output[0]["generated_text"]

    # The pipeline echoes the prompt; keep only the text that follows it.
    stripped_prompt = prompt.strip()
    if stripped_prompt in answer:
        answer = answer.split(stripped_prompt)[-1].strip()

    return answer

# -----------------------------
# Gradio UI
# -----------------------------
# Single text-in/text-out interface wired to chat(); launch() blocks and
# serves the app.
gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    title="Livestock Chatbot (RAG + Qwen)",
    description="This chatbot answers livestock questions using RAG retrieval and Qwen model generation."
).launch()