nneans committed on
Commit
b819706
·
verified ·
1 Parent(s): 7f7971a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +298 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================================================
2
+ # KB Financial RAG Chatbot (Local Self-Contained Version)
3
+ # =========================================================
4
+ # This code lets a user upload PDFs directly, with no server or cloud DB,
5
+ # to build a knowledge base locally and ask questions against it.
6
+ # It runs for free using the Groq (LLM) and Google (Voice/Translate) APIs.
7
+ # =========================================================
8
+
9
+ import os
10
+ import sys
11
+ import numpy as np
12
+ import traceback
13
+ import fitz # PyMuPDF (PDF ์ฒ˜๋ฆฌ)
14
+ from typing import List
15
+
16
+ # --- Library imports ---
17
+ import gradio as gr
18
+ import speech_recognition as sr
19
+ from deep_translator import GoogleTranslator
20
+ from sentence_transformers import SentenceTransformer
21
+ from groq import Groq
22
+ from qdrant_client import QdrantClient
23
+ from qdrant_client.models import Distance, VectorParams, PointStruct
24
+ try:
25
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
26
+ except ImportError:
27
+ # Fallback for langchain 0.2.0 and later, where the package layout changed
28
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
29
+
30
# =========================================================
# 1. Configuration and initialization
# =========================================================

# Groq API key (required). Read from the environment; the placeholder default
# lets the app start and print a warning instead of failing at import time.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "your_groq_api_key_here")
if not GROQ_API_KEY or GROQ_API_KEY == "your_groq_api_key_here":
    print("โš ๏ธ GROQ_API_KEY๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. RAG ๊ธฐ๋Šฅ ์‚ฌ์šฉ ์‹œ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")

# Model settings
EMBEDDING_MODEL_NAME = "jhgan/ko-sroberta-multitask"
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
COLLECTION_NAME = "local_kb"

print("๐Ÿ› ๏ธ ๋ชจ๋ธ ๋ฐ ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™” ์ค‘...")

# 1. Load the embedding model (runs locally)
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
embedding_model.max_seq_length = 512

# Take the vector size from the loaded model so the collection stays consistent
# if EMBEDDING_MODEL_NAME is changed; 768 is the previously hard-coded value
# and is kept as the fallback when the model does not report a dimension.
EMBEDDING_DIM = embedding_model.get_sentence_embedding_dimension() or 768

# 2. Qdrant client (local in-memory DB - data is dropped when the program exits)
# For persistent storage, change this to path="./local_qdrant_db".
# ':memory:' is the default here so that every demo run starts from a clean state.
qdrant_client = QdrantClient(":memory:")

# Create the collection (if it already exists, delete it and create it again).
# `recreate_collection` is deprecated in qdrant-client, so the same effect is
# spelled out with collection_exists / delete_collection / create_collection.
try:
    if qdrant_client.collection_exists(collection_name=COLLECTION_NAME):
        qdrant_client.delete_collection(collection_name=COLLECTION_NAME)
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=EMBEDDING_DIM, distance=Distance.COSINE),
    )
    print(f"โœ… ๋กœ์ปฌ Qdrant ์ปฌ๋ ‰์…˜ '{COLLECTION_NAME}' ์ƒ์„ฑ ์™„๋ฃŒ.")
except Exception as e:
    print(f"โŒ Qdrant ์ปฌ๋ ‰์…˜ ์ƒ์„ฑ ์‹คํŒจ: {e}")

# 3. Groq client. On failure the name is bound to None so callers can detect
# the missing client instead of hitting a NameError.
try:
    groq_client = Groq(api_key=GROQ_API_KEY)
except Exception as e:
    groq_client = None
    print(f"โŒ Groq ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: {e}")

# Global: running counter used as the point ID for stored chunks
doc_id_counter = 0

