Add serialized object to retriever start callback (#7074)

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
  - Tag maintainer: for a quicker response, tag the relevant maintainer
    (see below),
  - Twitter handle: we announce bigger features on Twitter. If your PR
    gets announced and you'd like a mention, we'll gladly shout you out!

If you're adding a new integration, please include:
  1. a test for the integration, preferably unit tests that do not rely on
     network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @dev2049
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @dev2049
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @vowelparrot
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
pull/7213/head
Nuno Campos 1 year ago committed by GitHub
parent baf48d3583
commit 81e5b1ad36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -168,10 +168,12 @@ class CallbackManagerMixin:
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
@ -421,6 +423,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,

@ -1196,6 +1196,7 @@ class CallbackManager(BaseCallbackManager):
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
@ -1209,6 +1210,7 @@ class CallbackManager(BaseCallbackManager):
self.handlers,
"on_retriever_start",
"ignore_retriever",
serialized,
query,
run_id=run_id,
parent_run_id=self.parent_run_id,
@ -1463,6 +1465,7 @@ class AsyncCallbackManager(BaseCallbackManager):
async def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
@ -1476,6 +1479,7 @@ class AsyncCallbackManager(BaseCallbackManager):
self.handlers,
"on_retriever_start",
"ignore_retriever",
serialized,
query,
run_id=run_id,
parent_run_id=self.parent_run_id,

@ -312,6 +312,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
@ -326,6 +327,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
id=run_id,
name="Retriever",
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"query": query},
extra=kwargs,
events=[{"name": "start", "time": start_time}],

@ -7,7 +7,7 @@ from typing import Dict, List, Optional
import aiohttp
import requests
from pydantic import BaseModel, Extra, root_validator
from pydantic import Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -17,7 +17,7 @@ from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env
class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
class AzureCognitiveSearchRetriever(BaseRetriever):
"""Wrapper around Azure Cognitive Search."""
service_name: str = ""

