mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Rm retriever kwargs (#7013)
Doesn't actually limit the Retriever interface but hopefully in practice it does
This commit is contained in:
parent
9dc77614e3
commit
7acd524210
@ -14,9 +14,10 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
@ -32,7 +33,6 @@ class BaseRetriever(ABC):
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
@ -110,7 +110,7 @@ class NumpyRetriever(BaseRetriever):
|
||||
]
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
@ -127,7 +127,6 @@ class NumpyRetriever(BaseRetriever):
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List
|
||||
from typing import List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@ -16,19 +16,11 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
|
||||
"""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
return self.load(query=query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
@ -87,11 +87,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
|
||||
return response_json["value"]
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
search_results = self._search(query)
|
||||
|
||||
@ -101,11 +97,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
|
||||
]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
search_results = await self._asearch(query)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
@ -26,11 +26,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
url, json, headers = self._create_request(query)
|
||||
response = requests.post(url, json=json, headers=headers)
|
||||
@ -45,11 +41,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
url, json, headers = self._create_request(query)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
@ -28,11 +28,7 @@ class DataberryRetriever(BaseRetriever):
|
||||
self.top_k = top_k
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
response = requests.post(
|
||||
self.datastore_url,
|
||||
@ -59,11 +55,7 @@ class DataberryRetriever(BaseRetriever):
|
||||
]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request(
|
||||
|
@ -56,8 +56,8 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
@ -213,7 +213,7 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -117,11 +117,7 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
||||
return ids
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
query_dict = {"query": {"match": {"content": query}}}
|
||||
res = self.client.search(index=self.index_name, body=query_dict)
|
||||
@ -132,10 +128,6 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -264,8 +264,8 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Run search on Kendra index and get top k documents
|
||||
|
||||
@ -281,7 +281,7 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("Async version is not implemented for Kendra yet.")
|
||||
|
@ -56,11 +56,7 @@ class KNNRetriever(BaseRetriever, BaseModel):
|
||||
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
query_embeds = np.array(self.embeddings.embed_query(query))
|
||||
# calc L2 norm
|
||||
@ -84,10 +80,6 @@ class KNNRetriever(BaseRetriever, BaseModel):
|
||||
return top_k_results
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -16,11 +16,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
||||
query_kwargs: Dict = Field(default_factory=dict)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query."""
|
||||
try:
|
||||
@ -44,11 +40,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("LlamaIndexRetriever does not support async")
|
||||
|
||||
@ -60,11 +52,7 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
||||
query_configs: List[Dict] = Field(default_factory=list)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query."""
|
||||
try:
|
||||
@ -96,10 +84,6 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List
|
||||
from typing import List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@ -31,8 +31,8 @@ class MergerRetriever(BaseRetriever):
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Get the relevant documents for a given query.
|
||||
@ -52,8 +52,8 @@ class MergerRetriever(BaseRetriever):
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Asynchronously get the relevant documents for a given query.
|
||||
|
@ -22,11 +22,7 @@ class MetalRetriever(BaseRetriever):
|
||||
self.params = params or {}
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
results = self.client.search({"text": query}, **self.params)
|
||||
final_results = []
|
||||
@ -36,10 +32,6 @@ class MetalRetriever(BaseRetriever):
|
||||
return final_results
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, List
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -98,8 +98,8 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Get relevated documents given a user query.
|
||||
|
||||
@ -117,8 +117,8 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -143,11 +143,7 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
|
||||
return values
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
from pinecone_text.hybrid import hybrid_convex_scale
|
||||
|
||||
@ -174,10 +170,6 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
|
||||
return final_result
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""A retriever that uses PubMed API to retrieve documents."""
|
||||
from typing import Any, List
|
||||
from typing import List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@ -17,19 +17,11 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
|
||||
"""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
return self.load_docs(query=query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
@ -20,11 +20,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
|
||||
metadata_key: str = "metadata"
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
response = requests.post(
|
||||
self.url, json={self.input_key: query}, headers=self.headers
|
||||
@ -38,11 +34,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
|
||||
]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request(
|
||||
|
@ -84,11 +84,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
return values
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
@ -121,11 +117,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun],
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -55,11 +55,7 @@ class SVMRetriever(BaseRetriever, BaseModel):
|
||||
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
from sklearn import svm
|
||||
|
||||
@ -98,10 +94,6 @@ class SVMRetriever(BaseRetriever, BaseModel):
|
||||
return top_k_results
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -64,11 +64,7 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
|
||||
)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
@ -82,10 +78,6 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
|
||||
return return_docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -86,11 +86,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
return results
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Return documents that are relevant to the query."""
|
||||
current_time = datetime.datetime.now()
|
||||
@ -115,11 +111,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
return result
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Return documents that are relevant to the query."""
|
||||
raise NotImplementedError
|
||||
|
@ -65,22 +65,14 @@ class VespaRetriever(BaseRetriever):
|
||||
return docs
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
body = self._query_body.copy()
|
||||
body["query"] = query
|
||||
return self._query(body)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -93,7 +93,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
where_filter: Optional[Dict[str, object]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Look up similar documents in Weaviate."""
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
@ -112,11 +111,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
where_filter: Optional[Dict[str, object]] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List
|
||||
from typing import List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@ -16,19 +16,11 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
|
||||
"""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
return self.load(query=query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@ -64,7 +64,6 @@ class ZepRetriever(BaseRetriever):
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
metadata: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
from zep_python import MemorySearchPayload
|
||||
|
||||
@ -84,7 +83,6 @@ class ZepRetriever(BaseRetriever):
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
metadata: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
from zep_python import MemorySearchPayload
|
||||
|
||||
|
@ -44,7 +44,6 @@ class BaseRetriever(ABC):
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
_new_arg_supported: bool = False
|
||||
@ -86,27 +85,23 @@ class BaseRetriever(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: String to find relevant documents for.
|
||||
run_manager: The callbacks handler to use.
|
||||
query: String to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
query: String to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
@ -117,8 +112,8 @@ class BaseRetriever(ABC):
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
Args:
|
||||
query: String to find relevant documents for.
|
||||
callbacks: Callback manager or list of callbacks.
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
@ -132,14 +127,13 @@ class BaseRetriever(ABC):
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
_kwargs = kwargs if self._expects_other_args else {}
|
||||
if self._new_arg_supported:
|
||||
result = self._get_relevant_documents(
|
||||
query, run_manager=run_manager, **kwargs
|
||||
query, run_manager=run_manager, **_kwargs
|
||||
)
|
||||
elif self._expects_other_args:
|
||||
result = self._get_relevant_documents(query, **kwargs)
|
||||
else:
|
||||
result = self._get_relevant_documents(query) # type: ignore[call-arg]
|
||||
result = self._get_relevant_documents(query, **_kwargs)
|
||||
except Exception as e:
|
||||
run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
@ -170,16 +164,13 @@ class BaseRetriever(ABC):
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
_kwargs = kwargs if self._expects_other_args else {}
|
||||
if self._new_arg_supported:
|
||||
result = await self._aget_relevant_documents(
|
||||
query, run_manager=run_manager, **kwargs
|
||||
query, run_manager=run_manager, **_kwargs
|
||||
)
|
||||
elif self._expects_other_args:
|
||||
result = await self._aget_relevant_documents(query, **kwargs)
|
||||
else:
|
||||
result = await self._aget_relevant_documents(
|
||||
query, # type: ignore[call-arg]
|
||||
)
|
||||
result = await self._aget_relevant_documents(query, **_kwargs)
|
||||
except Exception as e:
|
||||
await run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
|
@ -497,8 +497,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.vector_search(query, k=self.k)
|
||||
@ -513,8 +513,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError(
|
||||
"AzureSearchVectorStoreRetriever does not support async"
|
||||
|
@ -408,11 +408,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
return values
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
@ -432,11 +428,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
|
@ -620,11 +620,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
|
||||
return values
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||
@ -637,11 +633,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
|
||||
|
||||
|
@ -451,11 +451,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
|
||||
allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||
@ -464,11 +460,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError(
|
||||
"SingleStoreDBVectorStoreRetriever does not support async"
|
||||
|
@ -7,18 +7,19 @@ from langchain.document_loaders import TextLoader
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.vectorstores import FAISS
|
||||
|
||||
|
||||
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
|
||||
"""Test saving and loading."""
|
||||
loader = TextLoader("docs/modules/state_of_the_union.txt")
|
||||
loader = TextLoader("docs/extras/modules/state_of_the_union.txt")
|
||||
documents = loader.load()
|
||||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||
texts = text_splitter.split_documents(documents)
|
||||
embeddings = OpenAIEmbeddings()
|
||||
docsearch = Chroma.from_documents(texts, embeddings)
|
||||
docsearch = FAISS.from_documents(texts, embeddings)
|
||||
qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever())
|
||||
qa.run("What did the president say about Ketanji Brown Jackson?")
|
||||
|
||||
file_path = tmp_path / "RetrievalQA_chain.yaml"
|
||||
qa.save(file_path=file_path)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import EmbeddingsFilter
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.vectorstores import FAISS
|
||||
|
||||
|
||||
def test_contextual_compression_retriever_get_relevant_docs() -> None:
|
||||
@ -13,7 +13,7 @@ def test_contextual_compression_retriever_get_relevant_docs() -> None:
|
||||
]
|
||||
embeddings = OpenAIEmbeddings()
|
||||
base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
|
||||
base_retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever(
|
||||
base_retriever = FAISS.from_texts(texts, embedding=embeddings).as_retriever(
|
||||
search_kwargs={"k": len(texts)}
|
||||
)
|
||||
retriever = ContextualCompressionRetriever(
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@ -141,45 +141,48 @@ async def test_fake_retriever_v1_with_kwargs_upgrade_async(
|
||||
assert callbacks.retriever_errors == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_retriever_v2() -> BaseRetriever:
|
||||
class FakeRetrieverV2(BaseRetriever):
|
||||
class FakeRetrieverV2(BaseRetriever):
|
||||
def __init__(self, throw_error: bool = False) -> None:
|
||||
self.throw_error = throw_error
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun | None = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV2)
|
||||
assert run_manager is not None
|
||||
assert isinstance(run_manager, CallbackManagerForRetrieverRun)
|
||||
if "throw_error" in kwargs:
|
||||
if self.throw_error:
|
||||
raise ValueError("Test error")
|
||||
return [
|
||||
Document(page_content=query, metadata=kwargs),
|
||||
Document(page_content=query),
|
||||
]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun | None = None,
|
||||
**kwargs: Any,
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun | None
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV2)
|
||||
assert run_manager is not None
|
||||
assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
|
||||
if "throw_error" in kwargs:
|
||||
if self.throw_error:
|
||||
raise ValueError("Test error")
|
||||
return [
|
||||
Document(page_content=f"Async query {query}", metadata=kwargs),
|
||||
Document(page_content=f"Async query {query}"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_retriever_v2() -> BaseRetriever:
|
||||
return FakeRetrieverV2() # type: ignore[abstract]
|
||||
|
||||
|
||||
def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
|
||||
@pytest.fixture
|
||||
def fake_erroring_retriever_v2() -> BaseRetriever:
|
||||
return FakeRetrieverV2(throw_error=True) # type: ignore[abstract]
|
||||
|
||||
|
||||
def test_fake_retriever_v2(
|
||||
fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever
|
||||
) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v2._new_arg_supported is True
|
||||
results = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
|
||||
@ -187,20 +190,17 @@ def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
results2 = fake_retriever_v2.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], foo="bar"
|
||||
)
|
||||
assert results2[0].metadata == {"foo": "bar"}
|
||||
fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
|
||||
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
fake_retriever_v2.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], throw_error=True
|
||||
)
|
||||
fake_erroring_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
|
||||
assert callbacks.retriever_errors == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None:
|
||||
async def test_fake_retriever_v2_async(
|
||||
fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever
|
||||
) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v2._new_arg_supported is True
|
||||
results = await fake_retriever_v2.aget_relevant_documents(
|
||||
@ -210,11 +210,8 @@ async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
results2 = await fake_retriever_v2.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], foo="bar"
|
||||
)
|
||||
assert results2[0].metadata == {"foo": "bar"}
|
||||
await fake_retriever_v2.aget_relevant_documents("Foo", callbacks=[callbacks])
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await fake_retriever_v2.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], throw_error=True
|
||||
await fake_erroring_retriever_v2.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks]
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user