| """ChromaDB vector store operations.""" | |
| from pathlib import Path | |
| from typing import Optional | |
| import chromadb | |
| from chromadb.config import Settings | |
| from coderag.config import get_settings | |
| from coderag.logging import get_logger | |
| from coderag.models.chunk import Chunk | |
| logger = get_logger(__name__) | |

class VectorStore:
    """ChromaDB vector store for chunk storage and retrieval."""

    def __init__(
        self,
        persist_directory: Optional[Path] = None,
        collection_name: Optional[str] = None,
    ) -> None:
        settings = get_settings()
        self.persist_directory = persist_directory or settings.vectorstore.persist_directory
        self.collection_name = collection_name or settings.vectorstore.collection_name
        self._client: Optional[chromadb.PersistentClient] = None
        self._collection: Optional[chromadb.Collection] = None

    @property
    def client(self) -> chromadb.PersistentClient:
        """Lazily initialized ChromaDB client."""
        if self._client is None:
            self._init_client()
        return self._client

    @property
    def collection(self) -> chromadb.Collection:
        """Lazily initialized ChromaDB collection."""
        if self._collection is None:
            self._init_collection()
        return self._collection

    def _init_client(self) -> None:
        logger.info("Initializing ChromaDB", path=str(self.persist_directory))
        self.persist_directory.mkdir(parents=True, exist_ok=True)
        self._client = chromadb.PersistentClient(
            path=str(self.persist_directory),
            settings=Settings(anonymized_telemetry=False),
        )

    def _init_collection(self) -> None:
        self._collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        logger.info("Collection initialized", name=self.collection_name)

    def add_chunks(self, chunks: list[Chunk]) -> int:
        """Add embedded chunks to the collection; returns the number stored."""
        # Only chunks that carry an embedding can be stored; filtering the
        # list up front keeps ids/embeddings/documents/metadatas aligned.
        chunks = [chunk for chunk in chunks if chunk.embedding]
        if not chunks:
            return 0
        ids = [chunk.id for chunk in chunks]
        embeddings = [chunk.embedding for chunk in chunks]
        documents = [chunk.content for chunk in chunks]
        metadatas = [chunk.to_dict() for chunk in chunks]
        # ChromaDB metadata values must be str, int, float, or bool: drop the
        # embedding, the content (already stored in documents), and any
        # None-valued fields.
        cleaned_metadatas = []
        for m in metadatas:
            m.pop("embedding", None)
            m.pop("content", None)
            cleaned = {k: v for k, v in m.items() if v is not None}
            cleaned_metadatas.append(cleaned)
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=cleaned_metadatas,
        )
        logger.info("Chunks added to vector store", count=len(chunks))
        return len(chunks)
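
    # Illustrative metadata shape (field names other than repo_id/file_path
    # are assumptions about Chunk.to_dict()): a chunk serialized as
    #   {"id": "c1", "repo_id": "r1", "file_path": "a.py", "start_line": 1,
    #    "language": None, "content": "...", "embedding": [...]}
    # is stored with metadata
    #   {"id": "c1", "repo_id": "r1", "file_path": "a.py", "start_line": 1}
    # -- content and embedding live in the documents/embeddings arguments,
    # and the None-valued "language" entry is dropped.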

    def query(
        self,
        query_embedding: list[float],
        repo_id: str,
        top_k: int = 5,
        similarity_threshold: float = 0.0,
    ) -> list[tuple[Chunk, float]]:
        """Return up to top_k (chunk, similarity) pairs for a repository."""
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where={"repo_id": repo_id},
            include=["documents", "metadatas", "distances"],
        )
        chunks_with_scores = []
        if results["ids"] and results["ids"][0]:
            for i, chunk_id in enumerate(results["ids"][0]):
                # ChromaDB returns cosine distances; convert to similarity.
                distance = results["distances"][0][i]
                similarity = 1 - distance
                if similarity >= similarity_threshold:
                    metadata = results["metadatas"][0][i]
                    metadata["id"] = chunk_id
                    metadata["content"] = results["documents"][0][i]
                    chunk = Chunk.from_dict(metadata)
                    chunks_with_scores.append((chunk, similarity))
        return chunks_with_scores
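
    # Sketch: store.query(embedding, repo_id="r1", top_k=3,
    # similarity_threshold=0.25) returns at most three (Chunk, score) pairs;
    # a cosine distance of 0.2 reported by ChromaDB becomes a score of 0.8.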

    def delete_repo_chunks(self, repo_id: str) -> int:
        """Delete all chunks for a repository; returns the number deleted."""
        # Fetch matching ids only (include=[] skips documents/metadatas).
        results = self.collection.get(where={"repo_id": repo_id}, include=[])
        if results["ids"]:
            self.collection.delete(ids=results["ids"])
            count = len(results["ids"])
            logger.info("Deleted repo chunks", repo_id=repo_id, count=count)
            return count
        return 0

    def delete_file_chunks(self, repo_id: str, file_path: str) -> int:
        """Delete chunks for a specific file in a repository (for incremental updates)."""
        # ChromaDB requires an explicit $and to combine multiple filters.
        results = self.collection.get(
            where={"$and": [{"repo_id": repo_id}, {"file_path": file_path}]},
            include=[],
        )
        if results["ids"]:
            self.collection.delete(ids=results["ids"])
            count = len(results["ids"])
            logger.info("Deleted file chunks", repo_id=repo_id, file_path=file_path, count=count)
            return count
        return 0

    def get_indexed_files(self, repo_id: str) -> set[str]:
        """Get set of file paths indexed for a repository."""
        results = self.collection.get(
            where={"repo_id": repo_id},
            include=["metadatas"],
        )
        files = set()
        if results["metadatas"]:
            for metadata in results["metadatas"]:
                if "file_path" in metadata:
                    files.add(metadata["file_path"])
        return files

    def get_repo_chunk_count(self, repo_id: str) -> int:
        """Count the chunks stored for a repository."""
        results = self.collection.get(where={"repo_id": repo_id}, include=[])
        return len(results["ids"]) if results["ids"] else 0

    def get_all_repo_ids(self) -> list[str]:
        """List the distinct repo_ids present in the collection."""
        results = self.collection.get(include=["metadatas"])
        repo_ids = set()
        if results["metadatas"]:
            for metadata in results["metadatas"]:
                if "repo_id" in metadata:
                    repo_ids.add(metadata["repo_id"])
        return list(repo_ids)

    def clear(self) -> None:
        """Drop the collection; it is recreated lazily on next access."""
        self.client.delete_collection(self.collection_name)
        self._collection = None
        logger.info("Collection cleared", name=self.collection_name)