# coderag/generation/generator.py
"""Response generation using local or remote LLMs."""
from typing import Optional
from coderag.config import get_settings
from coderag.generation.citations import CitationParser
from coderag.generation.prompts import SYSTEM_PROMPT, build_prompt, build_no_context_response
from coderag.logging import get_logger
from coderag.models.response import Response
from coderag.models.query import Query
from coderag.retrieval.retriever import Retriever
logger = get_logger(__name__)
class ResponseGenerator:
"""Generates grounded responses using local or remote LLMs."""
def __init__(
self,
retriever: Optional[Retriever] = None,
) -> None:
self.settings = get_settings()
self.retriever = retriever or Retriever()
self.citation_parser = CitationParser()
self.provider = self.settings.models.llm_provider.lower()
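        # Backend handles below are created lazily on first use, so heavy
        # dependencies (openai/httpx or torch/transformers) are only imported
        # for the provider that is actually selected.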
self._client = None
self._local_model = None
self._local_tokenizer = None
logger.info("ResponseGenerator initialized", provider=self.provider)
def _get_api_client(self):
"""Get or create API client for remote providers."""
if self._client is not None:
return self._client
import httpx
from openai import OpenAI
api_key = self.settings.models.llm_api_key
if not api_key:
raise ValueError(f"API key required for provider: {self.provider}")
# Provider-specific configurations
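        # Each of these providers exposes an OpenAI-compatible Chat Completions
        # endpoint, so a single OpenAI client works for all of them; only the
        # base URL and default model differ. (For Anthropic this relies on its
        # OpenAI SDK compatibility layer.)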
provider_configs = {
"openai": {
"base_url": "https://api.openai.com/v1",
"default_model": "gpt-4o-mini",
},
"groq": {
"base_url": "https://api.groq.com/openai/v1",
"default_model": "llama-3.3-70b-versatile",
},
"anthropic": {
"base_url": "https://api.anthropic.com/v1",
"default_model": "claude-3-5-sonnet-20241022",
},
"openrouter": {
"base_url": "https://openrouter.ai/api/v1",
"default_model": "anthropic/claude-3.5-sonnet",
},
"together": {
"base_url": "https://api.together.xyz/v1",
"default_model": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
},
}
config = provider_configs.get(self.provider, {})
base_url = self.settings.models.llm_api_base or config.get("base_url")
if not base_url:
raise ValueError(f"Unknown provider: {self.provider}")
        # llm_name defaults to a local Qwen/* checkpoint; if it is still set to
        # that default, use the provider's recommended hosted model instead.
if self.settings.models.llm_name.startswith("Qwen/"):
self.model_name = config.get("default_model", self.settings.models.llm_name)
else:
self.model_name = self.settings.models.llm_name
self._client = OpenAI(
api_key=api_key,
base_url=base_url,
http_client=httpx.Client(timeout=120.0),
)
logger.info("API client created", provider=self.provider, model=self.model_name)
return self._client
def _load_local_model(self):
"""Load local model with transformers."""
if self._local_model is not None:
return
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
if not torch.cuda.is_available():
raise RuntimeError(
"Local LLM requires a CUDA-capable GPU. Options:\n"
" 1. Use a cloud provider (free): MODEL_LLM_PROVIDER=groq\n"
" Get API key at: https://console.groq.com/keys\n"
" 2. Install CUDA and a compatible GPU"
)
logger.info("Loading local LLM", model=self.settings.models.llm_name)
if self.settings.models.llm_use_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
else:
bnb_config = None
self._local_tokenizer = AutoTokenizer.from_pretrained(
self.settings.models.llm_name,
trust_remote_code=True,
)
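        # quantization_config is None when 4-bit loading is disabled; device_map
        # (e.g. "auto") lets accelerate place model layers across the available
        # GPU memory.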
self._local_model = AutoModelForCausalLM.from_pretrained(
self.settings.models.llm_name,
quantization_config=bnb_config,
device_map=self.settings.models.llm_device_map,
trust_remote_code=True,
torch_dtype=torch.float16,
)
logger.info("Local LLM loaded successfully")
def generate(self, query: Query) -> Response:
"""Generate a response for a query."""
# Retrieve relevant chunks
chunks, context = self.retriever.retrieve_with_context(
query.question,
query.repo_id,
query.top_k,
)
# Handle no results
if not chunks:
return Response(
answer=build_no_context_response(),
citations=[],
retrieved_chunks=[],
grounded=False,
query_id=query.id,
)
# Build prompt and generate
prompt = build_prompt(query.question, context)
if self.provider == "local":
answer = self._generate_local(prompt)
else:
answer = self._generate_api(prompt)
# Parse citations from answer
citations = self.citation_parser.parse_citations(answer)
# Determine if response is grounded
grounded = len(citations) > 0 and len(chunks) > 0
return Response(
answer=answer,
citations=citations,
retrieved_chunks=chunks,
grounded=grounded,
query_id=query.id,
)
def _generate_api(self, prompt: str) -> str:
"""Generate using remote API."""
client = self._get_api_client()
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
]
response = client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=self.settings.models.llm_max_new_tokens,
temperature=self.settings.models.llm_temperature,
top_p=self.settings.models.llm_top_p,
)
        content = response.choices[0].message.content
        # Some OpenAI-compatible backends may return None content (e.g. refusals).
        return (content or "").strip()
def _generate_local(self, prompt: str) -> str:
"""Generate using local model."""
import torch
self._load_local_model()
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
]
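        # Render the system/user messages into the model's native prompt format;
        # add_generation_prompt=True appends the assistant turn marker so the
        # model starts generating a reply rather than continuing the prompt.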
text = self._local_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = self._local_tokenizer(text, return_tensors="pt").to(self._local_model.device)
with torch.no_grad():
outputs = self._local_model.generate(
**inputs,
max_new_tokens=self.settings.models.llm_max_new_tokens,
temperature=self.settings.models.llm_temperature,
top_p=self.settings.models.llm_top_p,
do_sample=True,
pad_token_id=self._local_tokenizer.eos_token_id,
)
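        # Decode only the newly generated tokens, skipping the echoed prompt.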
generated = outputs[0][inputs["input_ids"].shape[1]:]
response = self._local_tokenizer.decode(generated, skip_special_tokens=True)
return response.strip()
def unload(self) -> None:
"""Unload models from memory."""
if self._local_model is not None:
del self._local_model
self._local_model = None
if self._local_tokenizer is not None:
del self._local_tokenizer
self._local_tokenizer = None
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Models unloaded")