QoL improvements to query constructor (#11504)

updating query constructor and self query retriever to
- make it easier to pass in examples
- validate attributes used in query
- remove invalid parts of query
- make it easier to get + edit prompt
- make query constructor a runnable
- make self query retriever use the query constructor as a runnable
pull/11555/head
Bagatur 9 months ago committed by GitHub
parent eec53fa294
commit e7a0def1bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,11 +2,14 @@
from __future__ import annotations
import json
from typing import Any, Callable, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
from langchain.chains.llm import LLMChain
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
FilterDirective,
Operation,
Operator,
StructuredQuery,
)
@ -14,17 +17,21 @@ from langchain.chains.query_constructor.parser import get_parser
from langchain.chains.query_constructor.prompt import (
DEFAULT_EXAMPLES,
DEFAULT_PREFIX,
DEFAULT_SCHEMA,
DEFAULT_SCHEMA_PROMPT,
DEFAULT_SUFFIX,
EXAMPLE_PROMPT,
EXAMPLES_WITH_LIMIT,
SCHEMA_WITH_LIMIT,
PREFIX_WITH_DATA_SOURCE,
SCHEMA_WITH_LIMIT_PROMPT,
SUFFIX_WITHOUT_DATA_SOURCE,
USER_SPECIFIED_EXAMPLE_PROMPT,
)
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.output_parsers.json import parse_and_check_json_markdown
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.runnable import Runnable
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
@ -59,6 +66,8 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
cls,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
allowed_attributes: Optional[Sequence[str]] = None,
fix_invalid: bool = False,
) -> StructuredQueryOutputParser:
"""
Create a structured query output parser from components.
@ -70,13 +79,73 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
Returns:
a structured query output parser
"""
ast_parser = get_parser(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
)
return cls(ast_parse=ast_parser.parse)
ast_parse: Callable
if fix_invalid:
def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
filter = cast(Optional[FilterDirective], get_parser().parse(raw_filter))
fixed = fix_filter_directive(
filter,
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
allowed_attributes=allowed_attributes,
)
return fixed
else:
ast_parse = get_parser(
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
allowed_attributes=allowed_attributes,
).parse
return cls(ast_parse=ast_parse)
def fix_filter_directive(
    filter: Optional[FilterDirective],
    *,
    allowed_comparators: Optional[Sequence[Comparator]] = None,
    allowed_operators: Optional[Sequence[Operator]] = None,
    allowed_attributes: Optional[Sequence[str]] = None,
) -> Optional[FilterDirective]:
    """Strip the parts of a filter directive that use disallowed components.

    Args:
        filter: Filter directive to repair. May be None.
        allowed_comparators: Comparators a Comparison may use. If empty or
            None, comparators are not checked.
        allowed_operators: Operators an Operation may use. If empty or None,
            operators are not checked.
        allowed_attributes: Attribute names a Comparison may refer to. If
            empty or None, attributes are not checked.

    Returns:
        The filter unchanged when no restriction is given, when the filter is
        falsy, or when it is neither a Comparison nor an Operation. Otherwise
        the repaired directive, or None if nothing valid is left.
    """
    restricted = allowed_comparators or allowed_operators or allowed_attributes
    if not restricted or not filter:
        return filter

    if isinstance(filter, Comparison):
        rejected = (
            allowed_comparators and filter.comparator not in allowed_comparators
        ) or (allowed_attributes and filter.attribute not in allowed_attributes)
        if rejected:
            return None
        return filter

    if not isinstance(filter, Operation):
        # Unknown directive kinds are passed through untouched.
        return filter

    if allowed_operators and filter.operator not in allowed_operators:
        return None

    limits = {
        "allowed_comparators": allowed_comparators,
        "allowed_operators": allowed_operators,
        "allowed_attributes": allowed_attributes,
    }
    surviving = []
    for child in filter.arguments:
        repaired = fix_filter_directive(child, **limits)
        if repaired is not None:
            surviving.append(repaired)

    if not surviving:
        return None
    # An AND/OR left with a single operand is equivalent to that operand.
    if len(surviving) == 1 and filter.operator in (Operator.AND, Operator.OR):
        return surviving[0]
    return Operation(operator=filter.operator, arguments=surviving)
