microcore.embedding_db.chromadb
1from dataclasses import dataclass 2import uuid 3import chromadb 4from chromadb.config import Settings 5from chromadb.utils import embedding_functions 6from ..configuration import Config 7from .. import SearchResult, SearchResults, AbstractEmbeddingDB 8 9 10@dataclass 11class ChromaEmbeddingDB(AbstractEmbeddingDB): 12 config: Config 13 embedding_function: embedding_functions.EmbeddingFunction = None 14 client: chromadb.Client = None 15 16 def __post_init__(self): 17 self.client = chromadb.PersistentClient( 18 path=f"{self.config.STORAGE_PATH}/{self.config.EMBEDDING_DB_FOLDER}", 19 settings=Settings(anonymized_telemetry=False), 20 ) 21 self.embedding_function = ( 22 self.config.EMBEDDING_DB_FUNCTION 23 or embedding_functions.DefaultEmbeddingFunction() 24 ) 25 26 @classmethod 27 def _wrap_results(cls, results) -> list[str | SearchResult]: 28 return SearchResults([ 29 SearchResult( 30 results["documents"][0][i], 31 dict( 32 metadata=results["metadatas"][0][i] or {}, 33 id=results["ids"][0][i], 34 distance=results["distances"][0][i], 35 ), 36 ) 37 for i in range(len(results["documents"][0])) 38 ]) 39 40 def search( 41 self, 42 collection: str, 43 query: str | list, 44 n_results: int = 5, 45 where: dict = None, 46 **kwargs, 47 ) -> list[str | SearchResult]: 48 try: 49 chroma_collection = self.client.get_collection( 50 collection, embedding_function=self.embedding_function 51 ) 52 except ValueError: 53 return SearchResults([]) 54 55 if isinstance(query, str): 56 query = [query] 57 58 d = chroma_collection.query( 59 query_texts=query, n_results=n_results, where=where, **kwargs 60 ) 61 return ( 62 self._wrap_results(d) 63 if d and d.get("documents") and d["documents"][0] 64 else SearchResults([]) 65 ) 66 67 def save_many(self, collection: str, items: list[tuple[str, dict] | str]): 68 chroma_collection = self.client.get_or_create_collection( 69 name=collection, embedding_function=self.embedding_function 70 ) 71 unique = not self.config.EMBEDDING_DB_ALLOW_DUPLICATES 72 
texts, ids, metadatas = [], [], [] 73 for i in items: 74 if isinstance(i, str): 75 text = i 76 metadata = None 77 else: 78 text = i[0] 79 metadata = i[1] or None 80 if unique and text in texts: 81 continue 82 texts.append(text) 83 metadatas.append(metadata) 84 ids.append(str(hash(text)) if unique else str(uuid.uuid4())) 85 chroma_collection.upsert(documents=texts, ids=ids, metadatas=metadatas) 86 87 def clear(self, collection: str): 88 try: 89 self.client.delete_collection(collection) 90 except ValueError: 91 pass 92 93 def count(self, collection: str) -> int: 94 try: 95 chroma_collection = self.client.get_collection( 96 collection, embedding_function=self.embedding_function 97 ) 98 except ValueError: 99 return 0 100 return chroma_collection.count() 101 102 def delete(self, collection: str, what: str | list[str] | dict): 103 try: 104 chroma_collection = self.client.get_collection( 105 collection, embedding_function=self.embedding_function 106 ) 107 except ValueError: 108 return 109 if isinstance(what, str): 110 ids, where = [what], None 111 elif isinstance(what, list): 112 ids, where = what, None 113 elif isinstance(what, dict): 114 ids, where = None, what 115 else: 116 raise ValueError("Invalid `what` argument") 117 chroma_collection.delete(ids=ids, where=where) 118 119 def get_all(self, collection: str) -> list[str | SearchResult]: 120 try: 121 chroma_collection = self.client.get_collection( 122 collection, embedding_function=self.embedding_function 123 ) 124 except ValueError: 125 return SearchResults([]) 126 results = chroma_collection.get() 127 return SearchResults([ 128 SearchResult( 129 results["documents"][i], 130 {"metadata": results["metadatas"][i] or {}, "id": results["ids"][i]}, 131 ) 132 for i in range(len(results["documents"])) 133 ])
11@dataclass 12class ChromaEmbeddingDB(AbstractEmbeddingDB): 13 config: Config 14 embedding_function: embedding_functions.EmbeddingFunction = None 15 client: chromadb.Client = None 16 17 def __post_init__(self): 18 self.client = chromadb.PersistentClient( 19 path=f"{self.config.STORAGE_PATH}/{self.config.EMBEDDING_DB_FOLDER}", 20 settings=Settings(anonymized_telemetry=False), 21 ) 22 self.embedding_function = ( 23 self.config.EMBEDDING_DB_FUNCTION 24 or embedding_functions.DefaultEmbeddingFunction() 25 ) 26 27 @classmethod 28 def _wrap_results(cls, results) -> list[str | SearchResult]: 29 return SearchResults([ 30 SearchResult( 31 results["documents"][0][i], 32 dict( 33 metadata=results["metadatas"][0][i] or {}, 34 id=results["ids"][0][i], 35 distance=results["distances"][0][i], 36 ), 37 ) 38 for i in range(len(results["documents"][0])) 39 ]) 40 41 def search( 42 self, 43 collection: str, 44 query: str | list, 45 n_results: int = 5, 46 where: dict = None, 47 **kwargs, 48 ) -> list[str | SearchResult]: 49 try: 50 chroma_collection = self.client.get_collection( 51 collection, embedding_function=self.embedding_function 52 ) 53 except ValueError: 54 return SearchResults([]) 55 56 if isinstance(query, str): 57 query = [query] 58 59 d = chroma_collection.query( 60 query_texts=query, n_results=n_results, where=where, **kwargs 61 ) 62 return ( 63 self._wrap_results(d) 64 if d and d.get("documents") and d["documents"][0] 65 else SearchResults([]) 66 ) 67 68 def save_many(self, collection: str, items: list[tuple[str, dict] | str]): 69 chroma_collection = self.client.get_or_create_collection( 70 name=collection, embedding_function=self.embedding_function 71 ) 72 unique = not self.config.EMBEDDING_DB_ALLOW_DUPLICATES 73 texts, ids, metadatas = [], [], [] 74 for i in items: 75 if isinstance(i, str): 76 text = i 77 metadata = None 78 else: 79 text = i[0] 80 metadata = i[1] or None 81 if unique and text in texts: 82 continue 83 texts.append(text) 84 metadatas.append(metadata) 85 
ids.append(str(hash(text)) if unique else str(uuid.uuid4())) 86 chroma_collection.upsert(documents=texts, ids=ids, metadatas=metadatas) 87 88 def clear(self, collection: str): 89 try: 90 self.client.delete_collection(collection) 91 except ValueError: 92 pass 93 94 def count(self, collection: str) -> int: 95 try: 96 chroma_collection = self.client.get_collection( 97 collection, embedding_function=self.embedding_function 98 ) 99 except ValueError: 100 return 0 101 return chroma_collection.count() 102 103 def delete(self, collection: str, what: str | list[str] | dict): 104 try: 105 chroma_collection = self.client.get_collection( 106 collection, embedding_function=self.embedding_function 107 ) 108 except ValueError: 109 return 110 if isinstance(what, str): 111 ids, where = [what], None 112 elif isinstance(what, list): 113 ids, where = what, None 114 elif isinstance(what, dict): 115 ids, where = None, what 116 else: 117 raise ValueError("Invalid `what` argument") 118 chroma_collection.delete(ids=ids, where=where) 119 120 def get_all(self, collection: str) -> list[str | SearchResult]: 121 try: 122 chroma_collection = self.client.get_collection( 123 collection, embedding_function=self.embedding_function 124 ) 125 except ValueError: 126 return SearchResults([]) 127 results = chroma_collection.get() 128 return SearchResults([ 129 SearchResult( 130 results["documents"][i], 131 {"metadata": results["metadatas"][i] or {}, "id": results["ids"][i]}, 132 ) 133 for i in range(len(results["documents"])) 134 ])
ChromaEmbeddingDB( config: microcore.Config, embedding_function: chromadb.api.types.EmbeddingFunction = None, client: <function Client> = None)
config: microcore.Config
def
search( self, collection: str, query: str | list, n_results: int = 5, where: dict = None, **kwargs) -> list[str | microcore.SearchResult]:
def search(
    self,
    collection: str,
    query: str | list,
    n_results: int = 5,
    where: dict = None,
    **kwargs,
) -> list[str | SearchResult]:
    """Similarity search.

    Arguments:
        - collection (str): collection name
        - query (str | list): query string or list of query strings;
          only the results for the first query are returned
        - n_results (int): number of results to return
        - where (dict): filter results by metadata
        - **kwargs: extra keyword arguments passed through to the ChromaDB
          ``Collection.query`` call

    Returns:
        SearchResults for the first query; empty when ``get_collection``
        raises ValueError (e.g. the collection does not exist) or the query
        returns no documents.
    """
    try:
        target = self.client.get_collection(
            collection, embedding_function=self.embedding_function
        )
    except ValueError:
        return SearchResults([])

    # Chroma expects a list of query texts; wrap a single string.
    query_texts = [query] if isinstance(query, str) else query
    response = target.query(
        query_texts=query_texts, n_results=n_results, where=where, **kwargs
    )
    # Guard against a missing/empty response before indexing the first query.
    if not (response and response.get("documents") and response["documents"][0]):
        return SearchResults([])
    return self._wrap_results(response)
