# KB_AI_Challenge / app.py
# Author: nneans - last commit "Update app.py" (f438fbf, verified)
# =========================================================
# KB AI Challenge - Professional RAG System (Multilingual)
# =========================================================
import os
import sys
import numpy as np
import traceback
import fitz # PyMuPDF
from typing import List
# --- Library imports ---
import gradio as gr
import speech_recognition as sr
from dotenv import load_dotenv
# Load environment variables (e.g. GROQ_API_KEY) from a .env file
load_dotenv()
from deep_translator import GoogleTranslator
from sentence_transformers import SentenceTransformer
from groq import Groq
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
try:
from langchain.text_splitter import RecursiveCharacterTextSplitter
except ImportError:
from langchain_text_splitters import RecursiveCharacterTextSplitter
# =========================================================
# 1. Configuration & initialization
# =========================================================
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "your_groq_api_key_here")
EMBEDDING_MODEL_NAME = "jhgan/ko-sroberta-multitask"
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
COLLECTION_NAME = "local_kb"

print("🛠️ 시스템 초기화 중... (System Init)")

# Embedding model, with inputs capped at 512 tokens.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
embedding_model.max_seq_length = 512
# Derive the vector size from the model instead of hard-coding it, so that
# changing EMBEDDING_MODEL_NAME cannot silently break upserts/searches.
# Falls back to 768 if the model does not report a dimension.
EMBEDDING_DIM = embedding_model.get_sentence_embedding_dimension() or 768

# In-memory Qdrant: the knowledge base lives only as long as this process.
qdrant_client = QdrantClient(":memory:")
try:
    # `recreate_collection` is deprecated in qdrant-client; this explicit
    # exists/delete/create sequence has the same effect.
    if qdrant_client.collection_exists(COLLECTION_NAME):
        qdrant_client.delete_collection(COLLECTION_NAME)
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=EMBEDDING_DIM, distance=Distance.COSINE),
    )
    print("✅ Qdrant Collection Ready.")
except Exception as e:
    print(f"❌ Qdrant Error: {e}")

# Groq client stays None when no real key is configured; generate_answer_groq
# checks for that and returns a "key required" message instead of calling out.
groq_client = None
if GROQ_API_KEY and GROQ_API_KEY != "your_groq_api_key_here":
    try:
        groq_client = Groq(api_key=GROQ_API_KEY)
    except Exception as e:
        print(f"❌ Groq Error: {e}")
else:
    print("⚠️ Groq API Key Missing.")

# Monotonic point-ID source shared by every upload in this process.
doc_id_counter = 0
print("✅ System Ready.")
# =========================================================
# 2. Multilingual support (Translation & STT)
# =========================================================
LANG_MAP = {
"ํ•œ๊ตญ์–ด (Korean)": {"code": "ko", "stt": "ko-KR"},
"English (์˜์–ด)": {"code": "en", "stt": "en-US"},
"ๆ—ฅๆœฌ่ชž (Japanese)": {"code": "ja", "stt": "ja-JP"},
"ไธญๆ–‡ (Chinese)": {"code": "zh-CN", "stt": "zh-CN"}
}
def translate_text(text, target_lang_code):
    """Translate *text* into *target_lang_code* using Google Translate.

    Korean ("ko") is the pipeline's working language, so it is returned as-is
    without a network call, as is empty/blank text. Translation is best-effort:
    if the translator raises (e.g. network failure or rejected input) or
    returns an empty result, the original text is returned so the chat flow
    never breaks.

    Args:
        text: Text to translate.
        target_lang_code: deep-translator target code, e.g. "en" or "zh-CN".

    Returns:
        The translated text, or *text* unchanged when no translation happens.
    """
    if target_lang_code == "ko" or not text or not text.strip():
        return text
    try:
        translated = GoogleTranslator(source="auto", target=target_lang_code).translate(text)
    except Exception as e:  # best-effort: degrade to the untranslated text
        print(f"⚠️ Translation failed ({target_lang_code}): {e}")
        return text
    # Never hand None/"" back to the UI in place of real content.
    return translated or text