def _format_attribute_info(info: Sequence[AttributeInfo]) -> str:
def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
info_dicts = {}
for i in info:
i_dict = dict(i)
@ -84,56 +153,90 @@ def _format_attribute_info(info: Sequence[AttributeInfo]) -> str:
return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")
def _get_prompt(
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
    """Turn (user query, structured request) pairs into few-shot example dicts.

    Args:
        input_output_pairs: Pairs of a user query string and the dict that
            represents the expected structured request.

    Returns:
        One dict per pair with the keys "i" (1-based position), "user_query"
        and "structured_request". The structured request is the pair's dict
        dumped as indented JSON with every curly brace doubled.
    """
    return [
        {
            "i": position,
            "user_query": user_query,
            "structured_request": json.dumps(request, indent=4)
            .replace("{", "{{")
            .replace("}", "}}"),
        }
        for position, (user_query, request) in enumerate(input_output_pairs, start=1)
    ]
def get_query_constructor_prompt(
document_contents: str,
attribute_info: Sequence[AttributeInfo],
examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
attribute_info: Sequence[Union[AttributeInfo, dict]],
*,
examples: Optional[Sequence] = None,
allowed_comparators: Sequence[Comparator] = tuple(Comparator),
allowed_operators: Sequence[Operator] = tuple(Operator),
enable_limit: bool = False,
schema_prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any,
) -> BasePromptTemplate:
"""Create query construction prompt.
Args:
document_contents: The contents of the document to be queried.
attribute_info: A list of AttributeInfo objects describing
the attributes of the document.
examples: Optional list of examples to use for the chain.
allowed_comparators: Sequence of allowed comparators.
allowed_operators: Sequence of allowed operators.
enable_limit: Whether to enable the limit operator. Defaults to False.
schema_prompt: Prompt for describing query schema. Should have string input
variables allowed_comparators and allowed_operators.
**kwargs: Additional named params to pass to FewShotPromptTemplate init.
"""
default_schema_prompt = (
SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
)
schema_prompt = schema_prompt or default_schema_prompt
attribute_str = _format_attribute_info(attribute_info)
allowed_comparators = allowed_comparators or list(Comparator)
allowed_operators = allowed_operators or list(Operator)
if enable_limit:
schema = SCHEMA_WITH_LIMIT.format(
allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators),
schema = schema_prompt.format(
allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators),
)
if examples and isinstance(examples[0], tuple):
examples = construct_examples(examples)
example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT
prefix = PREFIX_WITH_DATA_SOURCE.format(
schema=schema, content=document_contents, attributes=attribute_str
)
examples = examples or EXAMPLES_WITH_LIMIT
suffix = SUFFIX_WITHOUT_DATA_SOURCE.format(i=len(examples) + 1)
else:
schema = DEFAULT_SCHEMA.format(
allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators),
examples = examples or (
EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES
)
example_prompt = EXAMPLE_PROMPT
prefix = DEFAULT_PREFIX.format(schema=schema)
suffix = DEFAULT_SUFFIX.format(
i=len(examples) + 1, content=document_contents, attributes=attribute_str
)
examples = examples or DEFAULT_EXAMPLES
prefix = DEFAULT_PREFIX.format(schema=schema)
suffix = DEFAULT_SUFFIX.format(
i=len(examples) + 1, content=document_contents, attributes=attribute_str
)
output_parser = StructuredQueryOutputParser.from_components(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
)
return FewShotPromptTemplate(
examples=examples,
example_prompt=EXAMPLE_PROMPT,
examples=list(examples),
example_prompt=example_prompt,
input_variables=["query"],
suffix=suffix,
prefix=prefix,
output_parser=output_parser,
**kwargs,
)
def load_query_constructor_chain(
llm: BaseLanguageModel,
document_contents: str,
attribute_info: List[AttributeInfo],
attribute_info: Sequence[Union[AttributeInfo, dict]],
examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
allowed_comparators: Sequence[Comparator] = tuple(Comparator),
allowed_operators: Sequence[Operator] = tuple(Operator),
enable_limit: bool = False,
schema_prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any,
) -> LLMChain:
"""Load a query constructor chain.
@ -141,25 +244,95 @@ def load_query_constructor_chain(
Args:
llm: BaseLanguageModel to use for the chain.
document_contents: The contents of the document to be queried.
attribute_info: A list of AttributeInfo objects describing
the attributes of the document.
attribute_info: Sequence of attributes in the document.
examples: Optional list of examples to use for the chain.
allowed_comparators: An optional list of allowed comparators.
allowed_operators: An optional list of allowed operators.
allowed_comparators: Sequence of allowed comparators. Defaults to all
Comparators.
allowed_operators: Sequence of allowed operators. Defaults to all Operators.
enable_limit: Whether to enable the limit operator. Defaults to False.
**kwargs:
schema_prompt: Prompt for describing query schema. Should have string input
variables allowed_comparators and allowed_operators.
**kwargs: Arbitrary named params to pass to LLMChain.
Returns:
A LLMChain that can be used to construct queries.
"""
prompt = _get_prompt(
prompt = get_query_constructor_prompt(
document_contents,
attribute_info,
examples=examples,
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
enable_limit=enable_limit,
schema_prompt=schema_prompt,
)
allowed_attributes = []
for ainfo in attribute_info:
allowed_attributes.append(
ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"]
)
output_parser = StructuredQueryOutputParser.from_components(
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
allowed_attributes=allowed_attributes,
)
return LLMChain(
llm=llm, prompt=prompt, output_parser=prompt.output_parser, **kwargs
# For backwards compatibility.
prompt.output_parser = output_parser
return LLMChain(llm=llm, prompt=prompt, output_parser=output_parser, **kwargs)
def load_query_constructor_runnable(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    fix_invalid: bool = False,
    **kwargs: Any,
) -> Runnable:
    """Build the query constructor as a runnable: prompt, then llm, then parser.

    Args:
        llm: BaseLanguageModel to use for the chain.
        document_contents: The contents of the document to be queried.
        attribute_info: Sequence of attributes in the document, given either
            as AttributeInfo objects or as dicts with a "name" key.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            Comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all
            Operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing query schema. Should have string
            input variables allowed_comparators and allowed_operators.
        fix_invalid: Whether to fix invalid filter directives by ignoring
            invalid operators, comparators and attributes.
        **kwargs: Additional named params forwarded to
            get_query_constructor_prompt.

    Returns:
        A Runnable that can be used to construct queries.
    """
    query_prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
        **kwargs,
    )
    # The parser may only accept attributes that were actually described.
    attribute_names = [
        attr.name if isinstance(attr, AttributeInfo) else attr["name"]
        for attr in attribute_info
    ]
    parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=attribute_names,
        fix_invalid=fix_invalid,
    )
    return query_prompt | llm | parser