@ -4,7 +4,6 @@ from typing import List, Optional
import aiohttp
import requests
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -13,7 +12,7 @@ from langchain.callbacks.manager import (
from langchain.schema import BaseRetriever, Document
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
class ChatGPTPluginRetriever(BaseRetriever):
url: str
bearer_token: str
top_k: int = 3

@ -2,8 +2,6 @@
from typing import Any, List
from pydantic import BaseModel, Extra
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
@ -14,7 +12,7 @@ from langchain.retrievers.document_compressors.base import (
from langchain.schema import BaseRetriever, Document
class ContextualCompressionRetriever(BaseRetriever, BaseModel):
class ContextualCompressionRetriever(BaseRetriever):
"""Retriever that wraps a base retriever and compresses the results."""
base_compressor: BaseDocumentCompressor
@ -26,7 +24,6 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def _get_relevant_documents(

@ -17,16 +17,6 @@ class DataberryRetriever(BaseRetriever):
top_k: Optional[int]
api_key: Optional[str]
def __init__(
self,
datastore_url: str,
top_k: Optional[int] = None,
api_key: Optional[str] = None,
):
self.datastore_url = datastore_url
self.api_key = api_key
self.top_k = top_k
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:

@ -2,7 +2,6 @@ from enum import Enum
from typing import Any, Dict, List, Optional, Union
import numpy as np
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -20,7 +19,7 @@ class SearchType(str, Enum):
mmr = "mmr"
class DocArrayRetriever(BaseRetriever, BaseModel):
class DocArrayRetriever(BaseRetriever):
"""
Retriever class for DocArray Document Indices.

@ -40,9 +40,8 @@ class ElasticSearchBM25Retriever(BaseRetriever):
https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
"""
def __init__(self, client: Any, index_name: str):
self.client = client
self.index_name = index_name
client: Any
index_name: str
@classmethod
def create(
@ -75,7 +74,7 @@ class ElasticSearchBM25Retriever(BaseRetriever):
# Create the index with the specified settings and mappings
es.indices.create(index=index_name, mappings=mappings, settings=settings)
return cls(es, index_name)
return cls(client=es, index_name=index_name)
def add_texts(
self,

@ -1,7 +1,7 @@
import re
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -179,37 +179,34 @@ class AmazonKendraRetriever(BaseRetriever):
"""
def __init__(
self,
index_id: str,
region_name: Optional[str] = None,
credentials_profile_name: Optional[str] = None,
top_k: int = 3,
attribute_filter: Optional[Dict] = None,
client: Optional[Any] = None,
):
self.index_id = index_id
self.top_k = top_k
self.attribute_filter = attribute_filter
index_id: str
region_name: Optional[str] = None
credentials_profile_name: Optional[str] = None
top_k: int = 3
attribute_filter: Optional[Dict] = None
client: Any
if client is not None:
self.client = client
return
@root_validator(pre=True)
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values["client"] is not None:
return values
try:
import boto3
if credentials_profile_name is not None:
session = boto3.Session(profile_name=credentials_profile_name)
if values["credentials_profile_name"] is not None:
session = boto3.Session(profile_name=values["credentials_profile_name"])
else:
# use default credentials
session = boto3.Session()
client_params = {}
if region_name is not None:
client_params["region_name"] = region_name
if values["region_name"] is not None:
client_params["region_name"] = values["region_name"]
values["client"] = session.client("kendra", **client_params)
self.client = session.client("kendra", **client_params)
return values
except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "

@ -8,7 +8,6 @@ import concurrent.futures
from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -33,7 +32,7 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
return np.array(list(executor.map(embeddings.embed_query, contexts)))
class KNNRetriever(BaseRetriever, BaseModel):
class KNNRetriever(BaseRetriever):
"""KNN Retriever."""
embeddings: Embeddings

@ -1,6 +1,6 @@
from typing import Any, Dict, List, cast
from pydantic import BaseModel, Field
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
from langchain.schema import BaseRetriever, Document
class LlamaIndexRetriever(BaseRetriever, BaseModel):
class LlamaIndexRetriever(BaseRetriever):
"""Question-answering with sources over an LlamaIndex data structure."""
index: Any
@ -45,7 +45,7 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
raise NotImplementedError("LlamaIndexRetriever does not support async")
class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
class LlamaIndexGraphRetriever(BaseRetriever):
"""Question-answering with sources over an LlamaIndex graph data structure."""
graph: Any

@ -15,18 +15,7 @@ class MergerRetriever(BaseRetriever):
retrievers: A list of retrievers to merge.
"""
def __init__(
self,
retrievers: List[BaseRetriever],
):
"""
Initialize the MergerRetriever class.
Args:
retrievers: A list of retrievers to merge.
"""
self.retrievers = retrievers
retrievers: List[BaseRetriever]
def _get_relevant_documents(
self,

@ -1,5 +1,7 @@
from typing import Any, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
@ -10,16 +12,26 @@ from langchain.schema import BaseRetriever, Document
class MetalRetriever(BaseRetriever):
"""Retriever that uses the Metal API."""
def __init__(self, client: Any, params: Optional[dict] = None):
client: Any
params: Optional[dict] = None
@root_validator(pre=True)
def validate_client(cls, values: dict) -> dict:
"""Validate that the client is of the correct type."""
from metal_sdk.metal import Metal
if not isinstance(client, Metal):
raise ValueError(
"Got unexpected client, should be of type metal_sdk.metal.Metal. "
f"Instead, got {type(client)}"
)
self.client: Metal = client
self.params = params or {}
if "client" in values:
client = values["client"]
if not isinstance(client, Metal):
raise ValueError(
"Got unexpected client, should be of type metal_sdk.metal.Metal. "
f"Instead, got {type(client)}"
)
values["params"] = values.get("params", {})
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun

@ -2,6 +2,8 @@
import warnings
from typing import Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
@ -16,21 +18,28 @@ from langchain.vectorstores.milvus import Milvus
class MilvusRetriever(BaseRetriever):
"""Retriever that uses the Milvus API."""
def __init__(
self,
embedding_function: Embeddings,
collection_name: str = "LangChainCollection",
connection_args: Optional[Dict[str, Any]] = None,
consistency_level: str = "Session",
search_params: Optional[dict] = None,
):
self.store = Milvus(
embedding_function,
collection_name,
connection_args,
consistency_level,
embedding_function: Embeddings
collection_name: str = "LangChainCollection"
connection_args: Optional[Dict[str, Any]] = None
consistency_level: str = "Session"
search_params: Optional[dict] = None
store: Milvus
retriever: BaseRetriever
@root_validator(pre=True)
def create_retriever(cls, values: Dict) -> Dict:
"""Create the Milvus store and retriever."""
values["store"] = Milvus(
values["embedding_function"],
values["collection_name"],
values["connection_args"],
values["consistency_level"],
)
values["retriever"] = values["store"].as_retriever(
search_kwargs={"param": values["search_params"]}
)
self.retriever = self.store.as_retriever(search_kwargs={"param": search_params})
return values
def add_texts(
self, texts: List[str], metadatas: Optional[List[dict]] = None

@ -47,28 +47,10 @@ class MultiQueryRetriever(BaseRetriever):
"""Given a user query, use an LLM to write a set of queries.
Retrieve docs for each query. Take the unique union of all retrieved docs."""
def __init__(
self,
retriever: BaseRetriever,
llm_chain: LLMChain,
verbose: bool = True,
parser_key: str = "lines",
) -> None:
"""Initialize MultiQueryRetriever.
Args:
retriever: retriever to query documents from
llm_chain: llm_chain for query generation
verbose: show the queries that we generated to the user
parser_key: attribute name for the parsed output
Returns:
MultiQueryRetriever
"""
self.retriever = retriever
self.llm_chain = llm_chain
self.verbose = verbose
self.parser_key = parser_key
retriever: BaseRetriever
llm_chain: LLMChain
verbose: bool = True
parser_key: str = "lines"
@classmethod
def from_llm(

@ -3,7 +3,7 @@
import hashlib
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from pydantic import Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -98,7 +98,7 @@ def create_index(
index.upsert(vectors)
class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
class PineconeHybridSearchRetriever(BaseRetriever):
embeddings: Embeddings
"""description"""
sparse_encoder: Any

@ -2,7 +2,6 @@ from typing import List, Optional
import aiohttp
import requests
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -11,7 +10,7 @@ from langchain.callbacks.manager import (
from langchain.schema import BaseRetriever, Document
class RemoteLangChainRetriever(BaseRetriever, BaseModel):
class RemoteLangChainRetriever(BaseRetriever):
url: str
headers: Optional[dict] = None
input_key: str = "message"

@ -8,7 +8,6 @@ import concurrent.futures
from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -32,7 +31,7 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
return np.array(list(executor.map(embeddings.embed_query, contexts)))
class SVMRetriever(BaseRetriever, BaseModel):
class SVMRetriever(BaseRetriever):
"""SVM Retriever."""
embeddings: Embeddings

@ -7,8 +7,6 @@ from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
@ -16,7 +14,7 @@ from langchain.callbacks.manager import (
from langchain.schema import BaseRetriever, Document
class TFIDFRetriever(BaseRetriever, BaseModel):
class TFIDFRetriever(BaseRetriever):
vectorizer: Any
docs: List[Document]
tfidf_array: Any

@ -4,7 +4,7 @@ import datetime
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -19,7 +19,7 @@ def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> f
return (time - ref_time).total_seconds() / 3600
class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
class TimeWeightedVectorStoreRetriever(BaseRetriever):
"""Retriever combining embedding similarity with recency."""
vectorstore: VectorStore

@ -18,29 +18,13 @@ if TYPE_CHECKING:
class VespaRetriever(BaseRetriever):
"""Retriever that uses the Vespa."""
def __init__(
self,
app: Vespa,
body: Dict,
content_field: str,
metadata_fields: Optional[Sequence[str]] = None,
):
"""
Args:
app: Vespa client.
body: query body.
content_field: result field with document contents.
metadata_fields: result fields to include in document metadata.
"""
self._application = app
self._query_body = body
self._content_field = content_field
self._metadata_fields = metadata_fields or ()
app: Vespa
body: Dict
content_field: str
metadata_fields: Sequence[str]
def _query(self, body: Dict) -> List[Document]:
response = self._application.query(body)
response = self.app.query(body)
if not str(response.status_code).startswith("2"):
raise RuntimeError(
@ -55,11 +39,11 @@ class VespaRetriever(BaseRetriever):
docs = []
for child in response.hits:
page_content = child["fields"].pop(self._content_field, "")
if self._metadata_fields == "*":
page_content = child["fields"].pop(self.content_field, "")
if self.metadata_fields == "*":
metadata = child["fields"]
else:
metadata = {mf: child["fields"].get(mf) for mf in self._metadata_fields}
metadata = {mf: child["fields"].get(mf) for mf in self.metadata_fields}
metadata["id"] = child["id"]
docs.append(Document(page_content=page_content, metadata=metadata))
return docs
@ -67,7 +51,7 @@ class VespaRetriever(BaseRetriever):
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
body = self._query_body.copy()
body = self.body.copy()
body["query"] = query
return self._query(body)
@ -79,7 +63,7 @@ class VespaRetriever(BaseRetriever):
def get_relevant_documents_with_filter(
self, query: str, *, _filter: Optional[str] = None
) -> List[Document]:
body = self._query_body.copy()
body = self.body.copy()
_filter = f" and {_filter}" if _filter else ""
body["yql"] = body["yql"] + _filter
body["query"] = query
@ -139,4 +123,9 @@ class VespaRetriever(BaseRetriever):
body["yql"] = yql
if k:
body["hits"] = k
return cls(app, body, content_field, metadata_fields=metadata_fields)
return cls(
app=app,
body=body,
content_field=content_field,
metadata_fields=metadata_fields,
)

@ -2,10 +2,10 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
from uuid import uuid4
from pydantic import Extra
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -16,16 +16,19 @@ from langchain.schema import BaseRetriever
class WeaviateHybridSearchRetriever(BaseRetriever):
def __init__(
self,
client: Any,
index_name: str,
text_key: str,
alpha: float = 0.5,
k: int = 4,
attributes: Optional[List[str]] = None,
create_schema_if_missing: bool = True,
):
client: Any
index_name: str
text_key: str
alpha: float = 0.5
k: int = 4
attributes: List[str]
create_schema_if_missing: bool = True
@root_validator(pre=True)
def validate_client(
cls,
values: Dict[str, Any],
) -> Dict[str, Any]:
try:
import weaviate
except ImportError:
@ -33,36 +36,31 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`."
)
if not isinstance(client, weaviate.Client):
if not isinstance(values["client"], weaviate.Client):
client = values["client"]
raise ValueError(
f"client should be an instance of weaviate.Client, got {type(client)}"
)
self._client = client
self.k = k
self.alpha = alpha
self._index_name = index_name
self._text_key = text_key
self._query_attrs = [self._text_key]
if attributes is not None:
self._query_attrs.extend(attributes)
if create_schema_if_missing:
self._create_schema_if_missing()
def _create_schema_if_missing(self) -> None:
class_obj = {
"class": self._index_name,
"properties": [{"name": self._text_key, "dataType": ["text"]}],
"vectorizer": "text2vec-openai",
}
if not self._client.schema.exists(self._index_name):
self._client.schema.create_class(class_obj)
if values["attributes"] is None:
values["attributes"] = []
cast(List, values["attributes"]).append(values["text_key"])
if values["create_schema_if_missing"]:
class_obj = {
"class": values["index_name"],
"properties": [{"name": values["text_key"], "dataType": ["text"]}],
"vectorizer": "text2vec-openai",
}
if not values["client"].schema.exists(values["index_name"]):
values["client"].schema.create_class(class_obj)
return values
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
# added text_key
@ -70,11 +68,11 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
"""Upload documents to Weaviate."""
from weaviate.util import get_valid_uuid
with self._client.batch as batch:
with self.client.batch as batch:
ids = []
for i, doc in enumerate(docs):
metadata = doc.metadata or {}
data_properties = {self._text_key: doc.page_content, **metadata}
data_properties = {self.text_key: doc.page_content, **metadata}
# If the UUID of one of the objects already exists
# then the existing object will be replaced by the new object.
@ -83,7 +81,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
else:
_id = get_valid_uuid(uuid4())
batch.add_data_object(data_properties, self._index_name, _id)
batch.add_data_object(data_properties, self.index_name, _id)
ids.append(_id)
return ids
@ -95,7 +93,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
where_filter: Optional[Dict[str, object]] = None,
) -> List[Document]:
"""Look up similar documents in Weaviate."""
query_obj = self._client.query.get(self._index_name, self._query_attrs)
query_obj = self.client.query.get(self.index_name, self.attributes)
if where_filter:
query_obj = query_obj.with_where(where_filter)
@ -105,8 +103,8 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
docs = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
for res in result["data"]["Get"][self.index_name]:
text = res.pop(self.text_key)
docs.append(Document(page_content=text, metadata=res))
return docs

@ -1,6 +1,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
@ -27,13 +29,14 @@ class ZepRetriever(BaseRetriever):
https://docs.getzep.com/deployment/quickstart/
"""
def __init__(
self,
session_id: str,
url: str,
api_key: Optional[str] = None,
top_k: Optional[int] = None,
):
zep_client: Any
session_id: str
top_k: Optional[int]
@root_validator(pre=True)
def create_client(cls, values: dict) -> dict:
try:
from zep_python import ZepClient
except ImportError:
@ -41,10 +44,11 @@ class ZepRetriever(BaseRetriever):
"Could not import zep-python package. "
"Please install it with `pip install zep-python`."
)
self.zep_client = ZepClient(base_url=url, api_key=api_key)
self.session_id = session_id
self.top_k = top_k
values["zep_client"] = values.get(
"zep_client",
ZepClient(base_url=values["url"], api_key=values.get("api_key")),
)
return values
def _search_result_to_doc(
self, results: List[MemorySearchResult]

@ -2,6 +2,8 @@
import warnings
from typing import Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
@ -16,21 +18,27 @@ from langchain.vectorstores.zilliz import Zilliz
class ZillizRetriever(BaseRetriever):
"""Retriever that uses the Zilliz API."""
def __init__(
self,
embedding_function: Embeddings,
collection_name: str = "LangChainCollection",
connection_args: Optional[Dict[str, Any]] = None,
consistency_level: str = "Session",
search_params: Optional[dict] = None,
):
self.store = Zilliz(
embedding_function,
collection_name,
connection_args,
consistency_level,
embedding_function: Embeddings
collection_name: str = "LangChainCollection"
connection_args: Optional[Dict[str, Any]] = None
consistency_level: str = "Session"
search_params: Optional[dict] = None
store: Zilliz
retriever: BaseRetriever
@root_validator(pre=True)
def create_client(cls, values: dict) -> dict:
values["store"] = Zilliz(
values["embedding_function"],
values["collection_name"],
values["connection_args"],
values["consistency_level"],
)
values["retriever"] = values["store"].as_retriever(
search_kwargs={"param": values["search_params"]}
)
self.retriever = self.store.as_retriever(search_kwargs={"param": search_params})
return values
def add_texts(
self, texts: List[str], metadatas: Optional[List[dict]] = None

@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, List
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema.document import Document
if TYPE_CHECKING:
@ -15,7 +17,7 @@ if TYPE_CHECKING:
)
class BaseRetriever(ABC):
class BaseRetriever(Serializable, ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
@ -46,6 +48,11 @@ class BaseRetriever(ABC):
raise NotImplementedError
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
_new_arg_supported: bool = False
_expects_other_args: bool = False
@ -81,7 +88,9 @@ class BaseRetriever(ABC):
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
cls._expects_other_args = (
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
)
@abstractmethod
def _get_relevant_documents(
@ -123,6 +132,7 @@ class BaseRetriever(ABC):
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_manager = callback_manager.on_retriever_start(
dumpd(self),
query,
**kwargs,
)
@ -160,6 +170,7 @@ class BaseRetriever(ABC):
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_manager = await callback_manager.on_retriever_start(
dumpd(self),
query,
**kwargs,
)

@ -3,7 +3,7 @@ import logging
import os
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, root_validator
from langchain.schema import Document
@ -40,11 +40,6 @@ class ArxivAPIWrapper(BaseModel):
load_all_available_meta: bool = False
doc_content_chars_max: Optional[int] = 4000
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""

@ -5,7 +5,7 @@ import urllib.error
import urllib.request
from typing import List
from pydantic import BaseModel, Extra
from pydantic import BaseModel
from langchain.schema import Document
@ -42,11 +42,6 @@ class PubMedAPIWrapper(BaseModel):
load_all_available_meta: bool = False
email: str = "your_email@example.com"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def run(self, query: str) -> str:
"""
Run PubMed search and get the article meta information.

@ -2,7 +2,7 @@
import logging
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, root_validator
from langchain.schema import Document
@ -27,11 +27,6 @@ class WikipediaAPIWrapper(BaseModel):
load_all_available_meta: bool = False
doc_content_chars_max: int = 4000
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""

@ -24,7 +24,7 @@ def test_merger_retriever_get_relevant_docs() -> None:
)
# The Lord of the Retrievers.
lotr = MergerRetriever([retriever_a, retriever_b])
lotr = MergerRetriever(retrievers=[retriever_a, retriever_b])
actual = lotr.get_relevant_documents("Tell me about the Celtics")
assert len(actual) == 2

@ -146,7 +146,7 @@ def test_ignore_retriever() -> None:
handler1 = FakeCallbackHandler(ignore_retriever_=True)
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
run_manager = manager.on_retriever_start("")
run_manager = manager.on_retriever_start({}, "")
run_manager.on_retriever_end([])
run_manager.on_retriever_error(Exception())
assert handler1.starts == 0

@ -142,8 +142,7 @@ async def test_fake_retriever_v1_with_kwargs_upgrade_async(
class FakeRetrieverV2(BaseRetriever):
def __init__(self, throw_error: bool = False) -> None:
self.throw_error = throw_error
throw_error: bool = False
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None

Loading…
Cancel
Save