# textai-v2/core/models.py
"""
Model Service - All model operations
Handles: loading, inference, downloading, management
"""
import gc
import re
import requests
import threading
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any
from core.config import MODELS_DIR, HF_API_URL, HF_TOKEN, RECOMMENDED_QUANTS, MAX_PARAMS_BILLION
from core.state import get_state, InstalledModel
from core.logger import logger
# Lazy imports for heavy libraries
_llama_cpp = None
_transformers = None
_torch = None
def _get_llama_cpp():
"""Lazy load llama-cpp-python"""
global _llama_cpp
if _llama_cpp is None:
try:
from llama_cpp import Llama
_llama_cpp = Llama
logger.info("Models", "llama-cpp-python loaded")
except ImportError as e:
logger.warn("Models", f"llama-cpp-python not available: {e}")
_llama_cpp = False
return _llama_cpp if _llama_cpp else None
def _get_transformers():
"""Lazy load transformers"""
global _transformers, _torch
if _transformers is None:
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
_transformers = {"model": AutoModelForCausalLM, "tokenizer": AutoTokenizer}
_torch = torch
logger.info("Models", "transformers loaded")
except ImportError as e:
logger.warn("Models", f"transformers not available: {e}")
_transformers = False
return _transformers if _transformers else None
class ModelService:
"""
Service for all model operations.
Uses StateManager for persistence.
"""
def __init__(self):
self._current_model = None
self._current_tokenizer = None
        # Re-entrant lock: load_model() calls unload_model() while holding it
        self._lock = threading.RLock()
self._state = get_state()
# ══════════════════════════════════════════════════════════════════
# MODEL LISTING & INFO
# ══════════════════════════════════════════════════════════════════
def get_installed_models(self) -> List[Dict]:
"""Get all installed models from state"""
return self._state.get_installed_models()
def get_loaded_model(self) -> Optional[Dict]:
"""Get currently loaded model info"""
model_id = self._state.get_loaded_model_id()
if model_id:
return self._state.get_model_by_id(model_id)
return None
def is_model_loaded(self) -> bool:
"""Check if any model is loaded"""
return self._current_model is not None
# ══════════════════════════════════════════════════════════════════
# MODEL LOADING
# ══════════════════════════════════════════════════════════════════
def load_model(self, model_id: str) -> Dict[str, Any]:
"""Load a model by ID"""
with self._lock:
logger.info("Models", f"Loading model: {model_id}")
# Get model info from state
model_info = self._state.get_model_by_id(model_id)
if not model_info:
return {"success": False, "error": f"Model not found: {model_id}"}
# Unload current model first
if self._current_model is not None:
self.unload_model()
# Load based on type
model_path = MODELS_DIR / model_info["filename"]
if not model_path.exists():
return {"success": False, "error": f"Model file not found: {model_path}"}
try:
if model_info["model_type"] == "gguf":
result = self._load_gguf(model_path)
else:
result = self._load_transformers(model_path)
if result["success"]:
self._state.set_loaded_model(model_id)
return result
except Exception as e:
logger.error("Models", f"Load failed: {e}")
return {"success": False, "error": str(e)}
def _load_gguf(self, model_path: Path) -> Dict:
"""Load GGUF model"""
Llama = _get_llama_cpp()
if Llama is None:
return {"success": False, "error": "llama-cpp-python not installed"}
try:
            # Conservative CPU-only settings for a small shared instance
            self._current_model = Llama(
                model_path=str(model_path),
                n_ctx=4096,       # context window (tokens)
                n_threads=4,      # CPU threads used for inference
                n_gpu_layers=0,   # CPU-only: no layers offloaded to a GPU
                verbose=False
            )
logger.info("Models", f"GGUF loaded: {model_path.name}")
return {"success": True, "type": "gguf", "name": model_path.stem}
except Exception as e:
return {"success": False, "error": str(e)}
def _load_transformers(self, model_path: Path) -> Dict:
"""Load transformers model"""
tf = _get_transformers()
if tf is None:
return {"success": False, "error": "transformers not installed"}
try:
self._current_tokenizer = tf["tokenizer"].from_pretrained(str(model_path))
self._current_model = tf["model"].from_pretrained(
str(model_path),
torch_dtype=_torch.float32,
device_map="cpu",
low_cpu_mem_usage=True
)
logger.info("Models", f"Transformers loaded: {model_path.name}")
return {"success": True, "type": "transformers", "name": model_path.name}
except Exception as e:
return {"success": False, "error": str(e)}
def unload_model(self):
"""Unload current model"""
with self._lock:
if self._current_model:
del self._current_model
self._current_model = None
if self._current_tokenizer:
del self._current_tokenizer
self._current_tokenizer = None
self._state.set_loaded_model(None)
            gc.collect()
logger.info("Models", "Model unloaded")
# ══════════════════════════════════════════════════════════════════
# INFERENCE
# ══════════════════════════════════════════════════════════════════
def generate(
self,
messages: List[Dict],
max_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.9
) -> str:
"""Generate response from loaded model"""
if self._current_model is None:
return "[Error: No model loaded]"
model_info = self.get_loaded_model()
if not model_info:
return "[Error: Model info not found]"
try:
if model_info["model_type"] == "gguf":
return self._generate_gguf(messages, max_tokens, temperature, top_p)
else:
return self._generate_transformers(messages, max_tokens, temperature, top_p)
except Exception as e:
logger.error("Models", f"Generation error: {e}")
return f"[Error: {e}]"
def _generate_gguf(self, messages: List[Dict], max_tokens: int, temperature: float, top_p: float) -> str:
"""Generate with GGUF model"""
response = self._current_model.create_chat_completion(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stream=False
)
return response["choices"][0]["message"]["content"]
def _generate_transformers(self, messages: List[Dict], max_tokens: int, temperature: float, top_p: float) -> str:
"""Generate with transformers model"""
        # Build a simple role-tagged prompt from the messages; models that ship
        # a chat template could use tokenizer.apply_chat_template() instead.
        prompt = ""
        for msg in messages:
            role = msg["role"].capitalize()
            prompt += f"{role}: {msg['content']}\n\n"
        prompt += "Assistant: "
inputs = self._current_tokenizer(prompt, return_tensors="pt")
with _torch.no_grad():
outputs = self._current_model.generate(
inputs.input_ids,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
pad_token_id=self._current_tokenizer.eos_token_id
)
response = self._current_tokenizer.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
return response.strip()
# ══════════════════════════════════════════════════════════════════
# HUGGINGFACE SEARCH & DOWNLOAD
# ══════════════════════════════════════════════════════════════════
def search_hf_models(
self,
query: str = "",
max_params: float = MAX_PARAMS_BILLION,
limit: int = 20
) -> Tuple[List[Dict], str]:
"""
Search HuggingFace for GGUF models.
Returns: (results, status_message)
"""
logger.info("Models", f"HF search: {query}")
try:
params = {
"search": query,
"library": "gguf",
"pipeline_tag": "text-generation",
"sort": "downloads",
"direction": -1,
"limit": limit + 20 # Extra for filtering
}
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
resp = requests.get(f"{HF_API_URL}/models", params=params, headers=headers, timeout=30)
resp.raise_for_status()
results = []
for m in resp.json():
model_id = m.get("id", "")
params_b = self._estimate_params(model_id)
# Filter by params
if max_params and params_b and params_b > max_params:
continue
# Check compatibility
compat = self._check_compatibility(params_b)
results.append({
"id": model_id,
"downloads": m.get("downloads", 0),
"params_b": params_b,
"est_size_gb": round(params_b * 0.55, 1) if params_b else None,
"compatibility": compat,
"is_installed": self._is_repo_installed(model_id)
})
if len(results) >= limit:
break
logger.info("Models", f"HF search found {len(results)} models")
return results, f"Found {len(results)} models"
except Exception as e:
logger.error("Models", f"HF search error: {e}")
return [], f"Search failed: {e}"
def get_hf_model_files(self, repo_id: str) -> List[Dict]:
"""Get GGUF files available for a HF model"""
logger.info("Models", f"Getting files for: {repo_id}")
try:
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
resp = requests.get(f"{HF_API_URL}/models/{repo_id}", headers=headers, timeout=30)
resp.raise_for_status()
files = []
for s in resp.json().get("siblings", []):
filename = s.get("rfilename", "")
if filename.endswith(".gguf"):
quant = self._extract_quant(filename)
files.append({
"filename": filename,
"quant": quant,
"recommended": quant in RECOMMENDED_QUANTS,
"is_installed": self._state.is_model_installed(repo_id, filename)
})
# Sort: recommended first, then by name
files.sort(key=lambda x: (not x["recommended"], x["filename"]))
return files
except Exception as e:
logger.error("Models", f"Get files error: {e}")
return []
def download_model(self, repo_id: str, filename: str) -> Dict[str, Any]:
"""
Download a model from HuggingFace.
Returns: {success, message, model_id}
"""
logger.info("Models", f"Downloading: {repo_id}/{filename}")
# Check for duplicate
if self._state.is_model_installed(repo_id, filename):
return {
"success": False,
"error": f"Model already installed: {filename}",
"duplicate": True
}
        # Resolve the destination before the try block so the cleanup path
        # below can always reference it
        dest_path = MODELS_DIR / filename
        try:
            # Stream the download from the HF resolve endpoint
            url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
            headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
            resp = requests.get(url, headers=headers, stream=True, timeout=600)
resp.raise_for_status()
total_size = int(resp.headers.get('content-length', 0))
downloaded = 0
with open(dest_path, 'wb') as f:
for chunk in resp.iter_content(chunk_size=8192):
f.write(chunk)
downloaded += len(chunk)
# Create model entry
params_b = self._estimate_params(repo_id)
model = InstalledModel(
id=Path(filename).stem,
name=self._make_display_name(repo_id, filename),
hf_repo=repo_id,
filename=filename,
model_type="gguf" if filename.endswith(".gguf") else "transformers",
size_bytes=dest_path.stat().st_size,
quant=self._extract_quant(filename),
installed_at=datetime.now().isoformat(),
params_b=params_b or 0.0
)
# Add to state
self._state.add_model(model)
size_mb = dest_path.stat().st_size / (1024 * 1024)
logger.info("Models", f"Downloaded: {filename} ({size_mb:.1f} MB)")
return {
"success": True,
"message": f"Downloaded: {filename} ({size_mb:.1f} MB)",
"model_id": model.id
}
except Exception as e:
logger.error("Models", f"Download failed: {e}")
# Clean up partial download
if dest_path.exists():
dest_path.unlink()
return {"success": False, "error": str(e)}
def delete_model(self, model_id: str) -> Dict[str, Any]:
"""Delete an installed model"""
logger.info("Models", f"Deleting: {model_id}")
model_info = self._state.get_model_by_id(model_id)
if not model_info:
return {"success": False, "error": "Model not found"}
# Unload if currently loaded
if self._state.get_loaded_model_id() == model_id:
self.unload_model()
# Delete file
try:
model_path = MODELS_DIR / model_info["filename"]
if model_path.exists():
model_path.unlink()
except Exception as e:
logger.error("Models", f"File delete error: {e}")
# Remove from state
self._state.remove_model(model_id)
return {"success": True, "message": f"Deleted: {model_info['name']}"}
# ══════════════════════════════════════════════════════════════════
# UTILITY METHODS
# ══════════════════════════════════════════════════════════════════
def _estimate_params(self, model_id: str) -> Optional[float]:
"""Extract parameter count from model name"""
name = model_id.lower()
patterns = [
r'(\d+\.?\d*)b(?:illion)?',
r'(\d+\.?\d*)-?b(?:illion)?',
]
for pattern in patterns:
match = re.search(pattern, name)
if match:
try:
return float(match.group(1))
                except ValueError:
pass
return None
def _extract_quant(self, filename: str) -> str:
"""Extract quantization type from filename"""
quants = ["Q2_K", "Q3_K_S", "Q3_K_M", "Q3_K_L", "Q4_0", "Q4_K_S", "Q4_K_M",
"Q5_0", "Q5_K_S", "Q5_K_M", "Q6_K", "Q8_0", "F16", "F32"]
upper = filename.upper()
for q in quants:
if q in upper:
return q
return "unknown"
def _check_compatibility(self, params_b: Optional[float]) -> Dict:
"""Check if model is compatible with free tier"""
if params_b is None:
return {"status": "unknown", "label": "❓ Unknown", "ok": True}
if params_b <= 1.5:
return {"status": "best", "label": "βœ… Best", "ok": True}
elif params_b <= 3:
return {"status": "good", "label": "βœ… Good", "ok": True}
elif params_b <= 7:
return {"status": "ok", "label": "⚠️ OK", "ok": True}
elif params_b <= 13:
return {"status": "slow", "label": "⚠️ Slow", "ok": False}
else:
return {"status": "too_large", "label": "❌ Too Large", "ok": False}
def _make_display_name(self, repo_id: str, filename: str) -> str:
"""Create a nice display name"""
# Extract meaningful part from repo or filename
name = Path(filename).stem
# Clean up common patterns
name = re.sub(r'[-_]gguf$', '', name, flags=re.IGNORECASE)
name = re.sub(r'[-_]q\d.*$', '', name, flags=re.IGNORECASE)
return name.replace('-', ' ').replace('_', ' ').title()
def _is_repo_installed(self, repo_id: str) -> bool:
"""Check if any model from this repo is installed"""
for m in self._state.get_installed_models():
if m["hf_repo"] == repo_id:
return True
return False
# Singleton
_model_service: Optional[ModelService] = None
def get_model_service() -> ModelService:
"""Get singleton model service"""
global _model_service
if _model_service is None:
_model_service = ModelService()
return _model_service
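# ══════════════════════════════════════════════════════════════════
# Illustrative usage sketch (only runs when this module is executed directly).
# The search query below is a placeholder, not something this project pins;
# it assumes network access to HuggingFace and llama-cpp-python installed.
# ══════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    svc = get_model_service()
    results, status = svc.search_hf_models("tinyllama", limit=5)
    print(status)
    if results:
        files = svc.get_hf_model_files(results[0]["id"])
        if files:
            dl = svc.download_model(results[0]["id"], files[0]["filename"])
            if dl.get("success"):
                svc.load_model(dl["model_id"])
                print(svc.generate([{"role": "user", "content": "Hello!"}]))
                svc.unload_model()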