from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoModelForVision2Seq,
    AutoProcessor,
)
import torch, uvicorn, os, threading, shutil, time
from PIL import Image
import io

# =====================================================
# FastAPI App Setup
# =====================================================
app = FastAPI(title="AI Chat + Summarization + Vision API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# =====================================================
# Auto Disk Cleanup (for Codespaces)
# =====================================================
def check_disk_space(min_gb=2):
    stat = shutil.disk_usage("/")
    free_gb = stat.free / (1024 ** 3)
    if free_gb < min_gb:
        print(f"⚠️ Low disk space ({free_gb:.2f} GB). Clearing HuggingFace cache...")
        os.system("rm -rf ~/.cache/huggingface/*")


def background_health_monitor():
    while True:
        check_disk_space()
        time.sleep(600)


threading.Thread(target=background_health_monitor, daemon=True).start()

# =====================================================
# Model Loading (Lazy Initialization)
# =====================================================
chat_model_name = "Qwen/Qwen1.5-0.5B-Chat"
chat_tokenizer = None
chat_model = None
summary_pipe = None
vision_model = None
vision_processor = None


def load_chat_model():
    global chat_tokenizer, chat_model
    if chat_tokenizer is None or chat_model is None:
        print("Loading chat model...")
        chat_tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
        chat_model = AutoModelForCausalLM.from_pretrained(
            chat_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
            offload_folder="offload",
        ).eval()


def load_summary_model():
    global summary_pipe
    if summary_pipe is None:
        print("Loading summarization model...")
        summary_pipe = pipeline(
            "summarization",
            model="sshleifer/distilbart-cnn-6-6",
            device=0 if torch.cuda.is_available() else -1,
        )


def load_vision_model():
    global vision_model, vision_processor
    if vision_model is None or vision_processor is None:
        print("Loading vision model...")
        vision_model_name = "microsoft/git-base-coco"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        vision_model = AutoModelForVision2Seq.from_pretrained(vision_model_name).to(device)
        vision_processor = AutoProcessor.from_pretrained(vision_model_name)

# =====================================================
# API Schemas
# =====================================================
class ChatRequest(BaseModel):
    message: str
    max_new_tokens: int = 80
    temperature: float = 0.7


class SummaryRequest(BaseModel):
    text: str
    max_length: int = 100
    min_length: int = 25


class WordPredictionRequest(BaseModel):
    word: str
    num_predictions: int = 5

# =====================================================
# Chat Endpoint
# =====================================================
@app.post("/api/chat")
def chat_generate(req: ChatRequest):
    try:
        # Load models on first request
        load_chat_model()

        # Build prompt and run generation while requesting per-step scores
        prompt = (
            "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
            f"<|im_start|>user\n{req.message}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        inputs = chat_tokenizer(prompt, return_tensors="pt").to(chat_model.device)

        # Generate deterministically (greedy) while returning scores for each generated step
        outputs = chat_model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=False,  # greedy decoding; temperature is ignored when sampling is off
            output_scores=True,
            return_dict_in_generate=True,
            eos_token_id=chat_tokenizer.eos_token_id,
            pad_token_id=chat_tokenizer.eos_token_id,
        )

        # Full sequence and newly generated token ids
        sequence = outputs.sequences[0]
        start_idx = inputs["input_ids"].size(1)
        generated_ids = sequence[start_idx:].tolist()

        # Decode the full reply
        reply = chat_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # Prepare per-token alternatives using the per-step logits/scores
        tokens_info = []
        # outputs.scores is a tuple with one entry per generated step
        if hasattr(outputs, "scores") and outputs.scores is not None:
            for i, logits in enumerate(outputs.scores):
                # logits shape: (batch_size, vocab_size)
                probs = torch.softmax(logits[0], dim=-1)
                chosen_id = generated_ids[i]

                # Get top-k (we ask for 6 and drop the chosen token if present)
                topk = torch.topk(probs, k=6)
                alts = []
                for idx, val in zip(topk.indices.tolist(), topk.values.tolist()):
                    if idx == chosen_id:
                        continue
                    alts.append({
                        "id": idx,
                        "token": chat_tokenizer.decode([idx], skip_special_tokens=True).strip(),
                        "probability": float(val),
                    })
                    if len(alts) >= 5:
                        break

                # Fallback: if not enough alternatives yet, take further highest-probability tokens
                if len(alts) < 5:
                    fallback_topk = torch.topk(probs, k=10)
                    for idx, val in zip(fallback_topk.indices.tolist(), fallback_topk.values.tolist()):
                        if idx == chosen_id:
                            continue
                        if any(a["id"] == idx for a in alts):
                            continue
                        alts.append({
                            "id": idx,
                            "token": chat_tokenizer.decode([idx], skip_special_tokens=True).strip(),
                            "probability": float(val),
                        })
                        if len(alts) >= 5:
                            break

                tokens_info.append({
                    "id": chosen_id,
                    "token": chat_tokenizer.decode([chosen_id], skip_special_tokens=True).strip(),
                    "alternatives": alts,
                })

        return {"success": True, "response": reply, "tokens": tokens_info}
    except Exception as e:
        return {"success": False, "error": str(e)}

# =====================================================
# Word Prediction Endpoint
# =====================================================
@app.post("/predict_words")
def predict_words(req: WordPredictionRequest):
    try:
        # Load models on first request
        load_chat_model()

        # Encode the input and move it to the model's device
        input_ids = chat_tokenizer.encode(req.word, return_tensors="pt").to(chat_model.device)

        with torch.no_grad():
            outputs = chat_model(input_ids)
            predictions = outputs.logits[0, -1, :]

        # Softmax over the full vocabulary so probabilities are properly normalized
        probs = torch.softmax(predictions, dim=-1)
        top_k = torch.topk(probs, k=req.num_predictions)

        words = []
        for i in range(req.num_predictions):
            token = top_k.indices[i].item()
            prob = float(top_k.values[i].item())
            predicted_word = chat_tokenizer.decode([token])
            words.append({"word": predicted_word, "probability": prob})

        return words
    except Exception as e:
        return {"success": False, "error": str(e)}

# =====================================================
# Summarization Endpoint
# =====================================================
@app.post("/api/summarize")
def summarize_text(req: SummaryRequest):
    try:
        # Load models on first request
        load_summary_model()

        # Get word count
        word_count = len(req.text.split())

        # Adjust max_length to be ~30-50% of input length
        adjusted_max = min(req.max_length, max(20, word_count // 2))
        # Adjust min_length to be ~10-20% of input length
        adjusted_min = min(req.min_length, max(10, word_count // 5))

        result = summary_pipe(
            req.text,
            max_length=adjusted_max,
            min_length=min(adjusted_min, adjusted_max // 2),
            truncation=True,
        )

        key = "summary_text" if "summary_text" in result[0] else "generated_text"
        return {"success": True, "summary": result[0][key].strip()}
    except Exception as e:
        return {"success": False, "error": str(e)}

# =====================================================
# Image Processing Endpoint
# =====================================================
@app.post("/process_image")
async def process_image(image: UploadFile = File(...)):
    try:
        # Load models on first request
        load_vision_model()

        contents = await image.read()
        img = Image.open(io.BytesIO(contents)).convert("RGB")

        # Process image with vision model
        inputs = vision_processor(images=img, return_tensors="pt")
        inputs = {k: v.to(vision_model.device) for k, v in inputs.items()}

        # Generate description
        with torch.no_grad():
            outputs = vision_model.generate(
                **inputs,
                max_length=50,
                num_beams=5,
                temperature=0.8,
                do_sample=True,
            )

        description = vision_processor.batch_decode(outputs, skip_special_tokens=True)[0]

        return {"success": True, "description": description}
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return {"success": False, "error": str(e)}

# =====================================================
# Health + Static
# =====================================================
@app.get("/api/health")
def health_check():
    return {
        "status": "healthy",
        "models": [
            "Qwen-1.5-0.5B-Chat",
            "DistilBART-6-6",
            "microsoft/git-base-coco",
        ],
    }


if os.path.exists("static"):
    app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
def read_root():
    if os.path.exists("static/index.html"):
        return FileResponse("static/index.html")
    return {"message": "AI Chat & Summarization API running!"}

# =====================================================
# Run API
# =====================================================
if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
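
# =====================================================
# Example Client (illustrative)
# =====================================================
# A minimal sketch of how a client could call the /api/chat endpoint, assuming
# the server is reachable at http://localhost:8000 and the `requests` package
# is installed. The function name and payload values are illustrative only and
# are not part of the API above; run it from a separate process while the
# server is up.
def example_chat_request(base_url="http://localhost:8000"):
    import requests  # imported lazily so the server itself does not depend on it

    resp = requests.post(
        f"{base_url}/api/chat",
        json={"message": "Hello!", "max_new_tokens": 40},
        timeout=120,
    )
    data = resp.json()
    if data.get("success"):
        print(data["response"])
        # Each generated token carries up to five alternative candidates with probabilities
        for tok in data.get("tokens", []):
            print(tok["token"], "->", [a["token"] for a in tok["alternatives"]])
    else:
        print("Error:", data.get("error"))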