# coderag/indexing/vectorstore.py
"""ChromaDB vector store operations."""
from pathlib import Path
from typing import Optional
import chromadb
from chromadb.config import Settings
from coderag.config import get_settings
from coderag.logging import get_logger
from coderag.models.chunk import Chunk
logger = get_logger(__name__)

class VectorStore:
    """ChromaDB vector store for chunk storage and retrieval."""

    def __init__(
        self,
        persist_directory: Optional[Path] = None,
        collection_name: Optional[str] = None,
    ) -> None:
        settings = get_settings()
        self.persist_directory = persist_directory or settings.vectorstore.persist_directory
        self.collection_name = collection_name or settings.vectorstore.collection_name
        self._client: Optional[chromadb.PersistentClient] = None
        self._collection: Optional[chromadb.Collection] = None

    @property
    def client(self) -> chromadb.PersistentClient:
        """Lazily initialized ChromaDB client."""
        if self._client is None:
            self._init_client()
        return self._client

    @property
    def collection(self) -> chromadb.Collection:
        """Lazily initialized ChromaDB collection."""
        if self._collection is None:
            self._init_collection()
        return self._collection

    def _init_client(self) -> None:
        logger.info("Initializing ChromaDB", path=str(self.persist_directory))
        self.persist_directory.mkdir(parents=True, exist_ok=True)
        self._client = chromadb.PersistentClient(
            path=str(self.persist_directory),
            settings=Settings(anonymized_telemetry=False),
        )

    def _init_collection(self) -> None:
        self._collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        logger.info("Collection initialized", name=self.collection_name)

    def add_chunks(self, chunks: list[Chunk]) -> int:
        """Add embedded chunks to the collection and return how many were stored."""
        # Only chunks that carry an embedding can be stored; skipping the rest keeps
        # ids/embeddings/documents/metadatas the same length, as ChromaDB requires.
        chunks = [chunk for chunk in chunks if chunk.embedding]
        if not chunks:
            return 0
        ids = [chunk.id for chunk in chunks]
        embeddings = [chunk.embedding for chunk in chunks]
        documents = [chunk.content for chunk in chunks]
        metadatas = [chunk.to_dict() for chunk in chunks]
        # Drop the embedding and content keys and filter out None values:
        # ChromaDB metadata only accepts str, int, float, and bool.
        cleaned_metadatas = []
        for m in metadatas:
            m.pop("embedding", None)
            m.pop("content", None)  # Already stored in documents
            cleaned = {k: v for k, v in m.items() if v is not None}
            cleaned_metadatas.append(cleaned)
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=cleaned_metadatas,
        )
        logger.info("Chunks added to vector store", count=len(chunks))
        return len(chunks)

    def query(
        self,
        query_embedding: list[float],
        repo_id: str,
        top_k: int = 5,
        similarity_threshold: float = 0.0,
    ) -> list[tuple[Chunk, float]]:
        """Return up to top_k (chunk, similarity) pairs for a repo, filtered by similarity_threshold."""
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where={"repo_id": repo_id},
            include=["documents", "metadatas", "distances"],
        )
        chunks_with_scores = []
        if results["ids"] and results["ids"][0]:
            for i, chunk_id in enumerate(results["ids"][0]):
                # ChromaDB returns cosine distances; convert to similarity (1 - distance).
                distance = results["distances"][0][i]
                similarity = 1 - distance
                if similarity >= similarity_threshold:
                    metadata = results["metadatas"][0][i]
                    metadata["id"] = chunk_id
                    metadata["content"] = results["documents"][0][i]
                    chunk = Chunk.from_dict(metadata)
                    chunks_with_scores.append((chunk, similarity))
        return chunks_with_scores

    def delete_repo_chunks(self, repo_id: str) -> int:
        """Delete all chunks for a repository and return how many were removed."""
        # Fetch only the ids of this repo's chunks.
        results = self.collection.get(where={"repo_id": repo_id}, include=[])
        if results["ids"]:
            self.collection.delete(ids=results["ids"])
            count = len(results["ids"])
            logger.info("Deleted repo chunks", repo_id=repo_id, count=count)
            return count
        return 0

    def delete_file_chunks(self, repo_id: str, file_path: str) -> int:
        """Delete chunks for a specific file in a repository (for incremental updates)."""
        results = self.collection.get(
            where={"$and": [{"repo_id": repo_id}, {"file_path": file_path}]},
            include=[],
        )
        if results["ids"]:
            self.collection.delete(ids=results["ids"])
            count = len(results["ids"])
            logger.info("Deleted file chunks", repo_id=repo_id, file_path=file_path, count=count)
            return count
        return 0

    def get_indexed_files(self, repo_id: str) -> set[str]:
        """Get the set of file paths indexed for a repository."""
        results = self.collection.get(
            where={"repo_id": repo_id},
            include=["metadatas"],
        )
        files = set()
        if results["metadatas"]:
            for metadata in results["metadatas"]:
                if "file_path" in metadata:
                    files.add(metadata["file_path"])
        return files

    def get_repo_chunk_count(self, repo_id: str) -> int:
        """Return the number of chunks stored for a repository."""
        results = self.collection.get(where={"repo_id": repo_id}, include=[])
        return len(results["ids"]) if results["ids"] else 0

    def get_all_repo_ids(self) -> list[str]:
        """Return the ids of all repositories present in the collection."""
        results = self.collection.get(include=["metadatas"])
        repo_ids = set()
        if results["metadatas"]:
            for metadata in results["metadatas"]:
                if "repo_id" in metadata:
                    repo_ids.add(metadata["repo_id"])
        return list(repo_ids)

    def clear(self) -> None:
        """Drop the collection; it is recreated lazily on next access."""
        self.client.delete_collection(self.collection_name)
        self._collection = None
        logger.info("Collection cleared", name=self.collection_name)