# local-inference/server.py
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoModelForVision2Seq,
    AutoProcessor,
)
import torch, uvicorn, os, subprocess, threading, shutil, time
from typing import List
import numpy as np
from PIL import Image
import io
# =====================================================
# FastAPI App Setup
# =====================================================
app = FastAPI(title="AI Chat + Summarization + Vision API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =====================================================
# Auto Disk Cleanup (for Codespaces)
# =====================================================
def check_disk_space(min_gb=2):
    stat = shutil.disk_usage("/")
    free_gb = stat.free / (1024 ** 3)
    if free_gb < min_gb:
        print(f"⚠️ Low disk space ({free_gb:.2f} GB). Clearing HuggingFace cache...")
        os.system("rm -rf ~/.cache/huggingface/*")
def background_health_monitor():
    while True:
        check_disk_space()
        time.sleep(600)
threading.Thread(target=background_health_monitor, daemon=True).start()
# =====================================================
# Model Loading (Lazy Initialization)
# =====================================================
chat_model_name = "Qwen/Qwen1.5-0.5B-Chat"
chat_tokenizer = None
chat_model = None
summary_pipe = None
vision_model = None
vision_processor = None
def load_chat_model():
    global chat_tokenizer, chat_model
    if chat_tokenizer is None or chat_model is None:
        print("Loading chat model...")
        chat_tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
        chat_model = AutoModelForCausalLM.from_pretrained(
            chat_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
            offload_folder="offload",
        ).eval()
def load_summary_model():
    global summary_pipe
    if summary_pipe is None:
        print("Loading summarization model...")
        summary_pipe = pipeline(
            "summarization",
            model="sshleifer/distilbart-cnn-6-6",
            device=0 if torch.cuda.is_available() else -1,
        )
def load_vision_model():
    global vision_model, vision_processor
    if vision_model is None or vision_processor is None:
        print("Loading vision model...")
        vision_model_name = "microsoft/git-base-coco"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        vision_model = AutoModelForVision2Seq.from_pretrained(vision_model_name).to(device)
        vision_processor = AutoProcessor.from_pretrained(vision_model_name)
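# Note: the lazy loaders above are not guarded against concurrent first requests; if two
# requests hit an endpoint before its model is cached, both may trigger a load. A minimal
# sketch of one way to serialize loading (illustrative only, not wired into the endpoints):
#
#   _load_lock = threading.Lock()
#   def load_chat_model_safely():
#       with _load_lock:
#           load_chat_model()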
# =====================================================
# API Schemas
# =====================================================
class ChatRequest(BaseModel):
    message: str
    max_new_tokens: int = 80
    temperature: float = 0.7
class SummaryRequest(BaseModel):
    text: str
    max_length: int = 100
    min_length: int = 25
class WordPredictionRequest(BaseModel):
    word: str
    num_predictions: int = 5
# =====================================================
# Chat Endpoint
# =====================================================
@app.post("/api/chat")
def chat_generate(req: ChatRequest):
    try:
        # Load the model on first request
        load_chat_model()
        # Build the Qwen chat prompt; generation below also requests per-step scores
        prompt = (
            "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
            f"<|im_start|>user\n{req.message}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        inputs = chat_tokenizer(prompt, return_tensors="pt").to(chat_model.device)
        # Generate deterministically (greedy) while returning scores for each generated step.
        # req.temperature is accepted for API compatibility but has no effect when
        # do_sample=False, so it is not forwarded to generate().
        outputs = chat_model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True,
            eos_token_id=chat_tokenizer.eos_token_id,
            pad_token_id=chat_tokenizer.eos_token_id,
        )
        # Full sequence and newly generated token ids
        sequence = outputs.sequences[0]
        start_idx = inputs["input_ids"].size(1)
        generated_ids = sequence[start_idx:].tolist()
        # Decode the full reply
        reply = chat_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        # Prepare per-token alternatives using the per-step scores
        tokens_info = []
        # outputs.scores is a tuple with one entry per generated step
        if hasattr(outputs, "scores") and outputs.scores is not None:
            for i, logits in enumerate(outputs.scores):
                # logits shape: (batch_size, vocab_size)
                probs = torch.softmax(logits[0], dim=-1)
                chosen_id = generated_ids[i]
                # Take the top 6 and drop the chosen token if present, keeping up to 5 alternatives
                topk = torch.topk(probs, k=6)
                alts = []
                for idx, val in zip(topk.indices.tolist(), topk.values.tolist()):
                    if idx == chosen_id:
                        continue
                    alts.append({
                        "id": idx,
                        "token": chat_tokenizer.decode([idx], skip_special_tokens=True).strip(),
                        "probability": float(val),
                    })
                    if len(alts) >= 5:
                        break
                # Fallback: if fewer than 5 alternatives remain, widen the top-k window
                if len(alts) < 5:
                    fallback_topk = torch.topk(probs, k=10)
                    for idx, val in zip(fallback_topk.indices.tolist(), fallback_topk.values.tolist()):
                        if idx == chosen_id:
                            continue
                        if any(a["id"] == idx for a in alts):
                            continue
                        alts.append({
                            "id": idx,
                            "token": chat_tokenizer.decode([idx], skip_special_tokens=True).strip(),
                            "probability": float(val),
                        })
                        if len(alts) >= 5:
                            break
                tokens_info.append({
                    "id": chosen_id,
                    "token": chat_tokenizer.decode([chosen_id], skip_special_tokens=True).strip(),
                    "alternatives": alts,
                })
        return {"success": True, "response": reply, "tokens": tokens_info}
    except Exception as e:
        return {"success": False, "error": str(e)}
# =====================================================
# Word Prediction Endpoint
# =====================================================
@app.post("/predict_words")
def predict_words(req: WordPredictionRequest):
    try:
        # Load the model on first request
        load_chat_model()
        # Tokenize and move the input to the same device as the model
        input_ids = chat_tokenizer.encode(req.word, return_tensors="pt").to(chat_model.device)
        with torch.no_grad():
            outputs = chat_model(input_ids)
        # Next-token logits, converted to probabilities over the full vocabulary
        # (rather than renormalizing over only the top-k values)
        next_token_logits = outputs.logits[0, -1, :]
        probs = torch.softmax(next_token_logits, dim=-1)
        top_k = torch.topk(probs, k=req.num_predictions)
        words = []
        for token_id, prob in zip(top_k.indices.tolist(), top_k.values.tolist()):
            predicted_word = chat_tokenizer.decode([token_id])
            words.append({"word": predicted_word, "probability": float(prob)})
        return words
    except Exception as e:
        return {"success": False, "error": str(e)}
# =====================================================
# Summarization Endpoint
# =====================================================
@app.post("/api/summarize")
def summarize_text(req: SummaryRequest):
    try:
        # Load the model on first request
        load_summary_model()
        # Length heuristics based on the input word count.
        # Note: the pipeline measures max_length/min_length in tokens, so the word count
        # is only a rough proxy.
        word_count = len(req.text.split())
        # Cap max_length at roughly half of the input length (at least 20)
        adjusted_max = min(req.max_length, max(20, word_count // 2))
        # Cap min_length at roughly a fifth of the input length (at least 10)
        adjusted_min = min(req.min_length, max(10, word_count // 5))
        result = summary_pipe(
            req.text,
            max_length=adjusted_max,
            min_length=min(adjusted_min, adjusted_max // 2),
            truncation=True,
        )
        key = "summary_text" if "summary_text" in result[0] else "generated_text"
        return {"success": True, "summary": result[0][key].strip()}
    except Exception as e:
        return {"success": False, "error": str(e)}
# =====================================================
# Image Processing Endpoint
# =====================================================
@app.post("/process_image")
async def process_image(image: UploadFile = File(...)):
    try:
        # Load the model on first request
        load_vision_model()
        contents = await image.read()
        img = Image.open(io.BytesIO(contents)).convert("RGB")
        # Process image with vision model, moving tensors to the model's device
        inputs = vision_processor(images=img, return_tensors="pt")
        inputs = {k: v.to(vision_model.device) for k, v in inputs.items()}
        # Generate description
        with torch.no_grad():
            outputs = vision_model.generate(
                **inputs,
                max_length=50,
                num_beams=5,
                temperature=0.8,
                do_sample=True,
            )
        description = vision_processor.batch_decode(outputs, skip_special_tokens=True)[0]
        return {
            "success": True,
            "description": description,
        }
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return {"success": False, "error": str(e)}
# =====================================================
# Health + Static
# =====================================================
@app.get("/api/health")
def health_check():
    return {
        "status": "healthy",
        "models": [
            "Qwen/Qwen1.5-0.5B-Chat",
            "sshleifer/distilbart-cnn-6-6",
            "microsoft/git-base-coco",
        ],
    }
if os.path.exists("static"):
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def read_root():
if os.path.exists("static/index.html"):
return FileResponse("static/index.html")
return {"message": "AI Chat & Summarization API running!"}
# =====================================================
# Run API
# =====================================================
if __name__ == "__main__":
port = int(os.getenv("PORT", 8000))
uvicorn.run(app, host="0.0.0.0", port=port)
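# Alternatively, the app can be started from the command line (module name assumed to be
# "server", matching this file):
#
#   uvicorn server:app --host 0.0.0.0 --port 8000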