File size: 6,183 Bytes
42f5b98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""ChromaDB vector store operations."""

from pathlib import Path
from typing import Optional

import chromadb
from chromadb.config import Settings

from coderag.config import get_settings
from coderag.logging import get_logger
from coderag.models.chunk import Chunk

logger = get_logger(__name__)


class VectorStore:
    """ChromaDB vector store for chunk storage and retrieval.

    The persistent client and collection are created lazily on first access,
    so constructing a ``VectorStore`` performs no I/O.
    """

    def __init__(
        self,
        persist_directory: Optional[Path] = None,
        collection_name: Optional[str] = None,
    ) -> None:
        """Create a store; ``None`` arguments fall back to configured defaults."""
        settings = get_settings()
        self.persist_directory = persist_directory or settings.vectorstore.persist_directory
        self.collection_name = collection_name or settings.vectorstore.collection_name
        self._client: Optional[chromadb.PersistentClient] = None
        self._collection: Optional[chromadb.Collection] = None

    @property
    def client(self) -> chromadb.PersistentClient:
        """Lazily initialized persistent ChromaDB client."""
        if self._client is None:
            self._init_client()
        return self._client

    @property
    def collection(self) -> chromadb.Collection:
        """Lazily initialized (or fetched) ChromaDB collection."""
        if self._collection is None:
            self._init_collection()
        return self._collection

    def _init_client(self) -> None:
        """Create the persist directory (if needed) and open a client over it."""
        logger.info("Initializing ChromaDB", path=str(self.persist_directory))
        self.persist_directory.mkdir(parents=True, exist_ok=True)
        self._client = chromadb.PersistentClient(
            path=str(self.persist_directory),
            settings=Settings(anonymized_telemetry=False),
        )

    def _init_collection(self) -> None:
        """Get or create the collection, using cosine distance for the HNSW index."""
        self._collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        logger.info("Collection initialized", name=self.collection_name)

    def add_chunks(self, chunks: list[Chunk]) -> int:
        """Add embedded chunks to the collection.

        Chunks without an embedding are skipped (with a warning) rather than
        silently misaligning the parallel lists passed to ChromaDB: filtering
        only the embeddings list — as the previous implementation did — would
        desynchronize it from ids/documents/metadatas and attach embeddings
        to the wrong chunk ids.

        Args:
            chunks: Chunks to index; each should carry a non-empty embedding.

        Returns:
            The number of chunks actually added.
        """
        embedded = [chunk for chunk in chunks if chunk.embedding]
        skipped = len(chunks) - len(embedded)
        if skipped:
            logger.warning("Skipping chunks without embeddings", count=skipped)
        if not embedded:
            return 0

        ids = [chunk.id for chunk in embedded]
        embeddings = [chunk.embedding for chunk in embedded]
        documents = [chunk.content for chunk in embedded]

        cleaned_metadatas = []
        for chunk in embedded:
            metadata = chunk.to_dict()
            metadata.pop("embedding", None)  # stored via the embeddings list
            metadata.pop("content", None)  # already stored in documents
            # ChromaDB metadata values must be str/int/float/bool — drop None.
            cleaned_metadatas.append({k: v for k, v in metadata.items() if v is not None})

        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=cleaned_metadatas,
        )

        logger.info("Chunks added to vector store", count=len(embedded))
        return len(embedded)

    def query(
        self,
        query_embedding: list[float],
        repo_id: str,
        top_k: int = 5,
        similarity_threshold: float = 0.0,
    ) -> list[tuple[Chunk, float]]:
        """Return up to ``top_k`` chunks for ``repo_id`` ranked by similarity.

        Args:
            query_embedding: Embedding vector of the query text.
            repo_id: Repository to restrict the search to.
            top_k: Maximum number of results requested from ChromaDB.
            similarity_threshold: Minimum cosine similarity to keep a result.

        Returns:
            ``(chunk, similarity)`` pairs; results below the threshold are dropped.
        """
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            where={"repo_id": repo_id},
            include=["documents", "metadatas", "distances"],
        )

        chunks_with_scores = []
        if results["ids"] and results["ids"][0]:
            for i, chunk_id in enumerate(results["ids"][0]):
                # ChromaDB returns cosine *distances*; similarity = 1 - distance.
                distance = results["distances"][0][i]
                similarity = 1 - distance

                if similarity >= similarity_threshold:
                    metadata = results["metadatas"][0][i]
                    # Re-attach the fields stripped out at indexing time.
                    metadata["id"] = chunk_id
                    metadata["content"] = results["documents"][0][i]
                    chunk = Chunk.from_dict(metadata)
                    chunks_with_scores.append((chunk, similarity))

        return chunks_with_scores

    def _delete_matching(self, where: dict) -> int:
        """Delete every chunk matching ``where``; return the number deleted."""
        results = self.collection.get(where=where, include=[])
        ids = results["ids"]
        if not ids:
            return 0
        self.collection.delete(ids=ids)
        return len(ids)

    def delete_repo_chunks(self, repo_id: str) -> int:
        """Delete all chunks for a repository; returns the number deleted."""
        count = self._delete_matching({"repo_id": repo_id})
        if count:
            logger.info("Deleted repo chunks", repo_id=repo_id, count=count)
        return count

    def delete_file_chunks(self, repo_id: str, file_path: str) -> int:
        """Delete chunks for a specific file in a repository (for incremental updates)."""
        count = self._delete_matching(
            {"$and": [{"repo_id": repo_id}, {"file_path": file_path}]}
        )
        if count:
            logger.info("Deleted file chunks", repo_id=repo_id, file_path=file_path, count=count)
        return count

    def get_indexed_files(self, repo_id: str) -> set[str]:
        """Get set of file paths indexed for a repository."""
        results = self.collection.get(
            where={"repo_id": repo_id},
            include=["metadatas"],
        )
        return {
            metadata["file_path"]
            for metadata in results["metadatas"] or []
            if "file_path" in metadata
        }

    def get_repo_chunk_count(self, repo_id: str) -> int:
        """Return how many chunks are stored for ``repo_id``."""
        results = self.collection.get(where={"repo_id": repo_id}, include=[])
        return len(results["ids"]) if results["ids"] else 0

    def get_all_repo_ids(self) -> list[str]:
        """Return the distinct repo_ids present in the collection (unordered)."""
        results = self.collection.get(include=["metadatas"])
        repo_ids = {
            metadata["repo_id"]
            for metadata in results["metadatas"] or []
            if "repo_id" in metadata
        }
        return list(repo_ids)

    def clear(self) -> None:
        """Drop the collection; it will be recreated lazily on next access."""
        self.client.delete_collection(self.collection_name)
        self._collection = None
        logger.info("Collection cleared", name=self.collection_name)