"""Milvus and Zilliz retrievers.

Thin retriever wrappers around the ``Milvus`` and ``Zilliz`` vector stores.
Each wrapper builds the store, exposes ``add_texts`` for ingestion, and
delegates ``get_relevant_documents`` to the store's own retriever.
"""
# Provenance -- reconstructed from a whitespace-collapsed git patch:
#
#   From c632f7fc4e438fb8a9230fdd2cc8fb868aeac8bc Mon Sep 17 00:00:00 2001
#   From: Filip Haltmayer <81822489+filip-halt@users.noreply.github.com>
#   Date: Mon, 15 May 2023 21:22:54 -0700
#   Subject: [PATCH] Add Milvus and Zilliz Retrievals (#4416)
#
#   Adds the basic retrievers for Milvus and Zilliz. Hybrid search support
#   will be added in the future.
#
#   Signed-off-by: Filip Haltmayer
#
# The patch created two new 43-line modules, whose contents are combined
# below (2 files changed, 86 insertions(+)):
#
#   langchain/retrievers/milvus.py  (index 0000000000..915d61d989)
#   langchain/retrievers/zilliz.py  (index 0000000000..6b39a3a022)
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.milvus import Milvus
from langchain.vectorstores.zilliz import Zilliz

# TODO: Update to MilvusClient + Hybrid Search when available
# TODO: Update to ZillizClient + Hybrid Search when available


class MilvusRetriever(BaseRetriever):
    """Retriever backed by a ``Milvus`` vector store.

    Attributes are set in ``__init__``:
      * ``store`` -- the underlying ``Milvus`` vector store.
      * ``retriever`` -- the object returned by ``store.as_retriever(...)``,
        to which document lookups are delegated.
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        collection_name: str = "LangChainCollection",
        connection_args: Optional[Dict[str, Any]] = None,
        consistency_level: str = "Session",
        search_params: Optional[dict] = None,
    ):
        """Create the Milvus store and the retriever that wraps it.

        Args:
            embedding_function: Embeddings object handed to the store.
            collection_name: Collection name handed to the store.
            connection_args: Connection arguments handed to the store.
            consistency_level: Consistency level handed to the store.
            search_params: Passed to ``as_retriever`` as
                ``search_kwargs={"param": search_params}``.
                NOTE(review): presumably forwarded as the Milvus search
                parameters at query time -- confirm against the store's
                ``similarity_search`` signature.
        """
        # Arguments are passed positionally, in the order the store's
        # constructor is expected to declare them.
        self.store = Milvus(
            embedding_function,
            collection_name,
            connection_args,
            consistency_level,
        )
        self.retriever = self.store.as_retriever(
            search_kwargs={"param": search_params}
        )

    def add_texts(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> None:
        """Add text to the Milvus store

        Args:
            texts (List[str]): The text
            metadatas (List[dict]): Metadata dicts, must line up with existing store
        """
        # Whatever the store returns is intentionally not propagated.
        self.store.add_texts(texts, metadatas)

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents the wrapped store retriever finds for ``query``."""
        return self.retriever.get_relevant_documents(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not implemented; always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "MilvusRetriever does not support async retrieval"
        )


# Backward-compatible alias: the class originally shipped under this
# misspelled name, so existing imports must keep working.
MilvusRetreiver = MilvusRetriever


class ZillizRetriever(BaseRetriever):
    """Retriever backed by a ``Zilliz`` vector store.

    Attributes are set in ``__init__``:
      * ``store`` -- the underlying ``Zilliz`` vector store.
      * ``retriever`` -- the object returned by ``store.as_retriever(...)``,
        to which document lookups are delegated.
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        collection_name: str = "LangChainCollection",
        connection_args: Optional[Dict[str, Any]] = None,
        consistency_level: str = "Session",
        search_params: Optional[dict] = None,
    ):
        """Create the Zilliz store and the retriever that wraps it.

        Args:
            embedding_function: Embeddings object handed to the store.
            collection_name: Collection name handed to the store.
            connection_args: Connection arguments handed to the store.
            consistency_level: Consistency level handed to the store.
            search_params: Passed to ``as_retriever`` as
                ``search_kwargs={"param": search_params}``.
                NOTE(review): presumably forwarded as the search parameters
                at query time -- confirm against the store's
                ``similarity_search`` signature.
        """
        # Arguments are passed positionally, in the order the store's
        # constructor is expected to declare them.
        self.store = Zilliz(
            embedding_function,
            collection_name,
            connection_args,
            consistency_level,
        )
        self.retriever = self.store.as_retriever(
            search_kwargs={"param": search_params}
        )

    def add_texts(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> None:
        """Add text to the Zilliz store

        Args:
            texts (List[str]): The text
            metadatas (List[dict]): Metadata dicts, must line up with existing store
        """
        # Whatever the store returns is intentionally not propagated.
        self.store.add_texts(texts, metadatas)

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents the wrapped store retriever finds for ``query``."""
        return self.retriever.get_relevant_documents(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not implemented; always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "ZillizRetriever does not support async retrieval"
        )


# Backward-compatible alias: the class originally shipped under this
# misspelled name, so existing imports must keep working.
ZillizRetreiver = ZillizRetriever