expose output key to create_openai_fn_chain (#9155)

I quick change to allow the output key of create_openai_fn_chain to
optionally be changed.

@baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/9287/head
Kenny 1 year ago committed by GitHub
parent b9ca5cc5ea
commit 74a64cfbab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -192,6 +192,7 @@ def create_openai_fn_chain(
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: BasePromptTemplate, prompt: BasePromptTemplate,
*, *,
output_key: str = "function",
output_parser: Optional[BaseLLMOutputParser] = None, output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMChain: ) -> LLMChain:
@ -210,6 +211,7 @@ def create_openai_fn_chain(
pydantic.BaseModels for arguments. pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the OpenAI function-calling API. llm: Language model to use, assumed to support the OpenAI function-calling API.
prompt: BasePromptTemplate to pass to the model. prompt: BasePromptTemplate to pass to the model.
output_key: The key to use when returning the output in LLMChain.__call__.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise in, then the OutputParser will try to parse outputs using those. Otherwise
@ -274,7 +276,7 @@ def create_openai_fn_chain(
prompt=prompt, prompt=prompt,
output_parser=output_parser, output_parser=output_parser,
llm_kwargs=llm_kwargs, llm_kwargs=llm_kwargs,
output_key="function", output_key=output_key,
**kwargs, **kwargs,
) )
return llm_chain return llm_chain
@ -285,6 +287,7 @@ def create_structured_output_chain(
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: BasePromptTemplate, prompt: BasePromptTemplate,
*, *,
output_key: str = "function",
output_parser: Optional[BaseLLMOutputParser] = None, output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMChain: ) -> LLMChain:
@ -297,6 +300,7 @@ def create_structured_output_chain(
the schema represents and descriptions for the parameters. the schema represents and descriptions for the parameters.
llm: Language model to use, assumed to support the OpenAI function-calling API. llm: Language model to use, assumed to support the OpenAI function-calling API.
prompt: BasePromptTemplate to pass to the model. prompt: BasePromptTemplate to pass to the model.
output_key: The key to use when returning the output in LLMChain.__call__.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise in, then the OutputParser will try to parse outputs using those. Otherwise
@ -354,5 +358,10 @@ def create_structured_output_chain(
pydantic_schema=_OutputFormatter, attr_name="output" pydantic_schema=_OutputFormatter, attr_name="output"
) )
return create_openai_fn_chain( return create_openai_fn_chain(
[function], llm, prompt, output_parser=output_parser, **kwargs [function],
llm,
prompt,
output_key=output_key,
output_parser=output_parser,
**kwargs,
) )

Loading…
Cancel
Save