microcore.embedding_db.chromadb

  1from dataclasses import dataclass
  2import uuid
  3import chromadb
  4from chromadb.config import Settings
  5from chromadb.utils import embedding_functions
  6from ..configuration import Config
  7from .. import SearchResult, SearchResults, AbstractEmbeddingDB
  8
  9
 10@dataclass
 11class ChromaEmbeddingDB(AbstractEmbeddingDB):
 12    config: Config
 13    embedding_function: embedding_functions.EmbeddingFunction = None
 14    client: chromadb.Client = None
 15
 16    def __post_init__(self):
 17        self.client = chromadb.PersistentClient(
 18            path=f"{self.config.STORAGE_PATH}/{self.config.EMBEDDING_DB_FOLDER}",
 19            settings=Settings(anonymized_telemetry=False),
 20        )
 21        self.embedding_function = (
 22            self.config.EMBEDDING_DB_FUNCTION
 23            or embedding_functions.DefaultEmbeddingFunction()
 24        )
 25
 26    @classmethod
 27    def _wrap_results(cls, results) -> list[str | SearchResult]:
 28        return SearchResults([
 29            SearchResult(
 30                results["documents"][0][i],
 31                dict(
 32                    metadata=results["metadatas"][0][i] or {},
 33                    id=results["ids"][0][i],
 34                    distance=results["distances"][0][i],
 35                ),
 36            )
 37            for i in range(len(results["documents"][0]))
 38        ])
 39
 40    def search(
 41        self,
 42        collection: str,
 43        query: str | list,
 44        n_results: int = 5,
 45        where: dict = None,
 46        **kwargs,
 47    ) -> list[str | SearchResult]:
 48        try:
 49            chroma_collection = self.client.get_collection(
 50                collection, embedding_function=self.embedding_function
 51            )
 52        except ValueError:
 53            return SearchResults([])
 54
 55        if isinstance(query, str):
 56            query = [query]
 57
 58        d = chroma_collection.query(
 59            query_texts=query, n_results=n_results, where=where, **kwargs
 60        )
 61        return (
 62            self._wrap_results(d)
 63            if d and d.get("documents") and d["documents"][0]
 64            else SearchResults([])
 65        )
 66
 67    def save_many(self, collection: str, items: list[tuple[str, dict] | str]):
 68        chroma_collection = self.client.get_or_create_collection(
 69            name=collection, embedding_function=self.embedding_function
 70        )
 71        unique = not self.config.EMBEDDING_DB_ALLOW_DUPLICATES
 72        texts, ids, metadatas = [], [], []
 73        for i in items:
 74            if isinstance(i, str):
 75                text = i
 76                metadata = None
 77            else:
 78                text = i[0]
 79                metadata = i[1] or None
 80            if unique and text in texts:
 81                continue
 82            texts.append(text)
 83            metadatas.append(metadata)
 84            ids.append(str(hash(text)) if unique else str(uuid.uuid4()))
 85        chroma_collection.upsert(documents=texts, ids=ids, metadatas=metadatas)
 86
 87    def clear(self, collection: str):
 88        try:
 89            self.client.delete_collection(collection)
 90        except ValueError:
 91            pass
 92
 93    def count(self, collection: str) -> int:
 94        try:
 95            chroma_collection = self.client.get_collection(
 96                collection, embedding_function=self.embedding_function
 97            )
 98        except ValueError:
 99            return 0
