| """Response generation using local or remote LLMs.""" | |
| from typing import Optional | |
| from coderag.config import get_settings | |
| from coderag.generation.citations import CitationParser | |
| from coderag.generation.prompts import SYSTEM_PROMPT, build_prompt, build_no_context_response | |
| from coderag.logging import get_logger | |
| from coderag.models.response import Response | |
| from coderag.models.query import Query | |
| from coderag.retrieval.retriever import Retriever | |
| logger = get_logger(__name__) | |


class ResponseGenerator:
    """Generates grounded responses using local or remote LLMs."""

    def __init__(
        self,
        retriever: Optional[Retriever] = None,
    ) -> None:
        self.settings = get_settings()
        self.retriever = retriever or Retriever()
        self.citation_parser = CitationParser()
        self.provider = self.settings.models.llm_provider.lower()
        self._client = None
        self._local_model = None
        self._local_tokenizer = None
        logger.info("ResponseGenerator initialized", provider=self.provider)

    def _get_api_client(self):
        """Get or create API client for remote providers."""
        if self._client is not None:
            return self._client

        import httpx
        from openai import OpenAI

        api_key = self.settings.models.llm_api_key
        if not api_key:
            raise ValueError(f"API key required for provider: {self.provider}")
        # Provider-specific configurations
        provider_configs = {
            "openai": {
                "base_url": "https://api.openai.com/v1",
                "default_model": "gpt-4o-mini",
            },
            "groq": {
                "base_url": "https://api.groq.com/openai/v1",
                "default_model": "llama-3.3-70b-versatile",
            },
            "anthropic": {
                "base_url": "https://api.anthropic.com/v1",
                "default_model": "claude-3-5-sonnet-20241022",
            },
            "openrouter": {
                "base_url": "https://openrouter.ai/api/v1",
                "default_model": "anthropic/claude-3.5-sonnet",
            },
            "together": {
                "base_url": "https://api.together.xyz/v1",
                "default_model": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
            },
        }
        config = provider_configs.get(self.provider, {})
        base_url = self.settings.models.llm_api_base or config.get("base_url")
        if not base_url:
            raise ValueError(f"Unknown provider: {self.provider}")

        # Fall back to the provider's default model when the configured name
        # still points at a local Qwen model (i.e. the local default was
        # never overridden for this remote provider).
        if self.settings.models.llm_name.startswith("Qwen/"):
            self.model_name = config.get("default_model", self.settings.models.llm_name)
        else:
            self.model_name = self.settings.models.llm_name

        self._client = OpenAI(
            api_key=api_key,
            base_url=base_url,
            http_client=httpx.Client(timeout=120.0),
        )
        logger.info("API client created", provider=self.provider, model=self.model_name)
        return self._client

    def _load_local_model(self):
        """Load local model with transformers."""
        if self._local_model is not None:
            return

        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

        if not torch.cuda.is_available():
            raise RuntimeError(
                "Local LLM requires a CUDA-capable GPU. Options:\n"
                " 1. Use a cloud provider (free): MODEL_LLM_PROVIDER=groq\n"
                "    Get API key at: https://console.groq.com/keys\n"
                " 2. Install CUDA and a compatible GPU"
            )

        logger.info("Loading local LLM", model=self.settings.models.llm_name)
        if self.settings.models.llm_use_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        else:
            bnb_config = None
        self._local_tokenizer = AutoTokenizer.from_pretrained(
            self.settings.models.llm_name,
            trust_remote_code=True,
        )
        self._local_model = AutoModelForCausalLM.from_pretrained(
            self.settings.models.llm_name,
            quantization_config=bnb_config,
            device_map=self.settings.models.llm_device_map,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        logger.info("Local LLM loaded successfully")

    def generate(self, query: Query) -> Response:
        """Generate a response for a query."""
        # Retrieve relevant chunks
        chunks, context = self.retriever.retrieve_with_context(
            query.question,
            query.repo_id,
            query.top_k,
        )

        # Handle no results
        if not chunks:
            return Response(
                answer=build_no_context_response(),
                citations=[],
                retrieved_chunks=[],
                grounded=False,
                query_id=query.id,
            )

        # Build prompt and generate
        prompt = build_prompt(query.question, context)
        if self.provider == "local":
            answer = self._generate_local(prompt)
        else:
            answer = self._generate_api(prompt)

        # Parse citations from answer
        citations = self.citation_parser.parse_citations(answer)

        # Determine if response is grounded
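        # Heuristic: treat the answer as grounded when the model emitted at
        # least one citation and context chunks were actually retrieved.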
        grounded = len(citations) > 0 and len(chunks) > 0

        return Response(
            answer=answer,
            citations=citations,
            retrieved_chunks=chunks,
            grounded=grounded,
            query_id=query.id,
        )

    def _generate_api(self, prompt: str) -> str:
        """Generate using remote API."""
        client = self._get_api_client()
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        response = client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=self.settings.models.llm_max_new_tokens,
            temperature=self.settings.models.llm_temperature,
            top_p=self.settings.models.llm_top_p,
        )
        return response.choices[0].message.content.strip()

    def _generate_local(self, prompt: str) -> str:
        """Generate using local model."""
        import torch

        self._load_local_model()
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        text = self._local_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = self._local_tokenizer(text, return_tensors="pt").to(self._local_model.device)
        with torch.no_grad():
            outputs = self._local_model.generate(
                **inputs,
                max_new_tokens=self.settings.models.llm_max_new_tokens,
                temperature=self.settings.models.llm_temperature,
                top_p=self.settings.models.llm_top_p,
                do_sample=True,
                pad_token_id=self._local_tokenizer.eos_token_id,
            )
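        # `generate` returns prompt + completion tokens; slice off the prompt
        # so only the newly generated text is decoded.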
        generated = outputs[0][inputs["input_ids"].shape[1]:]
        response = self._local_tokenizer.decode(generated, skip_special_tokens=True)
        return response.strip()

    def unload(self) -> None:
        """Unload models from memory."""
        if self._local_model is not None:
            del self._local_model
            self._local_model = None
        if self._local_tokenizer is not None:
            del self._local_tokenizer
            self._local_tokenizer = None

        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Models unloaded")