fix openai structured chain with pydantic (#7622)

should return pydantic class
This commit is contained in:
Bagatur 2023-07-12 23:46:13 -04:00 committed by GitHub
parent ee70d4a0cd
commit 1d4db1327a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticAttrOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain.prompts import BasePromptTemplate
@ -318,17 +319,26 @@ def create_structured_output_chain(
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> Dog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
function: Dict = {
if isinstance(output_schema, dict):
function: Any = {
"name": "output_formatter",
"description": (
"Output formatter. Should always be used to format your response to the"
" user."
),
"parameters": output_schema,
}
parameters = (
output_schema if isinstance(output_schema, dict) else output_schema.schema()
else:
class _OutputFormatter(BaseModel):
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
output: output_schema # type: ignore
function = _OutputFormatter
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
pydantic_schema=_OutputFormatter, attr_name="output"
)
function["parameters"] = parameters
return create_openai_fn_chain(
[function], llm, prompt, output_parser=output_parser, **kwargs
)