from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoModelForVision2Seq,
    AutoProcessor
)
import torch, uvicorn, os, threading, shutil, time
from PIL import Image
import io
# =====================================================
# FastAPI App Setup
# =====================================================
app = FastAPI(title="AI Chat + Summarization + Vision API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =====================================================
# Auto Disk Cleanup (for Codespaces)
# =====================================================
def check_disk_space(min_gb=2):
    stat = shutil.disk_usage("/")
    free_gb = stat.free / (1024 ** 3)
    if free_gb < min_gb:
        print(f"⚠️ Low disk space ({free_gb:.2f} GB). Clearing HuggingFace cache...")
        os.system("rm -rf ~/.cache/huggingface/*")

def background_health_monitor():
    while True:
        check_disk_space()
        time.sleep(600)

threading.Thread(target=background_health_monitor, daemon=True).start()
# =====================================================
# Model Loading (Lazy Initialization)
# =====================================================
chat_model_name = "Qwen/Qwen1.5-0.5B-Chat"
chat_tokenizer = None
chat_model = None
summary_pipe = None
vision_model = None
vision_processor = None

def load_chat_model():
    global chat_tokenizer, chat_model
    if chat_tokenizer is None or chat_model is None:
        print("Loading chat model...")
        chat_tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
        chat_model = AutoModelForCausalLM.from_pretrained(
            chat_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
            offload_folder="offload",
        ).eval()

def load_summary_model():
    global summary_pipe
    if summary_pipe is None:
        print("Loading summarization model...")
        summary_pipe = pipeline(
            "summarization",
            model="sshleifer/distilbart-cnn-6-6",
            device=0 if torch.cuda.is_available() else -1
        )

def load_vision_model():
    global vision_model, vision_processor
    if vision_model is None or vision_processor is None:
        print("Loading vision model...")
        vision_model_name = "microsoft/git-base-coco"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        vision_model = AutoModelForVision2Seq.from_pretrained(vision_model_name).to(device)
        vision_processor = AutoProcessor.from_pretrained(vision_model_name)
# =====================================================
# API Schemas
# =====================================================
class ChatRequest(BaseModel):
    message: str
    max_new_tokens: int = 80
    temperature: float = 0.7

class SummaryRequest(BaseModel):
    text: str
    max_length: int = 100
    min_length: int = 25

class WordPredictionRequest(BaseModel):
    word: str
    num_predictions: int = 5
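
# Example request bodies matching the schemas above (values illustrative;
# the optional fields fall back to the defaults shown when omitted):
#   ChatRequest:           {"message": "Hello!", "max_new_tokens": 80, "temperature": 0.7}
#   SummaryRequest:        {"text": "Long article text ...", "max_length": 100, "min_length": 25}
#   WordPredictionRequest: {"word": "The quick brown", "num_predictions": 5}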

# =====================================================
# Chat Endpoint
# =====================================================
@app.post("/chat")  # route path assumed; adjust to match your frontend
def chat_generate(req: ChatRequest):
    try:
        # Load models on first request
        load_chat_model()
        # Build prompt and run generation while requesting per-step scores
        prompt = (
            "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
            f"<|im_start|>user\n{req.message}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        inputs = chat_tokenizer(prompt, return_tensors="pt").to(chat_model.device)
        # Generate deterministically (greedy) while returning scores for each generated step.
        # req.temperature is intentionally not passed here: it has no effect when do_sample=False.
        outputs = chat_model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True,
            eos_token_id=chat_tokenizer.eos_token_id,
            pad_token_id=chat_tokenizer.eos_token_id,
        )
        # Full sequence and newly generated token ids
        sequence = outputs.sequences[0]
        start_idx = inputs["input_ids"].size(1)
        generated_ids = sequence[start_idx:].tolist()
        # Decode the full reply
        reply = chat_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        # Prepare per-token alternatives using the per-step logits/scores
        tokens_info = []
        # outputs.scores is a tuple with one entry per generated step
        if hasattr(outputs, "scores") and outputs.scores is not None:
            for i, logits in enumerate(outputs.scores):
                # logits shape: (batch_size, vocab_size)
                probs = torch.softmax(logits[0], dim=-1)
                chosen_id = generated_ids[i]
                # Get top-k (we ask for 6 and drop the chosen token if present)
                topk = torch.topk(probs, k=6)
                alts = []
                for idx, val in zip(topk.indices.tolist(), topk.values.tolist()):
                    if idx == chosen_id:
                        continue
                    alts.append({
                        "id": idx,
                        "token": chat_tokenizer.decode([idx], skip_special_tokens=True).strip(),
                        "probability": float(val)
                    })
                    if len(alts) >= 5:
                        break
                # Fallback: if we still have fewer than 5 alternatives, widen the top-k window
                if len(alts) < 5:
                    fallback_topk = torch.topk(probs, k=10)
                    for idx, val in zip(fallback_topk.indices.tolist(), fallback_topk.values.tolist()):
                        if idx == chosen_id:
                            continue
                        if any(a["id"] == idx for a in alts):
                            continue
                        alts.append({
                            "id": idx,
                            "token": chat_tokenizer.decode([idx], skip_special_tokens=True).strip(),
                            "probability": float(val)
                        })
                        if len(alts) >= 5:
                            break
                tokens_info.append({
                    "id": chosen_id,
                    "token": chat_tokenizer.decode([chosen_id], skip_special_tokens=True).strip(),
                    "alternatives": alts
                })
        return {"success": True, "response": reply, "tokens": tokens_info}
    except Exception as e:
        return {"success": False, "error": str(e)}

# =====================================================
# Word Prediction Endpoint
# =====================================================
@app.post("/predict-words")  # route path assumed
def predict_words(req: WordPredictionRequest):
    try:
        # Load models on first request
        load_chat_model()
        input_ids = chat_tokenizer.encode(req.word, return_tensors="pt").to(chat_model.device)
        with torch.no_grad():
            outputs = chat_model(input_ids)
        predictions = outputs.logits[0, -1, :]
        # Normalize over the full vocabulary so the reported probabilities are
        # absolute, not merely relative to the top-k slice
        probs = torch.softmax(predictions, dim=-1)
        top_k = torch.topk(probs, k=req.num_predictions)
        words = []
        for i in range(req.num_predictions):
            token = top_k.indices[i].item()
            prob = float(top_k.values[i].item())
            predicted_word = chat_tokenizer.decode([token])
            words.append({"word": predicted_word, "probability": prob})
        return words
    except Exception as e:
        return {"success": False, "error": str(e)}

# =====================================================
# Summarization Endpoint
# =====================================================
@app.post("/summarize")  # route path assumed
def summarize_text(req: SummaryRequest):
    try:
        # Load models on first request
        load_summary_model()
        # Get word count
        word_count = len(req.text.split())
        # Cap max_length at roughly half the input length (never below 20)
        adjusted_max = min(req.max_length, max(20, word_count // 2))
        # Cap min_length at roughly a fifth of the input length (never below 10)
        adjusted_min = min(req.min_length, max(10, word_count // 5))
        result = summary_pipe(
            req.text,
            max_length=adjusted_max,
            min_length=min(adjusted_min, adjusted_max // 2),
            truncation=True,
        )
        key = "summary_text" if "summary_text" in result[0] else "generated_text"
        return {"success": True, "summary": result[0][key].strip()}
    except Exception as e:
        return {"success": False, "error": str(e)}

# =====================================================
# Image Processing Endpoint
# =====================================================
@app.post("/process-image")  # route path assumed
async def process_image(image: UploadFile = File(...)):
    try:
        # Load models on first request
        load_vision_model()
        contents = await image.read()
        img = Image.open(io.BytesIO(contents)).convert('RGB')
        # Process image with vision model
        inputs = vision_processor(images=img, return_tensors="pt")
        inputs = {k: v.to(vision_model.device) for k, v in inputs.items()}
        # Generate description (num_beams > 1 with do_sample=True selects
        # beam-search multinomial sampling, so captions are not deterministic)
        with torch.no_grad():
            outputs = vision_model.generate(
                **inputs,
                max_length=50,
                num_beams=5,
                temperature=0.8,
                do_sample=True
            )
        description = vision_processor.batch_decode(outputs, skip_special_tokens=True)[0]
        return {
            "success": True,
            "description": description
        }
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return {"success": False, "error": str(e)}

# =====================================================
# Health + Static
# =====================================================
@app.get("/health")  # route path assumed
def health_check():
    return {
        "status": "healthy",
        "models": [
            "Qwen/Qwen1.5-0.5B-Chat",
            "sshleifer/distilbart-cnn-6-6",
            "microsoft/git-base-coco"
        ]
    }

if os.path.exists("static"):
    app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/")
def read_root():
    if os.path.exists("static/index.html"):
        return FileResponse("static/index.html")
    return {"message": "AI Chat & Summarization API running!"}

# =====================================================
# Run API
# =====================================================
if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
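
# Sketch of exercising the API locally once the server is running. The route
# paths match the decorators assumed above and may need adjusting:
#
#   import requests
#   base = "http://localhost:8000"
#   print(requests.get(f"{base}/health").json())
#   print(requests.post(f"{base}/chat", json={"message": "Hello!"}).json())
#   print(requests.post(f"{base}/summarize", json={"text": "Long article text ..."}).json())
#   with open("photo.jpg", "rb") as f:
#       print(requests.post(f"{base}/process-image", files={"image": f}).json())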