def translate_to_korean(text):
    """Translate *text* (any source language, auto-detected) into Korean.

    Best-effort like translate_text: blank input is returned unchanged without
    a network call, and on translator failure or an empty result the original
    text is returned.

    Args:
        text: Text to translate.

    Returns:
        The Korean translation, or *text* unchanged when no translation happens.
    """
    if not text or not text.strip():
        return text
    try:
        translated = GoogleTranslator(source="auto", target="ko").translate(text)
    except Exception as e:  # best-effort: degrade to the untranslated text
        print(f"⚠️ Translation to Korean failed: {e}")
        return text
    return translated or text
# =========================================================
# 3. Core logic (RAG Pipeline)
# =========================================================
def process_uploaded_files(files):
    """Extract text from uploaded PDFs, embed it, and store it in Qdrant.

    Each PDF's text is split into chunks of up to 500 characters (50-character
    overlap), embedded with ``embedding_model`` and upserted into
    ``COLLECTION_NAME`` with ``{"filename", "text"}`` payloads. Point IDs come
    from the global ``doc_id_counter`` so they stay unique across uploads.

    Args:
        files: Value of a ``gr.File(file_count="multiple")`` component: a list
            of file paths or of objects exposing the path as ``.name``. A
            single (non-list) item is also accepted.

    Returns:
        A status report: a summary line followed by one line per file.
    """
    global doc_id_counter
    if not files:
        return "파일이 선택되지 않았습니다."
    if not isinstance(files, (list, tuple)):
        files = [files]

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50, length_function=len
    )
    total_chunks = 0
    status_parts = []

    for file in files:
        file_path = file.name if hasattr(file, "name") else file
        file_name = os.path.basename(str(file_path))
        try:
            # The context manager closes the PDF even if text extraction fails.
            with fitz.open(file_path) as doc:
                file_text = "".join(page.get_text() for page in doc)
            if not file_text.strip():
                status_parts.append(f"⚠️ {file_name}: 텍스트 없음.\n")
                continue

            chunks = text_splitter.split_text(file_text)
            if not chunks:
                continue
            # One batched encode call is much faster than encoding per chunk.
            vectors = embedding_model.encode(chunks)
            points = []
            for chunk, vector in zip(chunks, vectors):
                points.append(
                    PointStruct(
                        id=doc_id_counter,
                        vector=vector.tolist(),
                        payload={"filename": file_name, "text": chunk},
                    )
                )
                doc_id_counter += 1

            qdrant_client.upsert(collection_name=COLLECTION_NAME, points=points)
            total_chunks += len(points)
            status_parts.append(f"✅ {file_name} ({len(points)} 개 저장됨)\n")
        except Exception as e:
            # Report per-file failures and keep processing the remaining files.
            status_parts.append(f"❌ 오류: {file_name} - {e}\n")

    status_msg = "".join(status_parts)
    return f"총 {total_chunks}개 데이터 처리 완료.\n\n{status_msg}"
def search_knowledge_base(query, top_k=5):
    """Return up to ``top_k`` stored chunks most similar to *query*.

    Args:
        query: Search text (Korean in this pipeline).
        top_k: Maximum number of hits to return.

    Returns:
        The ``points`` list from Qdrant's ``query_points`` (payloads included),
        or an empty list for a blank query or when embedding/search fails.
    """
    if not query or not str(query).strip():
        return []
    try:
        query_vector = embedding_model.encode(query).tolist()
        response = qdrant_client.query_points(
            collection_name=COLLECTION_NAME,
            query=query_vector,
            limit=top_k,
            with_payload=True,
        )
    except Exception as e:
        # Best-effort: a failed search degrades to "no context" rather than
        # crashing the chat, but the cause is no longer silently discarded.
        print(f"❌ Search Error: {e}")
        return []
    return response.points
