Remove retriever kwargs (#7013)

Doesn't actually restrict the Retriever interface (subclasses can still
declare extra keyword arguments), but hopefully in practice it does
pull/6876/head
Bagatur 1 year ago committed by GitHub
parent 9dc77614e3
commit 7acd524210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,9 +14,10 @@ from langchain.callbacks.manager import (
)
class BaseRetriever(ABC):
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant to a query.
Args:
@ -32,7 +33,6 @@ class BaseRetriever(ABC):
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
@ -110,7 +110,7 @@ class NumpyRetriever(BaseRetriever):
]
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant to a query.
Args:
@ -127,7 +127,6 @@ class NumpyRetriever(BaseRetriever):
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:

@ -1,4 +1,4 @@
from typing import Any, List
from typing import List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -16,19 +16,11 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
"""
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load(query=query)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -3,7 +3,7 @@
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional
import aiohttp
import requests
@ -87,11 +87,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
return response_json["value"]
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
search_results = self._search(query)
@ -101,11 +97,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
]
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
search_results = await self._asearch(query)

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, List, Optional
from typing import List, Optional
import aiohttp
import requests
@ -26,11 +26,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
arbitrary_types_allowed = True
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
url, json, headers = self._create_request(query)
response = requests.post(url, json=json, headers=headers)
@ -45,11 +41,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
url, json, headers = self._create_request(query)

@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import List, Optional
import aiohttp
import requests
@ -28,11 +28,7 @@ class DataberryRetriever(BaseRetriever):
self.top_k = top_k
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
response = requests.post(
self.datastore_url,
@ -59,11 +55,7 @@ class DataberryRetriever(BaseRetriever):
]
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
async with aiohttp.ClientSession() as session:
async with session.request(

@ -56,8 +56,8 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query.
@ -213,7 +213,7 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -117,11 +117,7 @@ class ElasticSearchBM25Retriever(BaseRetriever):
return ids
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
query_dict = {"query": {"match": {"content": query}}}
res = self.client.search(index=self.index_name, body=query_dict)
@ -132,10 +128,6 @@ class ElasticSearchBM25Retriever(BaseRetriever):
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -264,8 +264,8 @@ class AmazonKendraRetriever(BaseRetriever):
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Run search on Kendra index and get top k documents
@ -281,7 +281,7 @@ class AmazonKendraRetriever(BaseRetriever):
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError("Async version is not implemented for Kendra yet.")

@ -56,11 +56,7 @@ class KNNRetriever(BaseRetriever, BaseModel):
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
query_embeds = np.array(self.embeddings.embed_query(query))
# calc L2 norm
@ -84,10 +80,6 @@ class KNNRetriever(BaseRetriever, BaseModel):
return top_k_results
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -16,11 +16,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
query_kwargs: Dict = Field(default_factory=dict)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
try:
@ -44,11 +40,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
) -> List[Document]:
raise NotImplementedError("LlamaIndexRetriever does not support async")
@ -60,11 +52,7 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
query_configs: List[Dict] = Field(default_factory=list)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
try:
@ -96,10 +84,6 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")

@ -1,4 +1,4 @@
from typing import Any, List
from typing import List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -31,8 +31,8 @@ class MergerRetriever(BaseRetriever):
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""
Get the relevant documents for a given query.
@ -52,8 +52,8 @@ class MergerRetriever(BaseRetriever):
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""
Asynchronously get the relevant documents for a given query.

@ -22,11 +22,7 @@ class MetalRetriever(BaseRetriever):
self.params = params or {}
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
results = self.client.search({"text": query}, **self.params)
final_results = []
@ -36,10 +32,6 @@ class MetalRetriever(BaseRetriever):
return final_results
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -1,5 +1,5 @@
import logging
from typing import Any, List
from typing import List
from pydantic import BaseModel, Field
@ -98,8 +98,8 @@ class MultiQueryRetriever(BaseRetriever):
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Get relevated documents given a user query.
@ -117,8 +117,8 @@ class MultiQueryRetriever(BaseRetriever):
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -143,11 +143,7 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
return values
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
from pinecone_text.hybrid import hybrid_convex_scale
@ -174,10 +170,6 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
return final_result
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -1,5 +1,5 @@
"""A retriever that uses PubMed API to retrieve documents."""
from typing import Any, List
from typing import List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -17,19 +17,11 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
"""
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load_docs(query=query)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import List, Optional
import aiohttp
import requests
@ -20,11 +20,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
metadata_key: str = "metadata"
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
response = requests.post(
self.url, json={self.input_key: query}, headers=self.headers
@ -38,11 +34,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
]
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
async with aiohttp.ClientSession() as session:
async with session.request(

@ -84,11 +84,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
return values
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query.
@ -121,11 +117,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun],
**kwargs: Any,
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
) -> List[Document]:
raise NotImplementedError

