"""
ChatGPT-Premium-like open-source Gradio app with:

- multi-image upload (practical "unlimited" via disk + queue)
- OCR (PaddleOCR preferred, fallback to pytesseract)
- visual reasoning (LLaVA/MiniGPT-style, if a model is available)
- math/aptitude pipeline (OCR -> math-specialized LLM)
- caching of processed images & embeddings
- simple in-process queue & streaming text output
- per-client rate limiting (token bucket)

NOTES:
- Replace the model IDs with ones that match your hardware and quotas.
- For production, swap the in-process queue for Redis/Celery and use S3/MinIO
  for storage (a commented Celery sketch appears near the worker pool below).
- Being strictly "better than ChatGPT" across the board is unrealistic; this
  app aims to be a strong open-source approximation.
"""

import os
import time
import uuid
import threading
import queue
import json
import math
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from collections import defaultdict, deque

import gradio as gr
from PIL import Image
import torch
from transformers import (
    AutoProcessor, AutoModelForCausalLM,
    AutoTokenizer, TextIteratorStreamer
)

try:
    from paddleocr import PaddleOCR
    PADDLE_AVAILABLE = True
except Exception:
    PADDLE_AVAILABLE = False

try:
    import pytesseract
    TESSERACT_AVAILABLE = True
except Exception:
    TESSERACT_AVAILABLE = False

# On-disk layout: uploaded images plus a simple JSON cache for OCR / vision results.
DATA_DIR = Path("data")
IMAGES_DIR = DATA_DIR / "images"
CACHE_DIR = DATA_DIR / "cache"
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Visual-reasoning model; set VISUAL_USE = False to skip loading it entirely.
VISUAL_MODEL_ID = "liuhaotian/llava-v1.5-7b"
VISUAL_USE = True
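
# Note: the stock liuhaotian/llava-v1.5-7b weights generally do not load through the
# generic AutoProcessor/AutoModelForCausalLM path used in init_visual_model() below.
# If loading fails, one alternative worth verifying against your transformers version
# is the converted checkpoint "llava-hf/llava-1.5-7b-hf" with transformers'
# LlavaForConditionalGeneration. The try/except in init_visual_model() simply
# disables visual reasoning when loading fails.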

# Math / general-reasoning LLM used to produce the final answer.
MATH_LLM_ID = "mistralai/Mistral-7B-Instruct-v0.2"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Per-request limits and generation/streaming knobs.
MAX_IMAGES_PER_REQUEST = 64
BATCH_SIZE = 4
MAX_HISTORY_TOKENS = 2048
STREAM_CHUNK_SECONDS = 0.12

# Token-bucket rate limiting: each client may spend RATE_TOKENS tokens per
# RATE_INTERVAL seconds; each request costs TOKENS_PER_REQUEST.
RATE_TOKENS = 40
RATE_INTERVAL = 60
TOKENS_PER_REQUEST = 1


def save_uploaded_image(uploaded) -> Path:
    """Copy an uploaded file into IMAGES_DIR and return the new path.

    Depending on Gradio version/configuration, gr.File may hand us a temp-file
    object (with a .name attribute) or a plain filepath string; handle both.
    """
    src_path = Path(getattr(uploaded, "name", uploaded))
    uid = uuid.uuid4().hex
    ext = src_path.suffix or ".png"
    dest = IMAGES_DIR / f"{int(time.time())}_{uid}{ext}"
    with open(src_path, "rb") as src, open(dest, "wb") as dst:
        dst.write(src.read())
    return dest


def cache_get(key: str) -> Optional[str]:
    p = CACHE_DIR / f"{key}.json"
    if p.exists():
        try:
            return json.loads(p.read_text())["value"]
        except Exception:
            return None
    return None


def cache_set(key: str, value: str):
    p = CACHE_DIR / f"{key}.json"
    p.write_text(json.dumps({"value": value}))


def path_hash(p: Path) -> str:
    """Cheap cache key: filename + size + mtime (no need to hash file bytes)."""
    st = p.stat()
    return f"{p.name}-{st.st_size}-{int(st.st_mtime)}"
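
# Example: path_hash(Path("data/images/1716400000_ab12cd.png")) yields something
# like "1716400000_ab12cd.png-48211-1716400003" (name, byte size, mtime), so cache
# entries are invalidated automatically if a file is replaced in place.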


class TokenBucket:
    """Classic token bucket: `rate` tokens refill evenly over `per` seconds."""

    def __init__(self, rate=RATE_TOKENS, per=RATE_INTERVAL):
        self.rate = rate
        self.per = per
        self.allowance = rate
        self.last_check = time.time()
        self._lock = threading.Lock()  # consume() can be hit from concurrent request threads

    def consume(self, tokens=1) -> bool:
        with self._lock:
            now = time.time()
            elapsed = now - self.last_check
            self.last_check = now
            self.allowance += elapsed * (self.rate / self.per)
            if self.allowance > self.rate:
                self.allowance = self.rate
            if self.allowance >= tokens:
                self.allowance -= tokens
                return True
            return False


rate_buckets = defaultdict(lambda: TokenBucket())


def rate_ok(client_id: str) -> bool:
    return rate_buckets[client_id].consume(TOKENS_PER_REQUEST)
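
# Worked example: with RATE_TOKENS=40 and RATE_INTERVAL=60 the bucket refills at
# 40/60 ≈ 0.67 tokens/second, so after burning the initial burst of 40 requests a
# client regains one request roughly every 1.5 seconds.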


paddle_ocr = None
if PADDLE_AVAILABLE:
    paddle_ocr = PaddleOCR(use_angle_cls=True, lang="en")


def run_ocr(path: Path) -> str:
    """
    OCR pipeline: PaddleOCR first, pytesseract as a fallback. Results are cached.
    """
    key = f"ocr-{path_hash(path)}"
    cached = cache_get(key)
    if cached:
        return cached

    text = ""
    try:
        if paddle_ocr:
            result = paddle_ocr.ocr(str(path), cls=True)
            lines = []
            for rec in (result or []):
                if not rec:
                    # PaddleOCR may return None/empty entries for images with no detected text
                    continue
                for box, rec_res in rec:
                    txt = rec_res[0]
                    lines.append(txt)
            text = "\n".join(lines).strip()
    except Exception as e:
        print("PaddleOCR error:", e)
        text = ""

    if not text and TESSERACT_AVAILABLE:
        try:
            pil = Image.open(path).convert("RGB")
            text = pytesseract.image_to_string(pil).strip()
        except Exception:
            text = ""

    cache_set(key, text or "")
    return text
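
# Tesseract accuracy often improves with simple preprocessing. A minimal sketch
# (not wired in; ImageOps usage is standard Pillow, the threshold value is a guess):
#
#   from PIL import ImageOps
#
#   def preprocess_for_tesseract(img: Image.Image) -> Image.Image:
#       g = ImageOps.grayscale(img)       # drop colour
#       g = ImageOps.autocontrast(g)      # stretch contrast
#       return g.point(lambda v: 255 if v > 160 else 0)  # crude binarisation
#
# and then: text = pytesseract.image_to_string(preprocess_for_tesseract(pil))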


visual_processor = None
visual_model = None
visual_tokenizer = None


def init_visual_model():
    global visual_processor, visual_model, visual_tokenizer
    if not VISUAL_USE:
        return
    try:
        visual_processor = AutoProcessor.from_pretrained(VISUAL_MODEL_ID)
        visual_model = AutoModelForCausalLM.from_pretrained(
            VISUAL_MODEL_ID,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto"
        )
        visual_tokenizer = AutoTokenizer.from_pretrained(VISUAL_MODEL_ID, use_fast=False)
        print("Visual model loaded.")
    except Exception as e:
        print("Could not load visual model:", e)
        visual_processor = visual_model = visual_tokenizer = None


def run_visual_reasoning(image_path: Path, question: str, max_new_tokens=256) -> str:
    if visual_processor is None or visual_model is None:
        return ""
    key = f"visual-{path_hash(image_path)}-{question[:96]}"
    cached = cache_get(key)
    if cached:
        return cached

    try:
        image = Image.open(image_path).convert("RGB")
        inputs = visual_processor(images=image, text=question, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outs = visual_model.generate(**inputs, max_new_tokens=max_new_tokens)
        ans = visual_tokenizer.decode(outs[0], skip_special_tokens=True)
        cache_set(key, ans)
        return ans
    except Exception as e:
        print("Visual reasoning error:", e)
        return ""


math_tokenizer = None
math_model = None


def init_math_model():
    global math_tokenizer, math_model
    try:
        math_tokenizer = AutoTokenizer.from_pretrained(MATH_LLM_ID, use_fast=False)
        math_model = AutoModelForCausalLM.from_pretrained(
            MATH_LLM_ID,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto"
        )
        print("Math LLM loaded.")
    except Exception as e:
        print("Could not load math model:", e)
        math_model = None


def ask_math_llm(prompt: str, stream=False):
    """
    If stream=True, return a generator that yields progressively longer text.
    Otherwise, return the final string.
    """
    if math_model is None or math_tokenizer is None:
        return "Math model not available."

    inputs = math_tokenizer(prompt, return_tensors="pt", truncation=True,
                            max_length=MAX_HISTORY_TOKENS).to(DEVICE)

    if not stream:
        with torch.no_grad():
            out_ids = math_model.generate(**inputs, max_new_tokens=512)
        return math_tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # Streaming path: run generate() on a background thread and read tokens from a
    # TextIteratorStreamer. The read loop lives in an inner generator so this
    # function can still *return* plain strings above (a bare `yield` in this body
    # would turn every call into a generator, breaking the stream=False contract).
    streamer = TextIteratorStreamer(math_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
    thread = threading.Thread(target=math_model.generate, kwargs=generation_kwargs)
    thread.start()

    def _stream():
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer

    return _stream()


# In-process job queue: requests enqueue (job_id, image_paths, question) tuples and
# worker threads write extraction results into results_cache keyed by job_id.
work_q = queue.Queue(maxsize=256)
results_cache = {}


def worker_loop():
    while True:
        job = work_q.get()
        if job is None:
            break
        job_id, image_paths, question = job
        try:
            ocr_texts = [run_ocr(p) for p in image_paths]
            visual_texts = []
            if visual_processor and visual_model:
                for p in image_paths:
                    v = run_visual_reasoning(p, question)
                    visual_texts.append(v)

            combined = {
                "ocr": ocr_texts,
                "visual": visual_texts
            }
            results_cache[job_id] = combined
        except Exception as e:
            results_cache[job_id] = {"error": str(e)}
        finally:
            work_q.task_done()


NUM_WORKERS = max(1, min(4, (os.cpu_count() or 2) // 2))
for _ in range(NUM_WORKERS):
    t = threading.Thread(target=worker_loop, daemon=True)
    t.start()
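
# The docstring NOTES suggest replacing this in-process queue with Redis/Celery in
# production. The commented sketch below is one way that could look; it is
# illustrative only -- the broker/backend URLs and task wiring are assumptions,
# not something this app ships with.
#
#   from celery import Celery
#
#   celery_app = Celery(
#       "extraction",
#       broker="redis://localhost:6379/0",
#       backend="redis://localhost:6379/1",
#   )
#
#   @celery_app.task
#   def extract_job(image_paths, question):
#       # Celery workers run in separate processes, so they load their own models.
#       return {
#           "ocr": [run_ocr(Path(p)) for p in image_paths],
#           "visual": [run_visual_reasoning(Path(p), question) for p in image_paths],
#       }
#
# A request handler would then call extract_job.delay(paths, question) and fetch the
# result with .get(timeout=...) instead of polling results_cache.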


def build_prompt(system_prompt: str, chat_history: List[Tuple[str, str]], extracted_texts: List[str], user_question: str) -> str:
    history_text = ""
    for role, text in chat_history[-8:]:
        history_text += f"{role}: {text}\n"
    img_ctx = ""
    if extracted_texts:
        img_ctx = "\n\nEXTRACTED_FROM_IMAGES:\n" + "\n---\n".join(extracted_texts)
    prompt = f"""{system_prompt}

Conversation:
{history_text}

User question:
{user_question}

{img_ctx}

Assistant (explain step-by-step, show calculations if any):"""
    return prompt


SYSTEM_PROMPT = "You are a helpful assistant that solves aptitude, math, and image-based questions. Be precise, show steps, and if images contain diagrams refer to them."


SESSION_MEMORY = defaultdict(lambda: {"history": [], "embeddings": []})
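# Per-client conversational state, e.g.:
#   SESSION_MEMORY["<client-id>"] == {"history": [("User", "..."), ("Assistant", "...")],
#                                     "embeddings": []}   # embeddings unused for now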


def process_request(client_id: str, uploaded_files, user_question: str, stream=True):
    """Generator: yields progressively longer answer text (or a single error message)."""
    if not rate_ok(client_id):
        yield "Rate limit exceeded. Try again later."
        return

    uploaded_files = uploaded_files or []
    if len(uploaded_files) > MAX_IMAGES_PER_REQUEST:
        yield f"Too many images - max {MAX_IMAGES_PER_REQUEST}"
        return

    image_paths = [save_uploaded_image(f) for f in uploaded_files]

    # Hand extraction off to the worker pool, then poll briefly for the result.
    job_id = uuid.uuid4().hex
    work_q.put((job_id, image_paths, user_question))

    wait_seconds = 0
    while job_id not in results_cache and wait_seconds < 12:
        time.sleep(0.25)
        wait_seconds += 0.25

    if job_id not in results_cache:
        # Timed out waiting on the workers; fall back to processing inline.
        # run_ocr/run_visual_reasoning are cached, so if the queued job finishes
        # later the duplicate work is cheap; its late result simply sits unused
        # in results_cache.
        ocr_texts = [run_ocr(p) for p in image_paths]
        visual_texts = []
        if visual_processor and visual_model:
            for p in image_paths:
                visual_texts.append(run_visual_reasoning(p, user_question))
        results = {"ocr": ocr_texts, "visual": visual_texts}
    else:
        results = results_cache.pop(job_id, {"ocr": [], "visual": []})

    # Merge OCR and visual descriptions per image into prompt context blocks.
    ocr_list = results.get("ocr", [])
    visual_list = results.get("visual", [])
    # Pad the shorter list so OCR text is not dropped when no visual model is loaded
    # (plain zip() would stop at the shorter of the two lists).
    if len(visual_list) < len(ocr_list):
        visual_list = visual_list + [""] * (len(ocr_list) - len(visual_list))

    extracted_texts = []
    for o, v in zip(ocr_list, visual_list):
        parts = []
        if o:
            parts.append("OCR: " + o)
        if v:
            parts.append("Visual: " + v)
        combined = "\n".join(parts).strip()
        if combined:
            extracted_texts.append(combined)

    sess = SESSION_MEMORY[client_id]
    sess["history"].append(("User", user_question))

    prompt = build_prompt(SYSTEM_PROMPT, sess["history"], extracted_texts, user_question)

    if stream:
        yield from _stream_llm_response_generator(prompt, client_id)
    else:
        # This function is a generator, so non-streaming callers also receive the
        # answer via a single yield rather than a return value.
        answer = ask_math_llm(prompt, stream=False)
        sess["history"].append(("Assistant", answer))
        yield answer


def _stream_llm_response_generator(prompt: str, client_id: str):
    session = SESSION_MEMORY[client_id]

    gen = ask_math_llm(prompt, stream=True)
    if isinstance(gen, str):
        # Model unavailable: ask_math_llm returned an error string instead of a generator.
        session["history"].append(("Assistant", gen))
        yield gen
        return

    partial = ""
    for chunk in gen:
        # Each chunk is the full text generated so far, so keep only the latest one.
        partial = chunk
        yield partial

    session["history"].append(("Assistant", partial))
with gr.Blocks(css=""" |
|
|
/* small CSS to make chat look nicer */ |
|
|
.chat-column { max-width: 900px; margin-left: auto; margin-right: auto; } |
|
|
""") as demo: |
|
|
|
|
|
gr.Markdown("# 🚀 Open-Source ChatGPT-like (Multimodal)") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=8, elem_classes="chat-column"): |
|
|
chatbot = gr.Chatbot( |
|
|
label="Assistant", |
|
|
elem_id="chatbot", |
|
|
show_label=False, |
|
|
type="messages", |
|
|
height=600 |
|
|
) |
|
|
with gr.Row(): |
|
|
txt = gr.Textbox( |
|
|
label="Type a message...", |
|
|
placeholder="Ask a question or upload images", |
|
|
show_label=False |
|
|
) |
|
|
submit = gr.Button("Send") |
|
|
with gr.Row(): |
|
|
img_in = gr.File( |
|
|
label="Upload images (multiple)", |
|
|
file_count="multiple", |
|
|
file_types=["image"] |
|
|
) |
|
|
clear_btn = gr.Button("New Chat") |
|
|
client_id_state = gr.State(str(uuid.uuid4())) |
|
|
|
|
|
|
|
|
|
|
|
|
|
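
    # client_id_state is scoped to the browser session (gr.State is not shared across
    # users); it keys both the rate-limiter buckets and SESSION_MEMORY above.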

    def handle_send(message, history, client_state, files):
        client_id = client_state or str(uuid.uuid4())
        # Keep earlier turns visible: append the new user turn to the existing
        # chat history instead of replacing it.
        history = list(history or []) + [{"role": "user", "content": message}]
        gen = process_request(client_id, files, message, stream=True)
        collected = ""
        for part in gen:
            collected = part
            # Each yield streams the partial answer to the browser.
            yield "", history + [{"role": "assistant", "content": collected}]
        yield "", history + [{"role": "assistant", "content": collected}]

    submit.click(handle_send, inputs=[txt, chatbot, client_id_state, img_in], outputs=[txt, chatbot])
    txt.submit(handle_send, inputs=[txt, chatbot, client_id_state, img_in], outputs=[txt, chatbot])

    def clear_chat():
        # Return a fresh client id as an output; mutating client_id_state.value
        # directly would not update the per-session state.
        return [], "", str(uuid.uuid4())

    clear_btn.click(clear_chat, None, [chatbot, txt, client_id_state])
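

# The code above defines init_visual_model()/init_math_model() and the Gradio Blocks
# app but never calls or launches them, so a minimal entry point is added here; the
# queue()/launch() defaults are assumptions to tune for your deployment.
if __name__ == "__main__":
    init_visual_model()
    init_math_model()
    demo.queue()    # request queuing so the streaming generators play nicely
    demo.launch()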