import os
import requests
import subprocess
import tarfile
import stat
from huggingface_hub import hf_hub_download
from langchain_core.language_models.llms import LLM
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from typing import Any, List, Optional, Mapping
# --- Helper to Setup llama-cli ---
def setup_llama_cli():
"""
Download and extract llama-cli binary and libs from official releases
"""
    # Prebuilt Linux x64 release of llama.cpp (build b7312). Newer builds should
    # work as well, as long as the archive keeps the same bin/ and lib/ layout.
CLI_URL = "https://github.com/ggml-org/llama.cpp/releases/download/b7312/llama-b7312-bin-ubuntu-x64.tar.gz"
LOCAL_TAR = "llama-cli.tar.gz"
BIN_DIR = "./llama_bin" # Extract to a subdirectory
CLI_BIN = os.path.join(BIN_DIR, "bin/llama-cli") # Standard structure usually has bin/
if os.path.exists(CLI_BIN):
return CLI_BIN, BIN_DIR
try:
print("⬇️ Downloading llama-cli binary...")
        response = requests.get(CLI_URL, stream=True, timeout=60)
if response.status_code == 200:
with open(LOCAL_TAR, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print("πŸ“¦ Extracting llama-cli...")
# Create dir
os.makedirs(BIN_DIR, exist_ok=True)
with tarfile.open(LOCAL_TAR, "r:gz") as tar:
tar.extractall(path=BIN_DIR)
            # Locate the binary: it may sit under bin/ or at the tar root, so walk the tree
found_bin = None
for root, dirs, files in os.walk(BIN_DIR):
if "llama-cli" in files:
found_bin = os.path.join(root, "llama-cli")
break
if not found_bin:
print("❌ Could not find llama-cli in extracted files.")
return None, None
# Make executable
st = os.stat(found_bin)
os.chmod(found_bin, st.st_mode | stat.S_IEXEC)
print(f"βœ… llama-cli binary ready at {found_bin}!")
return found_bin, BIN_DIR
else:
print(f"❌ Failed to download binary: {response.status_code}")
return None, None
except Exception as e:
print(f"❌ Error setting up llama-cli: {e}")
return None, None
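# A minimal sketch of how the helper above behaves (illustration only; the app
# itself calls it from LLMClient.__init__ below):
#
#   cli_path, bin_dir = setup_llama_cli()
#   if cli_path is None:
#       print("No local binary available; only the remote API can be used.")
#   else:
#       subprocess.run([cli_path, "--help"], check=False)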
# --- Custom LangChain LLM Wrapper for Hybrid Approach ---
class HybridLLM(LLM):
api_url: str = ""
model_path: str = ""
cli_path: str = ""
lib_path: str = "" # Path to folder containing .so files
@property
def _llm_type(self) -> str:
return "hybrid_llm"
def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
# 1. Try Colab API first
if self.api_url:
try:
print(f"🌐 Calling Colab API: {self.api_url}")
response = requests.post(
f"{self.api_url}/generate",
json={"prompt": prompt, "max_tokens": 512},
timeout=30
)
if response.status_code == 200:
return response.json()["response"]
else:
print(f"⚠️ API Error {response.status_code}: {response.text}")
except Exception as e:
print(f"⚠️ API Connection Failed: {e}")
# 2. Fallback to Local llama-cli
if self.model_path and self.cli_path and os.path.exists(self.cli_path):
print("πŸ’» Using Local llama-cli Fallback...")
try:
# Construct command
cmd = [
self.cli_path,
"-m", self.model_path,
"-p", prompt,
"-n", "512",
"--temp", "0.7",
"--no-display-prompt", # Don't echo prompt
"-c", "2048" # Context size
]
# Setup Environment with LD_LIBRARY_PATH
env = os.environ.copy()
# Add the directory containing the binary (and likely libs) to LD_LIBRARY_PATH
# Also check 'lib' subdir if it exists
lib_paths = [os.path.dirname(self.cli_path)]
lib_subdir = os.path.join(self.lib_path, "lib")
if os.path.exists(lib_subdir):
lib_paths.append(lib_subdir)
env["LD_LIBRARY_PATH"] = ":".join(lib_paths) + ":" + env.get("LD_LIBRARY_PATH", "")
# Run binary
result = subprocess.run(
cmd,
capture_output=True,
text=True,
encoding='utf-8',
errors='replace',
env=env
)
if result.returncode == 0:
return result.stdout.strip()
else:
return f"❌ llama-cli Error: {result.stderr}"
except Exception as e:
return f"❌ Local Inference Failed: {e}"
return "❌ Error: No working LLM available (API failed and no local model)."
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"api_url": self.api_url, "model_path": self.model_path}
class LLMClient:
def __init__(self, vector_store=None):
"""
Initialize Hybrid LLM Client with Binary Wrapper
"""
self.vector_store = vector_store
self.api_url = os.environ.get("COLAB_API_URL", "")
self.model_path = None
self.cli_path = None
self.lib_path = None
# Setup Local Fallback
try:
# 1. Setup Binary
self.cli_path, self.lib_path = setup_llama_cli()
# 2. Download Model (Qwen3-0.6B)
print("πŸ“‚ Loading Local Qwen3-0.6B (GGUF)...")
model_repo = "Qwen/Qwen3-0.6B-GGUF"
filename = "Qwen3-0.6B-Q8_0.gguf"
self.model_path = hf_hub_download(
repo_id=model_repo,
filename=filename
)
print(f"βœ… Model downloaded to: {self.model_path}")
except Exception as e:
print(f"⚠️ Could not setup local fallback: {e}")
        # Create Hybrid LangChain Wrapper (coalesce to empty strings so the
        # str-typed fields still validate when local setup failed)
        self.llm = HybridLLM(
            api_url=self.api_url,
            model_path=self.model_path or "",
            cli_path=self.cli_path or "",
            lib_path=self.lib_path or ""
        )
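    # Note: the wrapper above tries the Colab API first and only falls back to the
    # local llama-cli binary when the API is unset, unreachable, or returns an error
    # (see HybridLLM._call).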
def analyze(self, text, context_chunks=None):
"""
Analyze text using LangChain RetrievalQA
"""
if not self.vector_store:
return "❌ Vector Store not initialized."
# Custom Prompt Template
template = """<|im_start|>system
You are a cybersecurity expert. Task: Determine whether the input is 'PHISHING' or 'BENIGN' (Safe).
Respond in the following format:
LABEL: [PHISHING or BENIGN]
EXPLANATION: [A brief Vietnamese explanation]
Context:
{context}
<|im_end|>
<|im_start|>user
Input:
{question}
Short Analysis:
<|im_end|>
<|im_start|>assistant
"""
PROMPT = PromptTemplate(
template=template,
input_variables=["context", "question"]
)
# Create QA Chain
qa_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
chain_type_kwargs={"prompt": PROMPT}
)
try:
print("πŸ€– Generating response...")
            response = qa_chain.invoke({"query": text})  # RetrievalQA expects the "query" key
return response['result']
except Exception as e:
return f"❌ Error: {str(e)}"