Spaces:
Sleeping
Sleeping
| # ========================================================= | |
| # KB AI Challenge - Professional RAG System (Multilingual) | |
| # ========================================================= | |
| import os | |
| import sys | |
| import numpy as np | |
| import traceback | |
| import fitz # PyMuPDF | |
| from typing import List | |
| # --- Library imports --- | |
| import gradio as gr | |
| import speech_recognition as sr | |
| from dotenv import load_dotenv | |
| # Load .env | |
| load_dotenv() | |
| from deep_translator import GoogleTranslator | |
| from sentence_transformers import SentenceTransformer | |
| from groq import Groq | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Distance, VectorParams, PointStruct | |
| try: | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| except ImportError: | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| # ========================================================= | |
| # 1. Configuration and initialization | |
| # ========================================================= | |
# --- Configuration ---
# The placeholder default means "no key configured"; the Groq client below is
# only created when a real key is present in the environment.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "your_groq_api_key_here")
EMBEDDING_MODEL_NAME = "jhgan/ko-sroberta-multitask"
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
COLLECTION_NAME = "local_kb"

print("๐ ๏ธ ์์คํ ์ด๊ธฐํ ์ค... (System Init)")

# Load the embedding model.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
embedding_model.max_seq_length = 512

# Size the collection from the model itself so that changing
# EMBEDDING_MODEL_NAME cannot silently mismatch the vector size.
# 768 (the previously hard-coded value) is kept as the fallback.
EMBEDDING_DIM = embedding_model.get_sentence_embedding_dimension() or 768

# Qdrant (in-memory).
qdrant_client = QdrantClient(":memory:")
try:
    # `recreate_collection` is deprecated in qdrant-client; do the equivalent
    # drop-then-create explicitly.
    if qdrant_client.collection_exists(COLLECTION_NAME):
        qdrant_client.delete_collection(COLLECTION_NAME)
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=EMBEDDING_DIM, distance=Distance.COSINE),
    )
    print("โ Qdrant Collection Ready.")
except Exception as e:
    # Best-effort start-up: report the problem and keep the app importable.
    print(f"โ Qdrant Error: {e}")

# Groq init: stays None when no usable API key is configured.
groq_client = None
if GROQ_API_KEY and GROQ_API_KEY != "your_groq_api_key_here":
    try:
        groq_client = Groq(api_key=GROQ_API_KEY)
    except Exception as e:
        print(f"โ Groq Error: {e}")
else:
    print("โ ๏ธ Groq API Key Missing.")

# Next point id to assign when storing chunks in Qdrant.
doc_id_counter = 0
print("โ System Ready.")
| # ========================================================= | |
| # 2. Multilingual support logic (Translation & STT) | |
| # ========================================================= | |
# UI label -> {"code": translation target code, "stt": speech-recognition locale}.
LANG_MAP = {
    label: {"code": translate_code, "stt": stt_locale}
    for label, translate_code, stt_locale in (
        ("ํ๊ตญ์ด (Korean)", "ko", "ko-KR"),
        ("English (์์ด)", "en", "en-US"),
        ("ๆฅๆฌ่ช (Japanese)", "ja", "ja-JP"),
        ("ไธญๆ (Chinese)", "zh-CN", "zh-CN"),
    )
}
def translate_text(text, target_lang_code):
    """Translate ``text`` into ``target_lang_code``, best effort.

    Args:
        text: Text to translate (the pipeline produces Korean answers).
        target_lang_code: Language code passed to ``GoogleTranslator`` as the
            target. ``"ko"`` means "no translation needed".

    Returns:
        The translated text, or ``text`` unchanged when the target is Korean,
        the input is empty, the translator returns nothing, or the call fails.
    """
    # Nothing to do for Korean targets or empty input; skip the network call.
    if target_lang_code == "ko" or not text:
        return text
    try:
        translated = GoogleTranslator(source="auto", target=target_lang_code).translate(text)
    except Exception as e:
        # Deliberately best-effort: fall back to the untranslated text, but
        # leave a trace instead of swallowing the failure silently.
        print(f"Translation Error ({target_lang_code}): {e}")
        return text
    return translated or text
def translate_to_korean(text):
    """Translate ``text`` (source language auto-detected) into Korean, best effort.

    Args:
        text: User input in any language.

    Returns:
        The Korean translation, or ``text`` unchanged when the input is empty,
        the translator returns nothing, or the call fails.
    """
    # Empty input needs no network call.
    if not text:
        return text
    try:
        translated = GoogleTranslator(source="auto", target="ko").translate(text)
    except Exception as e:
        # Deliberately best-effort: keep the original text, but log the failure.
        print(f"Translation Error (ko): {e}")
        return text
    return translated or text
