ButterM40 committed · Commit b9ed0c9 · 1 Parent(s): a828cd4

Optimize build: lazy model loading + CPU torch wheel

Files changed (3):
  1. Dockerfile +5 -2
  2. requirements.txt +1 -2
  3. server.py +47 -22
Dockerfile CHANGED
@@ -32,9 +32,12 @@ RUN apt-get update && \
 # Copy requirements first for better caching
 COPY requirements.txt .
 
-# Upgrade pip and install dependencies preferring binary wheels
+# Upgrade pip and install torch CPU wheel first (faster than compiling)
 RUN pip install --no-cache-dir --upgrade pip setuptools wheel && \
-    pip install --no-cache-dir --prefer-binary -r requirements.txt || \
+    pip install --no-cache-dir torch==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu
+
+# Install remaining dependencies preferring binary wheels
+RUN pip install --no-cache-dir --prefer-binary -r requirements.txt || \
     (echo "Initial pip install failed, retrying without --prefer-binary" && pip install --no-cache-dir -r requirements.txt)
 
 # Copy the rest of the application
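
Worth noting: wheels from the CPU index carry a "+cpu" local version suffix, so the build can assert it got the right one. A minimal sanity check (a suggested addition, not part of this commit) that could run as a RUN step or inside the built container:

# sanity_check.py - hypothetical helper, not included in this commit
import torch

# The CPU-only wheel from download.pytorch.org/whl/cpu reports a "+cpu"
# local version, e.g. "2.5.1+cpu".
assert torch.__version__ == "2.5.1+cpu", torch.__version__
# A CPU-only build has no CUDA support compiled in.
assert not torch.cuda.is_available()
print(f"torch {torch.__version__}: CPU-only build confirmed")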
requirements.txt CHANGED
@@ -3,9 +3,8 @@ fastapi==0.115.5
 uvicorn[standard]==0.32.1
 pydantic==2.10.2
 
-# Transformers and ML
+# Transformers and ML (torch installed separately in Dockerfile)
 transformers==4.46.3
-torch==2.5.1
 accelerate>=0.26.0
 
 # Tokenizers
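
One caveat with dropping the torch pin here: if a later install or a conflicting transitive pin resolves torch differently, pip could replace the CPU wheel with the default CUDA build. Listing which installed packages declare a torch dependency makes that easy to audit; a sketch (hypothetical helper, not part of this commit):

# audit_torch_deps.py - hypothetical helper, not part of this commit
from importlib.metadata import distributions

for dist in distributions():
    for req in dist.requires or []:
        # Requirement strings look like "torch>=1.10"; this prefix match
        # also catches torchvision/torchaudio, which is useful to see too.
        if req.startswith("torch"):
            print(f"{dist.metadata['Name']} requires: {req}")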
server.py CHANGED
@@ -47,32 +47,45 @@ def background_health_monitor():
 threading.Thread(target=background_health_monitor, daemon=True).start()
 
 # =====================================================
-# Load Models
+# Model Loading (Lazy Initialization)
 # =====================================================
-print("Loading models...")
-
-# Chat Model
 chat_model_name = "Qwen/Qwen1.5-0.5B-Chat"
-chat_tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
-chat_model = AutoModelForCausalLM.from_pretrained(
-    chat_model_name,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    low_cpu_mem_usage=True,
-    offload_folder="offload",
-).eval()
+chat_tokenizer = None
+chat_model = None
+summary_pipe = None
+vision_model = None
+vision_processor = None
 
+def load_chat_model():
+    global chat_tokenizer, chat_model
+    if chat_tokenizer is None or chat_model is None:
+        print("Loading chat model...")
+        chat_tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
+        chat_model = AutoModelForCausalLM.from_pretrained(
+            chat_model_name,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            low_cpu_mem_usage=True,
+            offload_folder="offload",
+        ).eval()
+
-# Summarization Model
-summary_pipe = pipeline(
-    "summarization",
-    model="sshleifer/distilbart-cnn-6-6",
-    device=0 if torch.cuda.is_available() else -1
-)
+def load_summary_model():
+    global summary_pipe
+    if summary_pipe is None:
+        print("Loading summarization model...")
+        summary_pipe = pipeline(
+            "summarization",
+            model="sshleifer/distilbart-cnn-6-6",
+            device=0 if torch.cuda.is_available() else -1
+        )
 
-# Vision Model
-vision_model_name = "microsoft/git-base-coco"
-vision_model = AutoModelForVision2Seq.from_pretrained(vision_model_name).to("cuda" if torch.cuda.is_available() else "cpu")
-vision_processor = AutoProcessor.from_pretrained(vision_model_name)
+def load_vision_model():
+    global vision_model, vision_processor
+    if vision_model is None or vision_processor is None:
+        print("Loading vision model...")
+        vision_model_name = "microsoft/git-base-coco"
+        vision_model = AutoModelForVision2Seq.from_pretrained(vision_model_name).to("cuda" if torch.cuda.is_available() else "cpu")
+        vision_processor = AutoProcessor.from_pretrained(vision_model_name)
 
 # =====================================================
 # API Schemas
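
One caveat with the lazy loaders: FastAPI runs these plain `def` endpoints on a thread pool, so two first requests can both see `chat_model is None` and load the model twice. A double-checked-lock variant of `load_chat_model` (a sketch under that assumption, not part of this commit; `threading` is already imported for the health monitor) would close that window:

_chat_load_lock = threading.Lock()  # hypothetical guard; one per model family

def load_chat_model():
    global chat_tokenizer, chat_model
    if chat_tokenizer is None or chat_model is None:
        with _chat_load_lock:
            # Re-check inside the lock: another thread may have finished
            # the load while this one was waiting.
            if chat_tokenizer is None or chat_model is None:
                print("Loading chat model...")
                chat_tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
                chat_model = AutoModelForCausalLM.from_pretrained(
                    chat_model_name,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    low_cpu_mem_usage=True,
                    offload_folder="offload",
                ).eval()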
 
@@ -97,6 +110,9 @@ class WordPredictionRequest(BaseModel):
 @app.post("/api/chat")
 def chat_generate(req: ChatRequest):
     try:
+        # Load models on first request
+        load_chat_model()
+
         # Build prompt and run generation while requesting per-step scores
         prompt = (
             "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
 
@@ -181,6 +197,9 @@ def chat_generate(req: ChatRequest):
 @app.post("/predict_words")
 def predict_words(req: WordPredictionRequest):
     try:
+        # Load models on first request
+        load_chat_model()
+
         input_ids = chat_tokenizer.encode(req.word, return_tensors="pt")
         with torch.no_grad():
             outputs = chat_model(input_ids)
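
Exercising this endpoint also triggers the lazy chat-model load. The path and the `word` field follow from the route and `req.word` above; the base URL is an assumption:

import requests

# Base URL is hypothetical; adjust to wherever the server is running.
resp = requests.post(
    "http://localhost:7860/predict_words",
    json={"word": "The weather today is"},
)
print(resp.status_code, resp.json())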
 
@@ -204,6 +223,9 @@ def predict_words(req: WordPredictionRequest):
 @app.post("/api/summarize")
 def summarize_text(req: SummaryRequest):
     try:
+        # Load models on first request
+        load_summary_model()
+
         # Get word count
         word_count = len(req.text.split())
         # Adjust max_length to be ~30-50% of input length
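
The adjustment logic itself sits outside this hunk; one plausible reading of the ~30-50% comment (a hypothetical reconstruction, not code from this commit) is:

# Hypothetical reconstruction of the heuristic the comment describes.
word_count = len(req.text.split())            # e.g. 200 for a 200-word input
max_length = max(30, int(word_count * 0.5))   # upper bound: ~50% of the input
min_length = min(int(word_count * 0.3), max_length - 1)  # lower bound: ~30%
# For a 200-word input: max_length=100, min_length=60.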
 
@@ -228,6 +250,9 @@ def summarize_text(req: SummaryRequest):
 @app.post("/process_image")
 async def process_image(image: UploadFile = File(...)):
     try:
+        # Load models on first request
+        load_vision_model()
+
         contents = await image.read()
         img = Image.open(io.BytesIO(contents)).convert('RGB')
 
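
For completeness, a minimal client for this route (base URL assumed, as before); the multipart field name must be "image" so FastAPI can bind it to the `UploadFile` parameter:

import requests

with open("photo.jpg", "rb") as f:  # any local test image
    resp = requests.post(
        "http://localhost:7860/process_image",
        files={"image": ("photo.jpg", f, "image/jpeg")},
    )
print(resp.status_code, resp.json())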