diff --git a/docs/snippets/modules/data_connection/retrievers/how_to/custom_retriever.mdx b/docs/snippets/modules/data_connection/retrievers/how_to/custom_retriever.mdx index 911784d96c..863aaf163e 100644 --- a/docs/snippets/modules/data_connection/retrievers/how_to/custom_retriever.mdx +++ b/docs/snippets/modules/data_connection/retrievers/how_to/custom_retriever.mdx @@ -14,9 +14,10 @@ from langchain.callbacks.manager import ( ) class BaseRetriever(ABC): + @abstractmethod def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: """Get documents relevant to a query. Args: @@ -32,7 +33,6 @@ class BaseRetriever(ABC): query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: """Asynchronously get documents relevant to a query. Args: @@ -110,7 +110,7 @@ class NumpyRetriever(BaseRetriever): ] def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: """Get documents relevant to a query. Args: @@ -127,7 +127,6 @@ class NumpyRetriever(BaseRetriever): query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: """Asynchronously get documents relevant to a query. Args: diff --git a/langchain/retrievers/arxiv.py b/langchain/retrievers/arxiv.py index 46223397b7..89a1370271 100644 --- a/langchain/retrievers/arxiv.py +++ b/langchain/retrievers/arxiv.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import List from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, @@ -16,19 +16,11 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper): """ def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: return self.load(query=query) async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/azure_cognitive_search.py b/langchain/retrievers/azure_cognitive_search.py index d79053c7e9..518750d663 100644 --- a/langchain/retrievers/azure_cognitive_search.py +++ b/langchain/retrievers/azure_cognitive_search.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import aiohttp import requests @@ -87,11 +87,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel): return response_json["value"] def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: search_results = self._search(query) @@ -101,11 +97,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel): ] async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: search_results = await self._asearch(query) diff --git a/langchain/retrievers/chatgpt_plugin_retriever.py b/langchain/retrievers/chatgpt_plugin_retriever.py index 5da8d408d9..06a25735c9 100644 --- a/langchain/retrievers/chatgpt_plugin_retriever.py +++ b/langchain/retrievers/chatgpt_plugin_retriever.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional +from typing import List, Optional import aiohttp import requests @@ -26,11 +26,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): arbitrary_types_allowed = True def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: url, json, headers = self._create_request(query) response = requests.post(url, json=json, headers=headers) @@ -45,11 +41,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): return docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: url, json, headers = self._create_request(query) diff --git a/langchain/retrievers/databerry.py b/langchain/retrievers/databerry.py index e124d503b4..753823bdb4 100644 --- a/langchain/retrievers/databerry.py +++ b/langchain/retrievers/databerry.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import List, Optional import aiohttp import requests @@ -28,11 +28,7 @@ class DataberryRetriever(BaseRetriever): self.top_k = top_k def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: response = requests.post( self.datastore_url, @@ -59,11 +55,7 @@ class DataberryRetriever(BaseRetriever): ] async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: async with aiohttp.ClientSession() as session: async with session.request( diff --git a/langchain/retrievers/docarray.py b/langchain/retrievers/docarray.py index 37cf155593..93b35d1c40 100644 --- a/langchain/retrievers/docarray.py +++ b/langchain/retrievers/docarray.py @@ -56,8 +56,8 @@ class DocArrayRetriever(BaseRetriever, BaseModel): def _get_relevant_documents( self, query: str, + *, run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: """Get documents relevant for a query. @@ -213,7 +213,7 @@ class DocArrayRetriever(BaseRetriever, BaseModel): async def _aget_relevant_documents( self, query: str, + *, run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/elastic_search_bm25.py b/langchain/retrievers/elastic_search_bm25.py index bdcfcaa116..81d0192554 100644 --- a/langchain/retrievers/elastic_search_bm25.py +++ b/langchain/retrievers/elastic_search_bm25.py @@ -117,11 +117,7 @@ class ElasticSearchBM25Retriever(BaseRetriever): return ids def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: query_dict = {"query": {"match": {"content": query}}} res = self.client.search(index=self.index_name, body=query_dict) @@ -132,10 +128,6 @@ class ElasticSearchBM25Retriever(BaseRetriever): return docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/kendra.py b/langchain/retrievers/kendra.py index b748c11ad3..47d2e321ba 100644 --- a/langchain/retrievers/kendra.py +++ b/langchain/retrievers/kendra.py @@ -264,8 +264,8 @@ class AmazonKendraRetriever(BaseRetriever): def _get_relevant_documents( self, query: str, + *, run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: """Run search on Kendra index and get top k documents @@ -281,7 +281,7 @@ class AmazonKendraRetriever(BaseRetriever): async def _aget_relevant_documents( self, query: str, + *, run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: raise NotImplementedError("Async version is not implemented for Kendra yet.") diff --git a/langchain/retrievers/knn.py b/langchain/retrievers/knn.py index 924f991601..d41f5ab32f 100644 --- a/langchain/retrievers/knn.py +++ b/langchain/retrievers/knn.py @@ -56,11 +56,7 @@ class KNNRetriever(BaseRetriever, BaseModel): return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: query_embeds = np.array(self.embeddings.embed_query(query)) # calc L2 norm @@ -84,10 +80,6 @@ class KNNRetriever(BaseRetriever, BaseModel): return top_k_results async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/llama_index.py b/langchain/retrievers/llama_index.py index f9cf3b36ad..5ede25e7b3 100644 --- a/langchain/retrievers/llama_index.py +++ b/langchain/retrievers/llama_index.py @@ -16,11 +16,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel): query_kwargs: Dict = Field(default_factory=dict) def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: """Get documents relevant for a query.""" try: @@ -44,11 +40,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel): return docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None, - **kwargs: Any, + self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun] ) -> List[Document]: raise NotImplementedError("LlamaIndexRetriever does not support async") @@ -60,11 +52,7 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel): query_configs: List[Dict] = Field(default_factory=list) def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: """Get documents relevant for a query.""" try: @@ -96,10 +84,6 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel): return docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError("LlamaIndexGraphRetriever does not support async") diff --git a/langchain/retrievers/merger_retriever.py b/langchain/retrievers/merger_retriever.py index ec8dce44f1..59d217c31c 100644 --- a/langchain/retrievers/merger_retriever.py +++ b/langchain/retrievers/merger_retriever.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import List from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, @@ -31,8 +31,8 @@ class MergerRetriever(BaseRetriever): def _get_relevant_documents( self, query: str, + *, run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: """ Get the relevant documents for a given query. @@ -52,8 +52,8 @@ class MergerRetriever(BaseRetriever): async def _aget_relevant_documents( self, query: str, + *, run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: """ Asynchronously get the relevant documents for a given query. diff --git a/langchain/retrievers/metal.py b/langchain/retrievers/metal.py index ec6824eb79..e5e3b6fae8 100644 --- a/langchain/retrievers/metal.py +++ b/langchain/retrievers/metal.py @@ -22,11 +22,7 @@ class MetalRetriever(BaseRetriever): self.params = params or {} def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: results = self.client.search({"text": query}, **self.params) final_results = [] @@ -36,10 +32,6 @@ class MetalRetriever(BaseRetriever): return final_results async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/multi_query.py b/langchain/retrievers/multi_query.py index bf330284cc..10c540d94c 100644 --- a/langchain/retrievers/multi_query.py +++ b/langchain/retrievers/multi_query.py @@ -1,5 +1,5 @@ import logging -from typing import Any, List +from typing import List from pydantic import BaseModel, Field @@ -98,8 +98,8 @@ class MultiQueryRetriever(BaseRetriever): def _get_relevant_documents( self, query: str, + *, run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: """Get relevated documents given a user query. @@ -117,8 +117,8 @@ class MultiQueryRetriever(BaseRetriever): async def _aget_relevant_documents( self, query: str, + *, run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/pinecone_hybrid_search.py b/langchain/retrievers/pinecone_hybrid_search.py index 481856d65f..eca21de810 100644 --- a/langchain/retrievers/pinecone_hybrid_search.py +++ b/langchain/retrievers/pinecone_hybrid_search.py @@ -143,11 +143,7 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel): return values def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: from pinecone_text.hybrid import hybrid_convex_scale @@ -174,10 +170,6 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel): return final_result async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/pubmed.py b/langchain/retrievers/pubmed.py index edb05063fb..573a9f2c10 100644 --- a/langchain/retrievers/pubmed.py +++ b/langchain/retrievers/pubmed.py @@ -1,5 +1,5 @@ """A retriever that uses PubMed API to retrieve documents.""" -from typing import Any, List +from typing import List from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, @@ -17,19 +17,11 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper): """ def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: return self.load_docs(query=query) async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/remote_retriever.py b/langchain/retrievers/remote_retriever.py index f6d8758d02..a75d87f9e0 100644 --- a/langchain/retrievers/remote_retriever.py +++ b/langchain/retrievers/remote_retriever.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import List, Optional import aiohttp import requests @@ -20,11 +20,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel): metadata_key: str = "metadata" def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: response = requests.post( self.url, json={self.input_key: query}, headers=self.headers @@ -38,11 +34,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel): ] async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: async with aiohttp.ClientSession() as session: async with session.request( diff --git a/langchain/retrievers/self_query/base.py b/langchain/retrievers/self_query/base.py index 859d4d1d92..5e07a10d9e 100644 --- a/langchain/retrievers/self_query/base.py +++ b/langchain/retrievers/self_query/base.py @@ -84,11 +84,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): return values def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: """Get documents relevant for a query. @@ -121,11 +117,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): return docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: Optional[AsyncCallbackManagerForRetrieverRun], - **kwargs: Any, + self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun] ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/svm.py b/langchain/retrievers/svm.py index 2e96781220..5fcc3cd665 100644 --- a/langchain/retrievers/svm.py +++ b/langchain/retrievers/svm.py @@ -55,11 +55,7 @@ class SVMRetriever(BaseRetriever, BaseModel): return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: from sklearn import svm @@ -98,10 +94,6 @@ class SVMRetriever(BaseRetriever, BaseModel): return top_k_results async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/tfidf.py b/langchain/retrievers/tfidf.py index 71b1afb473..818541f09e 100644 --- a/langchain/retrievers/tfidf.py +++ b/langchain/retrievers/tfidf.py @@ -64,11 +64,7 @@ class TFIDFRetriever(BaseRetriever, BaseModel): ) def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: from sklearn.metrics.pairwise import cosine_similarity @@ -82,10 +78,6 @@ class TFIDFRetriever(BaseRetriever, BaseModel): return return_docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/time_weighted_retriever.py b/langchain/retrievers/time_weighted_retriever.py index 0f0681ad5b..0340785844 100644 --- a/langchain/retrievers/time_weighted_retriever.py +++ b/langchain/retrievers/time_weighted_retriever.py @@ -86,11 +86,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel): return results def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: """Return documents that are relevant to the query.""" current_time = datetime.datetime.now() @@ -115,11 +111,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel): return result async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: """Return documents that are relevant to the query.""" raise NotImplementedError diff --git a/langchain/retrievers/vespa_retriever.py b/langchain/retrievers/vespa_retriever.py index b95ce59c42..c8eb04b8b7 100644 --- a/langchain/retrievers/vespa_retriever.py +++ b/langchain/retrievers/vespa_retriever.py @@ -65,22 +65,14 @@ class VespaRetriever(BaseRetriever): return docs def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: body = self._query_body.copy() body["query"] = query return self._query(body) async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/weaviate_hybrid_search.py b/langchain/retrievers/weaviate_hybrid_search.py index b34edd9440..f8ad46fa44 100644 --- a/langchain/retrievers/weaviate_hybrid_search.py +++ b/langchain/retrievers/weaviate_hybrid_search.py @@ -93,7 +93,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever): *, run_manager: CallbackManagerForRetrieverRun, where_filter: Optional[Dict[str, object]] = None, - **kwargs: Any, ) -> List[Document]: """Look up similar documents in Weaviate.""" query_obj = self._client.query.get(self._index_name, self._query_attrs) @@ -112,11 +111,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever): return docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - where_filter: Optional[Dict[str, object]] = None, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/wikipedia.py b/langchain/retrievers/wikipedia.py index 013728396a..fe43099cf9 100644 --- a/langchain/retrievers/wikipedia.py +++ b/langchain/retrievers/wikipedia.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import List from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, @@ -16,19 +16,11 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper): """ def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: return self.load(query=query) async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: raise NotImplementedError diff --git a/langchain/retrievers/zep.py b/langchain/retrievers/zep.py index d6d43bf1b8..9062a5a982 100644 --- a/langchain/retrievers/zep.py +++ b/langchain/retrievers/zep.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, @@ -64,7 +64,6 @@ class ZepRetriever(BaseRetriever): *, run_manager: CallbackManagerForRetrieverRun, metadata: Optional[Dict] = None, - **kwargs: Any, ) -> List[Document]: from zep_python import MemorySearchPayload @@ -84,7 +83,6 @@ class ZepRetriever(BaseRetriever): *, run_manager: AsyncCallbackManagerForRetrieverRun, metadata: Optional[Dict] = None, - **kwargs: Any, ) -> List[Document]: from zep_python import MemorySearchPayload diff --git a/langchain/schema/retriever.py b/langchain/schema/retriever.py index d3e59fb75c..cd4ab75813 100644 --- a/langchain/schema/retriever.py +++ b/langchain/schema/retriever.py @@ -44,7 +44,6 @@ class BaseRetriever(ABC): async def aget_relevant_documents(self, query: str) -> List[Document]: raise NotImplementedError - """ # noqa: E501 _new_arg_supported: bool = False @@ -86,27 +85,23 @@ class BaseRetriever(ABC): @abstractmethod def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any + self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: """Get documents relevant to a query. Args: - query: String to find relevant documents for. - run_manager: The callbacks handler to use. + query: String to find relevant documents for + run_manager: The callbacks handler to use Returns: List of relevant documents """ @abstractmethod async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: """Asynchronously get documents relevant to a query. Args: - query: string to find relevant documents for + query: String to find relevant documents for run_manager: The callbacks handler to use Returns: List of relevant documents @@ -117,8 +112,8 @@ class BaseRetriever(ABC): ) -> List[Document]: """Retrieve documents relevant to a query. Args: - query: String to find relevant documents for. - callbacks: Callback manager or list of callbacks. + query: string to find relevant documents for + callbacks: Callback manager or list of callbacks Returns: List of relevant documents """ @@ -132,14 +127,13 @@ class BaseRetriever(ABC): **kwargs, ) try: + _kwargs = kwargs if self._expects_other_args else {} if self._new_arg_supported: result = self._get_relevant_documents( - query, run_manager=run_manager, **kwargs + query, run_manager=run_manager, **_kwargs ) - elif self._expects_other_args: - result = self._get_relevant_documents(query, **kwargs) else: - result = self._get_relevant_documents(query) # type: ignore[call-arg] + result = self._get_relevant_documents(query, **_kwargs) except Exception as e: run_manager.on_retriever_error(e) raise e @@ -170,16 +164,13 @@ class BaseRetriever(ABC): **kwargs, ) try: + _kwargs = kwargs if self._expects_other_args else {} if self._new_arg_supported: result = await self._aget_relevant_documents( - query, run_manager=run_manager, **kwargs + query, run_manager=run_manager, **_kwargs ) - elif self._expects_other_args: - result = await self._aget_relevant_documents(query, **kwargs) else: - result = await self._aget_relevant_documents( - query, # type: ignore[call-arg] - ) + result = await self._aget_relevant_documents(query, **_kwargs) except Exception as e: await run_manager.on_retriever_error(e) raise e diff --git a/langchain/vectorstores/azuresearch.py b/langchain/vectorstores/azuresearch.py index e97c65ed68..2404d31431 100644 --- a/langchain/vectorstores/azuresearch.py +++ b/langchain/vectorstores/azuresearch.py @@ -497,8 +497,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel): def _get_relevant_documents( self, query: str, + *, run_manager: CallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: if self.search_type == "similarity": docs = self.vectorstore.vector_search(query, k=self.k) @@ -513,8 +513,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel): async def _aget_relevant_documents( self, query: str, + *, run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, ) -> List[Document]: raise NotImplementedError( "AzureSearchVectorStoreRetriever does not support async" diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 4bbe7c8d84..1e574af5d3 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -408,11 +408,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel): return values def _get_relevant_documents( - self, - query: str, - *, - run_manager: Optional[CallbackManagerForRetrieverRun] = None, - **kwargs: Any, + self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] ) -> List[Document]: if self.search_type == "similarity": docs = self.vectorstore.similarity_search(query, **self.search_kwargs) @@ -432,11 +428,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel): return docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None, - **kwargs: Any, + self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun] ) -> List[Document]: if self.search_type == "similarity": docs = await self.vectorstore.asimilarity_search( diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index dbed8cb969..31c95e1f98 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -620,11 +620,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel): return values def _get_relevant_documents( - self, - query: str, - *, - run_manager: Optional[CallbackManagerForRetrieverRun] = None, - **kwargs: Any, + self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] ) -> List[Document]: if self.search_type == "similarity": docs = self.vectorstore.similarity_search(query, k=self.k) @@ -637,11 +633,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel): return docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None, - **kwargs: Any, + self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun] ) -> List[Document]: raise NotImplementedError("RedisVectorStoreRetriever does not support async") diff --git a/langchain/vectorstores/singlestoredb.py b/langchain/vectorstores/singlestoredb.py index b16b6cedb7..90f1693092 100644 --- a/langchain/vectorstores/singlestoredb.py +++ b/langchain/vectorstores/singlestoredb.py @@ -451,11 +451,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever): allowed_search_types: ClassVar[Collection[str]] = ("similarity",) def _get_relevant_documents( - self, - query: str, - *, - run_manager: Optional[CallbackManagerForRetrieverRun] = None, - **kwargs: Any, + self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] ) -> List[Document]: if self.search_type == "similarity": docs = self.vectorstore.similarity_search(query, k=self.k) @@ -464,11 +460,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever): return docs async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None, - **kwargs: Any, + self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun] ) -> List[Document]: raise NotImplementedError( "SingleStoreDBVectorStoreRetriever does not support async" diff --git a/tests/integration_tests/chains/test_retrieval_qa.py b/tests/integration_tests/chains/test_retrieval_qa.py index 7ce6fdff1e..35aaffefaa 100644 --- a/tests/integration_tests/chains/test_retrieval_qa.py +++ b/tests/integration_tests/chains/test_retrieval_qa.py @@ -7,18 +7,19 @@ from langchain.document_loaders import TextLoader from langchain.embeddings.openai import OpenAIEmbeddings from langchain.llms import OpenAI from langchain.text_splitter import CharacterTextSplitter -from langchain.vectorstores import Chroma +from langchain.vectorstores import FAISS def test_retrieval_qa_saving_loading(tmp_path: Path) -> None: """Test saving and loading.""" - loader = TextLoader("docs/modules/state_of_the_union.txt") + loader = TextLoader("docs/extras/modules/state_of_the_union.txt") documents = loader.load() text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) texts = text_splitter.split_documents(documents) embeddings = OpenAIEmbeddings() - docsearch = Chroma.from_documents(texts, embeddings) + docsearch = FAISS.from_documents(texts, embeddings) qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever()) + qa.run("What did the president say about Ketanji Brown Jackson?") file_path = tmp_path / "RetrievalQA_chain.yaml" qa.save(file_path=file_path) diff --git a/tests/integration_tests/retrievers/test_contextual_compression.py b/tests/integration_tests/retrievers/test_contextual_compression.py index 60eb206b8b..df74e94cac 100644 --- a/tests/integration_tests/retrievers/test_contextual_compression.py +++ b/tests/integration_tests/retrievers/test_contextual_compression.py @@ -1,7 +1,7 @@ from langchain.embeddings import OpenAIEmbeddings from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.document_compressors import EmbeddingsFilter -from langchain.vectorstores import Chroma +from langchain.vectorstores import FAISS def test_contextual_compression_retriever_get_relevant_docs() -> None: @@ -13,7 +13,7 @@ def test_contextual_compression_retriever_get_relevant_docs() -> None: ] embeddings = OpenAIEmbeddings() base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75) - base_retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever( + base_retriever = FAISS.from_texts(texts, embedding=embeddings).as_retriever( search_kwargs={"k": len(texts)} ) retriever = ContextualCompressionRetriever( diff --git a/tests/unit_tests/retrievers/test_base.py b/tests/unit_tests/retrievers/test_base.py index e3e15ff4e9..28295d1c56 100644 --- a/tests/unit_tests/retrievers/test_base.py +++ b/tests/unit_tests/retrievers/test_base.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import pytest @@ -141,45 +141,48 @@ async def test_fake_retriever_v1_with_kwargs_upgrade_async( assert callbacks.retriever_errors == 0 +class FakeRetrieverV2(BaseRetriever): + def __init__(self, throw_error: bool = False) -> None: + self.throw_error = throw_error + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None + ) -> List[Document]: + assert isinstance(self, FakeRetrieverV2) + assert run_manager is not None + assert isinstance(run_manager, CallbackManagerForRetrieverRun) + if self.throw_error: + raise ValueError("Test error") + return [ + Document(page_content=query), + ] + + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun | None + ) -> List[Document]: + assert isinstance(self, FakeRetrieverV2) + assert run_manager is not None + assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun) + if self.throw_error: + raise ValueError("Test error") + return [ + Document(page_content=f"Async query {query}"), + ] + + @pytest.fixture def fake_retriever_v2() -> BaseRetriever: - class FakeRetrieverV2(BaseRetriever): - def _get_relevant_documents( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun | None = None, - **kwargs: Any, - ) -> List[Document]: - assert isinstance(self, FakeRetrieverV2) - assert run_manager is not None - assert isinstance(run_manager, CallbackManagerForRetrieverRun) - if "throw_error" in kwargs: - raise ValueError("Test error") - return [ - Document(page_content=query, metadata=kwargs), - ] - - async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun | None = None, - **kwargs: Any, - ) -> List[Document]: - assert isinstance(self, FakeRetrieverV2) - assert run_manager is not None - assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun) - if "throw_error" in kwargs: - raise ValueError("Test error") - return [ - Document(page_content=f"Async query {query}", metadata=kwargs), - ] - return FakeRetrieverV2() # type: ignore[abstract] -def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None: +@pytest.fixture +def fake_erroring_retriever_v2() -> BaseRetriever: + return FakeRetrieverV2(throw_error=True) # type: ignore[abstract] + + +def test_fake_retriever_v2( + fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever +) -> None: callbacks = FakeCallbackHandler() assert fake_retriever_v2._new_arg_supported is True results = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks]) @@ -187,20 +190,17 @@ def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None: assert callbacks.retriever_starts == 1 assert callbacks.retriever_ends == 1 assert callbacks.retriever_errors == 0 - results2 = fake_retriever_v2.get_relevant_documents( - "Foo", callbacks=[callbacks], foo="bar" - ) - assert results2[0].metadata == {"foo": "bar"} + fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks]) with pytest.raises(ValueError, match="Test error"): - fake_retriever_v2.get_relevant_documents( - "Foo", callbacks=[callbacks], throw_error=True - ) + fake_erroring_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks]) assert callbacks.retriever_errors == 1 @pytest.mark.asyncio -async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None: +async def test_fake_retriever_v2_async( + fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever +) -> None: callbacks = FakeCallbackHandler() assert fake_retriever_v2._new_arg_supported is True results = await fake_retriever_v2.aget_relevant_documents( @@ -210,11 +210,8 @@ async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None assert callbacks.retriever_starts == 1 assert callbacks.retriever_ends == 1 assert callbacks.retriever_errors == 0 - results2 = await fake_retriever_v2.aget_relevant_documents( - "Foo", callbacks=[callbacks], foo="bar" - ) - assert results2[0].metadata == {"foo": "bar"} + await fake_retriever_v2.aget_relevant_documents("Foo", callbacks=[callbacks]) with pytest.raises(ValueError, match="Test error"): - await fake_retriever_v2.aget_relevant_documents( - "Foo", callbacks=[callbacks], throw_error=True + await fake_erroring_retriever_v2.aget_relevant_documents( + "Foo", callbacks=[callbacks] )