100        return chroma_collection.count()
101
102    def delete(self, collection: str, what: str | list[str] | dict):
103        try:
104            chroma_collection = self.client.get_collection(
105                collection, embedding_function=self.embedding_function
106            )
107        except ValueError:
108            return
109        if isinstance(what, str):
110            ids, where = [what], None
111        elif isinstance(what, list):
112            ids, where = what, None
113        elif isinstance(what, dict):
114            ids, where = None, what
115        else:
116            raise ValueError("Invalid `what` argument")
117        chroma_collection.delete(ids=ids, where=where)
118
119    def get_all(self, collection: str) -> list[str | SearchResult]:
120        try:
121            chroma_collection = self.client.get_collection(
122                collection, embedding_function=self.embedding_function
123            )
124        except ValueError:
125            return SearchResults([])
126        results = chroma_collection.get()
127        return SearchResults([
128            SearchResult(
129                results["documents"][i],
130                {"metadata": results["metadatas"][i] or {}, "id": results["ids"][i]},
131            )
132            for i in range(len(results["documents"]))
133        ])
@dataclass
class ChromaEmbeddingDB(microcore.embedding_db.AbstractEmbeddingDB):
 11@dataclass
 12class ChromaEmbeddingDB(AbstractEmbeddingDB):
 13    config: Config
 14    embedding_function: embedding_functions.EmbeddingFunction = None
 15    client: chromadb.Client = None
 16
 17    def __post_init__(self):
 18        self.client = chromadb.PersistentClient(
 19            path=f"{self.config.STORAGE_PATH}/{self.config.EMBEDDING_DB_FOLDER}",
 20            settings=Settings(anonymized_telemetry=False),
 21        )
 22        self.embedding_function = (
 23            self.config.EMBEDDING_DB_FUNCTION
 24            or embedding_functions.DefaultEmbeddingFunction()
 25        )
 26
 27    @classmethod
 28    def _wrap_results(cls, results) -> list[str | SearchResult]:
 29        return SearchResults([
 30            SearchResult(
 31                results["documents"][0][i],
 32                dict(
 33                    metadata=results["metadatas"][0][i] or {},
 34                    id=results["ids"][0][i],
 35                    distance=results["distances"][0][i],
 36                ),
 37            )
 38            for i in range(len(results["documents"][0]))
 39        ])
 40
 41    def search(
 42        self,
 43        collection: str,
 44        query: str | list,
 45        n_results: int = 5,
 46        where: dict = None,
 47        **kwargs,
 48    ) -> list[str | SearchResult]:
 49        try:
 50            chroma_collection = self.client.get_collection(
 51                collection, embedding_function=self.embedding_function
 52            )
 53        except ValueError:
 54            return SearchResults([])
 55
 56        if isinstance(query, str):
 57            query = [query]
 58
 59        d = chroma_collection.query(
 60            query_texts=query, n_results=n_results, where=where, **kwargs
 61        )
 62        return (
 63            self._wrap_results(d)
 64            if d and d.get("documents") and d["documents"][0]
 65            else SearchResults([])
 66        )
 67
 68    def save_many(self, collection: str, items: list[tuple[str, dict] | str]):
 69        chroma_collection = self.client.get_or_create_collection(
 70            name=collection, embedding_function=self.embedding_function
 71        )
 72        unique = not self.config.EMBEDDING_DB_ALLOW_DUPLICATES
 73        texts, ids, metadatas = [], [], []
 74        for i in items:
 75            if isinstance(i, str):
 76                text = i
 77                metadata = None
 78            else:
 79                text = i[0]
 80                metadata = i[1] or None
 81            if unique and text in texts:
 82                continue
 83            texts.append(text)
 84            metadatas.append(metadata)
 85            ids.append(str(hash(text)) if unique else str(uuid.uuid4()))
 86        chroma_collection.upsert(documents=texts, ids=ids, metadatas=metadatas)
 87
 88    def clear(self, collection: str):
 89        try:
 90            self.client.delete_collection(collection)
 91        except ValueError:
 92            pass
 93
 94    def count(self, collection: str) -> int:
 95        try:
 96            chroma_collection = self.client.get_collection(
 97                collection, embedding_function=self.embedding_function
 98            )
 99        except ValueError:
