Rm retriever kwargs (#7013)

Doesn't actually limit the Retriever interface but hopefully in practice
it does
This commit is contained in:
Bagatur 2023-07-02 08:22:24 -06:00 committed by GitHub
parent 9dc77614e3
commit 7acd524210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 131 additions and 319 deletions

View File

@ -14,9 +14,10 @@ from langchain.callbacks.manager import (
) )
class BaseRetriever(ABC): class BaseRetriever(ABC):
@abstractmethod @abstractmethod
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: ) -> List[Document]:
"""Get documents relevant to a query. """Get documents relevant to a query.
Args: Args:
@ -32,7 +33,6 @@ class BaseRetriever(ABC):
query: str, query: str,
*, *,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Asynchronously get documents relevant to a query. """Asynchronously get documents relevant to a query.
Args: Args:
@ -110,7 +110,7 @@ class NumpyRetriever(BaseRetriever):
] ]
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: ) -> List[Document]:
"""Get documents relevant to a query. """Get documents relevant to a query.
Args: Args:
@ -127,7 +127,6 @@ class NumpyRetriever(BaseRetriever):
query: str, query: str,
*, *,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Asynchronously get documents relevant to a query. """Asynchronously get documents relevant to a query.
Args: Args:

View File

@ -1,4 +1,4 @@
from typing import Any, List from typing import List
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun,
@ -16,19 +16,11 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
""" """
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
return self.load(query=query) return self.load(query=query)
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any, Dict, List, Optional from typing import Dict, List, Optional
import aiohttp import aiohttp
import requests import requests
@ -87,11 +87,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
return response_json["value"] return response_json["value"]
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
search_results = self._search(query) search_results = self._search(query)
@ -101,11 +97,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
] ]
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
search_results = await self._asearch(query) search_results = await self._asearch(query)

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Optional from typing import List, Optional
import aiohttp import aiohttp
import requests import requests
@ -26,11 +26,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
url, json, headers = self._create_request(query) url, json, headers = self._create_request(query)
response = requests.post(url, json=json, headers=headers) response = requests.post(url, json=json, headers=headers)
@ -45,11 +41,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
url, json, headers = self._create_request(query) url, json, headers = self._create_request(query)

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional from typing import List, Optional
import aiohttp import aiohttp
import requests import requests
@ -28,11 +28,7 @@ class DataberryRetriever(BaseRetriever):
self.top_k = top_k self.top_k = top_k
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
response = requests.post( response = requests.post(
self.datastore_url, self.datastore_url,
@ -59,11 +55,7 @@ class DataberryRetriever(BaseRetriever):
] ]
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.request( async with session.request(

View File

@ -56,8 +56,8 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
def _get_relevant_documents( def _get_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Get documents relevant for a query. """Get documents relevant for a query.
@ -213,7 +213,7 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -117,11 +117,7 @@ class ElasticSearchBM25Retriever(BaseRetriever):
return ids return ids
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
query_dict = {"query": {"match": {"content": query}}} query_dict = {"query": {"match": {"content": query}}}
res = self.client.search(index=self.index_name, body=query_dict) res = self.client.search(index=self.index_name, body=query_dict)
@ -132,10 +128,6 @@ class ElasticSearchBM25Retriever(BaseRetriever):
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -264,8 +264,8 @@ class AmazonKendraRetriever(BaseRetriever):
def _get_relevant_documents( def _get_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Run search on Kendra index and get top k documents """Run search on Kendra index and get top k documents
@ -281,7 +281,7 @@ class AmazonKendraRetriever(BaseRetriever):
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError("Async version is not implemented for Kendra yet.") raise NotImplementedError("Async version is not implemented for Kendra yet.")

View File

@ -56,11 +56,7 @@ class KNNRetriever(BaseRetriever, BaseModel):
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
query_embeds = np.array(self.embeddings.embed_query(query)) query_embeds = np.array(self.embeddings.embed_query(query))
# calc L2 norm # calc L2 norm
@ -84,10 +80,6 @@ class KNNRetriever(BaseRetriever, BaseModel):
return top_k_results return top_k_results
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -16,11 +16,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
query_kwargs: Dict = Field(default_factory=dict) query_kwargs: Dict = Field(default_factory=dict)
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Get documents relevant for a query.""" """Get documents relevant for a query."""
try: try:
@ -44,11 +40,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError("LlamaIndexRetriever does not support async") raise NotImplementedError("LlamaIndexRetriever does not support async")
@ -60,11 +52,7 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
query_configs: List[Dict] = Field(default_factory=list) query_configs: List[Dict] = Field(default_factory=list)
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Get documents relevant for a query.""" """Get documents relevant for a query."""
try: try:
@ -96,10 +84,6 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError("LlamaIndexGraphRetriever does not support async") raise NotImplementedError("LlamaIndexGraphRetriever does not support async")

View File

@ -1,4 +1,4 @@
from typing import Any, List from typing import List
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun,
@ -31,8 +31,8 @@ class MergerRetriever(BaseRetriever):
def _get_relevant_documents( def _get_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
""" """
Get the relevant documents for a given query. Get the relevant documents for a given query.
@ -52,8 +52,8 @@ class MergerRetriever(BaseRetriever):
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
""" """
Asynchronously get the relevant documents for a given query. Asynchronously get the relevant documents for a given query.

View File

@ -22,11 +22,7 @@ class MetalRetriever(BaseRetriever):
self.params = params or {} self.params = params or {}
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
results = self.client.search({"text": query}, **self.params) results = self.client.search({"text": query}, **self.params)
final_results = [] final_results = []
@ -36,10 +32,6 @@ class MetalRetriever(BaseRetriever):
return final_results return final_results
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Any, List from typing import List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -98,8 +98,8 @@ class MultiQueryRetriever(BaseRetriever):
def _get_relevant_documents( def _get_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Get relevated documents given a user query. """Get relevated documents given a user query.
@ -117,8 +117,8 @@ class MultiQueryRetriever(BaseRetriever):
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -143,11 +143,7 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
return values return values
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
from pinecone_text.hybrid import hybrid_convex_scale from pinecone_text.hybrid import hybrid_convex_scale
@ -174,10 +170,6 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
return final_result return final_result
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -1,5 +1,5 @@
"""A retriever that uses PubMed API to retrieve documents.""" """A retriever that uses PubMed API to retrieve documents."""
from typing import Any, List from typing import List
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun,
@ -17,19 +17,11 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
""" """
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
return self.load_docs(query=query) return self.load_docs(query=query)
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional from typing import List, Optional
import aiohttp import aiohttp
import requests import requests
@ -20,11 +20,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
metadata_key: str = "metadata" metadata_key: str = "metadata"
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
response = requests.post( response = requests.post(
self.url, json={self.input_key: query}, headers=self.headers self.url, json={self.input_key: query}, headers=self.headers
@ -38,11 +34,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
] ]
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.request( async with session.request(

View File

@ -84,11 +84,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
return values return values
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Get documents relevant for a query. """Get documents relevant for a query.
@ -121,11 +117,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun],
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -55,11 +55,7 @@ class SVMRetriever(BaseRetriever, BaseModel):
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
from sklearn import svm from sklearn import svm
@ -98,10 +94,6 @@ class SVMRetriever(BaseRetriever, BaseModel):
return top_k_results return top_k_results
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -64,11 +64,7 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
) )
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
@ -82,10 +78,6 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
return return_docs return return_docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -86,11 +86,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
return results return results
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return documents that are relevant to the query.""" """Return documents that are relevant to the query."""
current_time = datetime.datetime.now() current_time = datetime.datetime.now()
@ -115,11 +111,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
return result return result
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Return documents that are relevant to the query.""" """Return documents that are relevant to the query."""
raise NotImplementedError raise NotImplementedError

View File

@ -65,22 +65,14 @@ class VespaRetriever(BaseRetriever):
return docs return docs
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
body = self._query_body.copy() body = self._query_body.copy()
body["query"] = query body["query"] = query
return self._query(body) return self._query(body)
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -93,7 +93,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
*, *,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
where_filter: Optional[Dict[str, object]] = None, where_filter: Optional[Dict[str, object]] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Look up similar documents in Weaviate.""" """Look up similar documents in Weaviate."""
query_obj = self._client.query.get(self._index_name, self._query_attrs) query_obj = self._client.query.get(self._index_name, self._query_attrs)
@ -112,11 +111,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
where_filter: Optional[Dict[str, object]] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -1,4 +1,4 @@
from typing import Any, List from typing import List
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun,
@ -16,19 +16,11 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
""" """
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: CallbackManagerForRetrieverRun
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
return self.load(query=query) return self.load(query=query)
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError raise NotImplementedError

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun,
@ -64,7 +64,6 @@ class ZepRetriever(BaseRetriever):
*, *,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
metadata: Optional[Dict] = None, metadata: Optional[Dict] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
from zep_python import MemorySearchPayload from zep_python import MemorySearchPayload
@ -84,7 +83,6 @@ class ZepRetriever(BaseRetriever):
*, *,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
metadata: Optional[Dict] = None, metadata: Optional[Dict] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
from zep_python import MemorySearchPayload from zep_python import MemorySearchPayload

View File

@ -44,7 +44,6 @@ class BaseRetriever(ABC):
async def aget_relevant_documents(self, query: str) -> List[Document]: async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError raise NotImplementedError
""" # noqa: E501 """ # noqa: E501
_new_arg_supported: bool = False _new_arg_supported: bool = False
@ -86,27 +85,23 @@ class BaseRetriever(ABC):
@abstractmethod @abstractmethod
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: ) -> List[Document]:
"""Get documents relevant to a query. """Get documents relevant to a query.
Args: Args:
query: String to find relevant documents for. query: String to find relevant documents for
run_manager: The callbacks handler to use. run_manager: The callbacks handler to use
Returns: Returns:
List of relevant documents List of relevant documents
""" """
@abstractmethod @abstractmethod
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
"""Asynchronously get documents relevant to a query. """Asynchronously get documents relevant to a query.
Args: Args:
query: string to find relevant documents for query: String to find relevant documents for
run_manager: The callbacks handler to use run_manager: The callbacks handler to use
Returns: Returns:
List of relevant documents List of relevant documents
@ -117,8 +112,8 @@ class BaseRetriever(ABC):
) -> List[Document]: ) -> List[Document]:
"""Retrieve documents relevant to a query. """Retrieve documents relevant to a query.
Args: Args:
query: String to find relevant documents for. query: string to find relevant documents for
callbacks: Callback manager or list of callbacks. callbacks: Callback manager or list of callbacks
Returns: Returns:
List of relevant documents List of relevant documents
""" """
@ -132,14 +127,13 @@ class BaseRetriever(ABC):
**kwargs, **kwargs,
) )
try: try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported: if self._new_arg_supported:
result = self._get_relevant_documents( result = self._get_relevant_documents(
query, run_manager=run_manager, **kwargs query, run_manager=run_manager, **_kwargs
) )
elif self._expects_other_args:
result = self._get_relevant_documents(query, **kwargs)
else: else:
result = self._get_relevant_documents(query) # type: ignore[call-arg] result = self._get_relevant_documents(query, **_kwargs)
except Exception as e: except Exception as e:
run_manager.on_retriever_error(e) run_manager.on_retriever_error(e)
raise e raise e
@ -170,16 +164,13 @@ class BaseRetriever(ABC):
**kwargs, **kwargs,
) )
try: try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported: if self._new_arg_supported:
result = await self._aget_relevant_documents( result = await self._aget_relevant_documents(
query, run_manager=run_manager, **kwargs query, run_manager=run_manager, **_kwargs
) )
elif self._expects_other_args:
result = await self._aget_relevant_documents(query, **kwargs)
else: else:
result = await self._aget_relevant_documents( result = await self._aget_relevant_documents(query, **_kwargs)
query, # type: ignore[call-arg]
)
except Exception as e: except Exception as e:
await run_manager.on_retriever_error(e) await run_manager.on_retriever_error(e)
raise e raise e