def generate_answer_groq(query, context_text):
    """Ask the Groq LLM for a Korean answer to *query* grounded in *context_text*.

    Args:
        query: The user's question, already translated into Korean.
        context_text: Retrieved knowledge-base chunks joined into one string.

    Returns:
        The model's answer text; a Korean "API key required" message when no
        Groq client is configured; or a Korean error message if the request
        fails. Never returns None.
    """
    if not groq_client:
        return "API 키가 필요합니다."

    system_prompt = (
        "당신은 KB 금융그룹의 전문 AI 어시스턴트입니다.\n"
        "제공된 [문맥]에 기반하여 질문에 대해 정확하고 전문적인 답변을 작성하세요.\n"
        "모르는 내용은 모른다고 답하고, 추측하지 마세요.\n"
        "답변은 한국어로 작성하세요."
    )
    user_prompt = f"질문: {query}\n\n[문맥]\n{context_text}"
    try:
        response = groq_client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            model=GROQ_MODEL_NAME,
            temperature=0.1,  # low temperature: stay close to the provided context
        )
        # `content` is optional in the chat-completions schema; never pass None on.
        return response.choices[0].message.content or ""
    except Exception as e:
        return f"응답 생성 오류: {e}"
def _build_reference_text(hits):
    """Summarise search hits as one line per source file (first-seen order),
    with the chunk count and the highest similarity score for that file."""
    scores_by_file = {}
    for hit in hits:
        scores_by_file.setdefault(hit.payload["filename"], []).append(hit.score)
    return "\n".join(
        f"- {fname} (관련 내용 {len(scores)}건, 최고 유사도: {max(scores):.2f})"
        for fname, scores in scores_by_file.items()
    )


def run_rag_chat(message, history, lang_selection):
    """Answer one chat turn with retrieval-augmented generation.

    The question is translated into Korean, used to search the knowledge
    base, answered by the LLM in Korean, and the answer is translated back
    into the selected language (keeping the Korean original below it).

    Args:
        message: The user's message in the selected language.
        history: Chat history as a list of ``{"role", "content"}`` dicts
            (None is treated as empty).
        lang_selection: A LANG_MAP key; unknown values fall back to Korean.

    Returns:
        ``("", new_history, reference_text)``: the cleared input box, the
        history with this turn appended, and the per-file source summary.
        Blank messages return ``("", history, "")`` unchanged.
    """
    if not message or not message.strip():
        return "", history, ""
    history = list(history or [])
    target_lang = LANG_MAP.get(lang_selection, {}).get("code", "ko")

    # 1. Translate the question into Korean (target -> Korean).
    korean_query = message if target_lang == "ko" else translate_to_korean(message)

    # 2. Retrieve context and generate the answer in Korean.
    hits = search_knowledge_base(korean_query)
    if hits:
        context_text = "\n\n".join(hit.payload["text"] for hit in hits)
        reference_text = _build_reference_text(hits)
        bot_response_ko = generate_answer_groq(korean_query, context_text)
    else:
        bot_response_ko = "죄송합니다. 관련 정보를 찾을 수 없습니다."
        reference_text = "참고 문서 없음"

    # 3. Translate the answer back (Korean -> target), keeping the original.
    final_response = bot_response_ko
    if target_lang != "ko":
        translated_response = translate_text(bot_response_ko, target_lang)
        final_response = f"{translated_response}\n\n---\n[한국어 원문]\n{bot_response_ko}"

    # Append the turn in the messages format expected by gr.Chatbot.
    new_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": final_response},
    ]
    return "", new_history, reference_text
def _to_pcm16_mono(samples):
    """Convert a numpy audio buffer to mono 16-bit PCM.

    Float input is assumed to be normalised to [-1.0, 1.0]; it is clipped
    before scaling so out-of-range values saturate instead of wrapping around.
    int32 input is scaled down to 16 bits. Multi-channel input is assumed to
    be (n_samples, n_channels) and is averaged to mono.
    """
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.floating):
        samples = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    elif samples.dtype == np.int32:
        samples = (samples >> 16).astype(np.int16)
    if samples.ndim > 1:
        samples = samples.mean(axis=1).astype(np.int16)
    return samples.astype(np.int16, copy=False)


