diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py
index a3a9bbe0f7..039b4fd1e2 100644
--- a/langchain/callbacks/base.py
+++ b/langchain/callbacks/base.py
@@ -168,10 +168,12 @@ class CallbackManagerMixin:
     def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when Retriever starts running."""
@@ -421,6 +423,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     async def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         *,
         run_id: UUID,
diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py
index 42783176e2..3ae9e61240 100644
--- a/langchain/callbacks/manager.py
+++ b/langchain/callbacks/manager.py
@@ -1196,6 +1196,7 @@ class CallbackManager(BaseCallbackManager):
     def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         run_id: Optional[UUID] = None,
         parent_run_id: Optional[UUID] = None,
@@ -1209,6 +1210,7 @@ class CallbackManager(BaseCallbackManager):
             self.handlers,
             "on_retriever_start",
             "ignore_retriever",
+            serialized,
             query,
             run_id=run_id,
             parent_run_id=self.parent_run_id,
@@ -1463,6 +1465,7 @@ class AsyncCallbackManager(BaseCallbackManager):
     async def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         run_id: Optional[UUID] = None,
         parent_run_id: Optional[UUID] = None,
@@ -1476,6 +1479,7 @@ class AsyncCallbackManager(BaseCallbackManager):
             self.handlers,
             "on_retriever_start",
             "ignore_retriever",
+            serialized,
             query,
             run_id=run_id,
             parent_run_id=self.parent_run_id,
diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py
index 040aef5c08..0c25264ae4 100644
--- a/langchain/callbacks/tracers/base.py
+++ b/langchain/callbacks/tracers/base.py
@@ -312,6 +312,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
     def on_retriever_start(
         self,
+        serialized: Dict[str, Any],
         query: str,
         *,
         run_id: UUID,
@@ -326,6 +327,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             id=run_id,
             name="Retriever",
             parent_run_id=parent_run_id,
+            serialized=serialized,
             inputs={"query": query},
             extra=kwargs,
             events=[{"name": "start", "time": start_time}],
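Note for handler authors: the hunks above change `on_retriever_start` so that the serialized retriever is the first positional argument and the query is the second, with an optional `tags` keyword added; `CallbackManager.on_retriever_start` and its async counterpart forward `serialized` ahead of `query` in the same way (see the `manager.py` hunks and the `{}, ""` call in the updated callback-manager test at the end of this patch). Below is a minimal sketch of a handler written against the new signature; the handler class and its print-based body are illustrative assumptions, not part of this patch.

```python
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler


class RetrieverLoggingHandler(BaseCallbackHandler):
    """Hypothetical handler targeting the updated on_retriever_start hook."""

    def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        # `serialized` carries the dumpd() representation of the retriever;
        # `query` is the search string that used to be the first argument.
        print(f"retriever start (run {run_id}): {query!r}, tags={tags}")
```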
diff --git a/langchain/retrievers/azure_cognitive_search.py b/langchain/retrievers/azure_cognitive_search.py
index 518750d663..efaa0a8792 100644
--- a/langchain/retrievers/azure_cognitive_search.py
+++ b/langchain/retrievers/azure_cognitive_search.py
@@ -7,7 +7,7 @@ from typing import Dict, List, Optional
 import aiohttp
 import requests
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import Extra, root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -17,7 +17,7 @@ from langchain.schema import BaseRetriever, Document
 from langchain.utils import get_from_dict_or_env
-class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
+class AzureCognitiveSearchRetriever(BaseRetriever):
     """Wrapper around Azure Cognitive Search."""
     service_name: str = ""
diff --git a/langchain/retrievers/chatgpt_plugin_retriever.py b/langchain/retrievers/chatgpt_plugin_retriever.py
index 06a25735c9..4eba2a4459 100644
--- a/langchain/retrievers/chatgpt_plugin_retriever.py
+++ b/langchain/retrievers/chatgpt_plugin_retriever.py
@@ -4,7 +4,6 @@ from typing import List, Optional
 import aiohttp
 import requests
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -13,7 +12,7 @@ from langchain.callbacks.manager import (
 from langchain.schema import BaseRetriever, Document
-class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
+class ChatGPTPluginRetriever(BaseRetriever):
     url: str
     bearer_token: str
     top_k: int = 3
diff --git a/langchain/retrievers/contextual_compression.py b/langchain/retrievers/contextual_compression.py
index 634706d280..d3810893e5 100644
--- a/langchain/retrievers/contextual_compression.py
+++ b/langchain/retrievers/contextual_compression.py
@@ -2,8 +2,6 @@ from typing import Any, List
-from pydantic import BaseModel, Extra
-
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -14,7 +12,7 @@ from langchain.retrievers.document_compressors.base import (
 from langchain.schema import BaseRetriever, Document
-class ContextualCompressionRetriever(BaseRetriever, BaseModel):
+class ContextualCompressionRetriever(BaseRetriever):
     """Retriever that wraps a base retriever and compresses the results."""
     base_compressor: BaseDocumentCompressor
@@ -26,7 +24,6 @@ class ContextualCompressionRetriever(BaseRetriever):
     class Config:
         """Configuration for this pydantic object."""
-        extra = Extra.forbid
         arbitrary_types_allowed = True
     def _get_relevant_documents(
diff --git a/langchain/retrievers/databerry.py b/langchain/retrievers/databerry.py
index 753823bdb4..aa03d1cf46 100644
--- a/langchain/retrievers/databerry.py
+++ b/langchain/retrievers/databerry.py
@@ -17,16 +17,6 @@ class DataberryRetriever(BaseRetriever):
     datastore_url: str
     top_k: Optional[int]
     api_key: Optional[str]
-    def __init__(
-        self,
-        datastore_url: str,
-        top_k: Optional[int] = None,
-        api_key: Optional[str] = None,
-    ):
-        self.datastore_url = datastore_url
-        self.api_key = api_key
-        self.top_k = top_k
-
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
diff --git a/langchain/retrievers/docarray.py b/langchain/retrievers/docarray.py
index 93b35d1c40..e50a654ad3 100644
--- a/langchain/retrievers/docarray.py
+++ b/langchain/retrievers/docarray.py
@@ -2,7 +2,6 @@ from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 import numpy as np
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -20,7 +19,7 @@ class SearchType(str, Enum):
     mmr = "mmr"
-class DocArrayRetriever(BaseRetriever, BaseModel):
+class DocArrayRetriever(BaseRetriever):
     """
     Retriever class for DocArray Document Indices.
diff --git a/langchain/retrievers/elastic_search_bm25.py b/langchain/retrievers/elastic_search_bm25.py
index 81d0192554..e1e09f55c1 100644
--- a/langchain/retrievers/elastic_search_bm25.py
+++ b/langchain/retrievers/elastic_search_bm25.py
@@ -40,9 +40,8 @@ class ElasticSearchBM25Retriever(BaseRetriever):
         https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
     """
-    def __init__(self, client: Any, index_name: str):
-        self.client = client
-        self.index_name = index_name
+    client: Any
+    index_name: str
     @classmethod
     def create(
@@ -75,7 +74,7 @@ class ElasticSearchBM25Retriever(BaseRetriever):
         # Create the index with the specified settings and mappings
         es.indices.create(index=index_name, mappings=mappings, settings=settings)
-        return cls(es, index_name)
+        return cls(client=es, index_name=index_name)
     def add_texts(
         self,
diff --git a/langchain/retrievers/kendra.py b/langchain/retrievers/kendra.py
index 47d2e321ba..8008694001 100644
--- a/langchain/retrievers/kendra.py
+++ b/langchain/retrievers/kendra.py
@@ -1,7 +1,7 @@
 import re
 from typing import Any, Dict, List, Literal, Optional
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -179,37 +179,34 @@ class AmazonKendraRetriever(BaseRetriever):
     """
-    def __init__(
-        self,
-        index_id: str,
-        region_name: Optional[str] = None,
-        credentials_profile_name: Optional[str] = None,
-        top_k: int = 3,
-        attribute_filter: Optional[Dict] = None,
-        client: Optional[Any] = None,
-    ):
-        self.index_id = index_id
-        self.top_k = top_k
-        self.attribute_filter = attribute_filter
+    index_id: str
+    region_name: Optional[str] = None
+    credentials_profile_name: Optional[str] = None
+    top_k: int = 3
+    attribute_filter: Optional[Dict] = None
+    client: Any
-        if client is not None:
-            self.client = client
-            return
+    @root_validator(pre=True)
+    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if values["client"] is not None:
+            return values
         try:
             import boto3
-            if credentials_profile_name is not None:
-                session = boto3.Session(profile_name=credentials_profile_name)
+            if values["credentials_profile_name"] is not None:
+                session = boto3.Session(profile_name=values["credentials_profile_name"])
             else:
                 # use default credentials
                 session = boto3.Session()
             client_params = {}
-            if region_name is not None:
-                client_params["region_name"] = region_name
+            if values["region_name"] is not None:
+                client_params["region_name"] = values["region_name"]
+
+            values["client"] = session.client("kendra", **client_params)
-            self.client = session.client("kendra", **client_params)
+            return values
         except ImportError:
             raise ModuleNotFoundError(
                 "Could not import boto3 python package. "
diff --git a/langchain/retrievers/knn.py b/langchain/retrievers/knn.py
index d41f5ab32f..945909d10b 100644
--- a/langchain/retrievers/knn.py
+++ b/langchain/retrievers/knn.py
@@ -8,7 +8,6 @@ import concurrent.futures
 from typing import Any, List, Optional
 import numpy as np
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -33,7 +32,7 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
         return np.array(list(executor.map(embeddings.embed_query, contexts)))
-class KNNRetriever(BaseRetriever, BaseModel):
+class KNNRetriever(BaseRetriever):
     """KNN Retriever."""
     embeddings: Embeddings
diff --git a/langchain/retrievers/llama_index.py b/langchain/retrievers/llama_index.py
index 90be4d6445..8cce86418d 100644
--- a/langchain/retrievers/llama_index.py
+++ b/langchain/retrievers/llama_index.py
@@ -1,6 +1,6 @@
 from typing import Any, Dict, List, cast
-from pydantic import BaseModel, Field
+from pydantic import Field
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
 from langchain.schema import BaseRetriever, Document
-class LlamaIndexRetriever(BaseRetriever, BaseModel):
+class LlamaIndexRetriever(BaseRetriever):
     """Question-answering with sources over an LlamaIndex data structure."""
     index: Any
@@ -45,7 +45,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
         raise NotImplementedError("LlamaIndexRetriever does not support async")
-class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
+class LlamaIndexGraphRetriever(BaseRetriever):
     """Question-answering with sources over an LlamaIndex graph data structure."""
     graph: Any
diff --git a/langchain/retrievers/merger_retriever.py b/langchain/retrievers/merger_retriever.py
index 59d217c31c..b1f8b32992 100644
--- a/langchain/retrievers/merger_retriever.py
+++ b/langchain/retrievers/merger_retriever.py
@@ -15,18 +15,7 @@ class MergerRetriever(BaseRetriever):
         retrievers: A list of retrievers to merge.
     """
-    def __init__(
-        self,
-        retrievers: List[BaseRetriever],
-    ):
-        """
-        Initialize the MergerRetriever class.
-
-        Args:
-            retrievers: A list of retrievers to merge.
-        """
-
-        self.retrievers = retrievers
+    retrievers: List[BaseRetriever]
     def _get_relevant_documents(
         self,
diff --git a/langchain/retrievers/metal.py b/langchain/retrievers/metal.py
index e5e3b6fae8..80dd12cc31 100644
--- a/langchain/retrievers/metal.py
+++ b/langchain/retrievers/metal.py
@@ -1,5 +1,7 @@
 from typing import Any, List, Optional
+from pydantic import root_validator
+
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -10,16 +12,26 @@ from langchain.schema import BaseRetriever, Document
 class MetalRetriever(BaseRetriever):
     """Retriever that uses the Metal API."""
-    def __init__(self, client: Any, params: Optional[dict] = None):
+    client: Any
+
+    params: Optional[dict] = None
+
+    @root_validator(pre=True)
+    def validate_client(cls, values: dict) -> dict:
+        """Validate that the client is of the correct type."""
         from metal_sdk.metal import Metal
-        if not isinstance(client, Metal):
-            raise ValueError(
-                "Got unexpected client, should be of type metal_sdk.metal.Metal. "
-                f"Instead, got {type(client)}"
-            )
-        self.client: Metal = client
-        self.params = params or {}
+        if "client" in values:
+            client = values["client"]
+            if not isinstance(client, Metal):
+                raise ValueError(
+                    "Got unexpected client, should be of type metal_sdk.metal.Metal. "
+                    f"Instead, got {type(client)}"
+                )
+
+        values["params"] = values.get("params", {})
+
+        return values
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
diff --git a/langchain/retrievers/milvus.py b/langchain/retrievers/milvus.py
index 3255cde4ba..6541ce441f 100644
--- a/langchain/retrievers/milvus.py
+++ b/langchain/retrievers/milvus.py
@@ -2,6 +2,8 @@ import warnings
 from typing import Any, Dict, List, Optional
+from pydantic import root_validator
+
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -16,21 +18,28 @@ from langchain.vectorstores.milvus import Milvus
 class MilvusRetriever(BaseRetriever):
     """Retriever that uses the Milvus API."""
-    def __init__(
-        self,
-        embedding_function: Embeddings,
-        collection_name: str = "LangChainCollection",
-        connection_args: Optional[Dict[str, Any]] = None,
-        consistency_level: str = "Session",
-        search_params: Optional[dict] = None,
-    ):
-        self.store = Milvus(
-            embedding_function,
-            collection_name,
-            connection_args,
-            consistency_level,
+    embedding_function: Embeddings
+    collection_name: str = "LangChainCollection"
+    connection_args: Optional[Dict[str, Any]] = None
+    consistency_level: str = "Session"
+    search_params: Optional[dict] = None
+
+    store: Milvus
+    retriever: BaseRetriever
+
+    @root_validator(pre=True)
+    def create_retriever(cls, values: Dict) -> Dict:
+        """Create the Milvus store and retriever."""
+        values["store"] = Milvus(
+            values["embedding_function"],
+            values["collection_name"],
+            values["connection_args"],
+            values["consistency_level"],
+        )
+        values["retriever"] = values["store"].as_retriever(
+            search_kwargs={"param": values["search_params"]}
         )
-        self.retriever = self.store.as_retriever(search_kwargs={"param": search_params})
+        return values
     def add_texts(
         self, texts: List[str], metadatas: Optional[List[dict]] = None
diff --git a/langchain/retrievers/multi_query.py b/langchain/retrievers/multi_query.py
index 10c540d94c..4da52888e3 100644
--- a/langchain/retrievers/multi_query.py
+++ b/langchain/retrievers/multi_query.py
@@ -47,28 +47,10 @@ class MultiQueryRetriever(BaseRetriever):
     """Given a user query, use an LLM to write a set of queries.
     Retrieve docs for each query. Rake the unique union of all retrieved docs."""
-    def __init__(
-        self,
-        retriever: BaseRetriever,
-        llm_chain: LLMChain,
-        verbose: bool = True,
-        parser_key: str = "lines",
-    ) -> None:
-        """Initialize MultiQueryRetriever.
-
-        Args:
-            retriever: retriever to query documents from
-            llm_chain: llm_chain for query generation
-            verbose: show the queries that we generated to the user
-            parser_key: attribute name for the parsed output
-
-        Returns:
-            MultiQueryRetriever
-        """
-        self.retriever = retriever
-        self.llm_chain = llm_chain
-        self.verbose = verbose
-        self.parser_key = parser_key
+    retriever: BaseRetriever
+    llm_chain: LLMChain
+    verbose: bool = True
+    parser_key: str = "lines"
     @classmethod
     def from_llm(
diff --git a/langchain/retrievers/pinecone_hybrid_search.py b/langchain/retrievers/pinecone_hybrid_search.py
index eca21de810..6e8541937d 100644
--- a/langchain/retrievers/pinecone_hybrid_search.py
+++ b/langchain/retrievers/pinecone_hybrid_search.py
@@ -3,7 +3,7 @@ import hashlib
 from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import Extra, root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -98,7 +98,7 @@ def create_index(
     index.upsert(vectors)
-class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
+class PineconeHybridSearchRetriever(BaseRetriever):
     embeddings: Embeddings
     """description"""
     sparse_encoder: Any
diff --git a/langchain/retrievers/remote_retriever.py b/langchain/retrievers/remote_retriever.py
index a75d87f9e0..f0f7ba4dc8 100644
--- a/langchain/retrievers/remote_retriever.py
+++ b/langchain/retrievers/remote_retriever.py
@@ -2,7 +2,6 @@ from typing import List, Optional
 import aiohttp
 import requests
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -11,7 +10,7 @@ from langchain.callbacks.manager import (
 from langchain.schema import BaseRetriever, Document
-class RemoteLangChainRetriever(BaseRetriever, BaseModel):
+class RemoteLangChainRetriever(BaseRetriever):
     url: str
     headers: Optional[dict] = None
     input_key: str = "message"
diff --git a/langchain/retrievers/svm.py b/langchain/retrievers/svm.py
index 5fcc3cd665..95792a8531 100644
--- a/langchain/retrievers/svm.py
+++ b/langchain/retrievers/svm.py
@@ -8,7 +8,6 @@ import concurrent.futures
 from typing import Any, List, Optional
 import numpy as np
-from pydantic import BaseModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -32,7 +31,7 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
         return np.array(list(executor.map(embeddings.embed_query, contexts)))
-class SVMRetriever(BaseRetriever, BaseModel):
+class SVMRetriever(BaseRetriever):
     """SVM Retriever."""
     embeddings: Embeddings
diff --git a/langchain/retrievers/tfidf.py b/langchain/retrievers/tfidf.py
index 818541f09e..5517de547f 100644
--- a/langchain/retrievers/tfidf.py
+++ b/langchain/retrievers/tfidf.py
@@ -7,8 +7,6 @@ from __future__ import annotations
 from typing import Any, Dict, Iterable, List, Optional
-from pydantic import BaseModel
-
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -16,7 +14,7 @@ from langchain.callbacks.manager import (
 from langchain.schema import BaseRetriever, Document
-class TFIDFRetriever(BaseRetriever, BaseModel):
+class TFIDFRetriever(BaseRetriever):
     vectorizer: Any
     docs: List[Document]
     tfidf_array: Any
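The conversions above, together with the Kendra, Metal, and Milvus hunks earlier in this patch (and the Zep and Zilliz ones below), move work that used to live in `__init__` into a pydantic `@root_validator(pre=True)` that derives heavyweight objects from the declared fields before the model is finalized. A self-contained, pydantic v1-style sketch of that pattern follows; the class and field names are assumptions, and a plain dict stands in for a real SDK client so the example runs without boto3 or pymilvus.

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel, root_validator


class ClientBackedConfig(BaseModel):
    """Illustrative model: a derived `client` field filled in by a validator."""

    endpoint: str
    region: Optional[str] = None
    client: Any = None  # built below unless the caller injects one

    @root_validator(pre=True)
    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # pre=True means `values` holds the raw keyword arguments, before
        # defaults are applied and before per-field validation runs.
        if values.get("client") is not None:
            return values
        # Stand-in for e.g. boto3.Session().client("kendra", **client_params).
        values["client"] = {"endpoint": values["endpoint"], "region": values.get("region")}
        return values


config = ClientBackedConfig(endpoint="https://example.invalid")
assert config.client["endpoint"] == "https://example.invalid"
```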
diff --git a/langchain/retrievers/time_weighted_retriever.py b/langchain/retrievers/time_weighted_retriever.py
index 0340785844..64e641a6a5 100644
--- a/langchain/retrievers/time_weighted_retriever.py
+++ b/langchain/retrievers/time_weighted_retriever.py
@@ -4,7 +4,7 @@ import datetime
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Tuple
-from pydantic import BaseModel, Field
+from pydantic import Field
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -19,7 +19,7 @@ def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> f
     return (time - ref_time).total_seconds() / 3600
-class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
+class TimeWeightedVectorStoreRetriever(BaseRetriever):
     """Retriever combining embedding similarity with recency."""
     vectorstore: VectorStore
diff --git a/langchain/retrievers/vespa_retriever.py b/langchain/retrievers/vespa_retriever.py
index c8eb04b8b7..29bd420ced 100644
--- a/langchain/retrievers/vespa_retriever.py
+++ b/langchain/retrievers/vespa_retriever.py
@@ -18,29 +18,13 @@ if TYPE_CHECKING:
 class VespaRetriever(BaseRetriever):
     """Retriever that uses the Vespa."""
-    def __init__(
-        self,
-        app: Vespa,
-        body: Dict,
-        content_field: str,
-        metadata_fields: Optional[Sequence[str]] = None,
-    ):
-        """
-
-        Args:
-            app: Vespa client.
-            body: query body.
-            content_field: result field with document contents.
-            metadata_fields: result fields to include in document metadata.
-
-        """
-        self._application = app
-        self._query_body = body
-        self._content_field = content_field
-        self._metadata_fields = metadata_fields or ()
+    app: Vespa
+    body: Dict
+    content_field: str
+    metadata_fields: Sequence[str]
     def _query(self, body: Dict) -> List[Document]:
-        response = self._application.query(body)
+        response = self.app.query(body)
         if not str(response.status_code).startswith("2"):
             raise RuntimeError(
@@ -55,11 +39,11 @@ class VespaRetriever(BaseRetriever):
         docs = []
         for child in response.hits:
-            page_content = child["fields"].pop(self._content_field, "")
-            if self._metadata_fields == "*":
+            page_content = child["fields"].pop(self.content_field, "")
+            if self.metadata_fields == "*":
                 metadata = child["fields"]
             else:
-                metadata = {mf: child["fields"].get(mf) for mf in self._metadata_fields}
+                metadata = {mf: child["fields"].get(mf) for mf in self.metadata_fields}
             metadata["id"] = child["id"]
             docs.append(Document(page_content=page_content, metadata=metadata))
         return docs
@@ -67,7 +51,7 @@ class VespaRetriever(BaseRetriever):
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
-        body = self._query_body.copy()
+        body = self.body.copy()
         body["query"] = query
         return self._query(body)
@@ -79,7 +63,7 @@ class VespaRetriever(BaseRetriever):
     def get_relevant_documents_with_filter(
         self, query: str, *, _filter: Optional[str] = None
     ) -> List[Document]:
-        body = self._query_body.copy()
+        body = self.body.copy()
         _filter = f" and {_filter}" if _filter else ""
         body["yql"] = body["yql"] + _filter
         body["query"] = query
@@ -139,4 +123,9 @@ class VespaRetriever(BaseRetriever):
         body["yql"] = yql
         if k:
             body["hits"] = k
-        return cls(app, body, content_field, metadata_fields=metadata_fields)
+        return cls(
+            app=app,
+            body=body,
+            content_field=content_field,
+            metadata_fields=metadata_fields,
+        )
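With `BaseRetriever` itself now a pydantic model (see the `schema/retriever.py` hunks below), retrievers are constructed with keyword arguments rather than positional ones, which is why `VespaRetriever.from_params` above and the `MergerRetriever` test at the end of this patch switch to `cls(app=app, body=body, ...)`-style calls. Below is a hedged usage sketch against the `ChatGPTPluginRetriever` fields shown earlier in this patch; the URL and token values are placeholders.

```python
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever

# Field names come straight from the class body in the earlier hunk:
# url, bearer_token, and top_k.
retriever = ChatGPTPluginRetriever(
    url="https://example.com/plugin",  # placeholder endpoint
    bearer_token="YOUR_BEARER_TOKEN",  # placeholder credential
    top_k=3,
)
docs = retriever.get_relevant_documents("what documents mention pricing?")
```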
diff --git a/langchain/retrievers/weaviate_hybrid_search.py b/langchain/retrievers/weaviate_hybrid_search.py
index f8ad46fa44..84a5c7e963 100644
--- a/langchain/retrievers/weaviate_hybrid_search.py
+++ b/langchain/retrievers/weaviate_hybrid_search.py
@@ -2,10 +2,10 @@ from __future__ import annotations
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast
 from uuid import uuid4
-from pydantic import Extra
+from pydantic import root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -16,16 +16,19 @@ from langchain.schema import BaseRetriever
 class WeaviateHybridSearchRetriever(BaseRetriever):
-    def __init__(
-        self,
-        client: Any,
-        index_name: str,
-        text_key: str,
-        alpha: float = 0.5,
-        k: int = 4,
-        attributes: Optional[List[str]] = None,
-        create_schema_if_missing: bool = True,
-    ):
+    client: Any
+    index_name: str
+    text_key: str
+    alpha: float = 0.5
+    k: int = 4
+    attributes: List[str]
+    create_schema_if_missing: bool = True
+
+    @root_validator(pre=True)
+    def validate_client(
+        cls,
+        values: Dict[str, Any],
+    ) -> Dict[str, Any]:
         try:
             import weaviate
         except ImportError:
@@ -33,36 +36,31 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
                 "Could not import weaviate python package. "
                 "Please install it with `pip install weaviate-client`."
             )
-        if not isinstance(client, weaviate.Client):
+        if not isinstance(values["client"], weaviate.Client):
+            client = values["client"]
             raise ValueError(
                 f"client should be an instance of weaviate.Client, got {type(client)}"
             )
-        self._client = client
-        self.k = k
-        self.alpha = alpha
-        self._index_name = index_name
-        self._text_key = text_key
-        self._query_attrs = [self._text_key]
-        if attributes is not None:
-            self._query_attrs.extend(attributes)
-
-        if create_schema_if_missing:
-            self._create_schema_if_missing()
-
-    def _create_schema_if_missing(self) -> None:
-        class_obj = {
-            "class": self._index_name,
-            "properties": [{"name": self._text_key, "dataType": ["text"]}],
-            "vectorizer": "text2vec-openai",
-        }
-
-        if not self._client.schema.exists(self._index_name):
-            self._client.schema.create_class(class_obj)
+        if values["attributes"] is None:
+            values["attributes"] = []
+
+        cast(List, values["attributes"]).append(values["text_key"])
+
+        if values["create_schema_if_missing"]:
+            class_obj = {
+                "class": values["index_name"],
+                "properties": [{"name": values["text_key"], "dataType": ["text"]}],
+                "vectorizer": "text2vec-openai",
+            }
+
+            if not values["client"].schema.exists(values["index_name"]):
+                values["client"].schema.create_class(class_obj)
+
+        return values
     class Config:
         """Configuration for this pydantic object."""
-        extra = Extra.forbid
         arbitrary_types_allowed = True
     # added text_key
@@ -70,11 +68,11 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
         """Upload documents to Weaviate."""
         from weaviate.util import get_valid_uuid
-        with self._client.batch as batch:
+        with self.client.batch as batch:
             ids = []
             for i, doc in enumerate(docs):
                 metadata = doc.metadata or {}
-                data_properties = {self._text_key: doc.page_content, **metadata}
+                data_properties = {self.text_key: doc.page_content, **metadata}
                 # If the UUID of one of the objects already exists
                 # then the existing objectwill be replaced by the new object.
@@ -83,7 +81,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
                 else:
                     _id = get_valid_uuid(uuid4())
-                batch.add_data_object(data_properties, self._index_name, _id)
+                batch.add_data_object(data_properties, self.index_name, _id)
                 ids.append(_id)
         return ids
@@ -95,7 +93,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
         where_filter: Optional[Dict[str, object]] = None,
     ) -> List[Document]:
         """Look up similar documents in Weaviate."""
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
+        query_obj = self.client.query.get(self.index_name, self.attributes)
         if where_filter:
             query_obj = query_obj.with_where(where_filter)
@@ -105,8 +103,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
         docs = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
+        for res in result["data"]["Get"][self.index_name]:
+            text = res.pop(self.text_key)
             docs.append(Document(page_content=text, metadata=res))
         return docs
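The Metal and Weaviate retrievers now perform their client type checks inside `@root_validator(pre=True)` rather than in `__init__`. Below is a self-contained, pydantic v1-style sketch of that idea; `FakeClient` stands in for `weaviate.Client` or `metal_sdk.metal.Metal` so the example runs without either SDK, and the class names are assumptions rather than code from this patch.

```python
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, root_validator


class FakeClient:
    """Stand-in for a third-party SDK client."""


class TypedClientConfig(BaseModel):
    client: Any
    text_key: str
    attributes: Optional[List[str]] = None

    @root_validator(pre=True)
    def validate_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        client = values.get("client")
        if not isinstance(client, FakeClient):
            raise ValueError(
                f"client should be an instance of FakeClient, got {type(client)}"
            )
        # Normalize optional inputs, mirroring the Weaviate hunk above.
        if values.get("attributes") is None:
            values["attributes"] = []
        values["attributes"].append(values["text_key"])
        return values


config = TypedClientConfig(client=FakeClient(), text_key="text")
assert config.attributes == ["text"]
```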
diff --git a/langchain/retrievers/zep.py b/langchain/retrievers/zep.py
index 9062a5a982..64333ff8dc 100644
--- a/langchain/retrievers/zep.py
+++ b/langchain/retrievers/zep.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from pydantic import root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
@@ -27,13 +29,14 @@ class ZepRetriever(BaseRetriever):
         https://docs.getzep.com/deployment/quickstart/
     """
-    def __init__(
-        self,
-        session_id: str,
-        url: str,
-        api_key: Optional[str] = None,
-        top_k: Optional[int] = None,
-    ):
+    zep_client: Any
+
+    session_id: str
+
+    top_k: Optional[int]
+
+    @root_validator(pre=True)
+    def create_client(cls, values: dict) -> dict:
         try:
             from zep_python import ZepClient
         except ImportError:
             raise ValueError(
                 "Could not import zep-python package. "
                 "Please install it with `pip install zep-python`."
             )
-
-        self.zep_client = ZepClient(base_url=url, api_key=api_key)
-        self.session_id = session_id
-        self.top_k = top_k
+        values["zep_client"] = values.get(
+            "zep_client",
+            ZepClient(base_url=values["url"], api_key=values.get("api_key")),
+        )
+        return values
     def _search_result_to_doc(
         self, results: List[MemorySearchResult]
diff --git a/langchain/retrievers/zilliz.py b/langchain/retrievers/zilliz.py
index 810f6ae4a0..8ff463d283 100644
--- a/langchain/retrievers/zilliz.py
+++ b/langchain/retrievers/zilliz.py
@@ -2,6 +2,8 @@ import warnings
 from typing import Any, Dict, List, Optional
+from pydantic import root_validator
+
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
@@ -16,21 +18,27 @@ from langchain.vectorstores.zilliz import Zilliz
 class ZillizRetriever(BaseRetriever):
     """Retriever that uses the Zilliz API."""
-    def __init__(
-        self,
-        embedding_function: Embeddings,
-        collection_name: str = "LangChainCollection",
-        connection_args: Optional[Dict[str, Any]] = None,
-        consistency_level: str = "Session",
-        search_params: Optional[dict] = None,
-    ):
-        self.store = Zilliz(
-            embedding_function,
-            collection_name,
-            connection_args,
-            consistency_level,
+    embedding_function: Embeddings
+    collection_name: str = "LangChainCollection"
+    connection_args: Optional[Dict[str, Any]] = None
+    consistency_level: str = "Session"
+    search_params: Optional[dict] = None
+
+    store: Zilliz
+    retriever: BaseRetriever
+
+    @root_validator(pre=True)
+    def create_client(cls, values: dict) -> dict:
+        values["store"] = Zilliz(
+            values["embedding_function"],
+            values["collection_name"],
+            values["connection_args"],
+            values["consistency_level"],
+        )
+        values["retriever"] = values["store"].as_retriever(
+            search_kwargs={"param": values["search_params"]}
         )
-        self.retriever = self.store.as_retriever(search_kwargs={"param": search_params})
+        return values
     def add_texts(
         self, texts: List[str], metadatas: Optional[List[dict]] = None
diff --git a/langchain/schema/retriever.py b/langchain/schema/retriever.py
index cd4ab75813..ce9e6ce664 100644
--- a/langchain/schema/retriever.py
+++ b/langchain/schema/retriever.py
@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
 from inspect import signature
 from typing import TYPE_CHECKING, Any, List
+from langchain.load.dump import dumpd
+from langchain.load.serializable import Serializable
 from langchain.schema.document import Document
 if TYPE_CHECKING:
@@ -15,7 +17,7 @@ if TYPE_CHECKING:
     )
-class BaseRetriever(ABC):
+class BaseRetriever(Serializable, ABC):
     """Abstract base class for a Document retrieval system.
     A retrieval system is defined as something that can take string queries and return
@@ -46,6 +48,11 @@ class BaseRetriever(ABC):
                 raise NotImplementedError
     """  # noqa: E501
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
     _new_arg_supported: bool = False
     _expects_other_args: bool = False
@@ -81,7 +88,9 @@ class BaseRetriever(ABC):
         parameters = signature(cls._get_relevant_documents).parameters
         cls._new_arg_supported = parameters.get("run_manager") is not None
         # If a V1 retriever broke the interface and expects additional arguments
-        cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
+        cls._expects_other_args = (
+            len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
+        )
     @abstractmethod
     def _get_relevant_documents(
@@ -123,6 +132,7 @@ class BaseRetriever(ABC):
             callbacks, None, verbose=kwargs.get("verbose", False)
         )
         run_manager = callback_manager.on_retriever_start(
+            dumpd(self),
             query,
             **kwargs,
         )
@@ -160,6 +170,7 @@ class BaseRetriever(ABC):
             callbacks, None, verbose=kwargs.get("verbose", False)
         )
         run_manager = await callback_manager.on_retriever_start(
+            dumpd(self),
             query,
             **kwargs,
         )
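Because `BaseRetriever` now extends `Serializable`, custom retrievers declare their configuration as pydantic fields and implement `_get_relevant_documents`; `get_relevant_documents` then passes `dumpd(self)` to `on_retriever_start`, so handlers receive the serialized retriever alongside the query. Below is a minimal, hedged sketch of such a subclass (the class is illustrative, not from this patch; the async method is included in case the base class still declares it abstract in this version).

```python
from typing import List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document


class KeywordMatchRetriever(BaseRetriever):
    """Illustrative retriever: returns stored documents containing the query."""

    docs: List[Document]
    k: int = 4

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        matches = [d for d in self.docs if query.lower() in d.page_content.lower()]
        return matches[: self.k]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        matches = [d for d in self.docs if query.lower() in d.page_content.lower()]
        return matches[: self.k]


retriever = KeywordMatchRetriever(docs=[Document(page_content="pydantic retrievers")])
print(retriever.get_relevant_documents("pydantic"))
```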
diff --git a/langchain/utilities/arxiv.py b/langchain/utilities/arxiv.py
index 92af5df8ec..d958255abf 100644
--- a/langchain/utilities/arxiv.py
+++ b/langchain/utilities/arxiv.py
@@ -3,7 +3,7 @@ import logging
 import os
 from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import BaseModel, root_validator
 from langchain.schema import Document
@@ -40,11 +40,6 @@ class ArxivAPIWrapper(BaseModel):
     load_all_available_meta: bool = False
     doc_content_chars_max: Optional[int] = 4000
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
diff --git a/langchain/utilities/pupmed.py b/langchain/utilities/pupmed.py
index 8870456b21..f242be3abd 100644
--- a/langchain/utilities/pupmed.py
+++ b/langchain/utilities/pupmed.py
@@ -5,7 +5,7 @@ import urllib.error
 import urllib.request
 from typing import List
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel
 from langchain.schema import Document
@@ -42,11 +42,6 @@ class PubMedAPIWrapper(BaseModel):
     load_all_available_meta: bool = False
     email: str = "your_email@example.com"
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
     def run(self, query: str) -> str:
         """
         Run PubMed search and get the article meta information.
diff --git a/langchain/utilities/wikipedia.py b/langchain/utilities/wikipedia.py
index 6a33ccf8bf..1202f8b24e 100644
--- a/langchain/utilities/wikipedia.py
+++ b/langchain/utilities/wikipedia.py
@@ -2,7 +2,7 @@ import logging
 from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Extra, root_validator
+from pydantic import BaseModel, root_validator
 from langchain.schema import Document
@@ -27,11 +27,6 @@ class WikipediaAPIWrapper(BaseModel):
     load_all_available_meta: bool = False
     doc_content_chars_max: int = 4000
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
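Dropping `Extra.forbid` from these API wrappers leaves them on pydantic v1's default handling of unknown keyword arguments, which is to ignore them rather than raise. The short sketch below illustrates that difference in isolation; it is generic pydantic behavior, not code from this patch, and the class names are placeholders.

```python
from pydantic import BaseModel, Extra


class ForbidConfig(BaseModel):
    top_k_results: int = 3

    class Config:
        extra = Extra.forbid


class DefaultConfig(BaseModel):
    top_k_results: int = 3


try:
    ForbidConfig(top_k_results=3, unknown_flag=True)  # raises ValidationError
except Exception as err:
    print(f"forbid: {err}")

print(DefaultConfig(top_k_results=3, unknown_flag=True))  # unknown_flag is ignored
```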
diff --git a/tests/integration_tests/retrievers/test_merger_retriever.py b/tests/integration_tests/retrievers/test_merger_retriever.py
index f42f664478..1dab7507cd 100644
--- a/tests/integration_tests/retrievers/test_merger_retriever.py
+++ b/tests/integration_tests/retrievers/test_merger_retriever.py
@@ -24,7 +24,7 @@ def test_merger_retriever_get_relevant_docs() -> None:
     )
     # The Lord of the Retrievers.
-    lotr = MergerRetriever([retriever_a, retriever_b])
+    lotr = MergerRetriever(retrievers=[retriever_a, retriever_b])
     actual = lotr.get_relevant_documents("Tell me about the Celtics")
     assert len(actual) == 2
diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py
index 5793e71d14..8d82fa90a3 100644
--- a/tests/unit_tests/callbacks/test_callback_manager.py
+++ b/tests/unit_tests/callbacks/test_callback_manager.py
@@ -146,7 +146,7 @@ def test_ignore_retriever() -> None:
     handler1 = FakeCallbackHandler(ignore_retriever_=True)
     handler2 = FakeCallbackHandler()
     manager = CallbackManager(handlers=[handler1, handler2])
-    run_manager = manager.on_retriever_start("")
+    run_manager = manager.on_retriever_start({}, "")
     run_manager.on_retriever_end([])
     run_manager.on_retriever_error(Exception())
     assert handler1.starts == 0
diff --git a/tests/unit_tests/retrievers/test_base.py b/tests/unit_tests/retrievers/test_base.py
index 28295d1c56..b511dd7496 100644
--- a/tests/unit_tests/retrievers/test_base.py
+++ b/tests/unit_tests/retrievers/test_base.py
@@ -142,8 +142,7 @@ async def test_fake_retriever_v1_with_kwargs_upgrade_async(
 class FakeRetrieverV2(BaseRetriever):
-    def __init__(self, throw_error: bool = False) -> None:
-        self.throw_error = throw_error
+    throw_error: bool = False
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None