|
|
|
@ -18,6 +18,8 @@ from langchain.chains.query_constructor.prompt import (
|
|
|
|
|
DEFAULT_SCHEMA,
|
|
|
|
|
DEFAULT_SUFFIX,
|
|
|
|
|
EXAMPLE_PROMPT,
|
|
|
|
|
EXAMPLES_WITH_LIMIT,
|
|
|
|
|
SCHEMA_WITH_LIMIT,
|
|
|
|
|
)
|
|
|
|
|
from langchain.chains.query_constructor.schema import AttributeInfo
|
|
|
|
|
from langchain.output_parsers.structured import parse_json_markdown
|
|
|
|
@ -38,7 +40,11 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
|
|
|
|
parsed["filter"] = None
|
|
|
|
|
else:
|
|
|
|
|
parsed["filter"] = self.ast_parse(parsed["filter"])
|
|
|
|
|
return StructuredQuery(query=parsed["query"], filter=parsed["filter"])
|
|
|
|
|
return StructuredQuery(
|
|
|
|
|
query=parsed["query"],
|
|
|
|
|
filter=parsed["filter"],
|
|
|
|
|
limit=parsed.get("limit"),
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise OutputParserException(
|
|
|
|
|
f"Parsing text\n{text}\n raised following error:\n{e}"
|
|
|
|
@ -70,15 +76,25 @@ def _get_prompt(
|
|
|
|
|
examples: Optional[List] = None,
|
|
|
|
|
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
|
|
|
|
allowed_operators: Optional[Sequence[Operator]] = None,
|
|
|
|
|
enable_limit: bool = False,
|
|
|
|
|
) -> BasePromptTemplate:
|
|
|
|
|
attribute_str = _format_attribute_info(attribute_info)
|
|
|
|
|
examples = examples or DEFAULT_EXAMPLES
|
|
|
|
|
allowed_comparators = allowed_comparators or list(Comparator)
|
|
|
|
|
allowed_operators = allowed_operators or list(Operator)
|
|
|
|
|
if enable_limit:
|
|
|
|
|
schema = SCHEMA_WITH_LIMIT.format(
|
|
|
|
|
allowed_comparators=" | ".join(allowed_comparators),
|
|
|
|
|
allowed_operators=" | ".join(allowed_operators),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
examples = examples or EXAMPLES_WITH_LIMIT
|
|
|
|
|
else:
|
|
|
|
|
schema = DEFAULT_SCHEMA.format(
|
|
|
|
|
allowed_comparators=" | ".join(allowed_comparators),
|
|
|
|
|
allowed_operators=" | ".join(allowed_operators),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
examples = examples or DEFAULT_EXAMPLES
|
|
|
|
|
prefix = DEFAULT_PREFIX.format(schema=schema)
|
|
|
|
|
suffix = DEFAULT_SUFFIX.format(
|
|
|
|
|
i=len(examples) + 1, content=document_contents, attributes=attribute_str
|
|
|
|
@ -87,7 +103,7 @@ def _get_prompt(
|
|
|
|
|
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
|
|
|
|
)
|
|
|
|
|
return FewShotPromptTemplate(
|
|
|
|
|
examples=DEFAULT_EXAMPLES,
|
|
|
|
|
examples=examples,
|
|
|
|
|
example_prompt=EXAMPLE_PROMPT,
|
|
|
|
|
input_variables=["query"],
|
|
|
|
|
suffix=suffix,
|
|
|
|
@ -103,6 +119,7 @@ def load_query_constructor_chain(
|
|
|
|
|
examples: Optional[List] = None,
|
|
|
|
|
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
|
|
|
|
allowed_operators: Optional[Sequence[Operator]] = None,
|
|
|
|
|
enable_limit: bool = False,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> LLMChain:
|
|
|
|
|
prompt = _get_prompt(
|
|
|
|
@ -111,5 +128,6 @@ def load_query_constructor_chain(
|
|
|
|
|
examples=examples,
|
|
|
|
|
allowed_comparators=allowed_comparators,
|
|
|
|
|
allowed_operators=allowed_operators,
|
|
|
|
|
enable_limit=enable_limit,
|
|
|
|
|
)
|
|
|
|
|
return LLMChain(llm=llm, prompt=prompt, **kwargs)
|
|
|
|
|