From e7a0def1bc1d64209486e0f214f14e1faea5c480 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 9 Oct 2023 08:10:52 -0700 Subject: [PATCH] QoL improvements to query constructor (#11504) updating query constructor and self query retriever to - make it easier to pass in examples - validate attributes used in query - remove invalid parts of query - make it easier to get + edit prompt - make query constructor a runnable - make self query retriever use as runnable --- .../chains/query_constructor/base.py | 267 +++++++++++++++--- .../chains/query_constructor/parser.py | 12 +- .../chains/query_constructor/prompt.py | 114 ++++---- .../langchain/retrievers/self_query/base.py | 50 ++-- 4 files changed, 316 insertions(+), 127 deletions(-) diff --git a/libs/langchain/langchain/chains/query_constructor/base.py b/libs/langchain/langchain/chains/query_constructor/base.py index 266ea58bd2..5bd58811dd 100644 --- a/libs/langchain/langchain/chains/query_constructor/base.py +++ b/libs/langchain/langchain/chains/query_constructor/base.py @@ -2,11 +2,14 @@ from __future__ import annotations import json -from typing import Any, Callable, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast from langchain.chains.llm import LLMChain from langchain.chains.query_constructor.ir import ( Comparator, + Comparison, + FilterDirective, + Operation, Operator, StructuredQuery, ) @@ -14,17 +17,21 @@ from langchain.chains.query_constructor.parser import get_parser from langchain.chains.query_constructor.prompt import ( DEFAULT_EXAMPLES, DEFAULT_PREFIX, - DEFAULT_SCHEMA, + DEFAULT_SCHEMA_PROMPT, DEFAULT_SUFFIX, EXAMPLE_PROMPT, EXAMPLES_WITH_LIMIT, - SCHEMA_WITH_LIMIT, + PREFIX_WITH_DATA_SOURCE, + SCHEMA_WITH_LIMIT_PROMPT, + SUFFIX_WITHOUT_DATA_SOURCE, + USER_SPECIFIED_EXAMPLE_PROMPT, ) from langchain.chains.query_constructor.schema import AttributeInfo from langchain.output_parsers.json import parse_and_check_json_markdown from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.runnable import Runnable class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): @@ -59,6 +66,8 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): cls, allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_operators: Optional[Sequence[Operator]] = None, + allowed_attributes: Optional[Sequence[str]] = None, + fix_invalid: bool = False, ) -> StructuredQueryOutputParser: """ Create a structured query output parser from components. @@ -70,13 +79,73 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): Returns: a structured query output parser """ - ast_parser = get_parser( - allowed_comparators=allowed_comparators, allowed_operators=allowed_operators - ) - return cls(ast_parse=ast_parser.parse) + ast_parse: Callable + if fix_invalid: + + def ast_parse(raw_filter: str) -> Optional[FilterDirective]: + filter = cast(Optional[FilterDirective], get_parser().parse(raw_filter)) + fixed = fix_filter_directive( + filter, + allowed_comparators=allowed_comparators, + allowed_operators=allowed_operators, + allowed_attributes=allowed_attributes, + ) + return fixed + + else: + ast_parse = get_parser( + allowed_comparators=allowed_comparators, + allowed_operators=allowed_operators, + allowed_attributes=allowed_attributes, + ).parse + return cls(ast_parse=ast_parse) + + +def fix_filter_directive( + filter: Optional[FilterDirective], + *, + allowed_comparators: Optional[Sequence[Comparator]] = None, + allowed_operators: Optional[Sequence[Operator]] = None, + allowed_attributes: Optional[Sequence[str]] = None, +) -> Optional[FilterDirective]: + if ( + not (allowed_comparators or allowed_operators or allowed_attributes) + ) or not filter: + return filter + + elif isinstance(filter, Comparison): + if allowed_comparators and filter.comparator not in allowed_comparators: + return None + if allowed_attributes and filter.attribute not in allowed_attributes: + return None + return filter + elif isinstance(filter, Operation): + if allowed_operators and filter.operator not in allowed_operators: + return None + args = [ + fix_filter_directive( + arg, + allowed_comparators=allowed_comparators, + allowed_operators=allowed_operators, + allowed_attributes=allowed_attributes, + ) + for arg in filter.arguments + ] + args = [arg for arg in args if arg is not None] + if not args: + return None + elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR): + return args[0] + else: + return Operation( + operator=filter.operator, + arguments=args, + ) + else: + return filter -def _format_attribute_info(info: Sequence[AttributeInfo]) -> str: +def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str: info_dicts = {} for i in info: i_dict = dict(i) @@ -84,56 +153,90 @@ def _format_attribute_info(info: Sequence[AttributeInfo]) -> str: return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}") -def _get_prompt( +def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]: + examples = [] + for i, (_input, output) in enumerate(input_output_pairs): + structured_request = ( + json.dumps(output, indent=4).replace("{", "{{").replace("}", "}}") + ) + example = { + "i": i + 1, + "user_query": _input, + "structured_request": structured_request, + } + examples.append(example) + return examples + + +def get_query_constructor_prompt( document_contents: str, - attribute_info: Sequence[AttributeInfo], - examples: Optional[List] = None, - allowed_comparators: Optional[Sequence[Comparator]] = None, - allowed_operators: Optional[Sequence[Operator]] = None, + attribute_info: Sequence[Union[AttributeInfo, dict]], + *, + examples: Optional[Sequence] = None, + allowed_comparators: Sequence[Comparator] = tuple(Comparator), + allowed_operators: Sequence[Operator] = tuple(Operator), enable_limit: bool = False, + schema_prompt: Optional[BasePromptTemplate] = None, + **kwargs: Any, ) -> BasePromptTemplate: + """Create query construction prompt. + + Args: + document_contents: The contents of the document to be queried. + attribute_info: A list of AttributeInfo objects describing + the attributes of the document. + examples: Optional list of examples to use for the chain. + allowed_comparators: Sequence of allowed comparators. + allowed_operators: Sequence of allowed operators. + enable_limit: Whether to enable the limit operator. Defaults to False. + schema_prompt: Prompt for describing query schema. Should have string input + variables allowed_comparators and allowed_operators. + **kwargs: Additional named params to pass to FewShotPromptTemplate init. + """ + default_schema_prompt = ( + SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT + ) + schema_prompt = schema_prompt or default_schema_prompt attribute_str = _format_attribute_info(attribute_info) - allowed_comparators = allowed_comparators or list(Comparator) - allowed_operators = allowed_operators or list(Operator) - if enable_limit: - schema = SCHEMA_WITH_LIMIT.format( - allowed_comparators=" | ".join(allowed_comparators), - allowed_operators=" | ".join(allowed_operators), + schema = schema_prompt.format( + allowed_comparators=" | ".join(allowed_comparators), + allowed_operators=" | ".join(allowed_operators), + ) + if examples and isinstance(examples[0], tuple): + examples = construct_examples(examples) + example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT + prefix = PREFIX_WITH_DATA_SOURCE.format( + schema=schema, content=document_contents, attributes=attribute_str ) - - examples = examples or EXAMPLES_WITH_LIMIT + suffix = SUFFIX_WITHOUT_DATA_SOURCE.format(i=len(examples) + 1) else: - schema = DEFAULT_SCHEMA.format( - allowed_comparators=" | ".join(allowed_comparators), - allowed_operators=" | ".join(allowed_operators), + examples = examples or ( + EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES + ) + example_prompt = EXAMPLE_PROMPT + prefix = DEFAULT_PREFIX.format(schema=schema) + suffix = DEFAULT_SUFFIX.format( + i=len(examples) + 1, content=document_contents, attributes=attribute_str ) - - examples = examples or DEFAULT_EXAMPLES - prefix = DEFAULT_PREFIX.format(schema=schema) - suffix = DEFAULT_SUFFIX.format( - i=len(examples) + 1, content=document_contents, attributes=attribute_str - ) - output_parser = StructuredQueryOutputParser.from_components( - allowed_comparators=allowed_comparators, allowed_operators=allowed_operators - ) return FewShotPromptTemplate( - examples=examples, - example_prompt=EXAMPLE_PROMPT, + examples=list(examples), + example_prompt=example_prompt, input_variables=["query"], suffix=suffix, prefix=prefix, - output_parser=output_parser, + **kwargs, ) def load_query_constructor_chain( llm: BaseLanguageModel, document_contents: str, - attribute_info: List[AttributeInfo], + attribute_info: Sequence[Union[AttributeInfo, dict]], examples: Optional[List] = None, - allowed_comparators: Optional[Sequence[Comparator]] = None, - allowed_operators: Optional[Sequence[Operator]] = None, + allowed_comparators: Sequence[Comparator] = tuple(Comparator), + allowed_operators: Sequence[Operator] = tuple(Operator), enable_limit: bool = False, + schema_prompt: Optional[BasePromptTemplate] = None, **kwargs: Any, ) -> LLMChain: """Load a query constructor chain. @@ -141,25 +244,95 @@ def load_query_constructor_chain( Args: llm: BaseLanguageModel to use for the chain. document_contents: The contents of the document to be queried. - attribute_info: A list of AttributeInfo objects describing - the attributes of the document. + attribute_info: Sequence of attributes in the document. examples: Optional list of examples to use for the chain. - allowed_comparators: An optional list of allowed comparators. - allowed_operators: An optional list of allowed operators. + allowed_comparators: Sequence of allowed comparators. Defaults to all + Comparators. + allowed_operators: Sequence of allowed operators. Defaults to all Operators. enable_limit: Whether to enable the limit operator. Defaults to False. - **kwargs: + schema_prompt: Prompt for describing query schema. Should have string input + variables allowed_comparators and allowed_operators. + **kwargs: Arbitrary named params to pass to LLMChain. Returns: A LLMChain that can be used to construct queries. """ - prompt = _get_prompt( + prompt = get_query_constructor_prompt( document_contents, attribute_info, examples=examples, allowed_comparators=allowed_comparators, allowed_operators=allowed_operators, enable_limit=enable_limit, + schema_prompt=schema_prompt, + ) + allowed_attributes = [] + for ainfo in attribute_info: + allowed_attributes.append( + ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"] + ) + output_parser = StructuredQueryOutputParser.from_components( + allowed_comparators=allowed_comparators, + allowed_operators=allowed_operators, + allowed_attributes=allowed_attributes, ) - return LLMChain( - llm=llm, prompt=prompt, output_parser=prompt.output_parser, **kwargs + # For backwards compatibility. + prompt.output_parser = output_parser + return LLMChain(llm=llm, prompt=prompt, output_parser=output_parser, **kwargs) + + +def load_query_constructor_runnable( + llm: BaseLanguageModel, + document_contents: str, + attribute_info: Sequence[Union[AttributeInfo, dict]], + *, + examples: Optional[Sequence] = None, + allowed_comparators: Sequence[Comparator] = tuple(Comparator), + allowed_operators: Sequence[Operator] = tuple(Operator), + enable_limit: bool = False, + schema_prompt: Optional[BasePromptTemplate] = None, + fix_invalid: bool = False, + **kwargs: Any, +) -> Runnable: + """Load a query constructor runnable chain. + + Args: + llm: BaseLanguageModel to use for the chain. + document_contents: The contents of the document to be queried. + attribute_info: Sequence of attributes in the document. + examples: Optional list of examples to use for the chain. + allowed_comparators: Sequence of allowed comparators. Defaults to all + Comparators. + allowed_operators: Sequence of allowed operators. Defaults to all Operators. + enable_limit: Whether to enable the limit operator. Defaults to False. + schema_prompt: Prompt for describing query schema. Should have string input + variables allowed_comparators and allowed_operators. + fix_invalid: Whether to fix invalid filter directives by ignoring invalid + operators, comparators and attributes. + **kwargs: Additional named params to pass to FewShotPromptTemplate init. + + Returns: + A Runnable that can be used to construct queries. + """ + prompt = get_query_constructor_prompt( + document_contents, + attribute_info, + examples=examples, + allowed_comparators=allowed_comparators, + allowed_operators=allowed_operators, + enable_limit=enable_limit, + schema_prompt=schema_prompt, + **kwargs, + ) + allowed_attributes = [] + for ainfo in attribute_info: + allowed_attributes.append( + ainfo.name if isinstance(ainfo, AttributeInfo) else ainfo["name"] + ) + output_parser = StructuredQueryOutputParser.from_components( + allowed_comparators=allowed_comparators, + allowed_operators=allowed_operators, + allowed_attributes=allowed_attributes, + fix_invalid=fix_invalid, ) + return prompt | llm | output_parser diff --git a/libs/langchain/langchain/chains/query_constructor/parser.py b/libs/langchain/langchain/chains/query_constructor/parser.py index 6cbc42b5e7..b7af3ae6c7 100644 --- a/libs/langchain/langchain/chains/query_constructor/parser.py +++ b/libs/langchain/langchain/chains/query_constructor/parser.py @@ -61,11 +61,13 @@ class QueryTransformer(Transformer): *args: Any, allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_operators: Optional[Sequence[Operator]] = None, + allowed_attributes: Optional[Sequence[str]] = None, **kwargs: Any, ): super().__init__(*args, **kwargs) self.allowed_comparators = allowed_comparators self.allowed_operators = allowed_operators + self.allowed_attributes = allowed_attributes def program(self, *items: Any) -> tuple: return items @@ -73,6 +75,11 @@ class QueryTransformer(Transformer): def func_call(self, func_name: Any, args: list) -> FilterDirective: func = self._match_func_name(str(func_name)) if isinstance(func, Comparator): + if self.allowed_attributes and args[0] not in self.allowed_attributes: + raise ValueError( + f"Received invalid attributes {args[0]}. Allowed attributes are " + f"{self.allowed_attributes}" + ) return Comparison(comparator=func, attribute=args[0], value=args[1]) elif len(args) == 1 and func in (Operator.AND, Operator.OR): return args[0] @@ -134,6 +141,7 @@ class QueryTransformer(Transformer): def get_parser( allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_operators: Optional[Sequence[Operator]] = None, + allowed_attributes: Optional[Sequence[str]] = None, ) -> Lark: """ Returns a parser for the query language. @@ -151,6 +159,8 @@ def get_parser( "Cannot import lark, please install it with 'pip install lark'." ) transformer = QueryTransformer( - allowed_comparators=allowed_comparators, allowed_operators=allowed_operators + allowed_comparators=allowed_comparators, + allowed_operators=allowed_operators, + allowed_attributes=allowed_attributes, ) return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program") diff --git a/libs/langchain/langchain/chains/query_constructor/prompt.py b/libs/langchain/langchain/chains/query_constructor/prompt.py index fb6a7901c5..ead4a5b54d 100644 --- a/libs/langchain/langchain/chains/query_constructor/prompt.py +++ b/libs/langchain/langchain/chains/query_constructor/prompt.py @@ -3,36 +3,31 @@ from langchain.prompts import PromptTemplate SONG_DATA_SOURCE = """\ ```json -{ +{{ "content": "Lyrics of a song", - "attributes": { - "artist": { + "attributes": {{ + "artist": {{ "type": "string", "description": "Name of the song artist" - }, - "length": { + }}, + "length": {{ "type": "integer", "description": "Length of the song in seconds" - }, - "genre": { + }}, + "genre": {{ "type": "string", "description": "The song genre, one of \"pop\", \"rock\" or \"rap\"" - } - } -} + }} + }} +}} ```\ -""".replace( - "{", "{{" -).replace( - "}", "}}" -) +""" FULL_ANSWER = """\ ```json {{ "query": "teenager love", - "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \ -lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))" + "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))" }} ```\ """ @@ -104,16 +99,24 @@ Structured Request: {structured_request} """ -EXAMPLE_PROMPT = PromptTemplate( - input_variables=["i", "data_source", "user_query", "structured_request"], - template=EXAMPLE_PROMPT_TEMPLATE, -) +EXAMPLE_PROMPT = PromptTemplate.from_template(EXAMPLE_PROMPT_TEMPLATE) + +USER_SPECIFIED_EXAMPLE_PROMPT = PromptTemplate.from_template( + """\ +<< Example {i}. >> +User Query: +{user_query} +Structured Request: +```json +{structured_request} +``` +""" +) DEFAULT_SCHEMA = """\ << Structured Request Schema >> -When responding use a markdown code snippet with a JSON object formatted in the \ -following schema: +When responding use a markdown code snippet with a JSON object formatted in the following schema: ```json {{{{ @@ -122,11 +125,9 @@ following schema: }}}} ``` -The query string should contain only text that is expected to match the contents of \ -documents. Any conditions in the filter should not be mentioned in the query as well. +The query string should contain only text that is expected to match the contents of documents. Any conditions in the filter should not be mentioned in the query as well. -A logical condition statement is composed of one or more comparison and logical \ -operation statements. +A logical condition statement is composed of one or more comparison and logical operation statements. A comparison statement takes the form: `comp(attr, val)`: - `comp` ({allowed_comparators}): comparator @@ -135,24 +136,20 @@ A comparison statement takes the form: `comp(attr, val)`: A logical operation statement takes the form `op(statement1, statement2, ...)`: - `op` ({allowed_operators}): logical operator -- `statement1`, `statement2`, ... (comparison statements or logical operation \ -statements): one or more statements to apply the operation to +- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to -Make sure that you only use the comparators and logical operators listed above and \ -no others. +Make sure that you only use the comparators and logical operators listed above and no others. Make sure that filters only refer to attributes that exist in the data source. Make sure that filters only use the attributed names with its function names if there are functions applied on them. Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values. -Make sure that filters take into account the descriptions of attributes and only make \ -comparisons that are feasible given the type of data being stored. -Make sure that filters are only used as needed. If there are no filters that should be \ -applied return "NO_FILTER" for the filter value.\ +Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored. +Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.\ """ +DEFAULT_SCHEMA_PROMPT = PromptTemplate.from_template(DEFAULT_SCHEMA) SCHEMA_WITH_LIMIT = """\ << Structured Request Schema >> -When responding use a markdown code snippet with a JSON object formatted in the \ -following schema: +When responding use a markdown code snippet with a JSON object formatted in the following schema: ```json {{{{ @@ -162,11 +159,9 @@ following schema: }}}} ``` -The query string should contain only text that is expected to match the contents of \ -documents. Any conditions in the filter should not be mentioned in the query as well. +The query string should contain only text that is expected to match the contents of documents. Any conditions in the filter should not be mentioned in the query as well. -A logical condition statement is composed of one or more comparison and logical \ -operation statements. +A logical condition statement is composed of one or more comparison and logical operation statements. A comparison statement takes the form: `comp(attr, val)`: - `comp` ({allowed_comparators}): comparator @@ -175,20 +170,17 @@ A comparison statement takes the form: `comp(attr, val)`: A logical operation statement takes the form `op(statement1, statement2, ...)`: - `op` ({allowed_operators}): logical operator -- `statement1`, `statement2`, ... (comparison statements or logical operation \ -statements): one or more statements to apply the operation to +- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to -Make sure that you only use the comparators and logical operators listed above and \ -no others. +Make sure that you only use the comparators and logical operators listed above and no others. Make sure that filters only refer to attributes that exist in the data source. Make sure that filters only use the attributed names with its function names if there are functions applied on them. Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values. -Make sure that filters take into account the descriptions of attributes and only make \ -comparisons that are feasible given the type of data being stored. -Make sure that filters are only used as needed. If there are no filters that should be \ -applied return "NO_FILTER" for the filter value. -Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it is does not make sense. +Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored. +Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value. +Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it does not make sense. """ +SCHEMA_WITH_LIMIT_PROMPT = PromptTemplate.from_template(SCHEMA_WITH_LIMIT) DEFAULT_PREFIX = """\ Your goal is to structure the user's query to match the request schema provided below. @@ -196,6 +188,20 @@ Your goal is to structure the user's query to match the request schema provided {schema}\ """ +PREFIX_WITH_DATA_SOURCE = ( + DEFAULT_PREFIX + + """ + +<< Data Source >> +```json +{{{{ + "content": "{content}", + "attributes": {attributes} +}}}} +``` +""" +) + DEFAULT_SUFFIX = """\ << Example {i}. >> Data Source: @@ -211,3 +217,11 @@ User Query: Structured Request: """ + +SUFFIX_WITHOUT_DATA_SOURCE = """\ +<< Example {i}. >> +User Query: +{{query}} + +Structured Request: +""" diff --git a/libs/langchain/langchain/retrievers/self_query/base.py b/libs/langchain/langchain/retrievers/self_query/base.py index 0514f6c50e..adb67bda63 100644 --- a/libs/langchain/langchain/retrievers/self_query/base.py +++ b/libs/langchain/langchain/retrievers/self_query/base.py @@ -1,13 +1,12 @@ """Retriever that generates and executes structured queries over its own data source.""" import logging -from typing import Any, Dict, List, Optional, Tuple, Type, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.chains import LLMChain -from langchain.chains.query_constructor.base import load_query_constructor_chain +from langchain.chains.query_constructor.base import load_query_constructor_runnable from langchain.chains.query_constructor.ir import StructuredQuery, Visitor from langchain.chains.query_constructor.schema import AttributeInfo from langchain.pydantic_v1 import BaseModel, Field, root_validator @@ -27,6 +26,7 @@ from langchain.retrievers.self_query.vectara import VectaraTranslator from langchain.retrievers.self_query.weaviate import WeaviateTranslator from langchain.schema import BaseRetriever, Document from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.runnable import Runnable from langchain.vectorstores import ( Chroma, DashVector, @@ -86,8 +86,10 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): vectorstore: VectorStore """The underlying vector store from which documents will be retrieved.""" - llm_chain: LLMChain - """The LLMChain for generating the vector store queries.""" + query_constructor: Runnable[dict, StructuredQuery] = Field(alias="llm_chain") + """The query constructor chain for generating the vector store queries. + + llm_chain is legacy name kept for backwards compatibility.""" search_type: str = "similarity" """The search type to perform on the vector store.""" search_kwargs: dict = Field(default_factory=dict) @@ -103,6 +105,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): """Configuration for this pydantic object.""" arbitrary_types_allowed = True + allow_population_by_field_name = True @root_validator(pre=True) def validate_translator(cls, values: Dict) -> Dict: @@ -113,23 +116,10 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): ) return values - def _get_structured_query( - self, inputs: Dict[str, Any], run_manager: CallbackManagerForRetrieverRun - ) -> StructuredQuery: - structured_query = cast( - StructuredQuery, - self.llm_chain.predict(callbacks=run_manager.get_child(), **inputs), - ) - return structured_query - - async def _aget_structured_query( - self, inputs: Dict[str, Any], run_manager: AsyncCallbackManagerForRetrieverRun - ) -> StructuredQuery: - structured_query = cast( - StructuredQuery, - await self.llm_chain.apredict(callbacks=run_manager.get_child(), **inputs), - ) - return structured_query + @property + def llm_chain(self) -> Runnable: + """llm_chain is legacy name kept for backwards compatibility.""" + return self.query_constructor def _prepare_query( self, query: str, structured_query: StructuredQuery @@ -167,8 +157,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): Returns: List of relevant documents """ - inputs = self.llm_chain.prep_inputs({"query": query}) - structured_query = self._get_structured_query(inputs, run_manager) + structured_query = self.query_constructor.invoke( + {"query": query}, config={"callbacks": run_manager.get_child()} + ) if self.verbose: logger.info(f"Generated Query: {structured_query}") new_query, search_kwargs = self._prepare_query(query, structured_query) @@ -186,8 +177,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): Returns: List of relevant documents """ - inputs = self.llm_chain.prep_inputs({"query": query}) - structured_query = await self._aget_structured_query(inputs, run_manager) + structured_query = await self.query_constructor.ainvoke( + {"query": query}, config={"callbacks": run_manager.get_child()} + ) if self.verbose: logger.info(f"Generated Query: {structured_query}") new_query, search_kwargs = self._prepare_query(query, structured_query) @@ -200,7 +192,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): llm: BaseLanguageModel, vectorstore: VectorStore, document_contents: str, - metadata_field_info: List[AttributeInfo], + metadata_field_info: Sequence[Union[AttributeInfo, dict]], structured_query_translator: Optional[Visitor] = None, chain_kwargs: Optional[Dict] = None, enable_limit: bool = False, @@ -219,7 +211,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): chain_kwargs[ "allowed_operators" ] = structured_query_translator.allowed_operators - llm_chain = load_query_constructor_chain( + query_constructor = load_query_constructor_runnable( llm, document_contents, metadata_field_info, @@ -227,7 +219,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): **chain_kwargs, ) return cls( - llm_chain=llm_chain, + query_constructor=query_constructor, vectorstore=vectorstore, use_original_query=use_original_query, structured_query_translator=structured_query_translator,