100            return 0
101        return chroma_collection.count()
102
103    def delete(self, collection: str, what: str | list[str] | dict):
104        try:
105            chroma_collection = self.client.get_collection(
106                collection, embedding_function=self.embedding_function
107            )
108        except ValueError:
109            return
110        if isinstance(what, str):
111            ids, where = [what], None
112        elif isinstance(what, list):
113            ids, where = what, None
114        elif isinstance(what, dict):
115            ids, where = None, what
116        else:
117            raise ValueError("Invalid `what` argument")
118        chroma_collection.delete(ids=ids, where=where)
119
120    def get_all(self, collection: str) -> list[str | SearchResult]:
121        try:
122            chroma_collection = self.client.get_collection(
123                collection, embedding_function=self.embedding_function
124            )
125        except ValueError:
126            return SearchResults([])
127        results = chroma_collection.get()
128        return SearchResults([
129            SearchResult(
130                results["documents"][i],
131                {"metadata": results["metadatas"][i] or {}, "id": results["ids"][i]},
132            )
133            for i in range(len(results["documents"]))
134        ])
ChromaEmbeddingDB( config: microcore.Config, embedding_function: chromadb.api.types.EmbeddingFunction = None, client: <function Client> = None)
embedding_function: chromadb.api.types.EmbeddingFunction = None
client: <function Client at 0x7fb3982ed260> = None
def search( self, collection: str, query: str | list, n_results: int = 5, where: dict = None, **kwargs) -> list[str | microcore.SearchResult]:
41    def search(
42        self,
43        collection: str,
44        query: str | list,
45        n_results: int = 5,
46        where: dict = None,
47        **kwargs,
48    ) -> list[str | SearchResult]:
49        try:
50            chroma_collection = self.client.get_collection(
51                collection, embedding_function=self.embedding_function
52            )
53        except ValueError:
54            return SearchResults([])
55
56        if isinstance(query, str):
57            query = [query]
58
59        d = chroma_collection.query(
60            query_texts=query, n_results=n_results, where=where, **kwargs
61        )
62        return (
63            self._wrap_results(d)
64            if d and d.get("documents") and d["documents"][0]
65            else SearchResults([])
66        )

Similarity search

Arguments:
  • collection (str): collection name
  • query (str | list): query string or list of query strings (for a list, only the results of the first query are returned)
  • n_results (int): number of results to return
  • where (dict): filter results by metadata
  • **kwargs: additional keyword arguments passed through to Chroma's `Collection.query`
def save_many(self, collection: str, items: list[tuple[str, dict] | str]):
68    def save_many(self, collection: str, items: list[tuple[str, dict] | str]):
69        chroma_collection = self.client.get_or_create_collection(
70            name=collection, embedding_function=self.embedding_function
71        )
72        unique = not self.config.EMBEDDING_DB_ALLOW_DUPLICATES
73        texts, ids, metadatas = [], [], []
74        for i in items:
75            if isinstance(i, str):
76                text = i
77                metadata = None
78            else:
79                text = i[0]
80                metadata = i[1] or None
81            if unique and text in texts:
82                continue
83            texts.append(text)
84            metadatas.append(metadata)
85            ids.append(str(hash(text)) if unique else str(uuid.uuid4()))
86        chroma_collection.upsert(documents=texts, ids=ids, metadatas=metadatas)

Save multiple documents in the collection

def clear(self, collection: str):
88    def clear(self, collection: str):
89        try:
90            self.client.delete_collection(collection)
91        except ValueError:
92            pass

Clear the collection

def count(self, collection: str) -> int:
 94    def count(self, collection: str) -> int:
 95        try:
 96            chroma_collection = self.client.get_collection(
 97                collection, embedding_function=self.embedding_function
 98            )
 99        except ValueError:
100            return 0
101        return chroma_collection.count()

Count the number of documents in the collection

Returns:

Number of documents in the collection

def delete(self, collection: str, what: str | list[str] | dict):
103    def delete(self, collection: str, what: str | list[str] | dict):
104        try:
105            chroma_collection = self.client.get_collection(
106                collection, embedding_function=self.embedding_function
107            )
108        except ValueError:
109            return
110        if isinstance(what, str):
111            ids, where = [what], None
112        elif isinstance(what, list):
113            ids, where = what, None
114        elif isinstance(what, dict):
115            ids, where = None, what
116        else:
117            raise ValueError("Invalid `what` argument")
118        chroma_collection.delete(ids=ids, where=where)

Delete documents from the collection

Arguments:
  • collection (str): collection name
  • what (str | list[str] | dict): a single id, a list of ids, or a metadata filter (`where` clause)
def get_all(self, collection: str) -> list[str | microcore.SearchResult]:
120    def get_all(self, collection: str) -> list[str | SearchResult]:
121        try:
122            chroma_collection = self.client.get_collection(
123                collection, embedding_function=self.embedding_function
124            )
125        except ValueError:
126            return SearchResults([])
127        results = chroma_collection.get()
128        return SearchResults([
129            SearchResult(
130                results["documents"][i],
131                {"metadata": results["metadatas"][i] or {}, "id": results["ids"][i]},
132            )
133            for i in range(len(results["documents"]))
134        ])

Return all documents in the collection