Add async support to SelfQueryRetriever (#10175)

### Description

SelfQueryRetriever is missing async support, so I am adding it.
I also removed deprecated predict_and_parse method usage here, and added
some tests.

### Issue
N/A

### Tag maintainer
Not yet

### Twitter handle
N/A
pull/11461/head
Viktor Zhemchuzhnikov 10 months ago committed by GitHub
parent 35297ca0d3
commit fd9da60aea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -160,4 +160,6 @@ def load_query_constructor_chain(
allowed_operators=allowed_operators,
enable_limit=enable_limit,
)
return LLMChain(llm=llm, prompt=prompt, **kwargs)
return LLMChain(
llm=llm, prompt=prompt, output_parser=prompt.output_parser, **kwargs
)

@ -1,8 +1,11 @@
"""Retriever that generates and executes structured queries over its own data source."""
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, cast
from typing import Any, Dict, List, Optional, Type, cast
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains import LLMChain
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
@ -42,6 +45,8 @@ from langchain.vectorstores import (
Weaviate,
)
logger = logging.getLogger(__name__)
def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
"""Get the translator class corresponding to the vector store class."""
@ -108,6 +113,49 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
)
return values
def _get_structured_query(
self, inputs: Dict[str, Any], run_manager: CallbackManagerForRetrieverRun
) -> StructuredQuery:
structured_query = cast(
StructuredQuery,
self.llm_chain.predict(callbacks=run_manager.get_child(), **inputs),
)
return structured_query
async def _aget_structured_query(
self, inputs: Dict[str, Any], run_manager: AsyncCallbackManagerForRetrieverRun
) -> StructuredQuery:
structured_query = cast(
StructuredQuery,
await self.llm_chain.apredict(callbacks=run_manager.get_child(), **inputs),
)
return structured_query
def _prepare_query(
self, query: str, structured_query: StructuredQuery
) -> Tuple[str, Dict[str, Any]]:
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query
)
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
if self.use_original_query:
new_query = query
search_kwargs = {**self.search_kwargs, **new_kwargs}
return new_query, search_kwargs
def _get_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
) -> List[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs
async def _aget_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
) -> List[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
return docs
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
@ -120,25 +168,30 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = cast(
StructuredQuery,
self.llm_chain.predict_and_parse(
callbacks=run_manager.get_child(), **inputs
),
)
structured_query = self._get_structured_query(inputs, run_manager)
if self.verbose:
print(structured_query)
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query
)
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
return docs
if self.use_original_query:
new_query = query
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query.
search_kwargs = {**self.search_kwargs, **new_kwargs}
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = await self._aget_structured_query(inputs, run_manager)
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
return docs
@classmethod

@ -0,0 +1,142 @@
from typing import Any, Dict, List, Tuple, Union
import pytest
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers import SelfQueryRetriever
from langchain.schema import Document
from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeTranslator(Visitor):
allowed_comparators = (
Comparator.EQ,
Comparator.NE,
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
Comparator.CONTAIN,
Comparator.LIKE,
)
allowed_operators = (Operator.AND, Operator.OR, Operator.NOT)
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}
def visit_comparison(self, comparison: Comparison) -> Dict:
return {
comparison.attribute: {
self._format_func(comparison.comparator): comparison.value
}
}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
res = self.store.get(query)
if res is None:
return []
return [res]
@pytest.fixture()
def fake_llm() -> FakeLLM:
return FakeLLM(
queries={
"1": """```json
{
"query": "test",
"filter": null
}
```""",
"bar": "baz",
},
sequential_responses=True,
)
@pytest.fixture()
def fake_vectorstore() -> InMemoryVectorstoreWithSearch:
vectorstore = InMemoryVectorstoreWithSearch()
vectorstore.add_documents(
[
Document(
page_content="test",
metadata={
"foo": "bar",
},
),
],
ids=["test"],
)
return vectorstore
@pytest.fixture()
def fake_self_query_retriever(
fake_llm: FakeLLM, fake_vectorstore: InMemoryVectorstoreWithSearch
) -> SelfQueryRetriever:
return SelfQueryRetriever.from_llm(
llm=fake_llm,
vectorstore=fake_vectorstore,
document_contents="test",
metadata_field_info=[
AttributeInfo(
name="foo",
type="string",
description="test",
),
],
structured_query_translator=FakeTranslator(),
)
def test__get_relevant_documents(fake_self_query_retriever: SelfQueryRetriever) -> None:
relevant_documents = fake_self_query_retriever._get_relevant_documents(
"foo",
run_manager=CallbackManagerForRetrieverRun.get_noop_manager(),
)
assert len(relevant_documents) == 1
assert relevant_documents[0].metadata["foo"] == "bar"
@pytest.mark.asyncio
async def test__aget_relevant_documents(
fake_self_query_retriever: SelfQueryRetriever,
) -> None:
relevant_documents = await fake_self_query_retriever._aget_relevant_documents(
"foo",
run_manager=AsyncCallbackManagerForRetrieverRun.get_noop_manager(),
)
assert len(relevant_documents) == 1
assert relevant_documents[0].metadata["foo"] == "bar"
Loading…
Cancel
Save