print("โœ… ๋ชจ๋“  ์‹œ์Šคํ…œ ์ค€๋น„ ์™„๋ฃŒ!")
76
+
77
+
78
+ # =========================================================
79
+ # 2. Document processing and core RAG logic
80
+ # =========================================================
81
+
82
def process_uploaded_files(files):
    """Extract text from uploaded PDFs, embed it, and store it in the vector DB.

    Args:
        files: Iterable of uploaded files. Each item is either a path string
            or an object exposing the path as ``.name``.

    Returns:
        A status string. It lists one line per file (stored, no extractable
        text, or failed) and the total number of chunks stored.

    Side effects:
        Upserts points into ``COLLECTION_NAME`` and advances the module-level
        ``doc_id_counter`` by the number of points actually stored.
    """
    global doc_id_counter

    if not files:
        return "ํŒŒ์ผ์ด ์—…๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค."

    total_chunks = 0
    status_lines = []

    # Text splitter settings
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=len,
    )

    for file in files:
        try:
            # Depending on the Gradio version/config, `file` is either a path
            # string or an object carrying the path in `.name`.
            file_path = file.name if hasattr(file, 'name') else file
            file_label = os.path.basename(file_path)

            # 1. Extract the PDF text. The context manager closes the document
            # even when a page fails to parse.
            with fitz.open(file_path) as doc:
                file_text = "".join(page.get_text() for page in doc)

            if not file_text.strip():
                status_lines.append(
                    f"โš ๏ธ {file_label}: ํ…์ŠคํŠธ ์ถ”์ถœ ์‹คํŒจ (์ด๋ฏธ์ง€ PDF์ผ ์ˆ˜ ์žˆ์Œ)\n"
                )
                continue

            # 2. Split the text into chunks
            chunks = text_splitter.split_text(file_text)
            if not chunks:
                continue

            # 3. Embed all chunks in one batched call, then build the points
            vectors = embedding_model.encode(chunks)
            first_id = doc_id_counter
            points = [
                PointStruct(
                    id=first_id + i,
                    vector=vector.tolist(),
                    payload={
                        "filename": file_label,
                        "text": chunk,
                        "chunk_id": i,
                    },
                )
                for i, (chunk, vector) in enumerate(zip(chunks, vectors))
            ]

            # Store in Qdrant. The ID counter only advances after a successful
            # upsert, so a failed file does not consume IDs.
            qdrant_client.upsert(
                collection_name=COLLECTION_NAME,
                points=points,
            )
            doc_id_counter = first_id + len(points)
            total_chunks += len(points)
            status_lines.append(
                f"โœ… {file_label}: {len(points)}๊ฐœ ์ง€์‹ ์ €์žฅ ์™„๋ฃŒ.\n"
            )

        except Exception as e:
            # Best effort per file: report the failure and keep going.
            traceback.print_exc()
            file_name_debug = getattr(file, 'name', str(file))
            status_lines.append(
                f"โŒ {os.path.basename(file_name_debug)} ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜: {str(e)}\n"
            )

    status_msg = "".join(status_lines)

    print(f"DEBUG: ์ด ์ €์žฅ๋œ ์ฒญํฌ ์ˆ˜: {total_chunks}")
    if total_chunks == 0:
        return status_msg + "\n(์ €์žฅ๋œ ๋ฐ์ดํ„ฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. PDF๊ฐ€ ๋น„์–ด์žˆ๊ฑฐ๋‚˜ ์ด๋ฏธ์ง€์ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.)"

    return f"์ฒ˜๋ฆฌ ์™„๋ฃŒ! ์ด {total_chunks}๊ฐœ์˜ ์ง€์‹ ์กฐ๊ฐ์ด ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.\n\n{status_msg}"
150
+
151
def search_knowledge_base(query, top_k=5):
    """Look up the stored chunks closest to ``query`` in the local Qdrant DB.

    Args:
        query: Text to embed and search with.
        top_k: Maximum number of points to return.

    Returns:
        The ``points`` of the Qdrant query response, or an empty list when
        anything in the lookup raises (the error is printed).
    """
    try:
        embedded_query = embedding_model.encode(query).tolist()
        # `.query_points()` is used rather than `.search()`, which may be
        # missing or behave differently depending on the qdrant-client version.
        request = {
            "collection_name": COLLECTION_NAME,
            "query": embedded_query,
            "limit": top_k,
            "with_payload": True,
        }
        return qdrant_client.query_points(**request).points
    except Exception as err:
        print(f"๊ฒ€์ƒ‰ ์˜ค๋ฅ˜: {err}")
        return []
