@ -2,11 +2,14 @@
from __future__ import annotations
from __future__ import annotations
import json
import json
from typing import Any , Callable , List , Optional , Sequence
from typing import Any , Callable , List , Optional , Sequence , Tuple , Union , cast
from langchain . chains . llm import LLMChain
from langchain . chains . llm import LLMChain
from langchain . chains . query_constructor . ir import (
from langchain . chains . query_constructor . ir import (
Comparator ,
Comparator ,
Comparison ,
FilterDirective ,
Operation ,
Operator ,
Operator ,
StructuredQuery ,
StructuredQuery ,
)
)
@ -14,17 +17,21 @@ from langchain.chains.query_constructor.parser import get_parser
from langchain . chains . query_constructor . prompt import (
from langchain . chains . query_constructor . prompt import (
DEFAULT_EXAMPLES ,
DEFAULT_EXAMPLES ,
DEFAULT_PREFIX ,
DEFAULT_PREFIX ,
DEFAULT_SCHEMA ,
DEFAULT_SCHEMA _PROMPT ,
DEFAULT_SUFFIX ,
DEFAULT_SUFFIX ,
EXAMPLE_PROMPT ,
EXAMPLE_PROMPT ,
EXAMPLES_WITH_LIMIT ,
EXAMPLES_WITH_LIMIT ,
SCHEMA_WITH_LIMIT ,
PREFIX_WITH_DATA_SOURCE ,
SCHEMA_WITH_LIMIT_PROMPT ,
SUFFIX_WITHOUT_DATA_SOURCE ,
USER_SPECIFIED_EXAMPLE_PROMPT ,
)
)
from langchain . chains . query_constructor . schema import AttributeInfo
from langchain . chains . query_constructor . schema import AttributeInfo
from langchain . output_parsers . json import parse_and_check_json_markdown
from langchain . output_parsers . json import parse_and_check_json_markdown
from langchain . prompts . few_shot import FewShotPromptTemplate
from langchain . prompts . few_shot import FewShotPromptTemplate
from langchain . schema import BaseOutputParser , BasePromptTemplate , OutputParserException
from langchain . schema import BaseOutputParser , BasePromptTemplate , OutputParserException
from langchain . schema . language_model import BaseLanguageModel
from langchain . schema . language_model import BaseLanguageModel
from langchain . schema . runnable import Runnable
class StructuredQueryOutputParser ( BaseOutputParser [ StructuredQuery ] ) :
class StructuredQueryOutputParser ( BaseOutputParser [ StructuredQuery ] ) :
@ -59,6 +66,8 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
cls ,
cls ,
allowed_comparators : Optional [ Sequence [ Comparator ] ] = None ,
allowed_comparators : Optional [ Sequence [ Comparator ] ] = None ,
allowed_operators : Optional [ Sequence [ Operator ] ] = None ,
allowed_operators : Optional [ Sequence [ Operator ] ] = None ,
allowed_attributes : Optional [ Sequence [ str ] ] = None ,
fix_invalid : bool = False ,
) - > StructuredQueryOutputParser :
) - > StructuredQueryOutputParser :
"""
"""
Create a structured query output parser from components .
Create a structured query output parser from components .
@ -70,13 +79,73 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
Returns :
Returns :
a structured query output parser
a structured query output parser
"""
"""
ast_parser = get_parser (
ast_parse : Callable
allowed_comparators = allowed_comparators , allowed_operators = allowed_operators
if fix_invalid :
def ast_parse ( raw_filter : str ) - > Optional [ FilterDirective ] :
filter = cast ( Optional [ FilterDirective ] , get_parser ( ) . parse ( raw_filter ) )
fixed = fix_filter_directive (
filter ,
allowed_comparators = allowed_comparators ,
allowed_operators = allowed_operators ,
allowed_attributes = allowed_attributes ,
)
return fixed
else :
ast_parse = get_parser (
allowed_comparators = allowed_comparators ,
allowed_operators = allowed_operators ,
allowed_attributes = allowed_attributes ,
) . parse
return cls ( ast_parse = ast_parse )
def fix_filter_directive (
filter : Optional [ FilterDirective ] ,
* ,
allowed_comparators : Optional [ Sequence [ Comparator ] ] = None ,
allowed_operators : Optional [ Sequence [ Operator ] ] = None ,
allowed_attributes : Optional [ Sequence [ str ] ] = None ,
) - > Optional [ FilterDirective ] :
if (
not ( allowed_comparators or allowed_operators or allowed_attributes )
) or not filter :
return filter
elif isinstance ( filter , Comparison ) :
if allowed_comparators and filter . comparator not in allowed_comparators :
return None
if allowed_attributes and filter . attribute not in allowed_attributes :
return None
return filter
elif isinstance ( filter , Operation ) :
if allowed_operators and filter . operator not in allowed_operators :
return None
args = [
fix_filter_directive (
arg ,
allowed_comparators = allowed_comparators ,
allowed_operators = allowed_operators ,
allowed_attributes = allowed_attributes ,
)
for arg in filter . arguments
]
args = [ arg for arg in args if arg is not None ]
if not args :
return None
elif len ( args ) == 1 and filter . operator in ( Operator . AND , Operator . OR ) :
return args [ 0 ]
else :
return Operation (
operator = filter . operator ,
arguments = args ,
)
)
return cls ( ast_parse = ast_parser . parse )
else :
return filter
def _format_attribute_info ( info : Sequence [ AttributeInfo ] ) - > str :
def _format_attribute_info ( info : Sequence [ Union[ AttributeInfo, dict ] ] ) - > str :
info_dicts = { }
info_dicts = { }
for i in info :
for i in info :
i_dict = dict ( i )
i_dict = dict ( i )
@ -84,56 +153,90 @@ def _format_attribute_info(info: Sequence[AttributeInfo]) -> str:
return json . dumps ( info_dicts , indent = 4 ) . replace ( " { " , " {{ " ) . replace ( " } " , " }} " )
return json . dumps ( info_dicts , indent = 4 ) . replace ( " { " , " {{ " ) . replace ( " } " , " }} " )
def _get_prompt (
def construct_examples ( input_output_pairs : Sequence [ Tuple [ str , dict ] ] ) - > List [ dict ] :
examples = [ ]
for i , ( _input , output ) in enumerate ( input_output_pairs ) :
structured_request = (
json . dumps ( output , indent = 4 ) . replace ( " { " , " {{ " ) . replace ( " } " , " }} " )
)
example = {
" i " : i + 1 ,
" user_query " : _input ,
" structured_request " : structured_request ,
}
examples . append ( example )
return examples
def get_query_constructor_prompt (
document_contents : str ,
document_contents : str ,
attribute_info : Sequence [ AttributeInfo ] ,
attribute_info : Sequence [ Union [ AttributeInfo , dict ] ] ,
examples : Optional [ List ] = None ,
* ,
allowed_comparators : Optional [ Sequence [ Comparator ] ] = None ,
examples : Optional [ Sequence ] = None ,
allowed_operators : Optional [ Sequence [ Operator ] ] = None ,
allowed_comparators : Sequence [ Comparator ] = tuple ( Comparator ) ,
allowed_operators : Sequence [ Operator ] = tuple ( Operator ) ,
enable_limit : bool = False ,
enable_limit : bool = False ,
schema_prompt : Optional [ BasePromptTemplate ] = None ,
* * kwargs : Any ,
) - > BasePromptTemplate :
) - > BasePromptTemplate :
""" Create query construction prompt.
Args :
document_contents : The contents of the document to be queried .
attribute_info : A list of AttributeInfo objects describing
the attributes of the document .
examples : Optional list of examples to use for the chain .
allowed_comparators : Sequence of allowed comparators .
allowed_operators : Sequence of allowed operators .
enable_limit : Whether to enable the limit operator . Defaults to False .
schema_prompt : Prompt for describing query schema . Should have string input
variables allowed_comparators and allowed_operators .
* * kwargs : Additional named params to pass to FewShotPromptTemplate init .
"""
default_schema_prompt = (
SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
)
schema_prompt = schema_prompt or default_schema_prompt
attribute_str = _format_attribute_info ( attribute_info )
attribute_str = _format_attribute_info ( attribute_info )
allowed_comparators = allowed_comparators or list ( Comparator )
schema = schema_prompt . format (
allowed_operators = allowed_operators or list ( Operator )
if enable_limit :
schema = SCHEMA_WITH_LIMIT . format (
allowed_comparators = " | " . join ( allowed_comparators ) ,
allowed_comparators = " | " . join ( allowed_comparators ) ,
allowed_operators = " | " . join ( allowed_operators ) ,
allowed_operators = " | " . join ( allowed_operators ) ,
)
)
if examples and isinstance ( examples [ 0 ] , tuple ) :
examples = examples or EXAMPLES_WITH_LIMIT
examples = construct_examples ( examples )
example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT
prefix = PREFIX_WITH_DATA_SOURCE . format (
schema = schema , content = document_contents , attributes = attribute_str
)
suffix = SUFFIX_WITHOUT_DATA_SOURCE . format ( i = len ( examples ) + 1 )
else :
else :
schema = DEFAULT_SCHEMA . format (
examples = examples or (
allowed_comparators = " | " . join ( allowed_comparators ) ,
EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES
allowed_operators = " | " . join ( allowed_operators ) ,
)
)
example_prompt = EXAMPLE_PROMPT
examples = examples or DEFAULT_EXAMPLES
prefix = DEFAULT_PREFIX . format ( schema = schema )
prefix = DEFAULT_PREFIX . format ( schema = schema )
suffix = DEFAULT_SUFFIX . format (
suffix = DEFAULT_SUFFIX . format (
i = len ( examples ) + 1 , content = document_contents , attributes = attribute_str
i = len ( examples ) + 1 , content = document_contents , attributes = attribute_str
)
)
output_parser = StructuredQueryOutputParser . from_components (
allowed_comparators = allowed_comparators , allowed_operators = allowed_operators
)
return FewShotPromptTemplate (
return FewShotPromptTemplate (
examples = examples ,
examples = list ( examples ) ,
example_prompt = EXAMPLE_PROMPT ,
example_prompt = example_prompt ,
input_variables = [ " query " ] ,
input_variables = [ " query " ] ,
suffix = suffix ,
suffix = suffix ,
prefix = prefix ,
prefix = prefix ,
output_parser = output_parser ,
* * kwargs ,
)
)
def load_query_constructor_chain (
def load_query_constructor_chain (
llm : BaseLanguageModel ,
llm : BaseLanguageModel ,
document_contents : str ,
document_contents : str ,
attribute_info : List[ AttributeInfo ] ,
attribute_info : Sequence[ Union [ AttributeInfo , dict ] ] ,
examples : Optional [ List ] = None ,
examples : Optional [ List ] = None ,
allowed_comparators : Optional[ Sequence [ Comparator ] ] = None ,
allowed_comparators : Sequence[ Comparator ] = tuple ( Comparator ) ,
allowed_operators : Optional[ Sequence [ Operator ] ] = None ,
allowed_operators : Sequence[ Operator ] = tuple ( Operator ) ,
enable_limit : bool = False ,
enable_limit : bool = False ,
schema_prompt : Optional [ BasePromptTemplate ] = None ,
* * kwargs : Any ,
* * kwargs : Any ,
) - > LLMChain :
) - > LLMChain :
""" Load a query constructor chain.
""" Load a query constructor chain.
@ -141,25 +244,95 @@ def load_query_constructor_chain(
Args :
Args :
llm : BaseLanguageModel to use for the chain .
llm : BaseLanguageModel to use for the chain .
document_contents : The contents of the document to be queried .
document_contents : The contents of the document to be queried .
attribute_info : A list of AttributeInfo objects describing
attribute_info : Sequence of attributes in the document .
the attributes of the document .
examples : Optional list of examples to use for the chain .
examples : Optional list of examples to use for the chain .
allowed_comparators : An optional list of allowed comparators .
allowed_comparators : Sequence of allowed comparators . Defaults to all
allowed_operators : An optional list of allowed operators .
Comparators .
allowed_operators : Sequence of allowed operators . Defaults to all Operators .
enable_limit : Whether to enable the limit operator . Defaults to False .
enable_limit : Whether to enable the limit operator . Defaults to False .
* * kwargs :
schema_prompt : Prompt for describing query schema . Should have string input
variables allowed_comparators and allowed_operators .
* * kwargs : Arbitrary named params to pass to LLMChain .
Returns :
Returns :
A LLMChain that can be used to construct queries .
A LLMChain that can be used to construct queries .
"""
"""
prompt = _ get_prompt(
prompt = get_query_constructor _prompt(
document_contents ,
document_contents ,
attribute_info ,
attribute_info ,
examples = examples ,
examples = examples ,
allowed_comparators = allowed_comparators ,
allowed_comparators = allowed_comparators ,
allowed_operators = allowed_operators ,
allowed_operators = allowed_operators ,
enable_limit = enable_limit ,
enable_limit = enable_limit ,
schema_prompt = schema_prompt ,
)
allowed_attributes = [ ]
for ainfo in attribute_info :
allowed_attributes . append (
ainfo . name if isinstance ( ainfo , AttributeInfo ) else ainfo [ " name " ]
)
output_parser = StructuredQueryOutputParser . from_components (
allowed_comparators = allowed_comparators ,
allowed_operators = allowed_operators ,
allowed_attributes = allowed_attributes ,
)
)
return LLMChain (
# For backwards compatibility.
llm = llm , prompt = prompt , output_parser = prompt . output_parser , * * kwargs
prompt . output_parser = output_parser
return LLMChain ( llm = llm , prompt = prompt , output_parser = output_parser , * * kwargs )
def load_query_constructor_runnable (
llm : BaseLanguageModel ,
document_contents : str ,
attribute_info : Sequence [ Union [ AttributeInfo , dict ] ] ,
* ,
examples : Optional [ Sequence ] = None ,
allowed_comparators : Sequence [ Comparator ] = tuple ( Comparator ) ,
allowed_operators : Sequence [ Operator ] = tuple ( Operator ) ,
enable_limit : bool = False ,
schema_prompt : Optional [ BasePromptTemplate ] = None ,
fix_invalid : bool = False ,
* * kwargs : Any ,
) - > Runnable :
""" Load a query constructor runnable chain.
Args :
llm : BaseLanguageModel to use for the chain .
document_contents : The contents of the document to be queried .
attribute_info : Sequence of attributes in the document .
examples : Optional list of examples to use for the chain .
allowed_comparators : Sequence of allowed comparators . Defaults to all
Comparators .
allowed_operators : Sequence of allowed operators . Defaults to all Operators .
enable_limit : Whether to enable the limit operator . Defaults to False .
schema_prompt : Prompt for describing query schema . Should have string input
variables allowed_comparators and allowed_operators .
fix_invalid : Whether to fix invalid filter directives by ignoring invalid
operators , comparators and attributes .
* * kwargs : Additional named params to pass to FewShotPromptTemplate init .
Returns :
A Runnable that can be used to construct queries .
"""
prompt = get_query_constructor_prompt (
document_contents ,
attribute_info ,
examples = examples ,
allowed_comparators = allowed_comparators ,
allowed_operators = allowed_operators ,
enable_limit = enable_limit ,
schema_prompt = schema_prompt ,
* * kwargs ,
)
allowed_attributes = [ ]
for ainfo in attribute_info :
allowed_attributes . append (
ainfo . name if isinstance ( ainfo , AttributeInfo ) else ainfo [ " name " ]
)
output_parser = StructuredQueryOutputParser . from_components (
allowed_comparators = allowed_comparators ,
allowed_operators = allowed_operators ,
allowed_attributes = allowed_attributes ,
fix_invalid = fix_invalid ,
)
)
return prompt | llm | output_parser