Diego Adame committed
Commit 84fc33b · 1 Parent(s): 2062bc6

Final_Working

Files changed (1): server.py +23 -28
server.py CHANGED
@@ -11,7 +11,6 @@ import torch, uvicorn, os, subprocess, threading, shutil, time
 # =====================================================
 app = FastAPI(title="AI Chat + Summarization API")
 
-# Allow frontend requests
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -21,35 +20,34 @@ app.add_middleware(
 )
 
 # =====================================================
-# Automatic Disk Cleanup (safety for Codespaces)
+# Auto Disk Cleanup (for Codespaces)
 # =====================================================
 def check_disk_space(min_gb=2):
     stat = shutil.disk_usage("/")
     free_gb = stat.free / (1024 ** 3)
     if free_gb < min_gb:
-        print(f"⚠️ Low disk space ({free_gb:.2f} GB). Clearing Hugging Face cache...")
+        print(f"⚠️ Low disk space ({free_gb:.2f} GB). Clearing HuggingFace cache...")
         os.system("rm -rf ~/.cache/huggingface/*")
 
 def background_health_monitor():
     while True:
         check_disk_space()
-        time.sleep(600)  # every 10 minutes
+        time.sleep(600)
 
 threading.Thread(target=background_health_monitor, daemon=True).start()
 
 # =====================================================
-# Load Chat Model (Lightweight Qwen)
+# Load Chat Model (Qwen 1.5-0.5B-Chat)
 # =====================================================
-print("Loading lightweight chat model (Qwen 1.5 0.5B Chat)…")
-chat_model_name = "Qwen/Qwen1.5-0.1B-Chat"
+print("Loading Qwen 1.5-0.5B-Chat...")
+chat_model_name = "Qwen/Qwen1.5-0.5B-Chat"
 chat_tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
 chat_model = AutoModelForCausalLM.from_pretrained(
     chat_model_name,
     torch_dtype=torch.bfloat16,
-    low_cpu_mem_usage=True,
+    device_map="auto",
 ).eval()
 
-
 # =====================================================
 # Load Summarization Model
 # =====================================================
@@ -61,7 +59,7 @@ summary_pipe = pipeline(
 )
 
 # =====================================================
-# Request Models
+# API Schemas
 # =====================================================
 class ChatRequest(BaseModel):
     message: str
@@ -74,19 +72,16 @@ class SummaryRequest(BaseModel):
     min_length: int = 25
 
 # =====================================================
-# Chat Endpoint (Fixed for Qwen 1.5 Chat)
+# Chat Endpoint
 # =====================================================
 @app.post("/api/chat")
 def chat_generate(req: ChatRequest):
     try:
-        # Proper message template for Qwen 1.5 Chat
         prompt = (
             "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
             f"<|im_start|>user\n{req.message}<|im_end|>\n"
             "<|im_start|>assistant\n"
         )
-
-        # Tokenize and run inference
         inputs = chat_tokenizer(prompt, return_tensors="pt").to(chat_model.device)
         outputs = chat_model.generate(
             **inputs,
@@ -97,17 +92,11 @@ def chat_generate(req: ChatRequest):
             eos_token_id=chat_tokenizer.eos_token_id,
             pad_token_id=chat_tokenizer.eos_token_id,
         )
-
-        # Decode only newly generated tokens
         new_tokens = outputs[0][inputs["input_ids"].size(1):]
         reply = chat_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
-
-        # Fallback in case of empty output
         if not reply:
             reply = chat_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-
         return {"success": True, "response": reply}
-
     except Exception as e:
         return {"success": False, "error": str(e)}
 
@@ -129,11 +118,11 @@ def summarize_text(req: SummaryRequest):
         return {"success": False, "error": str(e)}
 
 # =====================================================
-# Health + Static Routes
+# Health + Static
 # =====================================================
 @app.get("/api/health")
 def health_check():
-    return {"status": "healthy", "models": ["chat: Qwen-0.5B-Chat", "summarization: DistilBART-6-6"]}
+    return {"status": "healthy", "models": ["Qwen-1.5-0.5B-Chat", "DistilBART-6-6"]}
 
 if os.path.exists("static"):
     app.mount("/static", StaticFiles(directory="static"), name="static")
@@ -145,11 +134,17 @@ def read_root():
     return {"message": "AI Chat & Summarization API running!"}
 
 # =====================================================
-# Run FastAPI Server
+# Run API + Cloudflare Tunnel
 # =====================================================
 if __name__ == "__main__":
-    # Get port from environment variable (Render provides this) or default to 8000
-    port = int(os.environ.get("PORT", 8000))
-
-    print(f"🚀 Starting FastAPI server on http://0.0.0.0:{port}")
-    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
+
+    def run_api():
+        print("🚀 Starting FastAPI server on http://0.0.0.0:8000")
+        uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
+
+    threading.Thread(target=run_api, daemon=True).start()
+
+    # Start Cloudflare tunnel
+    time.sleep(3)
+    print("🌐 Starting Cloudflare Tunnel…")
+    subprocess.run(["cloudflared", "tunnel", "--url", "http://localhost:8000"])
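Note: a minimal usage sketch (not part of this commit) for exercising the routes the diff touches, assuming the server is reachable at http://localhost:8000 or via the trycloudflare.com URL that cloudflared prints at startup:

import requests

BASE = "http://localhost:8000"  # or the tunnel URL printed by cloudflared

# /api/health lists the loaded chat and summarization models
print(requests.get(f"{BASE}/api/health").json())

# /api/chat takes a ChatRequest body with a "message" field
resp = requests.post(f"{BASE}/api/chat", json={"message": "Say hello in one sentence."})
print(resp.json())  # {"success": True, "response": "..."} on success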
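The chat endpoint builds the Qwen 1.5 ChatML prompt by hand. An alternative (again, not in this commit) is to let the tokenizer's built-in chat template assemble it; a sketch assuming the same model name used above:

from transformers import AutoTokenizer

chat_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")

messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Hello!"},
]
# add_generation_prompt=True appends the trailing "<|im_start|>assistant\n"
prompt = chat_tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
# `prompt` now matches the hand-written template in chat_generate() and can be
# tokenized and passed to chat_model.generate() exactly as the endpoint does.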