|
|
|
@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
|
|
|
|
|
from langchain.chains import LLMChain
|
|
|
|
|
from langchain.output_parsers.openai_functions import (
|
|
|
|
|
JsonOutputFunctionsParser,
|
|
|
|
|
PydanticAttrOutputFunctionsParser,
|
|
|
|
|
PydanticOutputFunctionsParser,
|
|
|
|
|
)
|
|
|
|
|
from langchain.prompts import BasePromptTemplate
|
|
|
|
@ -318,17 +319,26 @@ def create_structured_output_chain(
|
|
|
|
|
chain.run("Harry was a chubby brown beagle who loved chicken")
|
|
|
|
|
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
|
|
|
|
""" # noqa: E501
|
|
|
|
|
function: Dict = {
|
|
|
|
|
"name": "output_formatter",
|
|
|
|
|
"description": (
|
|
|
|
|
"Output formatter. Should always be used to format your response to the"
|
|
|
|
|
" user."
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
parameters = (
|
|
|
|
|
output_schema if isinstance(output_schema, dict) else output_schema.schema()
|
|
|
|
|
)
|
|
|
|
|
function["parameters"] = parameters
|
|
|
|
|
if isinstance(output_schema, dict):
|
|
|
|
|
function: Any = {
|
|
|
|
|
"name": "output_formatter",
|
|
|
|
|
"description": (
|
|
|
|
|
"Output formatter. Should always be used to format your response to the"
|
|
|
|
|
" user."
|
|
|
|
|
),
|
|
|
|
|
"parameters": output_schema,
|
|
|
|
|
}
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
class _OutputFormatter(BaseModel):
|
|
|
|
|
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
|
|
|
|
|
|
|
|
|
|
output: output_schema # type: ignore
|
|
|
|
|
|
|
|
|
|
function = _OutputFormatter
|
|
|
|
|
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
|
|
|
|
|
pydantic_schema=_OutputFormatter, attr_name="output"
|
|
|
|
|
)
|
|
|
|
|
return create_openai_fn_chain(
|
|
|
|
|
[function], llm, prompt, output_parser=output_parser, **kwargs
|
|
|
|
|
)
|
|
|
|
|