mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add serialized object to retriever start callback (#7074)
<!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @dev2049 - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @dev2049 - Memory: @hwchase17 - Agents / Tools / Toolkits: @vowelparrot - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
baf48d3583
commit
81e5b1ad36
@ -168,10 +168,12 @@ class CallbackManagerMixin:
|
|||||||
|
|
||||||
def on_retriever_start(
|
def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when Retriever starts running."""
|
"""Run when Retriever starts running."""
|
||||||
@ -421,6 +423,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
async def on_retriever_start(
|
async def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
|
@ -1196,6 +1196,7 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
|
|
||||||
def on_retriever_start(
|
def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
run_id: Optional[UUID] = None,
|
run_id: Optional[UUID] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -1209,6 +1210,7 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
self.handlers,
|
self.handlers,
|
||||||
"on_retriever_start",
|
"on_retriever_start",
|
||||||
"ignore_retriever",
|
"ignore_retriever",
|
||||||
|
serialized,
|
||||||
query,
|
query,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
@ -1463,6 +1465,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
|
|
||||||
async def on_retriever_start(
|
async def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
run_id: Optional[UUID] = None,
|
run_id: Optional[UUID] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -1476,6 +1479,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
self.handlers,
|
self.handlers,
|
||||||
"on_retriever_start",
|
"on_retriever_start",
|
||||||
"ignore_retriever",
|
"ignore_retriever",
|
||||||
|
serialized,
|
||||||
query,
|
query,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
|
@ -312,6 +312,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def on_retriever_start(
|
def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
|
serialized: Dict[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
@ -326,6 +327,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
id=run_id,
|
id=run_id,
|
||||||
name="Retriever",
|
name="Retriever",
|
||||||
parent_run_id=parent_run_id,
|
parent_run_id=parent_run_id,
|
||||||
|
serialized=serialized,
|
||||||
inputs={"query": query},
|
inputs={"query": query},
|
||||||
extra=kwargs,
|
extra=kwargs,
|
||||||
events=[{"name": "start", "time": start_time}],
|
events=[{"name": "start", "time": start_time}],
|
||||||
|
@ -7,7 +7,7 @@ from typing import Dict, List, Optional
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import Extra, root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -17,7 +17,7 @@ from langchain.schema import BaseRetriever, Document
|
|||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
|
class AzureCognitiveSearchRetriever(BaseRetriever):
|
||||||
"""Wrapper around Azure Cognitive Search."""
|
"""Wrapper around Azure Cognitive Search."""
|
||||||
|
|
||||||
service_name: str = ""
|
service_name: str = ""
|
||||||
|
@ -4,7 +4,6 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -13,7 +12,7 @@ from langchain.callbacks.manager import (
|
|||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
class ChatGPTPluginRetriever(BaseRetriever):
|
||||||
url: str
|
url: str
|
||||||
bearer_token: str
|
bearer_token: str
|
||||||
top_k: int = 3
|
top_k: int = 3
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
CallbackManagerForRetrieverRun,
|
CallbackManagerForRetrieverRun,
|
||||||
@ -14,7 +12,7 @@ from langchain.retrievers.document_compressors.base import (
|
|||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
|
||||||
class ContextualCompressionRetriever(BaseRetriever, BaseModel):
|
class ContextualCompressionRetriever(BaseRetriever):
|
||||||
"""Retriever that wraps a base retriever and compresses the results."""
|
"""Retriever that wraps a base retriever and compresses the results."""
|
||||||
|
|
||||||
base_compressor: BaseDocumentCompressor
|
base_compressor: BaseDocumentCompressor
|
||||||
@ -26,7 +24,6 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
|
@ -17,16 +17,6 @@ class DataberryRetriever(BaseRetriever):
|
|||||||
top_k: Optional[int]
|
top_k: Optional[int]
|
||||||
api_key: Optional[str]
|
api_key: Optional[str]
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
datastore_url: str,
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.datastore_url = datastore_url
|
|
||||||
self.api_key = api_key
|
|
||||||
self.top_k = top_k
|
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
|
@ -2,7 +2,6 @@ from enum import Enum
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -20,7 +19,7 @@ class SearchType(str, Enum):
|
|||||||
mmr = "mmr"
|
mmr = "mmr"
|
||||||
|
|
||||||
|
|
||||||
class DocArrayRetriever(BaseRetriever, BaseModel):
|
class DocArrayRetriever(BaseRetriever):
|
||||||
"""
|
"""
|
||||||
Retriever class for DocArray Document Indices.
|
Retriever class for DocArray Document Indices.
|
||||||
|
|
||||||
|
@ -40,9 +40,8 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
|||||||
https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
|
https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, client: Any, index_name: str):
|
client: Any
|
||||||
self.client = client
|
index_name: str
|
||||||
self.index_name = index_name
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@ -75,7 +74,7 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
|||||||
|
|
||||||
# Create the index with the specified settings and mappings
|
# Create the index with the specified settings and mappings
|
||||||
es.indices.create(index=index_name, mappings=mappings, settings=settings)
|
es.indices.create(index=index_name, mappings=mappings, settings=settings)
|
||||||
return cls(es, index_name)
|
return cls(client=es, index_name=index_name)
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -179,37 +179,34 @@ class AmazonKendraRetriever(BaseRetriever):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
index_id: str
|
||||||
self,
|
region_name: Optional[str] = None
|
||||||
index_id: str,
|
credentials_profile_name: Optional[str] = None
|
||||||
region_name: Optional[str] = None,
|
top_k: int = 3
|
||||||
credentials_profile_name: Optional[str] = None,
|
attribute_filter: Optional[Dict] = None
|
||||||
top_k: int = 3,
|
client: Any
|
||||||
attribute_filter: Optional[Dict] = None,
|
|
||||||
client: Optional[Any] = None,
|
|
||||||
):
|
|
||||||
self.index_id = index_id
|
|
||||||
self.top_k = top_k
|
|
||||||
self.attribute_filter = attribute_filter
|
|
||||||
|
|
||||||
if client is not None:
|
@root_validator(pre=True)
|
||||||
self.client = client
|
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
return
|
if values["client"] is not None:
|
||||||
|
return values
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
if credentials_profile_name is not None:
|
if values["credentials_profile_name"] is not None:
|
||||||
session = boto3.Session(profile_name=credentials_profile_name)
|
session = boto3.Session(profile_name=values["credentials_profile_name"])
|
||||||
else:
|
else:
|
||||||
# use default credentials
|
# use default credentials
|
||||||
session = boto3.Session()
|
session = boto3.Session()
|
||||||
|
|
||||||
client_params = {}
|
client_params = {}
|
||||||
if region_name is not None:
|
if values["region_name"] is not None:
|
||||||
client_params["region_name"] = region_name
|
client_params["region_name"] = values["region_name"]
|
||||||
|
|
||||||
self.client = session.client("kendra", **client_params)
|
values["client"] = session.client("kendra", **client_params)
|
||||||
|
|
||||||
|
return values
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ModuleNotFoundError(
|
raise ModuleNotFoundError(
|
||||||
"Could not import boto3 python package. "
|
"Could not import boto3 python package. "
|
||||||
|
@ -8,7 +8,6 @@ import concurrent.futures
|
|||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -33,7 +32,7 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
|
|||||||
return np.array(list(executor.map(embeddings.embed_query, contexts)))
|
return np.array(list(executor.map(embeddings.embed_query, contexts)))
|
||||||
|
|
||||||
|
|
||||||
class KNNRetriever(BaseRetriever, BaseModel):
|
class KNNRetriever(BaseRetriever):
|
||||||
"""KNN Retriever."""
|
"""KNN Retriever."""
|
||||||
|
|
||||||
embeddings: Embeddings
|
embeddings: Embeddings
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Dict, List, cast
|
from typing import Any, Dict, List, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
|
|||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
|
||||||
class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
class LlamaIndexRetriever(BaseRetriever):
|
||||||
"""Question-answering with sources over a LlamaIndex data structure."""
|
"""Question-answering with sources over a LlamaIndex data structure."""
|
||||||
|
|
||||||
index: Any
|
index: Any
|
||||||
@ -45,7 +45,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
|||||||
raise NotImplementedError("LlamaIndexRetriever does not support async")
|
raise NotImplementedError("LlamaIndexRetriever does not support async")
|
||||||
|
|
||||||
|
|
||||||
class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
class LlamaIndexGraphRetriever(BaseRetriever):
|
||||||
"""Question-answering with sources over a LlamaIndex graph data structure."""
|
"""Question-answering with sources over a LlamaIndex graph data structure."""
|
||||||
|
|
||||||
graph: Any
|
graph: Any
|
||||||
|
@ -15,18 +15,7 @@ class MergerRetriever(BaseRetriever):
|
|||||||
retrievers: A list of retrievers to merge.
|
retrievers: A list of retrievers to merge.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
retrievers: List[BaseRetriever]
|
||||||
self,
|
|
||||||
retrievers: List[BaseRetriever],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the MergerRetriever class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
retrievers: A list of retrievers to merge.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.retrievers = retrievers
|
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
CallbackManagerForRetrieverRun,
|
CallbackManagerForRetrieverRun,
|
||||||
@ -10,16 +12,26 @@ from langchain.schema import BaseRetriever, Document
|
|||||||
class MetalRetriever(BaseRetriever):
|
class MetalRetriever(BaseRetriever):
|
||||||
"""Retriever that uses the Metal API."""
|
"""Retriever that uses the Metal API."""
|
||||||
|
|
||||||
def __init__(self, client: Any, params: Optional[dict] = None):
|
client: Any
|
||||||
|
|
||||||
|
params: Optional[dict] = None
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def validate_client(cls, values: dict) -> dict:
|
||||||
|
"""Validate that the client is of the correct type."""
|
||||||
from metal_sdk.metal import Metal
|
from metal_sdk.metal import Metal
|
||||||
|
|
||||||
if not isinstance(client, Metal):
|
if "client" in values:
|
||||||
raise ValueError(
|
client = values["client"]
|
||||||
"Got unexpected client, should be of type metal_sdk.metal.Metal. "
|
if not isinstance(client, Metal):
|
||||||
f"Instead, got {type(client)}"
|
raise ValueError(
|
||||||
)
|
"Got unexpected client, should be of type metal_sdk.metal.Metal. "
|
||||||
self.client: Metal = client
|
f"Instead, got {type(client)}"
|
||||||
self.params = params or {}
|
)
|
||||||
|
|
||||||
|
values["params"] = values.get("params", {})
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
CallbackManagerForRetrieverRun,
|
CallbackManagerForRetrieverRun,
|
||||||
@ -16,21 +18,28 @@ from langchain.vectorstores.milvus import Milvus
|
|||||||
class MilvusRetriever(BaseRetriever):
|
class MilvusRetriever(BaseRetriever):
|
||||||
"""Retriever that uses the Milvus API."""
|
"""Retriever that uses the Milvus API."""
|
||||||
|
|
||||||
def __init__(
|
embedding_function: Embeddings
|
||||||
self,
|
collection_name: str = "LangChainCollection"
|
||||||
embedding_function: Embeddings,
|
connection_args: Optional[Dict[str, Any]] = None
|
||||||
collection_name: str = "LangChainCollection",
|
consistency_level: str = "Session"
|
||||||
connection_args: Optional[Dict[str, Any]] = None,
|
search_params: Optional[dict] = None
|
||||||
consistency_level: str = "Session",
|
|
||||||
search_params: Optional[dict] = None,
|
store: Milvus
|
||||||
):
|
retriever: BaseRetriever
|
||||||
self.store = Milvus(
|
|
||||||
embedding_function,
|
@root_validator(pre=True)
|
||||||
collection_name,
|
def create_retriever(cls, values: Dict) -> Dict:
|
||||||
connection_args,
|
"""Create the Milvus store and retriever."""
|
||||||
consistency_level,
|
values["store"] = Milvus(
|
||||||
|
values["embedding_function"],
|
||||||
|
values["collection_name"],
|
||||||
|
values["connection_args"],
|
||||||
|
values["consistency_level"],
|
||||||
)
|
)
|
||||||
self.retriever = self.store.as_retriever(search_kwargs={"param": search_params})
|
values["retriever"] = values["store"].as_retriever(
|
||||||
|
search_kwargs={"param": values["search_params"]}
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self, texts: List[str], metadatas: Optional[List[dict]] = None
|
self, texts: List[str], metadatas: Optional[List[dict]] = None
|
||||||
|
@ -47,28 +47,10 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
"""Given a user query, use an LLM to write a set of queries.
|
"""Given a user query, use an LLM to write a set of queries.
|
||||||
Retrieve docs for each query. Take the unique union of all retrieved docs."""
|
Retrieve docs for each query. Take the unique union of all retrieved docs."""
|
||||||
|
|
||||||
def __init__(
|
retriever: BaseRetriever
|
||||||
self,
|
llm_chain: LLMChain
|
||||||
retriever: BaseRetriever,
|
verbose: bool = True
|
||||||
llm_chain: LLMChain,
|
parser_key: str = "lines"
|
||||||
verbose: bool = True,
|
|
||||||
parser_key: str = "lines",
|
|
||||||
) -> None:
|
|
||||||
"""Initialize MultiQueryRetriever.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
retriever: retriever to query documents from
|
|
||||||
llm_chain: llm_chain for query generation
|
|
||||||
verbose: show the queries that we generated to the user
|
|
||||||
parser_key: attribute name for the parsed output
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MultiQueryRetriever
|
|
||||||
"""
|
|
||||||
self.retriever = retriever
|
|
||||||
self.llm_chain = llm_chain
|
|
||||||
self.verbose = verbose
|
|
||||||
self.parser_key = parser_key
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import Extra, root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -98,7 +98,7 @@ def create_index(
|
|||||||
index.upsert(vectors)
|
index.upsert(vectors)
|
||||||
|
|
||||||
|
|
||||||
class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
|
class PineconeHybridSearchRetriever(BaseRetriever):
|
||||||
embeddings: Embeddings
|
embeddings: Embeddings
|
||||||
"""description"""
|
"""description"""
|
||||||
sparse_encoder: Any
|
sparse_encoder: Any
|
||||||
|
@ -2,7 +2,6 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -11,7 +10,7 @@ from langchain.callbacks.manager import (
|
|||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
|
||||||
class RemoteLangChainRetriever(BaseRetriever, BaseModel):
|
class RemoteLangChainRetriever(BaseRetriever):
|
||||||
url: str
|
url: str
|
||||||
headers: Optional[dict] = None
|
headers: Optional[dict] = None
|
||||||
input_key: str = "message"
|
input_key: str = "message"
|
||||||
|
@ -8,7 +8,6 @@ import concurrent.futures
|
|||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -32,7 +31,7 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
|
|||||||
return np.array(list(executor.map(embeddings.embed_query, contexts)))
|
return np.array(list(executor.map(embeddings.embed_query, contexts)))
|
||||||
|
|
||||||
|
|
||||||
class SVMRetriever(BaseRetriever, BaseModel):
|
class SVMRetriever(BaseRetriever):
|
||||||
"""SVM Retriever."""
|
"""SVM Retriever."""
|
||||||
|
|
||||||
embeddings: Embeddings
|
embeddings: Embeddings
|
||||||
|
@ -7,8 +7,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Dict, Iterable, List, Optional
|
from typing import Any, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
CallbackManagerForRetrieverRun,
|
CallbackManagerForRetrieverRun,
|
||||||
@ -16,7 +14,7 @@ from langchain.callbacks.manager import (
|
|||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
|
||||||
class TFIDFRetriever(BaseRetriever, BaseModel):
|
class TFIDFRetriever(BaseRetriever):
|
||||||
vectorizer: Any
|
vectorizer: Any
|
||||||
docs: List[Document]
|
docs: List[Document]
|
||||||
tfidf_array: Any
|
tfidf_array: Any
|
||||||
|
@ -4,7 +4,7 @@ import datetime
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -19,7 +19,7 @@ def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> f
|
|||||||
return (time - ref_time).total_seconds() / 3600
|
return (time - ref_time).total_seconds() / 3600
|
||||||
|
|
||||||
|
|
||||||
class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
||||||
"""Retriever combining embedding similarity with recency."""
|
"""Retriever combining embedding similarity with recency."""
|
||||||
|
|
||||||
vectorstore: VectorStore
|
vectorstore: VectorStore
|
||||||
|
@ -18,29 +18,13 @@ if TYPE_CHECKING:
|
|||||||
class VespaRetriever(BaseRetriever):
|
class VespaRetriever(BaseRetriever):
|
||||||
"""Retriever that uses Vespa."""
|
"""Retriever that uses Vespa."""
|
||||||
|
|
||||||
def __init__(
|
app: Vespa
|
||||||
self,
|
body: Dict
|
||||||
app: Vespa,
|
content_field: str
|
||||||
body: Dict,
|
metadata_fields: Sequence[str]
|
||||||
content_field: str,
|
|
||||||
metadata_fields: Optional[Sequence[str]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
|
|
||||||
Args:
|
|
||||||
app: Vespa client.
|
|
||||||
body: query body.
|
|
||||||
content_field: result field with document contents.
|
|
||||||
metadata_fields: result fields to include in document metadata.
|
|
||||||
|
|
||||||
"""
|
|
||||||
self._application = app
|
|
||||||
self._query_body = body
|
|
||||||
self._content_field = content_field
|
|
||||||
self._metadata_fields = metadata_fields or ()
|
|
||||||
|
|
||||||
def _query(self, body: Dict) -> List[Document]:
|
def _query(self, body: Dict) -> List[Document]:
|
||||||
response = self._application.query(body)
|
response = self.app.query(body)
|
||||||
|
|
||||||
if not str(response.status_code).startswith("2"):
|
if not str(response.status_code).startswith("2"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -55,11 +39,11 @@ class VespaRetriever(BaseRetriever):
|
|||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
for child in response.hits:
|
for child in response.hits:
|
||||||
page_content = child["fields"].pop(self._content_field, "")
|
page_content = child["fields"].pop(self.content_field, "")
|
||||||
if self._metadata_fields == "*":
|
if self.metadata_fields == "*":
|
||||||
metadata = child["fields"]
|
metadata = child["fields"]
|
||||||
else:
|
else:
|
||||||
metadata = {mf: child["fields"].get(mf) for mf in self._metadata_fields}
|
metadata = {mf: child["fields"].get(mf) for mf in self.metadata_fields}
|
||||||
metadata["id"] = child["id"]
|
metadata["id"] = child["id"]
|
||||||
docs.append(Document(page_content=page_content, metadata=metadata))
|
docs.append(Document(page_content=page_content, metadata=metadata))
|
||||||
return docs
|
return docs
|
||||||
@ -67,7 +51,7 @@ class VespaRetriever(BaseRetriever):
|
|||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
body = self._query_body.copy()
|
body = self.body.copy()
|
||||||
body["query"] = query
|
body["query"] = query
|
||||||
return self._query(body)
|
return self._query(body)
|
||||||
|
|
||||||
@ -79,7 +63,7 @@ class VespaRetriever(BaseRetriever):
|
|||||||
def get_relevant_documents_with_filter(
|
def get_relevant_documents_with_filter(
|
||||||
self, query: str, *, _filter: Optional[str] = None
|
self, query: str, *, _filter: Optional[str] = None
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
body = self._query_body.copy()
|
body = self.body.copy()
|
||||||
_filter = f" and {_filter}" if _filter else ""
|
_filter = f" and {_filter}" if _filter else ""
|
||||||
body["yql"] = body["yql"] + _filter
|
body["yql"] = body["yql"] + _filter
|
||||||
body["query"] = query
|
body["query"] = query
|
||||||
@ -139,4 +123,9 @@ class VespaRetriever(BaseRetriever):
|
|||||||
body["yql"] = yql
|
body["yql"] = yql
|
||||||
if k:
|
if k:
|
||||||
body["hits"] = k
|
body["hits"] = k
|
||||||
return cls(app, body, content_field, metadata_fields=metadata_fields)
|
return cls(
|
||||||
|
app=app,
|
||||||
|
body=body,
|
||||||
|
content_field=content_field,
|
||||||
|
metadata_fields=metadata_fields,
|
||||||
|
)
|
||||||
|
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import Extra
|
from pydantic import root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -16,16 +16,19 @@ from langchain.schema import BaseRetriever
|
|||||||
|
|
||||||
|
|
||||||
class WeaviateHybridSearchRetriever(BaseRetriever):
|
class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||||
def __init__(
|
client: Any
|
||||||
self,
|
index_name: str
|
||||||
client: Any,
|
text_key: str
|
||||||
index_name: str,
|
alpha: float = 0.5
|
||||||
text_key: str,
|
k: int = 4
|
||||||
alpha: float = 0.5,
|
attributes: List[str]
|
||||||
k: int = 4,
|
create_schema_if_missing: bool = True
|
||||||
attributes: Optional[List[str]] = None,
|
|
||||||
create_schema_if_missing: bool = True,
|
@root_validator(pre=True)
|
||||||
):
|
def validate_client(
|
||||||
|
cls,
|
||||||
|
values: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
import weaviate
|
import weaviate
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -33,36 +36,31 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
"Could not import weaviate python package. "
|
"Could not import weaviate python package. "
|
||||||
"Please install it with `pip install weaviate-client`."
|
"Please install it with `pip install weaviate-client`."
|
||||||
)
|
)
|
||||||
if not isinstance(client, weaviate.Client):
|
if not isinstance(values["client"], weaviate.Client):
|
||||||
|
client = values["client"]
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"client should be an instance of weaviate.Client, got {type(client)}"
|
f"client should be an instance of weaviate.Client, got {type(client)}"
|
||||||
)
|
)
|
||||||
self._client = client
|
if values["attributes"] is None:
|
||||||
self.k = k
|
values["attributes"] = []
|
||||||
self.alpha = alpha
|
|
||||||
self._index_name = index_name
|
|
||||||
self._text_key = text_key
|
|
||||||
self._query_attrs = [self._text_key]
|
|
||||||
if attributes is not None:
|
|
||||||
self._query_attrs.extend(attributes)
|
|
||||||
|
|
||||||
if create_schema_if_missing:
|
cast(List, values["attributes"]).append(values["text_key"])
|
||||||
self._create_schema_if_missing()
|
|
||||||
|
|
||||||
def _create_schema_if_missing(self) -> None:
|
if values["create_schema_if_missing"]:
|
||||||
class_obj = {
|
class_obj = {
|
||||||
"class": self._index_name,
|
"class": values["index_name"],
|
||||||
"properties": [{"name": self._text_key, "dataType": ["text"]}],
|
"properties": [{"name": values["text_key"], "dataType": ["text"]}],
|
||||||
"vectorizer": "text2vec-openai",
|
"vectorizer": "text2vec-openai",
|
||||||
}
|
}
|
||||||
|
|
||||||
if not self._client.schema.exists(self._index_name):
|
if not values["client"].schema.exists(values["index_name"]):
|
||||||
self._client.schema.create_class(class_obj)
|
values["client"].schema.create_class(class_obj)
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
# added text_key
|
# added text_key
|
||||||
@ -70,11 +68,11 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
"""Upload documents to Weaviate."""
|
"""Upload documents to Weaviate."""
|
||||||
from weaviate.util import get_valid_uuid
|
from weaviate.util import get_valid_uuid
|
||||||
|
|
||||||
with self._client.batch as batch:
|
with self.client.batch as batch:
|
||||||
ids = []
|
ids = []
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
metadata = doc.metadata or {}
|
metadata = doc.metadata or {}
|
||||||
data_properties = {self._text_key: doc.page_content, **metadata}
|
data_properties = {self.text_key: doc.page_content, **metadata}
|
||||||
|
|
||||||
# If the UUID of one of the objects already exists
|
# If the UUID of one of the objects already exists
|
||||||
# then the existing object will be replaced by the new object.
|
# then the existing object will be replaced by the new object.
|
||||||
@ -83,7 +81,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
else:
|
else:
|
||||||
_id = get_valid_uuid(uuid4())
|
_id = get_valid_uuid(uuid4())
|
||||||
|
|
||||||
batch.add_data_object(data_properties, self._index_name, _id)
|
batch.add_data_object(data_properties, self.index_name, _id)
|
||||||
ids.append(_id)
|
ids.append(_id)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
@ -95,7 +93,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
where_filter: Optional[Dict[str, object]] = None,
|
where_filter: Optional[Dict[str, object]] = None,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Look up similar documents in Weaviate."""
|
"""Look up similar documents in Weaviate."""
|
||||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
query_obj = self.client.query.get(self.index_name, self.attributes)
|
||||||
if where_filter:
|
if where_filter:
|
||||||
query_obj = query_obj.with_where(where_filter)
|
query_obj = query_obj.with_where(where_filter)
|
||||||
|
|
||||||
@ -105,8 +103,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
|
|||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
for res in result["data"]["Get"][self._index_name]:
|
for res in result["data"]["Get"][self.index_name]:
|
||||||
text = res.pop(self._text_key)
|
text = res.pop(self.text_key)
|
||||||
docs.append(Document(page_content=text, metadata=res))
|
docs.append(Document(page_content=text, metadata=res))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -27,13 +29,14 @@ class ZepRetriever(BaseRetriever):
|
|||||||
https://docs.getzep.com/deployment/quickstart/
|
https://docs.getzep.com/deployment/quickstart/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
zep_client: Any
|
||||||
self,
|
|
||||||
session_id: str,
|
session_id: str
|
||||||
url: str,
|
|
||||||
api_key: Optional[str] = None,
|
top_k: Optional[int]
|
||||||
top_k: Optional[int] = None,
|
|
||||||
):
|
@root_validator(pre=True)
|
||||||
|
def create_client(cls, values: dict) -> dict:
|
||||||
try:
|
try:
|
||||||
from zep_python import ZepClient
|
from zep_python import ZepClient
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -41,10 +44,11 @@ class ZepRetriever(BaseRetriever):
|
|||||||
"Could not import zep-python package. "
|
"Could not import zep-python package. "
|
||||||
"Please install it with `pip install zep-python`."
|
"Please install it with `pip install zep-python`."
|
||||||
)
|
)
|
||||||
|
values["zep_client"] = values.get(
|
||||||
self.zep_client = ZepClient(base_url=url, api_key=api_key)
|
"zep_client",
|
||||||
self.session_id = session_id
|
ZepClient(base_url=values["url"], api_key=values.get("api_key")),
|
||||||
self.top_k = top_k
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
def _search_result_to_doc(
|
def _search_result_to_doc(
|
||||||
self, results: List[MemorySearchResult]
|
self, results: List[MemorySearchResult]
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
CallbackManagerForRetrieverRun,
|
CallbackManagerForRetrieverRun,
|
||||||
@ -16,21 +18,27 @@ from langchain.vectorstores.zilliz import Zilliz
|
|||||||
class ZillizRetriever(BaseRetriever):
|
class ZillizRetriever(BaseRetriever):
|
||||||
"""Retriever that uses the Zilliz API."""
|
"""Retriever that uses the Zilliz API."""
|
||||||
|
|
||||||
def __init__(
|
embedding_function: Embeddings
|
||||||
self,
|
collection_name: str = "LangChainCollection"
|
||||||
embedding_function: Embeddings,
|
connection_args: Optional[Dict[str, Any]] = None
|
||||||
collection_name: str = "LangChainCollection",
|
consistency_level: str = "Session"
|
||||||
connection_args: Optional[Dict[str, Any]] = None,
|
search_params: Optional[dict] = None
|
||||||
consistency_level: str = "Session",
|
|
||||||
search_params: Optional[dict] = None,
|
store: Zilliz
|
||||||
):
|
retriever: BaseRetriever
|
||||||
self.store = Zilliz(
|
|
||||||
embedding_function,
|
@root_validator(pre=True)
|
||||||
collection_name,
|
def create_client(cls, values: dict) -> dict:
|
||||||
connection_args,
|
values["store"] = Zilliz(
|
||||||
consistency_level,
|
values["embedding_function"],
|
||||||
|
values["collection_name"],
|
||||||
|
values["connection_args"],
|
||||||
|
values["consistency_level"],
|
||||||
)
|
)
|
||||||
self.retriever = self.store.as_retriever(search_kwargs={"param": search_params})
|
values["retriever"] = values["store"].as_retriever(
|
||||||
|
search_kwargs={"param": values["search_params"]}
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self, texts: List[str], metadatas: Optional[List[dict]] = None
|
self, texts: List[str], metadatas: Optional[List[dict]] = None
|
||||||
|
@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
|
|||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import TYPE_CHECKING, Any, List
|
from typing import TYPE_CHECKING, Any, List
|
||||||
|
|
||||||
|
from langchain.load.dump import dumpd
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.document import Document
|
from langchain.schema.document import Document
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -15,7 +17,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseRetriever(ABC):
|
class BaseRetriever(Serializable, ABC):
|
||||||
"""Abstract base class for a Document retrieval system.
|
"""Abstract base class for a Document retrieval system.
|
||||||
|
|
||||||
A retrieval system is defined as something that can take string queries and return
|
A retrieval system is defined as something that can take string queries and return
|
||||||
@ -46,6 +48,11 @@ class BaseRetriever(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
_new_arg_supported: bool = False
|
_new_arg_supported: bool = False
|
||||||
_expects_other_args: bool = False
|
_expects_other_args: bool = False
|
||||||
|
|
||||||
@ -81,7 +88,9 @@ class BaseRetriever(ABC):
|
|||||||
parameters = signature(cls._get_relevant_documents).parameters
|
parameters = signature(cls._get_relevant_documents).parameters
|
||||||
cls._new_arg_supported = parameters.get("run_manager") is not None
|
cls._new_arg_supported = parameters.get("run_manager") is not None
|
||||||
# If a V1 retriever broke the interface and expects additional arguments
|
# If a V1 retriever broke the interface and expects additional arguments
|
||||||
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
|
cls._expects_other_args = (
|
||||||
|
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
@ -123,6 +132,7 @@ class BaseRetriever(ABC):
|
|||||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||||
)
|
)
|
||||||
run_manager = callback_manager.on_retriever_start(
|
run_manager = callback_manager.on_retriever_start(
|
||||||
|
dumpd(self),
|
||||||
query,
|
query,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@ -160,6 +170,7 @@ class BaseRetriever(ABC):
|
|||||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||||
)
|
)
|
||||||
run_manager = await callback_manager.on_retriever_start(
|
run_manager = await callback_manager.on_retriever_start(
|
||||||
|
dumpd(self),
|
||||||
query,
|
query,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
@ -3,7 +3,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
|
|
||||||
@ -40,11 +40,6 @@ class ArxivAPIWrapper(BaseModel):
|
|||||||
load_all_available_meta: bool = False
|
load_all_available_meta: bool = False
|
||||||
doc_content_chars_max: Optional[int] = 4000
|
doc_content_chars_max: Optional[int] = 4000
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that the python package exists in environment."""
|
"""Validate that the python package exists in environment."""
|
||||||
|
@ -5,7 +5,7 @@ import urllib.error
|
|||||||
import urllib.request
|
import urllib.request
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
|
|
||||||
@ -42,11 +42,6 @@ class PubMedAPIWrapper(BaseModel):
|
|||||||
load_all_available_meta: bool = False
|
load_all_available_meta: bool = False
|
||||||
email: str = "your_email@example.com"
|
email: str = "your_email@example.com"
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
def run(self, query: str) -> str:
|
def run(self, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
Run PubMed search and get the article meta information.
|
Run PubMed search and get the article meta information.
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
|
|
||||||
@ -27,11 +27,6 @@ class WikipediaAPIWrapper(BaseModel):
|
|||||||
load_all_available_meta: bool = False
|
load_all_available_meta: bool = False
|
||||||
doc_content_chars_max: int = 4000
|
doc_content_chars_max: int = 4000
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that the python package exists in environment."""
|
"""Validate that the python package exists in environment."""
|
||||||
|
@ -24,7 +24,7 @@ def test_merger_retriever_get_relevant_docs() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# The Lord of the Retrievers.
|
# The Lord of the Retrievers.
|
||||||
lotr = MergerRetriever([retriever_a, retriever_b])
|
lotr = MergerRetriever(retrievers=[retriever_a, retriever_b])
|
||||||
|
|
||||||
actual = lotr.get_relevant_documents("Tell me about the Celtics")
|
actual = lotr.get_relevant_documents("Tell me about the Celtics")
|
||||||
assert len(actual) == 2
|
assert len(actual) == 2
|
||||||
|
@ -146,7 +146,7 @@ def test_ignore_retriever() -> None:
|
|||||||
handler1 = FakeCallbackHandler(ignore_retriever_=True)
|
handler1 = FakeCallbackHandler(ignore_retriever_=True)
|
||||||
handler2 = FakeCallbackHandler()
|
handler2 = FakeCallbackHandler()
|
||||||
manager = CallbackManager(handlers=[handler1, handler2])
|
manager = CallbackManager(handlers=[handler1, handler2])
|
||||||
run_manager = manager.on_retriever_start("")
|
run_manager = manager.on_retriever_start({}, "")
|
||||||
run_manager.on_retriever_end([])
|
run_manager.on_retriever_end([])
|
||||||
run_manager.on_retriever_error(Exception())
|
run_manager.on_retriever_error(Exception())
|
||||||
assert handler1.starts == 0
|
assert handler1.starts == 0
|
||||||
|
@ -142,8 +142,7 @@ async def test_fake_retriever_v1_with_kwargs_upgrade_async(
|
|||||||
|
|
||||||
|
|
||||||
class FakeRetrieverV2(BaseRetriever):
|
class FakeRetrieverV2(BaseRetriever):
|
||||||
def __init__(self, throw_error: bool = False) -> None:
|
throw_error: bool = False
|
||||||
self.throw_error = throw_error
|
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None
|
||||||
|
Loading…
Reference in New Issue
Block a user