| # ========================================================= | |
| # 3. Core logic (RAG Pipeline) | |
| # ========================================================= | |
def process_uploaded_files(files):
    """Extract text from uploaded PDFs, embed it in chunks and store it in Qdrant.

    Each file is read with PyMuPDF, split into 500-character chunks with a
    50-character overlap, embedded with ``embedding_model`` and upserted into
    ``COLLECTION_NAME``. A failure in one file is reported in the status text
    and does not stop the remaining files.

    Args:
        files: Iterable of uploads; each item is either a path or an object
            exposing the path as ``.name``.

    Returns:
        A status string: the total number of stored chunks followed by one
        line per file.
    """
    global doc_id_counter
    if not files:
        return "ํ์ผ์ด ์ ํ๋์ง ์์์ต๋๋ค."

    total_chunks = 0
    status_lines = []
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50, length_function=len
    )

    for file in files:
        # Resolved outside the try so the error branch can always name the file.
        file_path = file.name if hasattr(file, "name") else file
        display_name = os.path.basename(file_path)
        try:
            # Context manager guarantees the document handle is released.
            with fitz.open(file_path) as doc:
                file_text = "".join(page.get_text() for page in doc)

            if not file_text.strip():
                status_lines.append(f"โ ๏ธ {display_name}: ํ ์คํธ ์์.\n")
                continue

            chunks = text_splitter.split_text(file_text)
            if not chunks:
                continue

            # One batched encode call instead of one model call per chunk.
            vectors = embedding_model.encode(chunks)
            points = [
                PointStruct(
                    id=doc_id_counter + offset,
                    vector=vector.tolist(),
                    payload={"filename": display_name, "text": chunk},
                )
                for offset, (chunk, vector) in enumerate(zip(chunks, vectors))
            ]
            qdrant_client.upsert(collection_name=COLLECTION_NAME, points=points)

            # Advance the id counter only once the points are actually stored.
            doc_id_counter += len(points)
            total_chunks += len(points)
            status_lines.append(f"โ {display_name} ({len(points)} ๊ฐ ์ ์ฅ๋จ)\n")
        except Exception as e:
            status_lines.append(f"โ ์ค๋ฅ: {display_name} - {str(e)}\n")

    status_msg = "".join(status_lines)
    return f"์ด {total_chunks}๊ฐ ๋ฐ์ดํฐ ์ฒ๋ฆฌ ์๋ฃ.\n\n{status_msg}"
def search_knowledge_base(query, top_k=5):
    """Return the ``top_k`` stored chunks most similar to ``query``.

    Args:
        query: Search text; it is embedded with ``embedding_model``.
        top_k: Maximum number of hits to return.

    Returns:
        The list of scored points (with payload) from Qdrant, or an empty list
        when embedding or searching fails.
    """
    try:
        query_vector = embedding_model.encode(query).tolist()
        res = qdrant_client.query_points(
            collection_name=COLLECTION_NAME,
            query=query_vector,
            limit=top_k,
            with_payload=True,
        )
        return res.points
    except Exception as e:
        # Best-effort: callers treat "no hits" and "search failed" the same,
        # but the failure is logged rather than silently discarded.
        print(f"Search Error: {e}")
        return []