166
+
167
def generate_answer_groq(query, context_text):
    """Ask the Groq chat model to answer ``query`` from ``context_text``.

    Args:
        query: The user's question.
        context_text: Reference material placed in the user message.

    Returns:
        The model's reply text. A fixed error string is returned when no Groq
        client is available, and an error string carrying the exception text
        is returned when the API call raises.
    """
    if not groq_client:
        return "Groq API ์„ค์ • ์˜ค๋ฅ˜"

    # NOTE(review): the indentation inside this literal was reconstructed from
    # a whitespace-stripped source - confirm against the original file.
    system_prompt = """
    ๋‹น์‹ ์€ ์นœ์ ˆํ•˜๊ณ  ์ „๋ฌธ์ ์ธ ๊ธˆ์œต AI ์–ด์‹œ์Šคํ„ดํŠธ์ž…๋‹ˆ๋‹ค.
    ๋ฐ˜๋“œ์‹œ ์•„๋ž˜ ์ œ๊ณต๋œ [์ฐธ๊ณ ์ž๋ฃŒ]๋งŒ์„ ๋ฐ”ํƒ•์œผ๋กœ ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•˜์„ธ์š”.
    ์ฐธ๊ณ ์ž๋ฃŒ์— ๋‚ด์šฉ์ด ์—†๋‹ค๋ฉด ์†”์งํ•˜๊ฒŒ ๋ชจ๋ฅธ๋‹ค๊ณ  ๋Œ€๋‹ตํ•˜์„ธ์š”.
    ์ถœ์ฒ˜(ํŒŒ์ผ์ด๋ฆ„)๋ฅผ ๋‹ต๋ณ€ ๋์— ๋ช…์‹œํ•ด์ฃผ์„ธ์š”.
    """

    user_prompt = f"์งˆ๋ฌธ: {query}\n\n[์ฐธ๊ณ ์ž๋ฃŒ]\n{context_text}"

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    try:
        completion = groq_client.chat.completions.create(
            messages=conversation,
            model=GROQ_MODEL_NAME,
            temperature=0.1,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as err:
        return f"Groq ์ƒ์„ฑ ์˜ค๋ฅ˜: {err}"
193
+
194
# RAG pipeline (end-to-end)
def run_rag_pipeline(text_input, detected_lang='ko'):
    """Run translate -> retrieve -> generate -> translate-back for one question.

    Args:
        text_input: The user's question. Empty/None short-circuits.
        detected_lang: Language code of the question. Anything other than
            ``'ko'`` triggers translation of the query to Korean and of the
            answer back to this language. A falsy value is treated as Korean.

    Returns:
        A 4-tuple ``(korean_query, korean_answer, final_answer, ref_str)``.
        All four are empty strings when ``text_input`` is empty. When the
        search returns nothing, the second element is a fixed "upload a PDF
        first" message and the third is an empty string.
    """
    if not text_input:
        return "", "", "", ""

    needs_translation = bool(detected_lang) and detected_lang != 'ko'

    # 1. Translate the question (if needed). Translation is best effort: on
    # failure the original text is used, but the failure is logged.
    korean_query = text_input
    if needs_translation:
        try:
            translated = GoogleTranslator(source='auto', target='ko').translate(text_input)
            if translated:
                korean_query = translated
        except Exception as e:
            print(f"Query translation failed; searching with the original text: {e}")

    # 2. Retrieve documents
    hits = search_knowledge_base(korean_query)

    if not hits:
        return korean_query, "์ €์žฅ๋œ ์ง€์‹์ด ๋ถ€์กฑํ•˜์—ฌ ๋‹ต๋ณ€ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. PDF๋ฅผ ๋จผ์ € ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.", "", "์ฐธ๊ณ  ๋ฌธ์„œ ์—†์Œ"

    # 3. Build the context. Each chunk is tagged with its source filename,
    # because the system prompt asks the model to cite the filename and it
    # cannot do so unless the filename is part of the context.
    context_text = "".join(
        f"[source: {hit.payload['filename']}]\n{hit.payload['text']}\n\n"
        for hit in hits
    )
    ref_str = "\n".join(
        f"- {hit.payload['filename']} (์œ ์‚ฌ๋„: {hit.score:.2f})"
        for hit in hits
    )

    # 4. Generate the answer
    korean_answer = generate_answer_groq(korean_query, context_text)

    # 5. Translate the answer back (if needed); falls back to the Korean answer.
    final_answer = korean_answer
    if needs_translation:
        try:
            translated = GoogleTranslator(source='ko', target=detected_lang).translate(korean_answer)
            if translated:
                final_answer = translated
        except Exception as e:
            print(f"Answer translation failed; returning the Korean answer: {e}")

    return korean_query, korean_answer, final_answer, ref_str
232
+
233
+
234
+ # =========================================================
235
+ # 3. Voice and UI helper functions
236
+ # =========================================================
237
+
238
def voice_to_text(audio_input):
    """Transcribe recorded audio to Korean text with Google speech recognition.

    Args:
        audio_input: ``(sample_rate, samples)`` where ``samples`` is a NumPy
            array, or ``None`` when there is no audio.

    Returns:
        ``(text, lang)``. On success ``lang`` is ``'ko'``. When there is no
        input, recognition fails, or any other error occurs, ``text`` is a
        human-readable message and ``lang`` is ``None``.
    """
    if audio_input is None:
        return "์Œ์„ฑ ์ž…๋ ฅ ์—†์Œ", None

    try:
        sample_rate, samples = audio_input
        samples = np.asarray(samples)

        # AudioData below is told the sample width is 2 bytes, so the buffer
        # must really be int16 before it is serialized.
        if np.issubdtype(samples.dtype, np.floating):
            # Covers float64 as well as float32. Clipping makes out-of-range
            # values saturate instead of wrapping around in int16.
            samples = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
        elif samples.dtype == np.int32:
            # assumes full-scale 32-bit PCM; keep the 16 most significant bits
            samples = (samples >> 16).astype(np.int16)

        # Down-mix multi-channel audio to mono by averaging the channels.
        if samples.ndim > 1:
            samples = samples.mean(axis=1).astype(np.int16)

        audio_data = sr.AudioData(samples.tobytes(), sample_rate, 2)
        recognizer = sr.Recognizer()
        text = recognizer.recognize_google(audio_data, language='ko-KR')
        return text, 'ko'
    except sr.UnknownValueError:
        return "์ธ์‹ ์‹คํŒจ (๋‹ค์‹œ ๋งํ•ด์ฃผ์„ธ์š”)", None
    except Exception as e:
        return f"์˜ค๋ฅ˜: {e}", None
257
+
258
# =========================================================
# 4. Gradio UI layout
# =========================================================

# Builds the single-page UI: an upload section that feeds
# process_uploaded_files, and a chat section that feeds voice_to_text and
# run_rag_pipeline.
# NOTE(review): the nesting of the widgets below was reconstructed from a
# whitespace-stripped source - confirm against the original layout.
with gr.Blocks(theme=gr.themes.Soft(), title="KB AI Challenge") as demo:
    gr.Markdown("# KB AI Challenge")
    gr.Markdown("์„œ๋ฒ„ ์—†์ด ๋กœ์ปฌ์—์„œ ๋™์ž‘ํ•˜๋Š” **๊ฐœ์ธ์šฉ RAG ์‹œ์Šคํ…œ**์ž…๋‹ˆ๋‹ค. PDF๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด ์ฆ‰์‹œ ํ•™์Šตํ•˜์—ฌ ๋‹ต๋ณ€ํ•ฉ๋‹ˆ๋‹ค.")

    # Section 1: knowledge-base upload
    with gr.Accordion("๐Ÿ“‚ 1. ์ง€์‹ ๋ฒ ์ด์Šค ๊ตฌ์ถ• (ํŒŒ์ผ ์—…๋กœ๋“œ)", open=True):
        with gr.Row():
            file_input = gr.File(label="PDF ์—…๋กœ๋“œ (์—ฌ๋Ÿฌ ๊ฐœ ๊ฐ€๋Šฅ)", file_count="multiple", file_types=[".pdf"])
            upload_btn = gr.Button("์ €์žฅํ•˜๊ธฐ", variant="primary")
        upload_status = gr.Textbox(label="์ฒ˜๋ฆฌ ์ƒํƒœ", interactive=False)

    gr.Markdown("---")
    gr.Markdown("### ๐ŸŽค 2. AI์™€ ๋Œ€ํ™”ํ•˜๊ธฐ")

    # Section 2: question input (left column) and answer output (right column)
    with gr.Row():
        with gr.Column(scale=1):
            audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="์Œ์„ฑ ์งˆ๋ฌธ")
            asr_btn = gr.Button("์Œ์„ฑ ์ธ์‹ ์‹œ์ž‘", variant="secondary")
            text_in = gr.Textbox(label="์ธ์‹๋œ ํ…์ŠคํŠธ (์ง์ ‘ ์ž…๋ ฅ ๊ฐ€๋Šฅ)", lines=3)
            chat_btn = gr.Button("์งˆ๋ฌธํ•˜๊ธฐ", variant="primary")

        with gr.Column(scale=2):
            answer_box = gr.Textbox(label="AI ๋‹ต๋ณ€ (ํ•œ๊ตญ์–ด)", lines=6, interactive=False)
            ref_box = gr.Textbox(label="์ฐธ๊ณ  ๋ฌธํ—Œ", lines=4, interactive=False)

    # Event wiring
    upload_btn.click(process_uploaded_files, inputs=[file_input], outputs=[upload_status])

    # voice_to_text returns (text, lang); the language goes to a throwaway
    # State and is not read by the chat handler below.
    asr_btn.click(voice_to_text, inputs=[audio_in], outputs=[text_in, gr.State()])

    # run_rag_pipeline returns four values; only the Korean answer and the
    # reference list are shown, the other two go to throwaway State objects.
    chat_btn.click(
        run_rag_pipeline,
        inputs=[text_in, gr.State('ko')],  # language is fixed to Korean by default (simplification)
        outputs=[gr.State(), answer_box, gr.State(), ref_box]
    )

if __name__ == "__main__":
    demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
fastapi
uvicorn
python-multipart
groq
qdrant-client
sentence-transformers
langchain
langchain-text-splitters
PyMuPDF
numpy
# Imported by app.py (gradio, speech_recognition, deep_translator) but
# previously missing from this list.
gradio
SpeechRecognition
deep-translator