@ -55,11 +55,7 @@ class SVMRetriever(BaseRetriever, BaseModel):
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
from sklearn import svm
@ -98,10 +94,6 @@ class SVMRetriever(BaseRetriever, BaseModel):
return top_k_results
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -64,11 +64,7 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
@ -82,10 +78,6 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
return return_docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -86,11 +86,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
return results
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Return documents that are relevant to the query."""
current_time = datetime.datetime.now()
@ -115,11 +111,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
return result
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Return documents that are relevant to the query."""
raise NotImplementedError

@ -65,22 +65,14 @@ class VespaRetriever(BaseRetriever):
return docs
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
body = self._query_body.copy()
body["query"] = query
return self._query(body)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -93,7 +93,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
*,
run_manager: CallbackManagerForRetrieverRun,
where_filter: Optional[Dict[str, object]] = None,
**kwargs: Any,
) -> List[Document]:
"""Look up similar documents in Weaviate."""
query_obj = self._client.query.get(self._index_name, self._query_attrs)
@ -112,11 +111,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
where_filter: Optional[Dict[str, object]] = None,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -1,4 +1,4 @@
from typing import Any, List
from typing import List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -16,19 +16,11 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
"""
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load(query=query)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -64,7 +64,6 @@ class ZepRetriever(BaseRetriever):
*,
run_manager: CallbackManagerForRetrieverRun,
metadata: Optional[Dict] = None,
**kwargs: Any,
) -> List[Document]:
from zep_python import MemorySearchPayload
@ -84,7 +83,6 @@ class ZepRetriever(BaseRetriever):
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
metadata: Optional[Dict] = None,
**kwargs: Any,
) -> List[Document]:
from zep_python import MemorySearchPayload

@ -44,7 +44,6 @@ class BaseRetriever(ABC):
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
""" # noqa: E501
_new_arg_supported: bool = False
@ -86,27 +85,23 @@ class BaseRetriever(ABC):
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for.
run_manager: The callbacks handler to use.
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
@abstractmethod
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
@ -117,8 +112,8 @@ class BaseRetriever(ABC):
) -> List[Document]:
"""Retrieve documents relevant to a query.
Args:
query: String to find relevant documents for.
callbacks: Callback manager or list of callbacks.
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
Returns:
List of relevant documents
"""
@ -132,14 +127,13 @@ class BaseRetriever(ABC):
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = self._get_relevant_documents(
query, run_manager=run_manager, **kwargs
query, run_manager=run_manager, **_kwargs
)
elif self._expects_other_args:
result = self._get_relevant_documents(query, **kwargs)
else:
result = self._get_relevant_documents(query) # type: ignore[call-arg]
result = self._get_relevant_documents(query, **_kwargs)
except Exception as e:
run_manager.on_retriever_error(e)
raise e
@ -170,16 +164,13 @@ class BaseRetriever(ABC):
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = await self._aget_relevant_documents(
query, run_manager=run_manager, **kwargs
query, run_manager=run_manager, **_kwargs
)
elif self._expects_other_args:
result = await self._aget_relevant_documents(query, **kwargs)
else:
result = await self._aget_relevant_documents(
query, # type: ignore[call-arg]
)
result = await self._aget_relevant_documents(query, **_kwargs)
except Exception as e:
await run_manager.on_retriever_error(e)
raise e

@ -497,8 +497,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.vector_search(query, k=self.k)
@ -513,8 +513,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError(
"AzureSearchVectorStoreRetriever does not support async"

@ -408,11 +408,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
return values
def _get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
@ -432,11 +428,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
) -> List[Document]:
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(

@ -620,11 +620,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
return values
def _get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k)
@ -637,11 +633,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
) -> List[Document]:
raise NotImplementedError("RedisVectorStoreRetriever does not support async")

