mirror of https://github.com/hwchase17/langchain
langchain[minor]: Add PebbloRetrievalQA chain with Identity & Semantic Enforcement support (#20641)
- **Description:** PebbloRetrievalQA chain introduces identity enforcement using vector-db metadata filtering - **Dependencies:** None - **Issue:** None - **Documentation:** Adding documentation for PebbloRetrievalQA chain in a separate PR(https://github.com/langchain-ai/langchain/pull/20746) - **Unit tests:** New unit-tests added --------- Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>pull/21709/head
parent
f2f970f93d
commit
54e003268e
@ -0,0 +1,24 @@
|
||||
"""
|
||||
Chains module for langchain_community
|
||||
|
||||
This module contains the community chains.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chains.pebblo_retrieval.base import PebbloRetrievalQA
|
||||
|
||||
__all__ = ["PebbloRetrievalQA"]
|
||||
|
||||
_module_lookup = {
|
||||
"PebbloRetrievalQA": "langchain_community.chains.pebblo_retrieval.base"
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in _module_lookup:
|
||||
module = importlib.import_module(_module_lookup[name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
@ -0,0 +1,218 @@
|
||||
"""
|
||||
Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answering
|
||||
against a vector database.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Extra, Field, validator
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
|
||||
from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
|
||||
SUPPORTED_VECTORSTORES,
|
||||
set_enforcement_filters,
|
||||
)
|
||||
from langchain_community.chains.pebblo_retrieval.models import (
|
||||
AuthContext,
|
||||
SemanticContext,
|
||||
)
|
||||
|
||||
|
||||
class PebbloRetrievalQA(Chain):
|
||||
"""
|
||||
Retrieval Chain with Identity & Semantic Enforcement for question-answering
|
||||
against a vector database.
|
||||
"""
|
||||
|
||||
combine_documents_chain: BaseCombineDocumentsChain
|
||||
"""Chain to use to combine the documents."""
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
return_source_documents: bool = False
|
||||
"""Return the source documents or not."""
|
||||
|
||||
retriever: VectorStoreRetriever = Field(exclude=True)
|
||||
"""VectorStore to use for retrieval."""
|
||||
auth_context_key: str = "auth_context" #: :meta private:
|
||||
"""Authentication context for identity enforcement."""
|
||||
semantic_context_key: str = "semantic_context" #: :meta private:
|
||||
"""Semantic context for semantic enforcement."""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run get_relevant_text and llm on input query.
|
||||
|
||||
If chain has 'return_source_documents' as 'True', returns
|
||||
the retrieved documents as well under the key 'source_documents'.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
res = indexqa({'query': 'This is my query'})
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
auth_context = inputs.get(self.auth_context_key)
|
||||
semantic_context = inputs.get(self.semantic_context_key)
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._get_docs).parameters
|
||||
)
|
||||
if accepts_run_manager:
|
||||
docs = self._get_docs(
|
||||
question, auth_context, semantic_context, run_manager=_run_manager
|
||||
)
|
||||
else:
|
||||
docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
|
||||
answer = self.combine_documents_chain.run(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run get_relevant_text and llm on input query.
|
||||
|
||||
If chain has 'return_source_documents' as 'True', returns
|
||||
the retrieved documents as well under the key 'source_documents'.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
res = indexqa({'query': 'This is my query'})
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
auth_context = inputs.get(self.auth_context_key)
|
||||
semantic_context = inputs.get(self.semantic_context_key)
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._aget_docs).parameters
|
||||
)
|
||||
if accepts_run_manager:
|
||||
docs = await self._aget_docs(
|
||||
question, auth_context, semantic_context, run_manager=_run_manager
|
||||
)
|
||||
else:
|
||||
docs = await self._aget_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
|
||||
answer = await self.combine_documents_chain.arun(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key, self.auth_context_key, self.semantic_context_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
_output_keys = [self.output_key]
|
||||
if self.return_source_documents:
|
||||
_output_keys += ["source_documents"]
|
||||
return _output_keys
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
"""Return the chain type."""
|
||||
return "pebblo_retrieval_qa"
|
||||
|
||||
@classmethod
|
||||
def from_chain_type(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
chain_type_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> "PebbloRetrievalQA":
|
||||
"""Load chain from chain type."""
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
|
||||
_chain_type_kwargs = chain_type_kwargs or {}
|
||||
combine_documents_chain = load_qa_chain(
|
||||
llm, chain_type=chain_type, **_chain_type_kwargs
|
||||
)
|
||||
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
||||
|
||||
@validator("retriever", pre=True, always=True)
|
||||
def validate_vectorstore(
|
||||
cls, retriever: VectorStoreRetriever
|
||||
) -> VectorStoreRetriever:
|
||||
"""
|
||||
Validate that the vectorstore of the retriever is supported vectorstores.
|
||||
"""
|
||||
if not any(
|
||||
isinstance(retriever.vectorstore, supported_class)
|
||||
for supported_class in SUPPORTED_VECTORSTORES
|
||||
):
|
||||
raise ValueError(
|
||||
f"Vectorstore must be an instance of one of the supported "
|
||||
f"vectorstores: {SUPPORTED_VECTORSTORES}. "
|
||||
f"Got {type(retriever.vectorstore).__name__} instead."
|
||||
)
|
||||
return retriever
|
||||
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
auth_context: Optional[AuthContext],
|
||||
semantic_context: Optional[SemanticContext],
|
||||
*,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
set_enforcement_filters(self.retriever, auth_context, semantic_context)
|
||||
return self.retriever.get_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
)
|
||||
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
auth_context: Optional[AuthContext],
|
||||
semantic_context: Optional[SemanticContext],
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> List[Document]:
|
||||
"""Get docs."""
|
||||
set_enforcement_filters(self.retriever, auth_context, semantic_context)
|
||||
return await self.retriever.aget_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
)
|
@ -0,0 +1,265 @@
|
||||
"""
|
||||
Identity & Semantic Enforcement filters for PebbloRetrievalQA chain:
|
||||
|
||||
This module contains methods for applying Identity and Semantic Enforcement filters
|
||||
in the PebbloRetrievalQA chain.
|
||||
These filters are used to control the retrieval of documents based on authorization and
|
||||
semantic context.
|
||||
The Identity Enforcement filter ensures that only authorized identities can access
|
||||
certain documents, while the Semantic Enforcement filter controls document retrieval
|
||||
based on semantic context.
|
||||
|
||||
The methods in this module are designed to work with different types of vector stores.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
|
||||
from langchain_community.chains.pebblo_retrieval.models import (
|
||||
AuthContext,
|
||||
SemanticContext,
|
||||
)
|
||||
from langchain_community.vectorstores import Pinecone, Qdrant
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUPPORTED_VECTORSTORES = [Pinecone, Qdrant]
|
||||
|
||||
|
||||
def set_enforcement_filters(
|
||||
retriever: VectorStoreRetriever,
|
||||
auth_context: Optional[AuthContext],
|
||||
semantic_context: Optional[SemanticContext],
|
||||
) -> None:
|
||||
"""
|
||||
Set identity and semantic enforcement filters in the retriever.
|
||||
"""
|
||||
if auth_context is not None:
|
||||
_set_identity_enforcement_filter(retriever, auth_context)
|
||||
if semantic_context is not None:
|
||||
_set_semantic_enforcement_filter(retriever, semantic_context)
|
||||
|
||||
|
||||
def _apply_qdrant_semantic_filter(
|
||||
search_kwargs: dict, semantic_context: Optional[SemanticContext]
|
||||
) -> None:
|
||||
"""
|
||||
Set semantic enforcement filter in search_kwargs for Qdrant vectorstore.
|
||||
"""
|
||||
try:
|
||||
from qdrant_client.http import models as rest
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
"Could not import `qdrant-client.http` python package. "
|
||||
"Please install it with `pip install qdrant-client`."
|
||||
) from e
|
||||
|
||||
# Create a semantic enforcement filter condition
|
||||
semantic_filters: List[
|
||||
Union[
|
||||
rest.FieldCondition,
|
||||
rest.IsEmptyCondition,
|
||||
rest.IsNullCondition,
|
||||
rest.HasIdCondition,
|
||||
rest.NestedCondition,
|
||||
rest.Filter,
|
||||
]
|
||||
] = []
|
||||
|
||||
if (
|
||||
semantic_context is not None
|
||||
and semantic_context.pebblo_semantic_topics is not None
|
||||
):
|
||||
semantic_topics_filter = rest.FieldCondition(
|
||||
key="metadata.pebblo_semantic_topics",
|
||||
match=rest.MatchAny(any=semantic_context.pebblo_semantic_topics.deny),
|
||||
)
|
||||
semantic_filters.append(semantic_topics_filter)
|
||||
if (
|
||||
semantic_context is not None
|
||||
and semantic_context.pebblo_semantic_entities is not None
|
||||
):
|
||||
semantic_entities_filter = rest.FieldCondition(
|
||||
key="metadata.pebblo_semantic_entities",
|
||||
match=rest.MatchAny(any=semantic_context.pebblo_semantic_entities.deny),
|
||||
)
|
||||
semantic_filters.append(semantic_entities_filter)
|
||||
|
||||
# If 'filter' already exists in search_kwargs
|
||||
if "filter" in search_kwargs:
|
||||
existing_filter: rest.Filter = search_kwargs["filter"]
|
||||
|
||||
# Check if existing_filter is a qdrant-client filter
|
||||
if isinstance(existing_filter, rest.Filter):
|
||||
# If 'must_not' condition exists in the existing filter
|
||||
if isinstance(existing_filter.must_not, list):
|
||||
# Warn if 'pebblo_semantic_topics' or 'pebblo_semantic_entities'
|
||||
# filter is overridden
|
||||
new_must_not_conditions: List[
|
||||
Union[
|
||||
rest.FieldCondition,
|
||||
rest.IsEmptyCondition,
|
||||
rest.IsNullCondition,
|
||||
rest.HasIdCondition,
|
||||
rest.NestedCondition,
|
||||
rest.Filter,
|
||||
]
|
||||
] = []
|
||||
# Drop semantic filter conditions if already present
|
||||
for condition in existing_filter.must_not:
|
||||
if hasattr(condition, "key"):
|
||||
if condition.key == "metadata.pebblo_semantic_topics":
|
||||
continue
|
||||
if condition.key == "metadata.pebblo_semantic_entities":
|
||||
continue
|
||||
new_must_not_conditions.append(condition)
|
||||
# Add semantic enforcement filters to 'must_not' conditions
|
||||
existing_filter.must_not = new_must_not_conditions
|
||||
existing_filter.must_not.extend(semantic_filters)
|
||||
else:
|
||||
# Set 'must_not' condition with semantic enforcement filters
|
||||
existing_filter.must_not = semantic_filters
|
||||
else:
|
||||
raise TypeError(
|
||||
"Using dict as a `filter` is deprecated. "
|
||||
"Please use qdrant-client filters directly: "
|
||||
"https://qdrant.tech/documentation/concepts/filtering/"
|
||||
)
|
||||
else:
|
||||
# If 'filter' does not exist in search_kwargs, create it
|
||||
search_kwargs["filter"] = rest.Filter(must_not=semantic_filters)
|
||||
|
||||
|
||||
def _apply_qdrant_authorization_filter(
|
||||
search_kwargs: dict, auth_context: Optional[AuthContext]
|
||||
) -> None:
|
||||
"""
|
||||
Set identity enforcement filter in search_kwargs for Qdrant vectorstore.
|
||||
"""
|
||||
try:
|
||||
from qdrant_client.http import models as rest
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
"Could not import `qdrant-client.http` python package. "
|
||||
"Please install it with `pip install qdrant-client`."
|
||||
) from e
|
||||
|
||||
if auth_context is not None:
|
||||
# Create a identity enforcement filter condition
|
||||
identity_enforcement_filter = rest.FieldCondition(
|
||||
key="metadata.authorized_identities",
|
||||
match=rest.MatchAny(any=auth_context.user_auth),
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
# If 'filter' already exists in search_kwargs
|
||||
if "filter" in search_kwargs:
|
||||
existing_filter: rest.Filter = search_kwargs["filter"]
|
||||
|
||||
# Check if existing_filter is a qdrant-client filter
|
||||
if isinstance(existing_filter, rest.Filter):
|
||||
# If 'must' exists in the existing filter
|
||||
if existing_filter.must:
|
||||
new_must_conditions: List[
|
||||
Union[
|
||||
rest.FieldCondition,
|
||||
rest.IsEmptyCondition,
|
||||
rest.IsNullCondition,
|
||||
rest.HasIdCondition,
|
||||
rest.NestedCondition,
|
||||
rest.Filter,
|
||||
]
|
||||
] = []
|
||||
# Drop 'authorized_identities' filter condition if already present
|
||||
for condition in existing_filter.must:
|
||||
if (
|
||||
hasattr(condition, "key")
|
||||
and condition.key == "metadata.authorized_identities"
|
||||
):
|
||||
continue
|
||||
new_must_conditions.append(condition)
|
||||
|
||||
# Add identity enforcement filter to 'must' conditions
|
||||
existing_filter.must = new_must_conditions
|
||||
existing_filter.must.append(identity_enforcement_filter)
|
||||
else:
|
||||
# Set 'must' condition with identity enforcement filter
|
||||
existing_filter.must = [identity_enforcement_filter]
|
||||
else:
|
||||
raise TypeError(
|
||||
"Using dict as a `filter` is deprecated. "
|
||||
"Please use qdrant-client filters directly: "
|
||||
"https://qdrant.tech/documentation/concepts/filtering/"
|
||||
)
|
||||
else:
|
||||
# If 'filter' does not exist in search_kwargs, create it
|
||||
search_kwargs["filter"] = rest.Filter(must=[identity_enforcement_filter])
|
||||
|
||||
|
||||
def _apply_pinecone_semantic_filter(
|
||||
search_kwargs: dict, semantic_context: Optional[SemanticContext]
|
||||
) -> None:
|
||||
"""
|
||||
Set semantic enforcement filter in search_kwargs for Pinecone vectorstore.
|
||||
"""
|
||||
# Check if semantic_context is provided
|
||||
semantic_context = semantic_context
|
||||
if semantic_context is not None:
|
||||
if semantic_context.pebblo_semantic_topics is not None:
|
||||
# Add pebblo_semantic_topics filter to search_kwargs
|
||||
search_kwargs.setdefault("filter", {})["pebblo_semantic_topics"] = {
|
||||
"$nin": semantic_context.pebblo_semantic_topics.deny
|
||||
}
|
||||
|
||||
if semantic_context.pebblo_semantic_entities is not None:
|
||||
# Add pebblo_semantic_entities filter to search_kwargs
|
||||
search_kwargs.setdefault("filter", {})["pebblo_semantic_entities"] = {
|
||||
"$nin": semantic_context.pebblo_semantic_entities.deny
|
||||
}
|
||||
|
||||
|
||||
def _apply_pinecone_authorization_filter(
|
||||
search_kwargs: dict, auth_context: Optional[AuthContext]
|
||||
) -> None:
|
||||
"""
|
||||
Set identity enforcement filter in search_kwargs for Pinecone vectorstore.
|
||||
"""
|
||||
if auth_context is not None:
|
||||
search_kwargs.setdefault("filter", {})["authorized_identities"] = {
|
||||
"$in": auth_context.user_auth
|
||||
}
|
||||
|
||||
|
||||
def _set_identity_enforcement_filter(
|
||||
retriever: VectorStoreRetriever, auth_context: Optional[AuthContext]
|
||||
) -> None:
|
||||
"""
|
||||
Set identity enforcement filter in search_kwargs.
|
||||
|
||||
This method sets the identity enforcement filter in the search_kwargs
|
||||
of the retriever based on the type of the vectorstore.
|
||||
"""
|
||||
search_kwargs = retriever.search_kwargs
|
||||
if isinstance(retriever.vectorstore, Pinecone):
|
||||
_apply_pinecone_authorization_filter(search_kwargs, auth_context)
|
||||
elif isinstance(retriever.vectorstore, Qdrant):
|
||||
_apply_qdrant_authorization_filter(search_kwargs, auth_context)
|
||||
|
||||
|
||||
def _set_semantic_enforcement_filter(
|
||||
retriever: VectorStoreRetriever, semantic_context: Optional[SemanticContext]
|
||||
) -> None:
|
||||
"""
|
||||
Set semantic enforcement filter in search_kwargs.
|
||||
|
||||
This method sets the semantic enforcement filter in the search_kwargs
|
||||
of the retriever based on the type of the vectorstore.
|
||||
"""
|
||||
search_kwargs = retriever.search_kwargs
|
||||
if isinstance(retriever.vectorstore, Pinecone):
|
||||
_apply_pinecone_semantic_filter(search_kwargs, semantic_context)
|
||||
elif isinstance(retriever.vectorstore, Qdrant):
|
||||
_apply_qdrant_semantic_filter(search_kwargs, semantic_context)
|
@ -0,0 +1,62 @@
|
||||
"""Models for the PebbloRetrievalQA chain."""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class AuthContext(BaseModel):
|
||||
"""Class for an authorization context."""
|
||||
|
||||
name: Optional[str] = None
|
||||
user_id: str
|
||||
user_auth: List[str]
|
||||
"""List of user authorizations, which may include their User ID and
|
||||
the groups they are part of"""
|
||||
|
||||
|
||||
class SemanticEntities(BaseModel):
|
||||
"""Class for a semantic entity filter."""
|
||||
|
||||
deny: List[str]
|
||||
|
||||
|
||||
class SemanticTopics(BaseModel):
|
||||
"""Class for a semantic topic filter."""
|
||||
|
||||
deny: List[str]
|
||||
|
||||
|
||||
class SemanticContext(BaseModel):
|
||||
"""Class for a semantic context."""
|
||||
|
||||
pebblo_semantic_entities: Optional[SemanticEntities] = None
|
||||
pebblo_semantic_topics: Optional[SemanticTopics] = None
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
# Validate semantic_context
|
||||
if (
|
||||
self.pebblo_semantic_entities is None
|
||||
and self.pebblo_semantic_topics is None
|
||||
):
|
||||
raise ValueError(
|
||||
"semantic_context must contain 'pebblo_semantic_entities' or "
|
||||
"'pebblo_semantic_topics'"
|
||||
)
|
||||
|
||||
|
||||
class ChainInput(BaseModel):
|
||||
"""Input for PebbloRetrievalQA chain."""
|
||||
|
||||
query: str
|
||||
auth_context: Optional[AuthContext] = None
|
||||
semantic_context: Optional[SemanticContext] = None
|
||||
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
base_dict = super().dict(**kwargs)
|
||||
# Keep auth_context and semantic_context as it is(Pydantic models)
|
||||
base_dict["auth_context"] = self.auth_context
|
||||
base_dict["semantic_context"] = self.semantic_context
|
||||
return base_dict
|
@ -0,0 +1,129 @@
|
||||
"""
|
||||
Unit tests for the PebbloRetrievalQA chain
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
|
||||
|
||||
from langchain_community.chains import PebbloRetrievalQA
|
||||
from langchain_community.chains.pebblo_retrieval.models import (
|
||||
AuthContext,
|
||||
ChainInput,
|
||||
SemanticContext,
|
||||
)
|
||||
from langchain_community.vectorstores.chroma import Chroma
|
||||
from langchain_community.vectorstores.pinecone import Pinecone
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
class FakeRetriever(VectorStoreRetriever):
|
||||
"""
|
||||
Test util that parrots the query back as documents
|
||||
"""
|
||||
|
||||
vectorstore: VectorStore = Mock()
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
return [Document(page_content=query)]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
return [Document(page_content=query)]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unsupported_retriever() -> FakeRetriever:
|
||||
"""
|
||||
Create a FakeRetriever instance
|
||||
"""
|
||||
retriever = FakeRetriever()
|
||||
retriever.search_kwargs = {}
|
||||
# Set the class of vectorstore to Chroma
|
||||
retriever.vectorstore.__class__ = Chroma
|
||||
return retriever
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def retriever() -> FakeRetriever:
|
||||
"""
|
||||
Create a FakeRetriever instance
|
||||
"""
|
||||
retriever = FakeRetriever()
|
||||
retriever.search_kwargs = {}
|
||||
# Set the class of vectorstore to Pinecone
|
||||
retriever.vectorstore.__class__ = Pinecone
|
||||
return retriever
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA:
|
||||
"""
|
||||
Create a PebbloRetrievalQA instance
|
||||
"""
|
||||
pebblo_retrieval_qa = PebbloRetrievalQA.from_chain_type(
|
||||
llm=FakeLLM(), chain_type="stuff", retriever=retriever
|
||||
)
|
||||
|
||||
return pebblo_retrieval_qa
|
||||
|
||||
|
||||
def test_invoke(pebblo_retrieval_qa: PebbloRetrievalQA) -> None:
|
||||
"""
|
||||
Test that the invoke method returns a non-None result
|
||||
"""
|
||||
# Create a fake auth context and semantic context
|
||||
auth_context = AuthContext(
|
||||
user_id="fake_user@email.com",
|
||||
user_auth=["fake-group", "fake-group2"],
|
||||
)
|
||||
semantic_context_dict = {
|
||||
"pebblo_semantic_topics": {"deny": ["harmful-advice"]},
|
||||
"pebblo_semantic_entities": {"deny": ["credit-card"]},
|
||||
}
|
||||
semantic_context = SemanticContext(**semantic_context_dict)
|
||||
|
||||
question = "What is the meaning of life?"
|
||||
|
||||
chain_input_obj = ChainInput(
|
||||
query=question, auth_context=auth_context, semantic_context=semantic_context
|
||||
)
|
||||
response = pebblo_retrieval_qa.invoke(chain_input_obj.dict())
|
||||
assert response is not None
|
||||
|
||||
|
||||
def test_validate_vectorstore(
|
||||
retriever: FakeRetriever, unsupported_retriever: FakeRetriever
|
||||
) -> None:
|
||||
"""
|
||||
Test vectorstore validation
|
||||
"""
|
||||
|
||||
# No exception should be raised for supported vectorstores (Pinecone)
|
||||
_ = PebbloRetrievalQA.from_chain_type(
|
||||
llm=FakeLLM(),
|
||||
chain_type="stuff",
|
||||
retriever=retriever,
|
||||
)
|
||||
|
||||
# validate_vectorstore method should raise a ValueError for unsupported vectorstores
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
_ = PebbloRetrievalQA.from_chain_type(
|
||||
llm=FakeLLM(),
|
||||
chain_type="stuff",
|
||||
retriever=unsupported_retriever,
|
||||
)
|
||||
assert (
|
||||
"Vectorstore must be an instance of one of the supported vectorstores"
|
||||
in str(exc_info.value)
|
||||
)
|
Loading…
Reference in New Issue