def generate_answer_groq(query, context_text):
    """Ask the Groq chat model to answer ``query`` from ``context_text``.

    Args:
        query: The (Korean) question.
        context_text: Retrieved passages to ground the answer.

    Returns:
        The model's answer text; a fixed message when no Groq client is
        configured; or an error message string when the request fails.
    """
    # Guard: without a configured client there is nothing to call.
    if not groq_client:
        return "API ํค๊ฐ ํ์ํฉ๋๋ค."

    system_prompt = """
๋น์ ์ KB ๊ธ์ต๊ทธ๋ฃน์ ์ ๋ฌธ AI ์ด์์คํดํธ์ ๋๋ค.
์ ๊ณต๋ [๋ฌธ๋งฅ]์ ๊ธฐ๋ฐํ์ฌ ์ง๋ฌธ์ ๋ํด ์ ํํ๊ณ ์ ๋ฌธ์ ์ธ ๋ต๋ณ์ ์์ฑํ์ธ์.
๋ชจ๋ฅด๋ ๋ด์ฉ์ ๋ชจ๋ฅธ๋ค๊ณ ๋ตํ๊ณ , ์ถ์ธกํ์ง ๋ง์ธ์.
๋ต๋ณ์ ํ๊ตญ์ด๋ก ์์ฑํ์ธ์.
"""
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"์ง๋ฌธ: {query}\n\n[๋ฌธ๋งฅ]\n{context_text}"},
    ]

    try:
        completion = groq_client.chat.completions.create(
            messages=chat_messages,
            model=GROQ_MODEL_NAME,
            temperature=0.1,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as err:
        return f"์๋ต ์์ฑ ์ค๋ฅ: {err}"
def _format_references(hits):
    """Summarise search hits per source file: hit count and best similarity score."""
    scores_by_file = {}
    for hit in hits:
        scores_by_file.setdefault(hit.payload['filename'], []).append(hit.score)
    return "\n".join(
        f"- {fname} (๊ด๋ จ ๋ด์ฉ {len(scores)}๊ฑด, ์ต๊ณ ์ ์ฌ๋: {max(scores):.2f})"
        for fname, scores in scores_by_file.items()
    )


def run_rag_chat(message, history, lang_selection):
    """Answer one chat message through the translate -> retrieve -> generate pipeline.

    Args:
        message: The user's message in the selected language.
        history: Chat history as a list of ``{"role", "content"}`` dicts; may
            be ``None`` or empty.
        lang_selection: A key of ``LANG_MAP``; anything else is treated as Korean.

    Returns:
        A 3-tuple ``(textbox_value, new_history, reference_text)``. The first
        item is always ``""`` so the input box is cleared. For an empty or
        whitespace-only message the history is returned untouched.
    """
    if not message or not message.strip():
        return "", history, ""

    # An unknown or cleared selection falls back to Korean instead of raising KeyError.
    target_lang = LANG_MAP.get(lang_selection, {}).get("code", "ko")

    # 1. Translate the input (target language -> Korean).
    korean_query = message if target_lang == "ko" else translate_to_korean(message)

    # 2. Retrieve and generate the answer (in Korean).
    hits = search_knowledge_base(korean_query)
    if hits:
        context_text = "\n\n".join(h.payload['text'] for h in hits)
        # De-duplicate hits by grouping them per source file.
        reference_text = _format_references(hits)
        bot_response_ko = generate_answer_groq(korean_query, context_text)
    else:
        bot_response_ko = "์ฃ์กํฉ๋๋ค. ๊ด๋ จ ์ ๋ณด๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."
        reference_text = "์ฐธ๊ณ ๋ฌธ์ ์์"

    # 3. Translate the answer (Korean -> target language), keeping the original.
    final_response = bot_response_ko
    if target_lang != "ko":
        translated_response = translate_text(bot_response_ko, target_lang)
        final_response = f"{translated_response}\n\n---\n[ํ๊ตญ์ด ์๋ฌธ]\n{bot_response_ko}"

    # Append to the history (messages format); tolerate a None history.
    new_history = list(history or []) + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": final_response},
    ]
    return "", new_history, reference_text
def _to_mono_int16(samples):
    """Convert an audio sample array to mono 16-bit PCM (the width sent to the recognizer)."""
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.floating):
        # Float PCM: clip first so out-of-range samples cannot wrap around.
        samples = np.clip(samples, -1.0, 1.0) * 32767
    elif np.issubdtype(samples.dtype, np.signedinteger) and samples.dtype.itemsize != 2:
        # Other signed integer widths (e.g. int32): rescale to the 16-bit range.
        samples = samples.astype(np.float64) * (32767 / np.iinfo(samples.dtype).max)
    if samples.ndim > 1:
        # Multi-channel input: average the channels (axis 1) down to mono.
        samples = samples.mean(axis=1)
    return samples.astype(np.int16)


def voice_to_text_chat(audio, history, lang_selection):
    """Transcribe recorded audio and feed the text into ``run_rag_chat``.

    Args:
        audio: ``(sample_rate, samples)`` pair, or ``None`` when nothing was
            recorded.
        history: Current chat history, passed through to ``run_rag_chat``.
        lang_selection: A key of ``LANG_MAP``; anything else is recognised as
            Korean speech.

    Returns:
        The same 3-tuple as ``run_rag_chat``. On missing audio or a
        recognition failure the history is returned unchanged and the third
        item carries the status/error text.
    """
    if audio is None:
        return "", history, "์์ฑ ์ ๋ ฅ ์์"

    # An unknown or cleared selection falls back to Korean instead of raising KeyError.
    stt_lang = LANG_MAP.get(lang_selection, {}).get("stt", "ko-KR")
    try:
        sample_rate, audio_numpy = audio
        pcm = _to_mono_int16(audio_numpy)
        # Sample width 2 bytes matches the int16 data produced above.
        audio_data = sr.AudioData(pcm.tobytes(), sample_rate, 2)
        recognizer = sr.Recognizer()
        # Recognise in the selected language.
        text = recognizer.recognize_google(audio_data, language=stt_lang)
        # Hand the transcript to the chat function.
        return run_rag_chat(text, history, lang_selection)
    except sr.UnknownValueError:
        return "", history, "์์ฑ์ ์ดํดํ ์ ์์ต๋๋ค."
    except Exception as e:
        return "", history, f"์ค๋ฅ: {e}"