@ -451,11 +451,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k)
@ -464,11 +460,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
) -> List[Document]:
raise NotImplementedError(
"SingleStoreDBVectorStoreRetriever does not support async"

@ -7,18 +7,19 @@ from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import FAISS
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
"""Test saving and loading."""
loader = TextLoader("docs/modules/state_of_the_union.txt")
loader = TextLoader("docs/extras/modules/state_of_the_union.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
docsearch = Chroma.from_documents(texts, embeddings)
docsearch = FAISS.from_documents(texts, embeddings)
qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever())
qa.run("What did the president say about Ketanji Brown Jackson?")
file_path = tmp_path / "RetrievalQA_chain.yaml"
qa.save(file_path=file_path)

@ -1,7 +1,7 @@
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.vectorstores import Chroma
from langchain.vectorstores import FAISS
def test_contextual_compression_retriever_get_relevant_docs() -> None:
@ -13,7 +13,7 @@ def test_contextual_compression_retriever_get_relevant_docs() -> None:
]
embeddings = OpenAIEmbeddings()
base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
base_retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever(
base_retriever = FAISS.from_texts(texts, embedding=embeddings).as_retriever(
search_kwargs={"k": len(texts)}
)
retriever = ContextualCompressionRetriever(

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional
import pytest
@ -141,45 +141,48 @@ async def test_fake_retriever_v1_with_kwargs_upgrade_async(
assert callbacks.retriever_errors == 0
class FakeRetrieverV2(BaseRetriever):
def __init__(self, throw_error: bool = False) -> None:
self.throw_error = throw_error
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None
assert isinstance(run_manager, CallbackManagerForRetrieverRun)
if self.throw_error:
raise ValueError("Test error")
return [
Document(page_content=query),
]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun | None
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None
assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
if self.throw_error:
raise ValueError("Test error")
return [
Document(page_content=f"Async query {query}"),
]
@pytest.fixture
def fake_retriever_v2() -> BaseRetriever:
class FakeRetrieverV2(BaseRetriever):
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun | None = None,
**kwargs: Any,
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None
assert isinstance(run_manager, CallbackManagerForRetrieverRun)
if "throw_error" in kwargs:
raise ValueError("Test error")
return [
Document(page_content=query, metadata=kwargs),
]
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun | None = None,
**kwargs: Any,
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None
assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
if "throw_error" in kwargs:
raise ValueError("Test error")
return [
Document(page_content=f"Async query {query}", metadata=kwargs),
]
return FakeRetrieverV2() # type: ignore[abstract]
def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
@pytest.fixture
def fake_erroring_retriever_v2() -> BaseRetriever:
return FakeRetrieverV2(throw_error=True) # type: ignore[abstract]
def test_fake_retriever_v2(
fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever
) -> None:
callbacks = FakeCallbackHandler()
assert fake_retriever_v2._new_arg_supported is True
results = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
@ -187,20 +190,17 @@ def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
assert callbacks.retriever_starts == 1
assert callbacks.retriever_ends == 1
assert callbacks.retriever_errors == 0
results2 = fake_retriever_v2.get_relevant_documents(
"Foo", callbacks=[callbacks], foo="bar"
)
assert results2[0].metadata == {"foo": "bar"}
fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
with pytest.raises(ValueError, match="Test error"):
fake_retriever_v2.get_relevant_documents(
"Foo", callbacks=[callbacks], throw_error=True
)
fake_erroring_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
assert callbacks.retriever_errors == 1
@pytest.mark.asyncio
async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None:
async def test_fake_retriever_v2_async(
fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever
) -> None:
callbacks = FakeCallbackHandler()
assert fake_retriever_v2._new_arg_supported is True
results = await fake_retriever_v2.aget_relevant_documents(
@ -210,11 +210,8 @@ async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None
assert callbacks.retriever_starts == 1
assert callbacks.retriever_ends == 1
assert callbacks.retriever_errors == 0
results2 = await fake_retriever_v2.aget_relevant_documents(
"Foo", callbacks=[callbacks], foo="bar"
)
assert results2[0].metadata == {"foo": "bar"}
await fake_retriever_v2.aget_relevant_documents("Foo", callbacks=[callbacks])
with pytest.raises(ValueError, match="Test error"):
await fake_retriever_v2.aget_relevant_documents(
"Foo", callbacks=[callbacks], throw_error=True
await fake_erroring_retriever_v2.aget_relevant_documents(
"Foo", callbacks=[callbacks]
)

Loading…
Cancel
Save