View File

@ -497,8 +497,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
def _get_relevant_documents( def _get_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
if self.search_type == "similarity": if self.search_type == "similarity":
docs = self.vectorstore.vector_search(query, k=self.k) docs = self.vectorstore.vector_search(query, k=self.k)
@ -513,8 +513,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
query: str, query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError( raise NotImplementedError(
"AzureSearchVectorStoreRetriever does not support async" "AzureSearchVectorStoreRetriever does not support async"

View File

@ -408,11 +408,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
return values return values
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
if self.search_type == "similarity": if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs) docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
@ -432,11 +428,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
if self.search_type == "similarity": if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search( docs = await self.vectorstore.asimilarity_search(

View File

@ -620,11 +620,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
return values return values
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
if self.search_type == "similarity": if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k) docs = self.vectorstore.similarity_search(query, k=self.k)
@ -637,11 +633,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError("RedisVectorStoreRetriever does not support async") raise NotImplementedError("RedisVectorStoreRetriever does not support async")

View File

@ -451,11 +451,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
allowed_search_types: ClassVar[Collection[str]] = ("similarity",) allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
def _get_relevant_documents( def _get_relevant_documents(
self, self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
if self.search_type == "similarity": if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k) docs = self.vectorstore.similarity_search(query, k=self.k)
@ -464,11 +460,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
return docs return docs
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError( raise NotImplementedError(
"SingleStoreDBVectorStoreRetriever does not support async" "SingleStoreDBVectorStoreRetriever does not support async"

View File

@ -7,18 +7,19 @@ from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma from langchain.vectorstores import FAISS
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None: def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
"""Test saving and loading.""" """Test saving and loading."""
loader = TextLoader("docs/modules/state_of_the_union.txt") loader = TextLoader("docs/extras/modules/state_of_the_union.txt")
documents = loader.load() documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents) texts = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
docsearch = Chroma.from_documents(texts, embeddings) docsearch = FAISS.from_documents(texts, embeddings)
qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever()) qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever())
qa.run("What did the president say about Ketanji Brown Jackson?")
file_path = tmp_path / "RetrievalQA_chain.yaml" file_path = tmp_path / "RetrievalQA_chain.yaml"
qa.save(file_path=file_path) qa.save(file_path=file_path)

View File

@ -1,7 +1,7 @@
from langchain.embeddings import OpenAIEmbeddings from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.vectorstores import Chroma from langchain.vectorstores import FAISS
def test_contextual_compression_retriever_get_relevant_docs() -> None: def test_contextual_compression_retriever_get_relevant_docs() -> None:
@ -13,7 +13,7 @@ def test_contextual_compression_retriever_get_relevant_docs() -> None:
] ]
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75) base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
base_retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever( base_retriever = FAISS.from_texts(texts, embedding=embeddings).as_retriever(
search_kwargs={"k": len(texts)} search_kwargs={"k": len(texts)}
) )
retriever = ContextualCompressionRetriever( retriever = ContextualCompressionRetriever(

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Dict, List, Optional
import pytest import pytest
@ -141,45 +141,48 @@ async def test_fake_retriever_v1_with_kwargs_upgrade_async(
assert callbacks.retriever_errors == 0 assert callbacks.retriever_errors == 0
class FakeRetrieverV2(BaseRetriever):
def __init__(self, throw_error: bool = False) -> None:
self.throw_error = throw_error
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None
assert isinstance(run_manager, CallbackManagerForRetrieverRun)
if self.throw_error:
raise ValueError("Test error")
return [
Document(page_content=query),
]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun | None
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None
assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
if self.throw_error:
raise ValueError("Test error")
return [
Document(page_content=f"Async query {query}"),
]
@pytest.fixture @pytest.fixture
def fake_retriever_v2() -> BaseRetriever: def fake_retriever_v2() -> BaseRetriever:
class FakeRetrieverV2(BaseRetriever):
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun | None = None,
**kwargs: Any,
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None
assert isinstance(run_manager, CallbackManagerForRetrieverRun)
if "throw_error" in kwargs:
raise ValueError("Test error")
return [
Document(page_content=query, metadata=kwargs),
]
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun | None = None,
**kwargs: Any,
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None
assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
if "throw_error" in kwargs:
raise ValueError("Test error")
return [
Document(page_content=f"Async query {query}", metadata=kwargs),
]
return FakeRetrieverV2() # type: ignore[abstract] return FakeRetrieverV2() # type: ignore[abstract]
def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None: @pytest.fixture
def fake_erroring_retriever_v2() -> BaseRetriever:
return FakeRetrieverV2(throw_error=True) # type: ignore[abstract]
def test_fake_retriever_v2(
fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever
) -> None:
callbacks = FakeCallbackHandler() callbacks = FakeCallbackHandler()
assert fake_retriever_v2._new_arg_supported is True assert fake_retriever_v2._new_arg_supported is True
results = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks]) results = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
@ -187,20 +190,17 @@ def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
assert callbacks.retriever_starts == 1 assert callbacks.retriever_starts == 1
assert callbacks.retriever_ends == 1 assert callbacks.retriever_ends == 1
assert callbacks.retriever_errors == 0 assert callbacks.retriever_errors == 0
results2 = fake_retriever_v2.get_relevant_documents( fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
"Foo", callbacks=[callbacks], foo="bar"
)
assert results2[0].metadata == {"foo": "bar"}
with pytest.raises(ValueError, match="Test error"): with pytest.raises(ValueError, match="Test error"):
fake_retriever_v2.get_relevant_documents( fake_erroring_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
"Foo", callbacks=[callbacks], throw_error=True
)
assert callbacks.retriever_errors == 1 assert callbacks.retriever_errors == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None: async def test_fake_retriever_v2_async(
fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever
) -> None:
callbacks = FakeCallbackHandler() callbacks = FakeCallbackHandler()
assert fake_retriever_v2._new_arg_supported is True assert fake_retriever_v2._new_arg_supported is True
results = await fake_retriever_v2.aget_relevant_documents( results = await fake_retriever_v2.aget_relevant_documents(
@ -210,11 +210,8 @@ async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None
assert callbacks.retriever_starts == 1 assert callbacks.retriever_starts == 1
assert callbacks.retriever_ends == 1 assert callbacks.retriever_ends == 1
assert callbacks.retriever_errors == 0 assert callbacks.retriever_errors == 0
results2 = await fake_retriever_v2.aget_relevant_documents( await fake_retriever_v2.aget_relevant_documents("Foo", callbacks=[callbacks])
"Foo", callbacks=[callbacks], foo="bar"
)
assert results2[0].metadata == {"foo": "bar"}
with pytest.raises(ValueError, match="Test error"): with pytest.raises(ValueError, match="Test error"):
await fake_retriever_v2.aget_relevant_documents( await fake_erroring_retriever_v2.aget_relevant_documents(
"Foo", callbacks=[callbacks], throw_error=True "Foo", callbacks=[callbacks]
) )