Add serialized object to retriever start callback (#7074)

Nuno Campos 2023-07-05 18:04:43 +01:00 committed by GitHub
parent baf48d3583
commit 81e5b1ad36
32 changed files with 203 additions and 232 deletions
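For callback implementers, the change is that `on_retriever_start` now receives the serialized retriever (its `dumpd` representation) before the query, along with an optional `tags` argument. A minimal sketch of a custom handler against the new signature follows; the handler class and its print body are illustrative only and not part of this commit:

    from typing import Any, Dict, List, Optional
    from uuid import UUID

    from langchain.callbacks.base import BaseCallbackHandler


    class RetrieverLoggingHandler(BaseCallbackHandler):
        """Illustrative handler that logs retriever starts."""

        def on_retriever_start(
            self,
            serialized: Dict[str, Any],
            query: str,
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            **kwargs: Any,
        ) -> Any:
            # `serialized` carries the dumpd() representation of the retriever,
            # so the handler can see which retriever issued the query.
            print(f"retriever start (run {run_id}): {query!r}", serialized)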

View File

@@ -168,10 +168,12 @@ class CallbackManagerMixin:
     def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when Retriever starts running."""
@@ -421,6 +423,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     async def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         *,
         run_id: UUID,

View File

@@ -1196,6 +1196,7 @@ class CallbackManager(BaseCallbackManager):
     def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         run_id: Optional[UUID] = None,
         parent_run_id: Optional[UUID] = None,
@@ -1209,6 +1210,7 @@ class CallbackManager(BaseCallbackManager):
             self.handlers,
             "on_retriever_start",
             "ignore_retriever",
+            serialized,
             query,
             run_id=run_id,
             parent_run_id=self.parent_run_id,
@@ -1463,6 +1465,7 @@ class AsyncCallbackManager(BaseCallbackManager):
     async def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         run_id: Optional[UUID] = None,
         parent_run_id: Optional[UUID] = None,
@@ -1476,6 +1479,7 @@ class AsyncCallbackManager(BaseCallbackManager):
             self.handlers,
             "on_retriever_start",
             "ignore_retriever",
+            serialized,
             query,
             run_id=run_id,
             parent_run_id=self.parent_run_id,

View File

@@ -312,6 +312,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
     def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         *,
         run_id: UUID,
@@ -326,6 +327,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             id=run_id,
             name="Retriever",
             parent_run_id=parent_run_id,
+            serialized=serialized,
             inputs={"query": query},
             extra=kwargs,
             events=[{"name": "start", "time": start_time}],

View File

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional
 import aiohttp
 import requests
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import Extra, root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -17,7 +17,7 @@ from langchain.schema import BaseRetriever, Document
 from langchain.utils import get_from_dict_or_env
-class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
+class AzureCognitiveSearchRetriever(BaseRetriever):
     """Wrapper around Azure Cognitive Search."""
     service_name: str = ""

View File

@@ -4,7 +4,6 @@ from typing import List, Optional
 import aiohttp
 import requests
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -13,7 +12,7 @@ from langchain.callbacks.manager import (
 from langchain.schema import BaseRetriever, Document
-class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
+class ChatGPTPluginRetriever(BaseRetriever):
     url: str
     bearer_token: str
     top_k: int = 3

View File

@@ -2,8 +2,6 @@
 from typing import Any, List
-from pydantic import BaseModel, Extra
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -14,7 +12,7 @@ from langchain.retrievers.document_compressors.base import (
 from langchain.schema import BaseRetriever, Document
-class ContextualCompressionRetriever(BaseRetriever, BaseModel):
+class ContextualCompressionRetriever(BaseRetriever):
     """Retriever that wraps a base retriever and compresses the results."""
     base_compressor: BaseDocumentCompressor
@@ -26,7 +24,6 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
     class Config:
         """Configuration for this pydantic object."""
-        extra = Extra.forbid
         arbitrary_types_allowed = True
     def _get_relevant_documents(

View File

@@ -17,16 +17,6 @@ class DataberryRetriever(BaseRetriever):
     top_k: Optional[int]
     api_key: Optional[str]
-    def __init__(
-        self,
-        datastore_url: str,
-        top_k: Optional[int] = None,
-        api_key: Optional[str] = None,
-    ):
-        self.datastore_url = datastore_url
-        self.api_key = api_key
-        self.top_k = top_k
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:

View File

@@ -2,7 +2,6 @@ from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 import numpy as np
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -20,7 +19,7 @@ class SearchType(str, Enum):
     mmr = "mmr"
-class DocArrayRetriever(BaseRetriever, BaseModel):
+class DocArrayRetriever(BaseRetriever):
     """
     Retriever class for DocArray Document Indices.

View File

@@ -40,9 +40,8 @@ class ElasticSearchBM25Retriever(BaseRetriever):
     https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
     """
-    def __init__(self, client: Any, index_name: str):
-        self.client = client
-        self.index_name = index_name
+    client: Any
+    index_name: str
     @classmethod
     def create(
@@ -75,7 +74,7 @@ class ElasticSearchBM25Retriever(BaseRetriever):
         # Create the index with the specified settings and mappings
         es.indices.create(index=index_name, mappings=mappings, settings=settings)
-        return cls(es, index_name)
+        return cls(client=es, index_name=index_name)
     def add_texts(
         self,

View File

@@ -1,7 +1,7 @@
 import re
 from typing import Any, Dict, List, Literal, Optional
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -179,37 +179,34 @@ class AmazonKendraRetriever(BaseRetriever):
     """
-    def __init__(
-        self,
-        index_id: str,
-        region_name: Optional[str] = None,
-        credentials_profile_name: Optional[str] = None,
-        top_k: int = 3,
-        attribute_filter: Optional[Dict] = None,
-        client: Optional[Any] = None,
-    ):
-        self.index_id = index_id
-        self.top_k = top_k
-        self.attribute_filter = attribute_filter
-        if client is not None:
-            self.client = client
-            return
+    index_id: str
+    region_name: Optional[str] = None
+    credentials_profile_name: Optional[str] = None
+    top_k: int = 3
+    attribute_filter: Optional[Dict] = None
+    client: Any
+
+    @root_validator(pre=True)
+    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if values["client"] is not None:
+            return values
         try:
             import boto3
-            if credentials_profile_name is not None:
-                session = boto3.Session(profile_name=credentials_profile_name)
+            if values["credentials_profile_name"] is not None:
+                session = boto3.Session(profile_name=values["credentials_profile_name"])
             else:
                 # use default credentials
                 session = boto3.Session()
             client_params = {}
-            if region_name is not None:
-                client_params["region_name"] = region_name
-            self.client = session.client("kendra", **client_params)
+            if values["region_name"] is not None:
+                client_params["region_name"] = values["region_name"]
+            values["client"] = session.client("kendra", **client_params)
+            return values
         except ImportError:
             raise ModuleNotFoundError(
                 "Could not import boto3 python package. "

View File

@@ -8,7 +8,6 @@ import concurrent.futures
 from typing import Any, List, Optional
 import numpy as np
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -33,7 +32,7 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
     return np.array(list(executor.map(embeddings.embed_query, contexts)))
-class KNNRetriever(BaseRetriever, BaseModel):
+class KNNRetriever(BaseRetriever):
     """KNN Retriever."""
     embeddings: Embeddings

View File

@@ -1,6 +1,6 @@
 from typing import Any, Dict, List, cast
-from pydantic import BaseModel, Field
+from pydantic import Field
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
 from langchain.schema import BaseRetriever, Document
-class LlamaIndexRetriever(BaseRetriever, BaseModel):
+class LlamaIndexRetriever(BaseRetriever):
     """Question-answering with sources over an LlamaIndex data structure."""
     index: Any
@@ -45,7 +45,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
         raise NotImplementedError("LlamaIndexRetriever does not support async")
-class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
+class LlamaIndexGraphRetriever(BaseRetriever):
     """Question-answering with sources over an LlamaIndex graph data structure."""
     graph: Any

View File

@@ -15,18 +15,7 @@ class MergerRetriever(BaseRetriever):
         retrievers: A list of retrievers to merge.
     """
-    def __init__(
-        self,
-        retrievers: List[BaseRetriever],
-    ):
-        """
-        Initialize the MergerRetriever class.
-        Args:
-            retrievers: A list of retrievers to merge.
-        """
-        self.retrievers = retrievers
+    retrievers: List[BaseRetriever]
     def _get_relevant_documents(
         self,
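Because the custom __init__ is removed and `retrievers` is now a declared pydantic field, callers construct the retriever with keyword arguments, as the updated unit test later in this commit also shows. A minimal usage sketch, assuming two existing retriever instances and the usual top-level export:

    from langchain.retrievers import MergerRetriever

    # retriever_a and retriever_b are assumed to be existing BaseRetriever instances.
    lotr = MergerRetriever(retrievers=[retriever_a, retriever_b])
    docs = lotr.get_relevant_documents("Tell me about the Celtics")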

View File

@@ -1,5 +1,7 @@
 from typing import Any, List, Optional
+from pydantic import root_validator
+
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -10,16 +12,26 @@ from langchain.schema import BaseRetriever, Document
 class MetalRetriever(BaseRetriever):
     """Retriever that uses the Metal API."""
-    def __init__(self, client: Any, params: Optional[dict] = None):
+    client: Any
+    params: Optional[dict] = None
+
+    @root_validator(pre=True)
+    def validate_client(cls, values: dict) -> dict:
+        """Validate that the client is of the correct type."""
         from metal_sdk.metal import Metal
-        if not isinstance(client, Metal):
-            raise ValueError(
-                "Got unexpected client, should be of type metal_sdk.metal.Metal. "
-                f"Instead, got {type(client)}"
-            )
-        self.client: Metal = client
-        self.params = params or {}
+        if "client" in values:
+            client = values["client"]
+            if not isinstance(client, Metal):
+                raise ValueError(
+                    "Got unexpected client, should be of type metal_sdk.metal.Metal. "
+                    f"Instead, got {type(client)}"
+                )
+        values["params"] = values.get("params", {})
+        return values
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun

View File

@@ -2,6 +2,8 @@
 import warnings
 from typing import Any, Dict, List, Optional
+from pydantic import root_validator
+
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -16,21 +18,28 @@ from langchain.vectorstores.milvus import Milvus
 class MilvusRetriever(BaseRetriever):
     """Retriever that uses the Milvus API."""
-    def __init__(
-        self,
-        embedding_function: Embeddings,
-        collection_name: str = "LangChainCollection",
-        connection_args: Optional[Dict[str, Any]] = None,
-        consistency_level: str = "Session",
-        search_params: Optional[dict] = None,
-    ):
-        self.store = Milvus(
-            embedding_function,
-            collection_name,
-            connection_args,
-            consistency_level,
+    embedding_function: Embeddings
+    collection_name: str = "LangChainCollection"
+    connection_args: Optional[Dict[str, Any]] = None
+    consistency_level: str = "Session"
+    search_params: Optional[dict] = None
+
+    store: Milvus
+    retriever: BaseRetriever
+
+    @root_validator(pre=True)
+    def create_retriever(cls, values: Dict) -> Dict:
+        """Create the Milvus store and retriever."""
+        values["store"] = Milvus(
+            values["embedding_function"],
+            values["collection_name"],
+            values["connection_args"],
+            values["consistency_level"],
         )
-        self.retriever = self.store.as_retriever(search_kwargs={"param": search_params})
+        values["retriever"] = values["store"].as_retriever(
+            search_kwargs={"param": values["search_params"]}
+        )
+        return values
     def add_texts(
         self, texts: List[str], metadatas: Optional[List[dict]] = None

View File

@@ -47,28 +47,10 @@ class MultiQueryRetriever(BaseRetriever):
     """Given a user query, use an LLM to write a set of queries.
    Retrieve docs for each query. Rake the unique union of all retrieved docs."""
-    def __init__(
-        self,
-        retriever: BaseRetriever,
-        llm_chain: LLMChain,
-        verbose: bool = True,
-        parser_key: str = "lines",
-    ) -> None:
-        """Initialize MultiQueryRetriever.
-        Args:
-            retriever: retriever to query documents from
-            llm_chain: llm_chain for query generation
-            verbose: show the queries that we generated to the user
-            parser_key: attribute name for the parsed output
-        Returns:
-            MultiQueryRetriever
-        """
-        self.retriever = retriever
-        self.llm_chain = llm_chain
-        self.verbose = verbose
-        self.parser_key = parser_key
+    retriever: BaseRetriever
+    llm_chain: LLMChain
+    verbose: bool = True
+    parser_key: str = "lines"
     @classmethod
     def from_llm(
View File

@@ -3,7 +3,7 @@
 import hashlib
 from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import Extra, root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -98,7 +98,7 @@ def create_index(
     index.upsert(vectors)
-class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
+class PineconeHybridSearchRetriever(BaseRetriever):
     embeddings: Embeddings
     """description"""
     sparse_encoder: Any

View File

@@ -2,7 +2,6 @@ from typing import List, Optional
 import aiohttp
 import requests
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -11,7 +10,7 @@ from langchain.callbacks.manager import (
 from langchain.schema import BaseRetriever, Document
-class RemoteLangChainRetriever(BaseRetriever, BaseModel):
+class RemoteLangChainRetriever(BaseRetriever):
     url: str
     headers: Optional[dict] = None
     input_key: str = "message"

View File

@@ -8,7 +8,6 @@ import concurrent.futures
 from typing import Any, List, Optional
 import numpy as np
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -32,7 +31,7 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
     return np.array(list(executor.map(embeddings.embed_query, contexts)))
-class SVMRetriever(BaseRetriever, BaseModel):
+class SVMRetriever(BaseRetriever):
     """SVM Retriever."""
     embeddings: Embeddings

View File

@@ -7,8 +7,6 @@ from __future__ import annotations
 from typing import Any, Dict, Iterable, List, Optional
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -16,7 +14,7 @@ from langchain.callbacks.manager import (
 from langchain.schema import BaseRetriever, Document
-class TFIDFRetriever(BaseRetriever, BaseModel):
+class TFIDFRetriever(BaseRetriever):
     vectorizer: Any
     docs: List[Document]
     tfidf_array: Any

View File

@@ -4,7 +4,7 @@ import datetime
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Tuple
-from pydantic import BaseModel, Field
+from pydantic import Field
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -19,7 +19,7 @@ def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> f
     return (time - ref_time).total_seconds() / 3600
-class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
+class TimeWeightedVectorStoreRetriever(BaseRetriever):
     """Retriever combining embedding similarity with recency."""
     vectorstore: VectorStore

View File

@@ -18,29 +18,13 @@ if TYPE_CHECKING:
 class VespaRetriever(BaseRetriever):
     """Retriever that uses the Vespa."""
-    def __init__(
-        self,
-        app: Vespa,
-        body: Dict,
-        content_field: str,
-        metadata_fields: Optional[Sequence[str]] = None,
-    ):
-        """
-        Args:
-            app: Vespa client.
-            body: query body.
-            content_field: result field with document contents.
-            metadata_fields: result fields to include in document metadata.
-        """
-        self._application = app
-        self._query_body = body
-        self._content_field = content_field
-        self._metadata_fields = metadata_fields or ()
+    app: Vespa
+    body: Dict
+    content_field: str
+    metadata_fields: Sequence[str]
     def _query(self, body: Dict) -> List[Document]:
-        response = self._application.query(body)
+        response = self.app.query(body)
         if not str(response.status_code).startswith("2"):
             raise RuntimeError(
@@ -55,11 +39,11 @@ class VespaRetriever(BaseRetriever):
         docs = []
         for child in response.hits:
-            page_content = child["fields"].pop(self._content_field, "")
-            if self._metadata_fields == "*":
+            page_content = child["fields"].pop(self.content_field, "")
+            if self.metadata_fields == "*":
                 metadata = child["fields"]
             else:
-                metadata = {mf: child["fields"].get(mf) for mf in self._metadata_fields}
+                metadata = {mf: child["fields"].get(mf) for mf in self.metadata_fields}
             metadata["id"] = child["id"]
             docs.append(Document(page_content=page_content, metadata=metadata))
         return docs
@@ -67,7 +51,7 @@ class VespaRetriever(BaseRetriever):
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
-        body = self._query_body.copy()
+        body = self.body.copy()
         body["query"] = query
         return self._query(body)
@@ -79,7 +63,7 @@ class VespaRetriever(BaseRetriever):
     def get_relevant_documents_with_filter(
         self, query: str, *, _filter: Optional[str] = None
     ) -> List[Document]:
-        body = self._query_body.copy()
+        body = self.body.copy()
         _filter = f" and {_filter}" if _filter else ""
         body["yql"] = body["yql"] + _filter
         body["query"] = query
@@ -139,4 +123,9 @@ class VespaRetriever(BaseRetriever):
             body["yql"] = yql
         if k:
             body["hits"] = k
-        return cls(app, body, content_field, metadata_fields=metadata_fields)
+        return cls(
+            app=app,
+            body=body,
+            content_field=content_field,
+            metadata_fields=metadata_fields,
+        )

View File

@@ -2,10 +2,10 @@
 from __future__ import annotations
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast
 from uuid import uuid4
-from pydantic import Extra
+from pydantic import root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -16,16 +16,19 @@ from langchain.schema import BaseRetriever
 class WeaviateHybridSearchRetriever(BaseRetriever):
-    def __init__(
-        self,
-        client: Any,
-        index_name: str,
-        text_key: str,
-        alpha: float = 0.5,
-        k: int = 4,
-        attributes: Optional[List[str]] = None,
-        create_schema_if_missing: bool = True,
-    ):
+    client: Any
+    index_name: str
+    text_key: str
+    alpha: float = 0.5
+    k: int = 4
+    attributes: List[str]
+    create_schema_if_missing: bool = True
+
+    @root_validator(pre=True)
+    def validate_client(
+        cls,
+        values: Dict[str, Any],
+    ) -> Dict[str, Any]:
         try:
             import weaviate
         except ImportError:
@@ -33,36 +36,31 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
                 "Could not import weaviate python package. "
                 "Please install it with `pip install weaviate-client`."
             )
-        if not isinstance(client, weaviate.Client):
+        if not isinstance(values["client"], weaviate.Client):
+            client = values["client"]
             raise ValueError(
                 f"client should be an instance of weaviate.Client, got {type(client)}"
             )
-        self._client = client
-        self.k = k
-        self.alpha = alpha
-        self._index_name = index_name
-        self._text_key = text_key
-        self._query_attrs = [self._text_key]
-        if attributes is not None:
-            self._query_attrs.extend(attributes)
+        if values["attributes"] is None:
+            values["attributes"] = []
+        cast(List, values["attributes"]).append(values["text_key"])
-        if create_schema_if_missing:
-            self._create_schema_if_missing()
-    def _create_schema_if_missing(self) -> None:
-        class_obj = {
-            "class": self._index_name,
-            "properties": [{"name": self._text_key, "dataType": ["text"]}],
-            "vectorizer": "text2vec-openai",
-        }
-        if not self._client.schema.exists(self._index_name):
-            self._client.schema.create_class(class_obj)
+        if values["create_schema_if_missing"]:
+            class_obj = {
+                "class": values["index_name"],
+                "properties": [{"name": values["text_key"], "dataType": ["text"]}],
+                "vectorizer": "text2vec-openai",
+            }
+            if not values["client"].schema.exists(values["index_name"]):
+                values["client"].schema.create_class(class_obj)
+        return values
     class Config:
         """Configuration for this pydantic object."""
-        extra = Extra.forbid
         arbitrary_types_allowed = True
     # added text_key
@@ -70,11 +68,11 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
         """Upload documents to Weaviate."""
         from weaviate.util import get_valid_uuid
-        with self._client.batch as batch:
+        with self.client.batch as batch:
             ids = []
             for i, doc in enumerate(docs):
                 metadata = doc.metadata or {}
-                data_properties = {self._text_key: doc.page_content, **metadata}
+                data_properties = {self.text_key: doc.page_content, **metadata}
                 # If the UUID of one of the objects already exists
                 # then the existing objectwill be replaced by the new object.
@@ -83,7 +81,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
                 else:
                     _id = get_valid_uuid(uuid4())
-                batch.add_data_object(data_properties, self._index_name, _id)
+                batch.add_data_object(data_properties, self.index_name, _id)
                 ids.append(_id)
         return ids
@@ -95,7 +93,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
         where_filter: Optional[Dict[str, object]] = None,
     ) -> List[Document]:
         """Look up similar documents in Weaviate."""
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
+        query_obj = self.client.query.get(self.index_name, self.attributes)
         if where_filter:
             query_obj = query_obj.with_where(where_filter)
@@ -105,8 +103,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
         docs = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
+        for res in result["data"]["Get"][self.index_name]:
+            text = res.pop(self.text_key)
             docs.append(Document(page_content=text, metadata=res))
         return docs

View File

@@ -1,6 +1,8 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from pydantic import root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -27,13 +29,14 @@ class ZepRetriever(BaseRetriever):
     https://docs.getzep.com/deployment/quickstart/
     """
-    def __init__(
-        self,
-        session_id: str,
-        url: str,
-        api_key: Optional[str] = None,
-        top_k: Optional[int] = None,
-    ):
+    zep_client: Any
+    session_id: str
+    top_k: Optional[int]
+
+    @root_validator(pre=True)
+    def create_client(cls, values: dict) -> dict:
         try:
             from zep_python import ZepClient
         except ImportError:
@@ -41,10 +44,11 @@ class ZepRetriever(BaseRetriever):
                 "Could not import zep-python package. "
                 "Please install it with `pip install zep-python`."
             )
-        self.zep_client = ZepClient(base_url=url, api_key=api_key)
-        self.session_id = session_id
-        self.top_k = top_k
+        values["zep_client"] = values.get(
+            "zep_client",
+            ZepClient(base_url=values["url"], api_key=values.get("api_key")),
+        )
+        return values
     def _search_result_to_doc(
         self, results: List[MemorySearchResult]

View File

@@ -2,6 +2,8 @@
 import warnings
 from typing import Any, Dict, List, Optional
+from pydantic import root_validator
+
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -16,21 +18,27 @@ from langchain.vectorstores.zilliz import Zilliz
 class ZillizRetriever(BaseRetriever):
     """Retriever that uses the Zilliz API."""
-    def __init__(
-        self,
-        embedding_function: Embeddings,
-        collection_name: str = "LangChainCollection",
-        connection_args: Optional[Dict[str, Any]] = None,
-        consistency_level: str = "Session",
-        search_params: Optional[dict] = None,
-    ):
-        self.store = Zilliz(
-            embedding_function,
-            collection_name,
-            connection_args,
-            consistency_level,
+    embedding_function: Embeddings
+    collection_name: str = "LangChainCollection"
+    connection_args: Optional[Dict[str, Any]] = None
+    consistency_level: str = "Session"
+    search_params: Optional[dict] = None
+
+    store: Zilliz
+    retriever: BaseRetriever
+
+    @root_validator(pre=True)
+    def create_client(cls, values: dict) -> dict:
+        values["store"] = Zilliz(
+            values["embedding_function"],
+            values["collection_name"],
+            values["connection_args"],
+            values["consistency_level"],
         )
-        self.retriever = self.store.as_retriever(search_kwargs={"param": search_params})
+        values["retriever"] = values["store"].as_retriever(
+            search_kwargs={"param": values["search_params"]}
+        )
+        return values
     def add_texts(
         self, texts: List[str], metadatas: Optional[List[dict]] = None

View File

@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
 from inspect import signature
 from typing import TYPE_CHECKING, Any, List
+from langchain.load.dump import dumpd
+from langchain.load.serializable import Serializable
 from langchain.schema.document import Document
 if TYPE_CHECKING:
@@ -15,7 +17,7 @@ if TYPE_CHECKING:
     )
-class BaseRetriever(ABC):
+class BaseRetriever(Serializable, ABC):
     """Abstract base class for a Document retrieval system.
     A retrieval system is defined as something that can take string queries and return
@@ -46,6 +48,11 @@ class BaseRetriever(ABC):
                 raise NotImplementedError
     """  # noqa: E501
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
     _new_arg_supported: bool = False
     _expects_other_args: bool = False
@@ -81,7 +88,9 @@ class BaseRetriever(ABC):
         parameters = signature(cls._get_relevant_documents).parameters
         cls._new_arg_supported = parameters.get("run_manager") is not None
         # If a V1 retriever broke the interface and expects additional arguments
-        cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
+        cls._expects_other_args = (
+            len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
+        )
     @abstractmethod
     def _get_relevant_documents(
@@ -123,6 +132,7 @@ class BaseRetriever(ABC):
             callbacks, None, verbose=kwargs.get("verbose", False)
         )
         run_manager = callback_manager.on_retriever_start(
+            dumpd(self),
             query,
             **kwargs,
         )
@@ -160,6 +170,7 @@ class BaseRetriever(ABC):
             callbacks, None, verbose=kwargs.get("verbose", False)
        )
         run_manager = await callback_manager.on_retriever_start(
+            dumpd(self),
             query,
             **kwargs,
         )
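With BaseRetriever now a Serializable pydantic model, subclasses declare their configuration as fields instead of defining __init__, and get_relevant_documents passes dumpd(self) to on_retriever_start before dispatching to _get_relevant_documents. A minimal sketch of a custom retriever against the new base class; the class name and fields here are illustrative only:

    from typing import List

    from langchain.callbacks.manager import CallbackManagerForRetrieverRun
    from langchain.schema import BaseRetriever, Document


    class StaticRetriever(BaseRetriever):
        """Illustrative retriever: returns the first k docs of a fixed list."""

        docs: List[Document]
        k: int = 2

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            # The base class has already fired on_retriever_start with dumpd(self)
            # by the time this runs; only the retrieval itself lives here.
            return self.docs[: self.k]


    # Construction uses keyword arguments, as with any pydantic model:
    # retriever = StaticRetriever(docs=[Document(page_content="hello world")])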

View File

@@ -3,7 +3,7 @@ import logging
 import os
 from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import BaseModel, root_validator
 from langchain.schema import Document
@@ -40,11 +40,6 @@ class ArxivAPIWrapper(BaseModel):
     load_all_available_meta: bool = False
     doc_content_chars_max: Optional[int] = 4000
-    class Config:
-        """Configuration for this pydantic object."""
-        extra = Extra.forbid
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""

View File

@@ -5,7 +5,7 @@ import urllib.error
 import urllib.request
 from typing import List
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel
 from langchain.schema import Document
@@ -42,11 +42,6 @@ class PubMedAPIWrapper(BaseModel):
     load_all_available_meta: bool = False
     email: str = "your_email@example.com"
-    class Config:
-        """Configuration for this pydantic object."""
-        extra = Extra.forbid
     def run(self, query: str) -> str:
         """
         Run PubMed search and get the article meta information.

View File

@@ -2,7 +2,7 @@
 import logging
 from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import BaseModel, root_validator
 from langchain.schema import Document
@@ -27,11 +27,6 @@ class WikipediaAPIWrapper(BaseModel):
     load_all_available_meta: bool = False
     doc_content_chars_max: int = 4000
-    class Config:
-        """Configuration for this pydantic object."""
-        extra = Extra.forbid
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""

View File

@@ -24,7 +24,7 @@ def test_merger_retriever_get_relevant_docs() -> None:
     )
     # The Lord of the Retrievers.
-    lotr = MergerRetriever([retriever_a, retriever_b])
+    lotr = MergerRetriever(retrievers=[retriever_a, retriever_b])
     actual = lotr.get_relevant_documents("Tell me about the Celtics")
     assert len(actual) == 2

View File

@@ -146,7 +146,7 @@ def test_ignore_retriever() -> None:
     handler1 = FakeCallbackHandler(ignore_retriever_=True)
     handler2 = FakeCallbackHandler()
     manager = CallbackManager(handlers=[handler1, handler2])
-    run_manager = manager.on_retriever_start("")
+    run_manager = manager.on_retriever_start({}, "")
     run_manager.on_retriever_end([])
     run_manager.on_retriever_error(Exception())
     assert handler1.starts == 0

View File

@@ -142,8 +142,7 @@ async def test_fake_retriever_v1_with_kwargs_upgrade_async(
 class FakeRetrieverV2(BaseRetriever):
-    def __init__(self, throw_error: bool = False) -> None:
-        self.throw_error = throw_error
+    throw_error: bool = False
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None