| import sys |
| import base64 |
| import uuid |
| import tempfile |
| from pathlib import Path |
| from fastapi import FastAPI, Request, Response |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.responses import FileResponse |
| import uvicorn |
| import zipfile |
| import requests |
|
|
|
|
# Put this file's directory at the front of sys.path so the sibling
# ``onnx_inference`` module can be imported.
sys.path.insert(0, str(Path(__file__).parent))
from onnx_inference import OnnxAsrPipeline  # noqa: E402


app = FastAPI()

# Static assets live next to this file. The mount uses the same path as the
# "/" route (STATIC_DIR / "index.html") so both resolve to the same directory
# regardless of the working directory the process was started from.
STATIC_DIR = Path(__file__).parent / "static"
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# Scratch directory under the system temp dir where uploaded audio is written.
UPLOAD_DIR = Path(tempfile.gettempdir()) / "qwen3_asr_uploads"
UPLOAD_DIR.mkdir(exist_ok=True)

# Model directory and download archive path (both relative to the working
# directory), and the URL the archive is fetched from.
MODEL_DIR = Path("./official_models/onnx_models")
ZIP_URL = "http://f.zm66.top:9041/models/onnx_models.zip"
ZIP_PATH = Path("./onnx_models.zip")
|
|
| |
def _download_models():
    """Download the model archive from ZIP_URL and unpack it into MODEL_DIR.

    The archive is streamed to ZIP_PATH, extracted into MODEL_DIR, and then
    deleted. Deletion happens in a ``finally`` so a failed download or a
    failed extraction does not leave a partial archive on disk.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    try:
        # Stream to disk in chunks instead of holding the archive in memory.
        with requests.get(ZIP_URL, stream=True, timeout=300) as r:
            r.raise_for_status()
            with open(ZIP_PATH, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        print("✅ 下载完成,开始解压...")

        with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
            zip_ref.extractall(MODEL_DIR)
    finally:
        # Remove the archive whether or not the steps above succeeded.
        if ZIP_PATH.exists():
            ZIP_PATH.unlink()


# tokenizer.json is used as the marker that the model files are present.
if not (MODEL_DIR / "tokenizer.json").exists():
    print("🔍 模型不存在,开始从服务器下载...")
    _download_models()
    print("✅ 模型解压完成!")
else:
    print("✅ 模型已存在,跳过下载")
|
|
| |
# Build the ASR pipeline once at import time; the request handlers below use
# this module-level instance.
print("Loading Qwen3-ASR model...")
pipeline = OnnxAsrPipeline(
    # Use MODEL_DIR rather than a second hard-coded copy of the path, so the
    # directory that is checked/downloaded is the directory that is loaded.
    onnx_dir=str(MODEL_DIR),
    num_threads=4,
    quantize="int8",
    providers=["CPUExecutionProvider"],
)
print("✅ 模型加载成功!")
|
|
| |
@app.get("/")
async def root():
    """Serve ``index.html`` from STATIC_DIR as the landing page."""
    index_page = STATIC_DIR / "index.html"
    return FileResponse(index_page)
|
|
| |
@app.get("/status")
async def status():
    """Report which execution provider the pipeline lists.

    Returns a dict with ``provider`` ("cuda" or "cpu") and a display
    ``label``, chosen by whether "CUDAExecutionProvider" appears in
    ``pipeline.providers``.
    """
    if "CUDAExecutionProvider" in pipeline.providers:
        return {"provider": "cuda", "label": "🚀 GPU"}
    return {"provider": "cpu", "label": "💻 CPU"}
|
|
| |
@app.post("/transcribe")
async def transcribe(request: Request):
    """Transcribe base64-encoded audio posted as a JSON body.

    JSON fields read from the request:
        audio_data: base64-encoded audio bytes (required).
        language: forwarded to the pipeline; the value "auto" is forwarded
            as None. Defaults to "Chinese".
        hotwords: forwarded to the pipeline as ``context``. Defaults to "".

    The decoded bytes are written to a uniquely named ``.wav`` file under
    UPLOAD_DIR, handed to ``pipeline.transcribe`` by path, and the file is
    removed afterwards whether or not transcription succeeded.

    Returns:
        On success, a dict with ``success=True`` plus ``text``, ``language``
        and ``timing`` taken from the pipeline result. On any failure, a dict
        with ``success=False`` and ``error`` set to the message.
    """
    try:
        data = await request.json()
        audio_b64 = data.get("audio_data")
        if not audio_b64:
            # Fail with a clear message instead of the TypeError that
            # b64decode(None) would otherwise surface to the client.
            return {"success": False, "error": "Missing 'audio_data' in request body"}
        lang = data.get("language", "Chinese")
        hotwords = data.get("hotwords", "")

        audio_bytes = base64.b64decode(audio_b64)
        temp_path = UPLOAD_DIR / f"{uuid.uuid4().hex}.wav"
        try:
            temp_path.write_bytes(audio_bytes)
            # NOTE(review): this is a synchronous call inside an async
            # handler; if it runs long it will hold up the event loop.
            result = pipeline.transcribe(
                str(temp_path),
                language=lang if lang != "auto" else None,
                context=hotwords,
            )
        finally:
            # Always clean up the upload, including when transcription raises.
            if temp_path.exists():
                temp_path.unlink()

        return {
            "success": True,
            "text": result.get("text", ""),
            "language": result.get("language", lang),
            "timing": result.get("timing", {}),
        }
    except Exception as e:
        # Top-level boundary: report every failure in the JSON error shape.
        return {"success": False, "error": str(e)}
|
|
| |
| |
| |
if __name__ == "__main__":
    PORT = 7860
    print(f"\n✅ 服务启动:http://127.0.0.1:{PORT}")
    print("✅ 支持:录音 + 上传 + 热词")
    print("✅ 服务运行中...\n")

    # Pass the app object instead of the "main:app" import string. With the
    # string form uvicorn imports this file again as module ``main``, which
    # re-runs the top-level model check/load a second time and only works
    # when the file is named main.py. reload is off, so the import string is
    # not required.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=PORT,
        reload=False,
    )