mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Rm retriever kwargs (#7013)
This change doesn't actually restrict the Retriever interface, but hopefully in practice it does.
This commit is contained in:
parent
9dc77614e3
commit
7acd524210
@ -14,9 +14,10 @@ from langchain.callbacks.manager import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
class BaseRetriever(ABC):
|
class BaseRetriever(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Get documents relevant to a query.
|
"""Get documents relevant to a query.
|
||||||
Args:
|
Args:
|
||||||
@ -32,7 +33,6 @@ class BaseRetriever(ABC):
|
|||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Asynchronously get documents relevant to a query.
|
"""Asynchronously get documents relevant to a query.
|
||||||
Args:
|
Args:
|
||||||
@ -110,7 +110,7 @@ class NumpyRetriever(BaseRetriever):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Get documents relevant to a query.
|
"""Get documents relevant to a query.
|
||||||
Args:
|
Args:
|
||||||
@ -127,7 +127,6 @@ class NumpyRetriever(BaseRetriever):
|
|||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Asynchronously get documents relevant to a query.
|
"""Asynchronously get documents relevant to a query.
|
||||||
Args:
|
Args:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List
|
from typing import List
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -16,19 +16,11 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
return self.load(query=query)
|
return self.load(query=query)
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
@ -87,11 +87,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
|
|||||||
return response_json["value"]
|
return response_json["value"]
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
search_results = self._search(query)
|
search_results = self._search(query)
|
||||||
|
|
||||||
@ -101,11 +97,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
|
|||||||
]
|
]
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
search_results = await self._asearch(query)
|
search_results = await self._asearch(query)
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
@ -26,11 +26,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
url, json, headers = self._create_request(query)
|
url, json, headers = self._create_request(query)
|
||||||
response = requests.post(url, json=json, headers=headers)
|
response = requests.post(url, json=json, headers=headers)
|
||||||
@ -45,11 +41,7 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
url, json, headers = self._create_request(query)
|
url, json, headers = self._create_request(query)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
@ -28,11 +28,7 @@ class DataberryRetriever(BaseRetriever):
|
|||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.datastore_url,
|
self.datastore_url,
|
||||||
@ -59,11 +55,7 @@ class DataberryRetriever(BaseRetriever):
|
|||||||
]
|
]
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.request(
|
async with session.request(
|
||||||
|
@ -56,8 +56,8 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
|
|||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Get documents relevant for a query.
|
"""Get documents relevant for a query.
|
||||||
|
|
||||||
@ -213,7 +213,7 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
|
|||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -117,11 +117,7 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
|||||||
return ids
|
return ids
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
query_dict = {"query": {"match": {"content": query}}}
|
query_dict = {"query": {"match": {"content": query}}}
|
||||||
res = self.client.search(index=self.index_name, body=query_dict)
|
res = self.client.search(index=self.index_name, body=query_dict)
|
||||||
@ -132,10 +128,6 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -264,8 +264,8 @@ class AmazonKendraRetriever(BaseRetriever):
|
|||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Run search on Kendra index and get top k documents
|
"""Run search on Kendra index and get top k documents
|
||||||
|
|
||||||
@ -281,7 +281,7 @@ class AmazonKendraRetriever(BaseRetriever):
|
|||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError("Async version is not implemented for Kendra yet.")
|
raise NotImplementedError("Async version is not implemented for Kendra yet.")
|
||||||
|
@ -56,11 +56,7 @@ class KNNRetriever(BaseRetriever, BaseModel):
|
|||||||
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
query_embeds = np.array(self.embeddings.embed_query(query))
|
query_embeds = np.array(self.embeddings.embed_query(query))
|
||||||
# calc L2 norm
|
# calc L2 norm
|
||||||
@ -84,10 +80,6 @@ class KNNRetriever(BaseRetriever, BaseModel):
|
|||||||
return top_k_results
|
return top_k_results
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -16,11 +16,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
|||||||
query_kwargs: Dict = Field(default_factory=dict)
|
query_kwargs: Dict = Field(default_factory=dict)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Get documents relevant for a query."""
|
"""Get documents relevant for a query."""
|
||||||
try:
|
try:
|
||||||
@ -44,11 +40,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError("LlamaIndexRetriever does not support async")
|
raise NotImplementedError("LlamaIndexRetriever does not support async")
|
||||||
|
|
||||||
@ -60,11 +52,7 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
|||||||
query_configs: List[Dict] = Field(default_factory=list)
|
query_configs: List[Dict] = Field(default_factory=list)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Get documents relevant for a query."""
|
"""Get documents relevant for a query."""
|
||||||
try:
|
try:
|
||||||
@ -96,10 +84,6 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")
|
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List
|
from typing import List
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -31,8 +31,8 @@ class MergerRetriever(BaseRetriever):
|
|||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
Get the relevant documents for a given query.
|
Get the relevant documents for a given query.
|
||||||
@ -52,8 +52,8 @@ class MergerRetriever(BaseRetriever):
|
|||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
Asynchronously get the relevant documents for a given query.
|
Asynchronously get the relevant documents for a given query.
|
||||||
|
@ -22,11 +22,7 @@ class MetalRetriever(BaseRetriever):
|
|||||||
self.params = params or {}
|
self.params = params or {}
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
results = self.client.search({"text": query}, **self.params)
|
results = self.client.search({"text": query}, **self.params)
|
||||||
final_results = []
|
final_results = []
|
||||||
@ -36,10 +32,6 @@ class MetalRetriever(BaseRetriever):
|
|||||||
return final_results
|
return final_results
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -98,8 +98,8 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
|
"""Get relevant documents given a user query.
|
"""Get relevant documents given a user query.
|
||||||
|
|
||||||
@ -117,8 +117,8 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -143,11 +143,7 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
from pinecone_text.hybrid import hybrid_convex_scale
|
from pinecone_text.hybrid import hybrid_convex_scale
|
||||||
|
|
||||||
@ -174,10 +170,6 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
|
|||||||
return final_result
|
return final_result
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""A retriever that uses PubMed API to retrieve documents."""
|
"""A retriever that uses PubMed API to retrieve documents."""
|
||||||
from typing import Any, List
|
from typing import List
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -17,19 +17,11 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
return self.load_docs(query=query)
|
return self.load_docs(query=query)
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
@ -20,11 +20,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
|
|||||||
metadata_key: str = "metadata"
|
metadata_key: str = "metadata"
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.url, json={self.input_key: query}, headers=self.headers
|
self.url, json={self.input_key: query}, headers=self.headers
|
||||||
@ -38,11 +34,7 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
|
|||||||
]
|
]
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.request(
|
async with session.request(
|
||||||
|
@ -84,11 +84,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Get documents relevant for a query.
|
"""Get documents relevant for a query.
|
||||||
|
|
||||||
@ -121,11 +117,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun],
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -55,11 +55,7 @@ class SVMRetriever(BaseRetriever, BaseModel):
|
|||||||
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
from sklearn import svm
|
from sklearn import svm
|
||||||
|
|
||||||
@ -98,10 +94,6 @@ class SVMRetriever(BaseRetriever, BaseModel):
|
|||||||
return top_k_results
|
return top_k_results
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -64,11 +64,7 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|
||||||
@ -82,10 +78,6 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
|
|||||||
return return_docs
|
return return_docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -86,11 +86,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return documents that are relevant to the query."""
|
"""Return documents that are relevant to the query."""
|
||||||
current_time = datetime.datetime.now()
|
current_time = datetime.datetime.now()
|
||||||
@ -115,11 +111,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return documents that are relevant to the query."""
|
"""Return documents that are relevant to the query."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -65,22 +65,14 @@ class VespaRetriever(BaseRetriever):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
body = self._query_body.copy()
|
body = self._query_body.copy()
|
||||||
body["query"] = query
|
body["query"] = query
|
||||||
return self._query(body)
|
return self._query(body)
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -93,7 +93,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
*,
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
where_filter: Optional[Dict[str, object]] = None,
|
where_filter: Optional[Dict[str, object]] = None,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Look up similar documents in Weaviate."""
|
"""Look up similar documents in Weaviate."""
|
||||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||||
@ -112,11 +111,6 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
where_filter: Optional[Dict[str, object]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List
|
from typing import List
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -16,19 +16,11 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
return self.load(query=query)
|
return self.load(query=query)
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -64,7 +64,6 @@ class ZepRetriever(BaseRetriever):
|
|||||||
*,
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
metadata: Optional[Dict] = None,
|
metadata: Optional[Dict] = None,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
from zep_python import MemorySearchPayload
|
from zep_python import MemorySearchPayload
|
||||||
|
|
||||||
@ -84,7 +83,6 @@ class ZepRetriever(BaseRetriever):
|
|||||||
*,
|
*,
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
metadata: Optional[Dict] = None,
|
metadata: Optional[Dict] = None,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
from zep_python import MemorySearchPayload
|
from zep_python import MemorySearchPayload
|
||||||
|
|
||||||
|
@ -44,7 +44,6 @@ class BaseRetriever(ABC):
|
|||||||
|
|
||||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
_new_arg_supported: bool = False
|
_new_arg_supported: bool = False
|
||||||
@ -86,27 +85,23 @@ class BaseRetriever(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Get documents relevant to a query.
|
"""Get documents relevant to a query.
|
||||||
Args:
|
Args:
|
||||||
query: String to find relevant documents for.
|
query: String to find relevant documents for
|
||||||
run_manager: The callbacks handler to use.
|
run_manager: The callbacks handler to use
|
||||||
Returns:
|
Returns:
|
||||||
List of relevant documents
|
List of relevant documents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Asynchronously get documents relevant to a query.
|
"""Asynchronously get documents relevant to a query.
|
||||||
Args:
|
Args:
|
||||||
query: string to find relevant documents for
|
query: String to find relevant documents for
|
||||||
run_manager: The callbacks handler to use
|
run_manager: The callbacks handler to use
|
||||||
Returns:
|
Returns:
|
||||||
List of relevant documents
|
List of relevant documents
|
||||||
@ -117,8 +112,8 @@ class BaseRetriever(ABC):
|
|||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Retrieve documents relevant to a query.
|
"""Retrieve documents relevant to a query.
|
||||||
Args:
|
Args:
|
||||||
query: String to find relevant documents for.
|
query: string to find relevant documents for
|
||||||
callbacks: Callback manager or list of callbacks.
|
callbacks: Callback manager or list of callbacks
|
||||||
Returns:
|
Returns:
|
||||||
List of relevant documents
|
List of relevant documents
|
||||||
"""
|
"""
|
||||||
@ -132,14 +127,13 @@ class BaseRetriever(ABC):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
_kwargs = kwargs if self._expects_other_args else {}
|
||||||
if self._new_arg_supported:
|
if self._new_arg_supported:
|
||||||
result = self._get_relevant_documents(
|
result = self._get_relevant_documents(
|
||||||
query, run_manager=run_manager, **kwargs
|
query, run_manager=run_manager, **_kwargs
|
||||||
)
|
)
|
||||||
elif self._expects_other_args:
|
|
||||||
result = self._get_relevant_documents(query, **kwargs)
|
|
||||||
else:
|
else:
|
||||||
result = self._get_relevant_documents(query) # type: ignore[call-arg]
|
result = self._get_relevant_documents(query, **_kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
run_manager.on_retriever_error(e)
|
run_manager.on_retriever_error(e)
|
||||||
raise e
|
raise e
|
||||||
@ -170,16 +164,13 @@ class BaseRetriever(ABC):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
_kwargs = kwargs if self._expects_other_args else {}
|
||||||
if self._new_arg_supported:
|
if self._new_arg_supported:
|
||||||
result = await self._aget_relevant_documents(
|
result = await self._aget_relevant_documents(
|
||||||
query, run_manager=run_manager, **kwargs
|
query, run_manager=run_manager, **_kwargs
|
||||||
)
|
)
|
||||||
elif self._expects_other_args:
|
|
||||||
result = await self._aget_relevant_documents(query, **kwargs)
|
|
||||||
else:
|
else:
|
||||||
result = await self._aget_relevant_documents(
|
result = await self._aget_relevant_documents(query, **_kwargs)
|
||||||
query, # type: ignore[call-arg]
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await run_manager.on_retriever_error(e)
|
await run_manager.on_retriever_error(e)
|
||||||
raise e
|
raise e
|
||||||
|
@ -497,8 +497,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = self.vectorstore.vector_search(query, k=self.k)
|
docs = self.vectorstore.vector_search(query, k=self.k)
|
||||||
@ -513,8 +513,8 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
*,
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"AzureSearchVectorStoreRetriever does not support async"
|
"AzureSearchVectorStoreRetriever does not support async"
|
||||||
|
@ -408,11 +408,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||||
@ -432,11 +428,7 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = await self.vectorstore.asimilarity_search(
|
docs = await self.vectorstore.asimilarity_search(
|
||||||
|
@ -620,11 +620,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = self.vectorstore.similarity_search(query, k=self.k)
|
docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||||
@ -637,11 +633,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
|
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
|
||||||
|
|
||||||
|
@ -451,11 +451,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
|
|||||||
allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
|
allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun]
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = self.vectorstore.similarity_search(query, k=self.k)
|
docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||||
@ -464,11 +460,7 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self, query: str, *, run_manager: Optional[AsyncCallbackManagerForRetrieverRun]
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"SingleStoreDBVectorStoreRetriever does not support async"
|
"SingleStoreDBVectorStoreRetriever does not support async"
|
||||||
|
@ -7,18 +7,19 @@ from langchain.document_loaders import TextLoader
|
|||||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
from langchain.llms import OpenAI
|
from langchain.llms import OpenAI
|
||||||
from langchain.text_splitter import CharacterTextSplitter
|
from langchain.text_splitter import CharacterTextSplitter
|
||||||
from langchain.vectorstores import Chroma
|
from langchain.vectorstores import FAISS
|
||||||
|
|
||||||
|
|
||||||
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
|
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
|
||||||
"""Test saving and loading."""
|
"""Test saving and loading."""
|
||||||
loader = TextLoader("docs/modules/state_of_the_union.txt")
|
loader = TextLoader("docs/extras/modules/state_of_the_union.txt")
|
||||||
documents = loader.load()
|
documents = loader.load()
|
||||||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||||
texts = text_splitter.split_documents(documents)
|
texts = text_splitter.split_documents(documents)
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
docsearch = Chroma.from_documents(texts, embeddings)
|
docsearch = FAISS.from_documents(texts, embeddings)
|
||||||
qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever())
|
qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever())
|
||||||
|
qa.run("What did the president say about Ketanji Brown Jackson?")
|
||||||
|
|
||||||
file_path = tmp_path / "RetrievalQA_chain.yaml"
|
file_path = tmp_path / "RetrievalQA_chain.yaml"
|
||||||
qa.save(file_path=file_path)
|
qa.save(file_path=file_path)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from langchain.embeddings import OpenAIEmbeddings
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
|
||||||
from langchain.retrievers.document_compressors import EmbeddingsFilter
|
from langchain.retrievers.document_compressors import EmbeddingsFilter
|
||||||
from langchain.vectorstores import Chroma
|
from langchain.vectorstores import FAISS
|
||||||
|
|
||||||
|
|
||||||
def test_contextual_compression_retriever_get_relevant_docs() -> None:
|
def test_contextual_compression_retriever_get_relevant_docs() -> None:
|
||||||
@ -13,7 +13,7 @@ def test_contextual_compression_retriever_get_relevant_docs() -> None:
|
|||||||
]
|
]
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
|
base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
|
||||||
base_retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever(
|
base_retriever = FAISS.from_texts(texts, embedding=embeddings).as_retriever(
|
||||||
search_kwargs={"k": len(texts)}
|
search_kwargs={"k": len(texts)}
|
||||||
)
|
)
|
||||||
retriever = ContextualCompressionRetriever(
|
retriever = ContextualCompressionRetriever(
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -141,45 +141,48 @@ async def test_fake_retriever_v1_with_kwargs_upgrade_async(
|
|||||||
assert callbacks.retriever_errors == 0
|
assert callbacks.retriever_errors == 0
|
||||||
|
|
||||||
|
|
||||||
|
class FakeRetrieverV2(BaseRetriever):
|
||||||
|
def __init__(self, throw_error: bool = False) -> None:
|
||||||
|
self.throw_error = throw_error
|
||||||
|
|
||||||
|
def _get_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None
|
||||||
|
) -> List[Document]:
|
||||||
|
assert isinstance(self, FakeRetrieverV2)
|
||||||
|
assert run_manager is not None
|
||||||
|
assert isinstance(run_manager, CallbackManagerForRetrieverRun)
|
||||||
|
if self.throw_error:
|
||||||
|
raise ValueError("Test error")
|
||||||
|
return [
|
||||||
|
Document(page_content=query),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _aget_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun | None
|
||||||
|
) -> List[Document]:
|
||||||
|
assert isinstance(self, FakeRetrieverV2)
|
||||||
|
assert run_manager is not None
|
||||||
|
assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
|
||||||
|
if self.throw_error:
|
||||||
|
raise ValueError("Test error")
|
||||||
|
return [
|
||||||
|
Document(page_content=f"Async query {query}"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def fake_retriever_v2() -> BaseRetriever:
|
def fake_retriever_v2() -> BaseRetriever:
|
||||||
class FakeRetrieverV2(BaseRetriever):
|
|
||||||
def _get_relevant_documents(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: CallbackManagerForRetrieverRun | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
|
||||||
assert isinstance(self, FakeRetrieverV2)
|
|
||||||
assert run_manager is not None
|
|
||||||
assert isinstance(run_manager, CallbackManagerForRetrieverRun)
|
|
||||||
if "throw_error" in kwargs:
|
|
||||||
raise ValueError("Test error")
|
|
||||||
return [
|
|
||||||
Document(page_content=query, metadata=kwargs),
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
|
||||||
assert isinstance(self, FakeRetrieverV2)
|
|
||||||
assert run_manager is not None
|
|
||||||
assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
|
|
||||||
if "throw_error" in kwargs:
|
|
||||||
raise ValueError("Test error")
|
|
||||||
return [
|
|
||||||
Document(page_content=f"Async query {query}", metadata=kwargs),
|
|
||||||
]
|
|
||||||
|
|
||||||
return FakeRetrieverV2() # type: ignore[abstract]
|
return FakeRetrieverV2() # type: ignore[abstract]
|
||||||
|
|
||||||
|
|
||||||
def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
|
@pytest.fixture
|
||||||
|
def fake_erroring_retriever_v2() -> BaseRetriever:
|
||||||
|
return FakeRetrieverV2(throw_error=True) # type: ignore[abstract]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fake_retriever_v2(
|
||||||
|
fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever
|
||||||
|
) -> None:
|
||||||
callbacks = FakeCallbackHandler()
|
callbacks = FakeCallbackHandler()
|
||||||
assert fake_retriever_v2._new_arg_supported is True
|
assert fake_retriever_v2._new_arg_supported is True
|
||||||
results = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
|
results = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
|
||||||
@ -187,20 +190,17 @@ def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
|
|||||||
assert callbacks.retriever_starts == 1
|
assert callbacks.retriever_starts == 1
|
||||||
assert callbacks.retriever_ends == 1
|
assert callbacks.retriever_ends == 1
|
||||||
assert callbacks.retriever_errors == 0
|
assert callbacks.retriever_errors == 0
|
||||||
results2 = fake_retriever_v2.get_relevant_documents(
|
fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
|
||||||
"Foo", callbacks=[callbacks], foo="bar"
|
|
||||||
)
|
|
||||||
assert results2[0].metadata == {"foo": "bar"}
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Test error"):
|
with pytest.raises(ValueError, match="Test error"):
|
||||||
fake_retriever_v2.get_relevant_documents(
|
fake_erroring_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
|
||||||
"Foo", callbacks=[callbacks], throw_error=True
|
|
||||||
)
|
|
||||||
assert callbacks.retriever_errors == 1
|
assert callbacks.retriever_errors == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None:
|
async def test_fake_retriever_v2_async(
|
||||||
|
fake_retriever_v2: BaseRetriever, fake_erroring_retriever_v2: BaseRetriever
|
||||||
|
) -> None:
|
||||||
callbacks = FakeCallbackHandler()
|
callbacks = FakeCallbackHandler()
|
||||||
assert fake_retriever_v2._new_arg_supported is True
|
assert fake_retriever_v2._new_arg_supported is True
|
||||||
results = await fake_retriever_v2.aget_relevant_documents(
|
results = await fake_retriever_v2.aget_relevant_documents(
|
||||||
@ -210,11 +210,8 @@ async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None
|
|||||||
assert callbacks.retriever_starts == 1
|
assert callbacks.retriever_starts == 1
|
||||||
assert callbacks.retriever_ends == 1
|
assert callbacks.retriever_ends == 1
|
||||||
assert callbacks.retriever_errors == 0
|
assert callbacks.retriever_errors == 0
|
||||||
results2 = await fake_retriever_v2.aget_relevant_documents(
|
await fake_retriever_v2.aget_relevant_documents("Foo", callbacks=[callbacks])
|
||||||
"Foo", callbacks=[callbacks], foo="bar"
|
|
||||||
)
|
|
||||||
assert results2[0].metadata == {"foo": "bar"}
|
|
||||||
with pytest.raises(ValueError, match="Test error"):
|
with pytest.raises(ValueError, match="Test error"):
|
||||||
await fake_retriever_v2.aget_relevant_documents(
|
await fake_erroring_retriever_v2.aget_relevant_documents(
|
||||||
"Foo", callbacks=[callbacks], throw_error=True
|
"Foo", callbacks=[callbacks]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user