zm623 / main.py
zming623's picture
Update main.py
eb22929 verified
import sys
import base64
import uuid
import tempfile
from pathlib import Path
from fastapi import FastAPI, Request, Response
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import uvicorn
import zipfile
import requests
sys.path.insert(0, str(Path(__file__).parent))
from onnx_inference import OnnxAsrPipeline
app = FastAPI()

# Directory holding the static front-end (index.html and assets), resolved
# relative to this file so it does not depend on the working directory.
STATIC_DIR = Path(__file__).parent / "static"
# Mount the same directory that root() serves index.html from; the previous
# relative "static" string was resolved against the current working directory
# and could point somewhere else (or fail) when launched from another folder.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# Scratch directory for uploaded audio, placed under the system temp dir.
UPLOAD_DIR = Path(tempfile.gettempdir()) / "qwen3_asr_uploads"
UPLOAD_DIR.mkdir(exist_ok=True)
# ======================
# The only settings you need to change are here.
# ======================
MODEL_DIR = Path("./official_models/onnx_models")
ZIP_URL = "http://f.zm66.top:9041/models/onnx_models.zip"  # download URL of the model zip
ZIP_PATH = Path("./onnx_models.zip")

# Download and unpack the model on first start. The presence of
# tokenizer.json inside MODEL_DIR is used as the "already installed" marker.
if not (MODEL_DIR / "tokenizer.json").exists():
    print("🔍 模型不存在,开始从服务器下载...")
    # 1. Create the target directory.
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    try:
        # 2. Stream the zip to disk in chunks so it is never held in memory.
        # NOTE(review): the archive is fetched over plain HTTP with no
        # checksum verification — consider HTTPS and/or a hash check.
        with requests.get(ZIP_URL, stream=True, timeout=300) as r:
            r.raise_for_status()
            with open(ZIP_PATH, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        print("✅ 下载完成,开始解压...")
        # 3. Extract into the target directory.
        with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
            zip_ref.extractall(MODEL_DIR)
    finally:
        # 4. Always remove the archive, including a partial one left behind
        # by a failed download or a corrupt zip, so a retry starts clean.
        ZIP_PATH.unlink(missing_ok=True)
    print("✅ 模型解压完成!")
else:
    print("✅ 模型已存在,跳过下载")
# Load the model once at import time; the request handlers below share this
# single pipeline instance.
print("Loading Qwen3-ASR model...")
pipeline = OnnxAsrPipeline(
    # Derived from MODEL_DIR so the configuration section above really is the
    # only place to edit; as_posix() keeps the forward-slash form that was
    # previously hard-coded here.
    onnx_dir=MODEL_DIR.as_posix(),
    num_threads=4,
    quantize="int8",
    providers=["CPUExecutionProvider"],
)
print("✅ 模型加载成功!")
# Home page
@app.get("/")
async def root():
    """Serve index.html from the static directory."""
    index_page = STATIC_DIR / "index.html"
    return FileResponse(index_page)
# Status endpoint
@app.get("/status")
async def status():
    """Report whether the pipeline's providers include the CUDA provider.

    Returns a dict with a machine-readable ``provider`` ("cuda" or "cpu")
    and a display ``label`` for the front-end.
    """
    if "CUDAExecutionProvider" in pipeline.providers:
        return {"provider": "cuda", "label": "🚀 GPU"}
    return {"provider": "cpu", "label": "💻 CPU"}
# Transcription endpoint
@app.post("/transcribe")
async def transcribe(request: Request):
    """Transcribe base64-encoded audio posted as JSON.

    Request body keys:
        audio_data: base64 string of the audio file (required).
        language:   language name, or "auto" to pass ``None`` to the
                    pipeline; defaults to "Chinese".
        hotwords:   text passed to the pipeline as ``context``; defaults
                    to an empty string.

    Returns:
        ``{"success": True, "text", "language", "timing"}`` on success, or
        ``{"success": False, "error": <message>}`` on any failure. Errors are
        reported in the body rather than raised, so the caller always
        receives a JSON object.
    """
    temp_path = None
    try:
        data = await request.json()
        audio_b64 = data.get("audio_data")
        if not audio_b64:
            # Fail with a clear message instead of the opaque TypeError that
            # base64.b64decode(None) would produce.
            return {"success": False, "error": "Missing 'audio_data' in request body"}
        lang = data.get("language", "Chinese")
        hotwords = data.get("hotwords", "")

        audio_bytes = base64.b64decode(audio_b64)
        # The pipeline is given a file path, so spill the bytes to a uniquely
        # named temp file.
        temp_path = UPLOAD_DIR / f"{uuid.uuid4().hex}.wav"
        temp_path.write_bytes(audio_bytes)

        # NOTE(review): this call is synchronous inside an async handler, so
        # it blocks the event loop for the duration of the transcription.
        result = pipeline.transcribe(
            str(temp_path),
            language=lang if lang != "auto" else None,
            context=hotwords,
        )
        return {
            "success": True,
            "text": result.get("text", ""),
            "language": result.get("language", lang),
            "timing": result.get("timing", {}),
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
    finally:
        # Remove the temp file on every path; previously it was leaked
        # whenever the pipeline raised.
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)
# ==========================
# Script entry point: start the server directly with `python main.py`.
# ==========================
if __name__ == "__main__":
    PORT = 7860
    print(f"\n✅ 服务启动:http://127.0.0.1:{PORT}")
    print("✅ 支持:录音 + 上传 + 热词")
    print("✅ 服务运行中...\n")
    # Pass the app object rather than the "main:app" import string: the
    # string form makes uvicorn import this file a second time as module
    # "main", which repeats the model download check and model load, and the
    # import may not resolve at all inside a frozen executable.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=PORT,
        reload=False,  # reload must stay disabled when packaged as an exe
    )