"""
ChatGPT-Premium-like open-source Gradio app with:
- multi-image upload (practical "unlimited" via disk + queue)
- OCR (PaddleOCR preferred, with a pytesseract fallback)
- visual reasoning (LLaVA/MiniGPT-style, when a model is available)
- math/aptitude pipeline (OCR -> math-specialized LLM)
- caching of processed images & embeddings
- simple in-process queue & streaming text output
- per-client rate limiting (token bucket)

NOTES:
- Replace the model IDs with ones that match your hardware/quotas.
- For production, swap the in-process queue for Redis/Celery and use S3/MinIO for storage.
- Being strictly "better than ChatGPT" across the board is unrealistic; this app
  aims to be the best open-source approximation.
"""
|
|
import os
import time
import uuid
import threading
import queue
import json
from itertools import zip_longest
from pathlib import Path
from typing import List, Tuple, Optional
from collections import defaultdict
|
|
import gradio as gr
from PIL import Image
import torch
from transformers import (
    AutoProcessor, AutoModelForCausalLM,
    AutoTokenizer, TextIteratorStreamer,
    LlavaForConditionalGeneration,
)
|
|
# Optional OCR backends: prefer PaddleOCR, fall back to pytesseract.
try:
    from paddleocr import PaddleOCR
    PADDLE_AVAILABLE = True
except Exception:
    PADDLE_AVAILABLE = False

try:
    import pytesseract
    TESSERACT_AVAILABLE = True
except Exception:
    TESSERACT_AVAILABLE = False
|
|
# ---------- Storage ----------

DATA_DIR = Path("data")
IMAGES_DIR = DATA_DIR / "images"
CACHE_DIR = DATA_DIR / "cache"
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
# ---------- Model configuration ----------

# The original id "liuhaotian/llava-v1.5-7b" does not load through AutoProcessor /
# AutoModelForCausalLM; the transformers-native LLaVA checkpoint is assumed instead.
VISUAL_MODEL_ID = "llava-hf/llava-1.5-7b-hf"
VISUAL_USE = True

MATH_LLM_ID = "mistralai/Mistral-7B-Instruct-v0.2"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Request limits ----------

MAX_IMAGES_PER_REQUEST = 64
BATCH_SIZE = 4
MAX_HISTORY_TOKENS = 2048
STREAM_CHUNK_SECONDS = 0.12

# ---------- Rate limiting (token bucket) ----------

RATE_TOKENS = 40        # bucket capacity
RATE_INTERVAL = 60      # seconds to refill a full bucket
TOKENS_PER_REQUEST = 1
|
|
# ---------- File & cache helpers ----------

def save_uploaded_image(tempfile) -> Path:
    """Copy an uploaded temp file into IMAGES_DIR under a unique name."""
    uid = uuid.uuid4().hex
    ext = Path(tempfile.name).suffix or ".png"
    dest = IMAGES_DIR / f"{int(time.time())}_{uid}{ext}"
    with open(tempfile.name, "rb") as src, open(dest, "wb") as dst:
        dst.write(src.read())
    return dest
|
|
def cache_get(key: str) -> Optional[str]:
    p = CACHE_DIR / f"{key}.json"
    if p.exists():
        try:
            return json.loads(p.read_text())["value"]
        except Exception:
            return None
    return None


def cache_set(key: str, value: str):
    p = CACHE_DIR / f"{key}.json"
    p.write_text(json.dumps({"value": value}))


def path_hash(p: Path) -> str:
    """Cheap cache key from name + size + mtime (no need to read the file)."""
    st = p.stat()
    return f"{p.name}-{st.st_size}-{int(st.st_mtime)}"
|
|
# ---------- Rate limiter ----------

class TokenBucket:
    def __init__(self, rate=RATE_TOKENS, per=RATE_INTERVAL):
        self.rate = rate
        self.per = per
        self.allowance = rate
        self.last_check = time.time()
        self.lock = threading.Lock()  # Gradio handlers can run concurrently

    def consume(self, tokens=1) -> bool:
        with self.lock:
            now = time.time()
            elapsed = now - self.last_check
            self.last_check = now
            # Refill proportionally to elapsed time, capped at bucket capacity.
            self.allowance = min(self.rate, self.allowance + elapsed * (self.rate / self.per))
            if self.allowance >= tokens:
                self.allowance -= tokens
                return True
            return False


rate_buckets = defaultdict(TokenBucket)


def rate_ok(client_id: str) -> bool:
    return rate_buckets[client_id].consume(TOKENS_PER_REQUEST)
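
# Sanity check of the bucket semantics (hypothetical values): a bucket with
# rate=2, per=60 allows two immediate requests, then refuses until it refills.
#   b = TokenBucket(rate=2, per=60)
#   b.consume(), b.consume(), b.consume()  # -> True, True, False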
|
|
# ---------- OCR ----------

paddle_ocr = None
if PADDLE_AVAILABLE:
    paddle_ocr = PaddleOCR(use_angle_cls=True, lang="en")


def run_ocr(path: Path) -> str:
    """OCR pipeline: PaddleOCR first, pytesseract as fallback. Results are cached."""
    key = f"ocr-{path_hash(path)}"
    cached = cache_get(key)
    if cached:
        return cached

    text = ""
    try:
        if paddle_ocr:
            result = paddle_ocr.ocr(str(path), cls=True)
            lines = []
            for page in result:
                if not page:  # PaddleOCR yields None/empty for pages with no text
                    continue
                for box, rec_res in page:
                    lines.append(rec_res[0])  # rec_res is (text, confidence)
            text = "\n".join(lines).strip()
    except Exception:
        text = ""

    if not text and TESSERACT_AVAILABLE:
        try:
            pil = Image.open(path).convert("RGB")
            text = pytesseract.image_to_string(pil).strip()
        except Exception:
            text = ""

    cache_set(key, text)
    return text
|
|
# ---------- Visual reasoning ----------

visual_processor = None
visual_model = None


def init_visual_model():
    global visual_processor, visual_model
    if not VISUAL_USE:
        return
    try:
        # LLaVA checkpoints load via LlavaForConditionalGeneration rather than
        # AutoModelForCausalLM; the processor bundles the tokenizer, so no
        # separate AutoTokenizer is needed.
        visual_processor = AutoProcessor.from_pretrained(VISUAL_MODEL_ID)
        visual_model = LlavaForConditionalGeneration.from_pretrained(
            VISUAL_MODEL_ID,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto"
        )
        print("Visual model loaded.")
    except Exception as e:
        print("Could not load visual model:", e)
        visual_processor = visual_model = None


def run_visual_reasoning(image_path: Path, question: str, max_new_tokens=256) -> str:
    if visual_processor is None or visual_model is None:
        return ""
    key = f"visual-{path_hash(image_path)}-{question[:96]}"
    cached = cache_get(key)
    if cached:
        return cached

    try:
        image = Image.open(image_path).convert("RGB")
        # LLaVA expects the <image> placeholder in its prompt format.
        prompt = f"USER: <image>\n{question} ASSISTANT:"
        inputs = visual_processor(images=image, text=prompt, return_tensors="pt").to(visual_model.device)
        with torch.no_grad():
            outs = visual_model.generate(**inputs, max_new_tokens=max_new_tokens)
        ans = visual_processor.decode(outs[0], skip_special_tokens=True)
        cache_set(key, ans)
        return ans
    except Exception as e:
        print("Visual reasoning error:", e)
        return ""
|
|
# ---------- Math / aptitude LLM ----------

math_tokenizer = None
math_model = None


def init_math_model():
    global math_tokenizer, math_model
    try:
        math_tokenizer = AutoTokenizer.from_pretrained(MATH_LLM_ID, use_fast=False)
        math_model = AutoModelForCausalLM.from_pretrained(
            MATH_LLM_ID,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto"
        )
        print("Math LLM loaded.")
    except Exception as e:
        print("Could not load math model:", e)
        math_model = None


def ask_math_llm(prompt: str, stream=False):
    """
    If stream=True, return a generator that yields the accumulated text so far.
    Otherwise, return the final string.

    The streaming path lives in an inner generator: a bare `yield` in this
    function's body would make the whole function a generator, so the
    stream=False branch would return a generator object instead of a string.
    """
    if math_model is None:
        msg = "Math model not available."
        return iter([msg]) if stream else msg

    inputs = math_tokenizer(prompt, return_tensors="pt", truncation=True,
                            max_length=MAX_HISTORY_TOKENS).to(math_model.device)

    if not stream:
        with torch.no_grad():
            out_ids = math_model.generate(**inputs, max_new_tokens=512)
        return math_tokenizer.decode(out_ids[0], skip_special_tokens=True)

    def _generate():
        streamer = TextIteratorStreamer(math_tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
        # generate() blocks, so run it in a thread and consume the streamer here.
        thread = threading.Thread(target=math_model.generate, kwargs=generation_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer

    return _generate()
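
# Usage sketch (assuming init_math_model() succeeded; the prompt is illustrative):
#   final = ask_math_llm("What is 17 * 23? Show steps.", stream=False)
#   for partial in ask_math_llm("What is 17 * 23? Show steps.", stream=True):
#       print(partial)  # each yield is the full text generated so far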
|
|
# ---------- In-process work queue ----------

work_q = queue.Queue(maxsize=256)
results_cache = {}


def worker_loop():
    while True:
        job = work_q.get()
        if job is None:  # sentinel for shutdown
            break
        job_id, image_paths, question = job
        try:
            ocr_texts = [run_ocr(p) for p in image_paths]
            visual_texts = []
            if visual_processor and visual_model:
                for p in image_paths:
                    visual_texts.append(run_visual_reasoning(p, question))
            results_cache[job_id] = {"ocr": ocr_texts, "visual": visual_texts}
        except Exception as e:
            results_cache[job_id] = {"error": str(e)}
        finally:
            work_q.task_done()


NUM_WORKERS = max(1, min(4, (os.cpu_count() or 2) // 2))
for _ in range(NUM_WORKERS):
    threading.Thread(target=worker_loop, daemon=True).start()
|
|
# ---------- Prompt construction ----------

def build_prompt(system_prompt: str, chat_history: List[Tuple[str, str]],
                 extracted_texts: List[str], user_question: str) -> str:
    history_text = ""
    for role, text in chat_history[-8:]:  # keep only the most recent turns
        history_text += f"{role}: {text}\n"
    img_ctx = ""
    if extracted_texts:
        img_ctx = "\n\nEXTRACTED_FROM_IMAGES:\n" + "\n---\n".join(extracted_texts)
    prompt = f"""{system_prompt}

Conversation:
{history_text}

User question:
{user_question}
{img_ctx}

Assistant (explain step-by-step, show calculations if any):"""
    return prompt


SYSTEM_PROMPT = (
    "You are a helpful assistant that solves aptitude, math, and image-based "
    "questions. Be precise, show steps, and if images contain diagrams refer to them."
)

# Per-client conversation memory (in-process; use a real store in production).
SESSION_MEMORY = defaultdict(lambda: {"history": [], "embeddings": []})
|
|
def process_request(client_id: str, uploaded_files, user_question: str, stream=True):
    """Generator: yields progressively longer answer text (or a single error message)."""
    if not rate_ok(client_id):
        yield "Rate limit exceeded. Try again later."
        return

    # Check the count before saving anything to disk.
    uploaded_files = uploaded_files or []
    if len(uploaded_files) > MAX_IMAGES_PER_REQUEST:
        yield f"Too many images - max {MAX_IMAGES_PER_REQUEST}"
        return
    image_paths = [save_uploaded_image(f) for f in uploaded_files]

    # Hand OCR/visual extraction to the worker pool.
    job_id = uuid.uuid4().hex
    work_q.put((job_id, image_paths, user_question))

    # Wait briefly for a worker; past 12s, fall back to inline processing
    # (run_ocr/run_visual_reasoning are cached, so duplicated work is cheap).
    wait_seconds = 0
    while job_id not in results_cache and wait_seconds < 12:
        time.sleep(0.25)
        wait_seconds += 0.25

    if job_id not in results_cache:
        ocr_texts = [run_ocr(p) for p in image_paths]
        visual_texts = []
        if visual_processor and visual_model:
            for p in image_paths:
                visual_texts.append(run_visual_reasoning(p, user_question))
        results = {"ocr": ocr_texts, "visual": visual_texts}
    else:
        results = results_cache.pop(job_id, {"ocr": [], "visual": []})

    # Merge OCR and visual outputs per image. zip_longest matters here: with a
    # plain zip, an empty visual list would silently drop every OCR result.
    extracted_texts = []
    for o, v in zip_longest(results.get("ocr", []), results.get("visual", []), fillvalue=""):
        parts = []
        if o:
            parts.append("OCR: " + o)
        if v:
            parts.append("Visual: " + v)
        combined = "\n".join(parts).strip()
        if combined:
            extracted_texts.append(combined)

    sess = SESSION_MEMORY[client_id]
    sess["history"].append(("User", user_question))
    prompt = build_prompt(SYSTEM_PROMPT, sess["history"], extracted_texts, user_question)

    if stream:
        yield from _stream_llm_response_generator(prompt, client_id)
    else:
        # This function contains `yield`, so it is always a generator; results
        # must be yielded - a `return [answer]` would be discarded.
        answer = ask_math_llm(prompt, stream=False)
        sess["history"].append(("Assistant", answer))
        yield answer
|
|
def _stream_llm_response_generator(prompt: str, client_id: str):
    session = SESSION_MEMORY[client_id]
    partial = ""
    for chunk in ask_math_llm(prompt, stream=True):
        partial = chunk  # each chunk is the full accumulated text so far
        yield partial
    session["history"].append(("Assistant", partial))
|
|
# ---------- Gradio UI ----------

with gr.Blocks(css="""
/* small CSS to make chat look nicer */
.chat-column { max-width: 900px; margin-left: auto; margin-right: auto; }
""") as demo:

    gr.Markdown("# 🚀 Open-Source ChatGPT-like (Multimodal)")

    with gr.Row():
        with gr.Column(scale=8, elem_classes="chat-column"):
            chatbot = gr.Chatbot(
                label="Assistant",
                elem_id="chatbot",
                show_label=False,
                type="messages",
                height=600
            )
            with gr.Row():
                txt = gr.Textbox(
                    label="Type a message...",
                    placeholder="Ask a question or upload images",
                    show_label=False
                )
                submit = gr.Button("Send")
            with gr.Row():
                img_in = gr.File(
                    label="Upload images (multiple)",
                    file_count="multiple",
                    file_types=["image"]
                )
                clear_btn = gr.Button("New Chat")
    client_id_state = gr.State(str(uuid.uuid4()))
|
|
    def handle_send(message, history, client_state, files):
        client_id = client_state or str(uuid.uuid4())
        # Accumulate onto the existing chat history instead of replacing it,
        # so earlier turns stay visible in the Chatbot.
        history = (history or []) + [{"role": "user", "content": message}]
        collected = ""
        for part in process_request(client_id, files, message, stream=True):
            collected = part
            yield "", history + [{"role": "assistant", "content": collected}]
        # Final state (also covers the case where nothing was streamed).
        yield "", history + [{"role": "assistant", "content": collected}]

    submit.click(handle_send, inputs=[txt, chatbot, client_id_state, img_in], outputs=[txt, chatbot])
    txt.submit(handle_send, inputs=[txt, chatbot, client_id_state, img_in], outputs=[txt, chatbot])

    def clear_chat():
        # gr.State cannot be mutated in place from a handler; return a fresh
        # client id as an output instead.
        return [], "", str(uuid.uuid4())

    clear_btn.click(clear_chat, None, [chatbot, txt, client_id_state])
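
# The init_* functions above are never called by the UI wiring itself; a minimal
# launch block is sketched below (adjust queueing/auth for production).
if __name__ == "__main__":
    init_visual_model()
    init_math_model()
    demo.queue()   # request queuing; needed for generator (streaming) handlers in older Gradio
    demo.launch()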