| # ========================================================= | |
| # 4. UI Layout (Clean Professional Korean) | |
| # ========================================================= | |
theme = gr.themes.Soft(
    primary_hue="amber",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Noto Sans KR"), "sans-serif"],
)

# Hide the footer element and remove the container's minimum height.
css = """
footer {visibility: hidden !important;}
.gradio-container {min-height: 0px !important;}
"""

# NOTE(review): the source this was restored from had its indentation stripped;
# the `with` nesting below is the most plausible reading - confirm against the
# intended layout (in particular which widgets belong to gr.Group / the inner gr.Row).
with gr.Blocks(theme=theme, title="KB AI Challenge", css=css) as demo:
    with gr.Row():
        # --- LEFT SIDEBAR: language, knowledge-base upload, voice input ---
        with gr.Column(scale=1, min_width=300, variant="panel"):
            gr.Markdown("## KB AI Challenge")
            gr.Markdown("**๋ค๊ตญ์ด ๊ธ์ต AI ์ด์์คํดํธ**")

            with gr.Group():
                lang_dropdown = gr.Dropdown(
                    choices=list(LANG_MAP.keys()),
                    value="ํ๊ตญ์ด (Korean)",
                    label="์ธ์ด ์ค์ ",
                    interactive=True,
                )
                file_input = gr.File(
                    label="์ง์ ๋ฒ ์ด์ค (PDF)",
                    file_count="multiple",
                    file_types=[".pdf"],
                )
                with gr.Row():
                    upload_btn = gr.Button("์ ๋ก๋ ๋ฐ ๋ถ์", variant="primary", size="sm")
                upload_status = gr.Textbox(
                    show_label=False,
                    placeholder="์ํ ๋๊ธฐ ์ค...",
                    interactive=False,
                    lines=1,
                    max_lines=1,
                )

            gr.Markdown("### ์์ฑ ๋ํ")
            audio_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="์์ฑ ์ ๋ ฅ",
                show_label=False,
            )

            with gr.Accordion("์์คํ ์ํคํ ์ฒ", open=False):
                gr.Markdown(
                    """
**์ต์ ํ ๋ด์ญ**
1. **STT**: Google Speech API
2. **๋ฒ์ญ**: Google Translate API
3. **LLM**: Groq LPU (Llama 3)
"""
                )

        # --- RIGHT MAIN: conversation, references, text input ---
        with gr.Column(scale=3):
            # Chat transcript; the handlers return messages-format history.
            chatbot = gr.Chatbot(label="๋ํ", height=500, show_label=False)

            # Source documents used for the latest answer.
            gr.Markdown("**์ฐธ๊ณ ๋ฌธ์**")
            ref_output = gr.Textbox(
                show_label=False,
                interactive=False,
                lines=3,
                max_lines=5,
                placeholder="๊ด๋ จ ๋ฌธ์๊ฐ ํ์๋ฉ๋๋ค.",
            )

            # Text input row.
            with gr.Row():
                msg = gr.Textbox(
                    scale=6,
                    show_label=False,
                    placeholder="์ง๋ฌธ์ ์ ๋ ฅํ์ธ์...",
                    container=False,
                )
                submit_btn = gr.Button("์ ์ก", scale=1, variant="primary")

    # --- Event handlers ---
    upload_btn.click(process_uploaded_files, inputs=[file_input], outputs=[upload_status])

    # Pressing Enter and clicking the send button share the same chat wiring.
    for register_chat_event in (msg.submit, submit_btn.click):
        register_chat_event(
            run_rag_chat,
            [msg, chatbot, lang_dropdown],
            [msg, chatbot, ref_output],
        )

    # A finished recording is transcribed and sent through the same pipeline.
    audio_input.stop_recording(
        voice_to_text_chat,
        [audio_input, chatbot, lang_dropdown],
        [msg, chatbot, ref_output],
    )

if __name__ == "__main__":
    demo.launch(share=True)