def voice_to_text_chat(audio, history, lang_selection):
    """Transcribe a microphone recording and answer it via run_rag_chat.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio(type="numpy")``,
            or None when nothing was recorded.
        history: Chat history passed through to run_rag_chat.
        lang_selection: A LANG_MAP key selecting the speech-recognition locale;
            unknown values fall back to Korean.

    Returns:
        run_rag_chat's ``(textbox, history, references)`` tuple on success, or
        ``("", history, <status message>)`` when there is no usable speech or
        an error occurs.
    """
    if audio is None:
        return "", history, "음성 입력 없음"
    stt_lang = LANG_MAP.get(lang_selection, {}).get("stt", "ko-KR")
    try:
        sample_rate, samples = audio
        pcm = _to_pcm16_mono(samples)
        if pcm.size == 0:
            return "", history, "음성 입력 없음"
        # sample_width=2 bytes matches the int16 PCM produced above.
        audio_data = sr.AudioData(pcm.tobytes(), sample_rate, 2)
        text = sr.Recognizer().recognize_google(audio_data, language=stt_lang)
        return run_rag_chat(text, history, lang_selection)
    except sr.UnknownValueError:
        return "", history, "음성을 이해할 수 없습니다."
    except Exception as e:
        return "", history, f"오류: {e}"
# =========================================================
# 4. UI Layout (Clean Professional Korean)
# =========================================================
theme = gr.themes.Soft(
    primary_hue="amber",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Noto Sans KR"), "sans-serif"],
)

# Hide the Gradio footer and let the container shrink to its content.
css = """
footer {visibility: hidden !important;}
.gradio-container {min-height: 0px !important;}
"""

with gr.Blocks(theme=theme, title="KB AI Challenge", css=css) as demo:
    with gr.Row():
        # --- LEFT SIDEBAR: language, knowledge-base upload, voice input ---
        with gr.Column(scale=1, min_width=300, variant="panel"):
            gr.Markdown("## KB AI Challenge")
            gr.Markdown("**다국어 금융 AI 어시스턴트**")
            with gr.Group():
                lang_dropdown = gr.Dropdown(
                    choices=list(LANG_MAP.keys()),
                    # Use the first LANG_MAP entry instead of a duplicated
                    # literal that could drift from the dict's keys.
                    value=next(iter(LANG_MAP)),
                    label="언어 설정",
                    interactive=True,
                )
            file_input = gr.File(
                label="지식 베이스 (PDF)", file_count="multiple", file_types=[".pdf"]
            )
            with gr.Row():
                upload_btn = gr.Button("업로드 및 분석", variant="primary", size="sm")
            # The upload report has a summary plus one line per file, so let
            # the box grow instead of truncating it to a single line.
            upload_status = gr.Textbox(
                show_label=False,
                placeholder="상태 대기 중...",
                interactive=False,
                lines=1,
                max_lines=6,
            )
            gr.Markdown("### 음성 대화")
            audio_input = gr.Audio(
                sources=["microphone"], type="numpy", label="음성 입력", show_label=False
            )
            with gr.Accordion("시스템 아키텍처", open=False):
                gr.Markdown(
                    """
**최적화 내역**
1. **STT**: Google Speech API
2. **번역**: Google Translate API
3. **LLM**: Groq LPU (Llama 3)
"""
                )

        # --- RIGHT MAIN: chat, references, text input ---
        with gr.Column(scale=3):
            # History is exchanged as {"role", "content"} message dicts.
            chatbot = gr.Chatbot(label="대화", height=500, show_label=False)
            gr.Markdown("**참고 문서**")
            ref_output = gr.Textbox(
                show_label=False,
                interactive=False,
                lines=3,
                max_lines=5,
                placeholder="관련 문서가 표시됩니다.",
            )
            with gr.Row():
                msg = gr.Textbox(
                    scale=6,
                    show_label=False,
                    placeholder="질문을 입력하세요...",
                    container=False,
                )
                submit_btn = gr.Button("전송", scale=1, variant="primary")

    # --- Event handlers ---
    upload_btn.click(process_uploaded_files, inputs=[file_input], outputs=[upload_status])
    # Output order matches the handlers' return tuple:
    # (cleared textbox, updated history, reference summary).
    chat_inputs = [msg, chatbot, lang_dropdown]
    chat_outputs = [msg, chatbot, ref_output]
    msg.submit(run_rag_chat, chat_inputs, chat_outputs)
    submit_btn.click(run_rag_chat, chat_inputs, chat_outputs)
    audio_input.stop_recording(
        voice_to_text_chat, [audio_input, chatbot, lang_dropdown], chat_outputs
    )

if __name__ == "__main__":
    # NOTE(review): share=True requests a public gradio.live link when run
    # locally; hosted environments such as HF Spaces ignore it.
    demo.launch(share=True)