Similarity search
Arguments:
- collection (str): collection name
- query (str | list): query string or list of query strings (when a list is given, only the results for the first query are returned)
- n_results (int): number of results to return
- where (dict): filter results by metadata
- **kwargs: additional keyword arguments passed through to the ChromaDB `Collection.query` call
def
save_many(self, collection: str, items: list[tuple[str, dict] | str]):
68 def save_many(self, collection: str, items: list[tuple[str, dict] | str]): 69 chroma_collection = self.client.get_or_create_collection( 70 name=collection, embedding_function=self.embedding_function 71 ) 72 unique = not self.config.EMBEDDING_DB_ALLOW_DUPLICATES 73 texts, ids, metadatas = [], [], [] 74 for i in items: 75 if isinstance(i, str): 76 text = i 77 metadata = None 78 else: 79 text = i[0] 80 metadata = i[1] or None 81 if unique and text in texts: 82 continue 83 texts.append(text) 84 metadatas.append(metadata) 85 ids.append(str(hash(text)) if unique else str(uuid.uuid4())) 86 chroma_collection.upsert(documents=texts, ids=ids, metadatas=metadatas)
Save multiple documents in the collection, creating the collection if it does not exist
def
clear(self, collection: str):
def clear(self, collection: str):
    """Clear the collection by deleting it from the client.

    A ValueError from ``delete_collection`` (e.g. a missing collection) is
    ignored, so clearing is best-effort.
    """
    try:
        self.client.delete_collection(collection)
    except ValueError:
        # Best effort: nothing to clear.
        return