@ -61,11 +61,13 @@ class QueryTransformer(Transformer):
*args: Any,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
allowed_attributes: Optional[Sequence[str]] = None,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.allowed_comparators = allowed_comparators
self.allowed_operators = allowed_operators
self.allowed_attributes = allowed_attributes
def program(self, *items: Any) -> tuple:
return items
@ -73,6 +75,11 @@ class QueryTransformer(Transformer):
def func_call(self, func_name: Any, args: list) -> FilterDirective:
func = self._match_func_name(str(func_name))
if isinstance(func, Comparator):
if self.allowed_attributes and args[0] not in self.allowed_attributes:
raise ValueError(
f"Received invalid attributes {args[0]}. Allowed attributes are "
f"{self.allowed_attributes}"
)
return Comparison(comparator=func, attribute=args[0], value=args[1])
elif len(args) == 1 and func in (Operator.AND, Operator.OR):
return args[0]
@ -134,6 +141,7 @@ class QueryTransformer(Transformer):
def get_parser(
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
allowed_attributes: Optional[Sequence[str]] = None,
) -> Lark:
"""
Returns a parser for the query language.
@ -151,6 +159,8 @@ def get_parser(
"Cannot import lark, please install it with 'pip install lark'."
)
transformer = QueryTransformer(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
allowed_attributes=allowed_attributes,
)
return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program")

@ -3,36 +3,31 @@ from langchain.prompts import PromptTemplate
SONG_DATA_SOURCE = """\
```json
{
{{
"content": "Lyrics of a song",
"attributes": {
"artist": {
"attributes": {{
"artist": {{
"type": "string",
"description": "Name of the song artist"
},
"length": {
}},
"length": {{
"type": "integer",
"description": "Length of the song in seconds"
},
"genre": {
}},
"genre": {{
"type": "string",
"description": "The song genre, one of \"pop\", \"rock\" or \"rap\""
}
}
}
}}
}}
}}
```\
""".replace(
"{", "{{"
).replace(
"}", "}}"
)
"""
FULL_ANSWER = """\
```json
{{
"query": "teenager love",
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))"
}}
```\
"""
@ -104,16 +99,24 @@ Structured Request:
{structured_request}
"""
EXAMPLE_PROMPT = PromptTemplate(
input_variables=["i", "data_source", "user_query", "structured_request"],
template=EXAMPLE_PROMPT_TEMPLATE,
)
EXAMPLE_PROMPT = PromptTemplate.from_template(EXAMPLE_PROMPT_TEMPLATE)
# Example template used for examples supplied by the caller. It has no
# per-example data source section; its variables ("i", "user_query",
# "structured_request") are the keys of the dicts built by construct_examples.
USER_SPECIFIED_EXAMPLE_PROMPT = PromptTemplate.from_template(
"""\
<< Example {i}. >>
User Query:
{user_query}
Structured Request:
```json
{structured_request}
```
"""
)
DEFAULT_SCHEMA = """\
<< Structured Request Schema >>
When responding use a markdown code snippet with a JSON object formatted in the \
following schema:
When responding use a markdown code snippet with a JSON object formatted in the following schema:
```json
{{{{
@ -122,11 +125,9 @@ following schema:
}}}}
```
The query string should contain only text that is expected to match the contents of \
documents. Any conditions in the filter should not be mentioned in the query as well.
The query string should contain only text that is expected to match the contents of documents. Any conditions in the filter should not be mentioned in the query as well.
A logical condition statement is composed of one or more comparison and logical \
operation statements.
A logical condition statement is composed of one or more comparison and logical operation statements.
A comparison statement takes the form: `comp(attr, val)`:
- `comp` ({allowed_comparators}): comparator
@ -135,24 +136,20 @@ A comparison statement takes the form: `comp(attr, val)`:
A logical operation statement takes the form `op(statement1, statement2, ...)`:
- `op` ({allowed_operators}): logical operator
- `statement1`, `statement2`, ... (comparison statements or logical operation \
statements): one or more statements to apply the operation to
- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that you only use the comparators and logical operators listed above and no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
applied return "NO_FILTER" for the filter value.\
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.\
"""
DEFAULT_SCHEMA_PROMPT = PromptTemplate.from_template(DEFAULT_SCHEMA)
SCHEMA_WITH_LIMIT = """\
<< Structured Request Schema >>
When responding use a markdown code snippet with a JSON object formatted in the \
following schema:
When responding use a markdown code snippet with a JSON object formatted in the following schema:
```json
{{{{
@ -162,11 +159,9 @@ following schema:
}}}}
```
The query string should contain only text that is expected to match the contents of \
documents. Any conditions in the filter should not be mentioned in the query as well.
The query string should contain only text that is expected to match the contents of documents. Any conditions in the filter should not be mentioned in the query as well.
A logical condition statement is composed of one or more comparison and logical \
operation statements.
A logical condition statement is composed of one or more comparison and logical operation statements.
A comparison statement takes the form: `comp(attr, val)`:
- `comp` ({allowed_comparators}): comparator
@ -175,20 +170,17 @@ A comparison statement takes the form: `comp(attr, val)`:
A logical operation statement takes the form `op(statement1, statement2, ...)`:
- `op` ({allowed_operators}): logical operator
- `statement1`, `statement2`, ... (comparison statements or logical operation \
statements): one or more statements to apply the operation to
- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that you only use the comparators and logical operators listed above and no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
applied return "NO_FILTER" for the filter value.
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it is does not make sense.
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it does not make sense.
"""
SCHEMA_WITH_LIMIT_PROMPT = PromptTemplate.from_template(SCHEMA_WITH_LIMIT)
DEFAULT_PREFIX = """\
Your goal is to structure the user's query to match the request schema provided below.
@ -196,6 +188,20 @@ Your goal is to structure the user's query to match the request schema provided
{schema}\
"""
# Prefix that states the data source once up front instead of in every example.
# get_query_constructor_prompt fills it with str.format (schema, content,
# attributes), which collapses the quadrupled braces to doubled braces.
# NOTE(review): presumably the doubled braces are then rendered as literal
# braces by the few-shot prompt's own formatting pass; confirm.
PREFIX_WITH_DATA_SOURCE = (
DEFAULT_PREFIX
+ """
<< Data Source >>
```json
{{{{
"content": "{content}",
"attributes": {attributes}
}}}}
```
"""
)
DEFAULT_SUFFIX = """\
<< Example {i}. >>
Data Source:
@ -211,3 +217,11 @@ User Query:
Structured Request:
"""
# Suffix that poses the actual user query when the data source is already in
# the prefix. "{i}" is filled via str.format with the next example number;
# "{{query}}" survives that pass as "{query}", matching the "query" input
# variable of the few-shot prompt.
SUFFIX_WITHOUT_DATA_SOURCE = """\
<< Example {i}. >>
User Query:
{{query}}
Structured Request:
"""

@ -1,13 +1,12 @@
"""Retriever that generates and executes structured queries over its own data source."""
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains import LLMChain
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.base import load_query_constructor_runnable
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.pydantic_v1 import BaseModel, Field, root_validator
@ -27,6 +26,7 @@ from langchain.retrievers.self_query.vectara import VectaraTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.runnable import Runnable
from langchain.vectorstores import (
Chroma,
DashVector,
@ -86,8 +86,10 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
vectorstore: VectorStore
"""The underlying vector store from which documents will be retrieved."""
llm_chain: LLMChain
"""The LLMChain for generating the vector store queries."""
query_constructor: Runnable[dict, StructuredQuery] = Field(alias="llm_chain")
"""The query constructor chain for generating the vector store queries.
llm_chain is legacy name kept for backwards compatibility."""
search_type: str = "similarity"
"""The search type to perform on the vector store."""
search_kwargs: dict = Field(default_factory=dict)
@ -103,6 +105,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
allow_population_by_field_name = True
@root_validator(pre=True)
def validate_translator(cls, values: Dict) -> Dict:
@ -113,23 +116,10 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
)
return values
def _get_structured_query(
self, inputs: Dict[str, Any], run_manager: CallbackManagerForRetrieverRun
) -> StructuredQuery:
structured_query = cast(
StructuredQuery,
self.llm_chain.predict(callbacks=run_manager.get_child(), **inputs),
)
return structured_query
async def _aget_structured_query(
self, inputs: Dict[str, Any], run_manager: AsyncCallbackManagerForRetrieverRun
) -> StructuredQuery:
structured_query = cast(
StructuredQuery,
await self.llm_chain.apredict(callbacks=run_manager.get_child(), **inputs),
)
return structured_query
@property
def llm_chain(self) -> Runnable:
"""llm_chain is legacy name kept for backwards compatibility."""
return self.query_constructor
def _prepare_query(
self, query: str, structured_query: StructuredQuery
@ -167,8 +157,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
Returns:
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = self._get_structured_query(inputs, run_manager)
structured_query = self.query_constructor.invoke(
{"query": query}, config={"callbacks": run_manager.get_child()}
)
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
@ -186,8 +177,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
Returns:
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = await self._aget_structured_query(inputs, run_manager)
structured_query = await self.query_constructor.ainvoke(
{"query": query}, config={"callbacks": run_manager.get_child()}
)
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
@ -200,7 +192,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
llm: BaseLanguageModel,
vectorstore: VectorStore,
document_contents: str,
metadata_field_info: List[AttributeInfo],
metadata_field_info: Sequence[Union[AttributeInfo, dict]],
structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None,
enable_limit: bool = False,
@ -219,7 +211,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
chain_kwargs[
"allowed_operators"
] = structured_query_translator.allowed_operators
llm_chain = load_query_constructor_chain(
query_constructor = load_query_constructor_runnable(
llm,
document_contents,
metadata_field_info,
@ -227,7 +219,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
**chain_kwargs,
)
return cls(
llm_chain=llm_chain,
query_constructor=query_constructor,
vectorstore=vectorstore,
use_original_query=use_original_query,
structured_query_translator=structured_query_translator,

Loading…
Cancel
Save