mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
77ad857934
Description: Enable app discovery and Prompt/Response apis in PebbloSafeRetrieval Documentation: NA Unit test: N/A --------- Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com> Co-authored-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
141 lines
3.9 KiB
Python
141 lines
3.9 KiB
Python
"""
|
|
Unit tests for the PebbloRetrievalQA chain
|
|
"""
|
|
|
|
from typing import List
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForRetrieverRun,
|
|
CallbackManagerForRetrieverRun,
|
|
)
|
|
from langchain_core.documents import Document
|
|
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
|
|
|
|
from langchain_community.chains import PebbloRetrievalQA
|
|
from langchain_community.chains.pebblo_retrieval.models import (
|
|
AuthContext,
|
|
ChainInput,
|
|
SemanticContext,
|
|
)
|
|
from langchain_community.vectorstores.chroma import Chroma
|
|
from langchain_community.vectorstores.pinecone import Pinecone
|
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
|
|
|
|
|
class FakeRetriever(VectorStoreRetriever):
    """
    Stand-in retriever that echoes each query back as a single Document.

    ``vectorstore`` is a ``Mock`` placeholder so no real vector database is
    required; neither the sync nor the async retrieval path reads it.
    """

    vectorstore: VectorStore = Mock()

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # The run manager is unused: there is nothing to report for an echo.
        echoed = Document(page_content=query)
        return [echoed]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        # Async twin of the sync path; there is nothing to await.
        echoed = Document(page_content=query)
        return [echoed]
|
|
|
|
|
|
def _make_retriever(vectorstore_cls: type) -> FakeRetriever:
    """
    Build a FakeRetriever whose vectorstore mock reports ``vectorstore_cls``.

    A brand-new ``Mock`` is created on every call, so retrievers returned to
    different fixtures never share a vectorstore object. This does not rely on
    the model copying the class-level ``Mock`` default declared on
    ``FakeRetriever`` for each instance.

    Args:
        vectorstore_cls: Vectorstore class the mock should impersonate.

    Returns:
        A FakeRetriever with empty ``search_kwargs`` and the fresh mock as its
        ``vectorstore``.
    """
    vectorstore = Mock()
    # Assigning ``__class__`` on a Mock makes isinstance() checks against
    # ``vectorstore_cls`` succeed without restricting the mock via a spec.
    vectorstore.__class__ = vectorstore_cls
    return FakeRetriever(vectorstore=vectorstore, search_kwargs={})


@pytest.fixture
def unsupported_retriever() -> FakeRetriever:
    """
    Create a FakeRetriever whose vectorstore reports itself as Chroma
    (an unsupported vectorstore for PebbloRetrievalQA).
    """
    return _make_retriever(Chroma)


@pytest.fixture
def retriever() -> FakeRetriever:
    """
    Create a FakeRetriever whose vectorstore reports itself as Pinecone
    (a supported vectorstore for PebbloRetrievalQA).
    """
    return _make_retriever(Pinecone)
|
|
|
|
|
|
@pytest.fixture
def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA:
    """
    Build a "stuff"-type PebbloRetrievalQA chain over the ``retriever`` fixture,
    with FakeLLM standing in for a real language model.
    """
    chain_options = {
        "chain_type": "stuff",
        "owner": "owner",
        "description": "description",
        "app_name": "app_name",
    }
    return PebbloRetrievalQA.from_chain_type(
        llm=FakeLLM(), retriever=retriever, **chain_options
    )
|
|
|
|
|
|
def test_invoke(pebblo_retrieval_qa: PebbloRetrievalQA) -> None:
    """
    Invoking the chain with an auth context and a semantic context
    produces a non-None response.
    """
    chain_input = ChainInput(
        query="What is the meaning of life?",
        auth_context=AuthContext(
            user_id="fake_user@email.com",
            user_auth=["fake-group", "fake-group2"],
        ),
        semantic_context=SemanticContext(
            pebblo_semantic_topics={"deny": ["harmful-advice"]},
            pebblo_semantic_entities={"deny": ["credit-card"]},
        ),
    )

    # The chain is invoked with the plain-dict form of the ChainInput model.
    response = pebblo_retrieval_qa.invoke(chain_input.dict())
    assert response is not None
|
|
|
|
|
|
def test_validate_vectorstore(
    retriever: FakeRetriever, unsupported_retriever: FakeRetriever
) -> None:
    """
    Chain construction succeeds for the Pinecone-backed retriever and raises
    a ValueError mentioning the supported vectorstores for the Chroma-backed one.
    """

    def build_chain(candidate: FakeRetriever) -> PebbloRetrievalQA:
        # A fresh FakeLLM per build, just like constructing the chain inline.
        return PebbloRetrievalQA.from_chain_type(
            llm=FakeLLM(),
            chain_type="stuff",
            retriever=candidate,
            owner="owner",
            description="description",
            app_name="app_name",
        )

    # Supported vectorstore (Pinecone): construction must not raise.
    _ = build_chain(retriever)

    # Unsupported vectorstore (Chroma): vectorstore validation must reject it.
    with pytest.raises(ValueError) as exc_info:
        _ = build_chain(unsupported_retriever)
    expected = "Vectorstore must be an instance of one of the supported vectorstores"
    assert expected in str(exc_info.value)
|