Clear the collection by deleting it entirely; a `ValueError` from the client (e.g. for a missing collection) is ignored
def
count(self, collection: str) -> int:
def count(self, collection: str) -> int:
    """Count the number of documents in the collection.

    Returns:
        Number of documents in the collection, or 0 when ``get_collection``
        raises ValueError (e.g. the collection does not exist).
    """
    try:
        target = self.client.get_collection(
            collection, embedding_function=self.embedding_function
        )
    except ValueError:
        return 0
    else:
        return target.count()
Count the number of documents in the collection
Returns:
Number of documents in the collection, or 0 if the collection cannot be opened (e.g. it does not exist)
def
delete(self, collection: str, what: str | list[str] | dict):
103 def delete(self, collection: str, what: str | list[str] | dict): 104 try: 105 chroma_collection = self.client.get_collection( 106 collection, embedding_function=self.embedding_function 107 ) 108 except ValueError: 109 return 110 if isinstance(what, str): 111 ids, where = [what], None 112 elif isinstance(what, list): 113 ids, where = what, None 114 elif isinstance(what, dict): 115 ids, where = None, what 116 else: 117 raise ValueError("Invalid `what` argument") 118 chroma_collection.delete(ids=ids, where=where)
Delete documents from the collection
Arguments:
- collection (str): collection name
- what (str | list[str] | dict): a single id, a list of ids, or a metadata filter (`where` dict)
def get_all(self, collection: str) -> list[str | SearchResult]:
    """Return all documents in the collection.

    Each result carries a dict with ``metadata`` (``{}`` when missing) and
    ``id``. Returns an empty SearchResults when ``get_collection`` raises
    ValueError (e.g. the collection does not exist).
    """
    try:
        chroma_collection = self.client.get_collection(
            collection, embedding_function=self.embedding_function
        )
    except ValueError:
        return SearchResults([])
    results = chroma_collection.get()
    # Iterate the parallel result lists in lockstep instead of by index.
    return SearchResults([
        SearchResult(document, {"metadata": metadata or {}, "id": doc_id})
        for document, metadata, doc_id in zip(
            results["documents"], results["metadatas"], results["ids"]
        )
    ])
Return all documents in the collection