mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
fix openai structured chain with pydantic (#7622)
should return pydantic class
This commit is contained in:
parent
ee70d4a0cd
commit
1d4db1327a
@ -9,6 +9,7 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.output_parsers.openai_functions import (
|
||||
JsonOutputFunctionsParser,
|
||||
PydanticAttrOutputFunctionsParser,
|
||||
PydanticOutputFunctionsParser,
|
||||
)
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
@ -318,17 +319,26 @@ def create_structured_output_chain(
|
||||
chain.run("Harry was a chubby brown beagle who loved chicken")
|
||||
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
||||
""" # noqa: E501
|
||||
function: Dict = {
|
||||
if isinstance(output_schema, dict):
|
||||
function: Any = {
|
||||
"name": "output_formatter",
|
||||
"description": (
|
||||
"Output formatter. Should always be used to format your response to the"
|
||||
" user."
|
||||
),
|
||||
"parameters": output_schema,
|
||||
}
|
||||
parameters = (
|
||||
output_schema if isinstance(output_schema, dict) else output_schema.schema()
|
||||
else:
|
||||
|
||||
class _OutputFormatter(BaseModel):
|
||||
"""Output formatter. Should always be used to format your response to the user.""" # noqa: E501
|
||||
|
||||
output: output_schema # type: ignore
|
||||
|
||||
function = _OutputFormatter
|
||||
output_parser = output_parser or PydanticAttrOutputFunctionsParser(
|
||||
pydantic_schema=_OutputFormatter, attr_name="output"
|
||||
)
|
||||
function["parameters"] = parameters
|
||||
return create_openai_fn_chain(
|
||||
[function], llm, prompt, output_parser=output_parser, **kwargs
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user