openai fn update nb (#7352)

pull/7361/head
Bagatur 1 year ago committed by GitHub
parent 0ed2da7020
commit d1c7237034
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,7 +24,9 @@
"from langchain.chains.openai_functions import (\n",
" create_openai_fn_chain, create_structured_output_chain\n",
")\n",
"from langchain.prompts import ChatPromptTemplate"
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate\n",
"from langchain.schema import HumanMessage, SystemMessage"
]
},
{
@ -33,7 +35,7 @@
"metadata": {},
"source": [
"## Getting structured outputs\n",
"We can take advantage of OpenAI functions to try and force the model to return a particular kind of structured output. We'll use the `create_structured_output_chain` to create our chain, which takes the desired structured output either as a Pydantic object or as JsonSchema.\n",
"We can take advantage of OpenAI functions to try and force the model to return a particular kind of structured output. We'll use the `create_structured_output_chain` to create our chain, which takes the desired structured output either as a Pydantic class or as JsonSchema.\n",
"\n",
"See here for relevant [reference docs](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_functions.base.create_structured_output_chain.html)."
]
@ -43,13 +45,29 @@
"id": "e052faae",
"metadata": {},
"source": [
"### Using Pydantic objects\n",
"When passing in Pydantic objects to structure our text, we need to make sure to have a docstring description for the class. It also helps to have descriptions for each of the object attributes."
"### Using Pydantic classes\n",
"When passing in Pydantic classes to structure our text, we need to make sure to have a docstring description for the class. It also helps to have descriptions for each of the class's attributes."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0e085c99",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"\n",
"class Person(BaseModel):\n",
" \"\"\"Identifying information about a person.\"\"\"\n",
" name: str = Field(..., description=\"The person's name\")\n",
" age: int = Field(..., description=\"The person's age\")\n",
" fav_food: Optional[str] = Field(None, description=\"The person's favorite food\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b459a33e",
"metadata": {},
"outputs": [
@ -61,7 +79,10 @@
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mHuman: Sally is 13\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
"Human: Use the given format to extract information from the following input:\n",
"Human: Sally is 13\n",
"Human: Tips: Make sure to answer in the correct format\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -72,21 +93,26 @@
"{'name': 'Sally', 'age': 13}"
]
},
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pydantic import BaseModel, Field\n",
"# If we pass in a model explicitly, we need to make sure it supports the OpenAI function-calling API.\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0613\", temperature=0)\n",
"\n",
"class Person(BaseModel):\n",
" \"\"\"Identifying information about a person.\"\"\"\n",
" name: str = Field(..., description=\"The person's name\")\n",
" age: int = Field(..., description=\"The person's age\")\n",
" fav_food: Optional[str] = Field(None, description=\"The person's favorite food\")\n",
" \n",
"chain = create_structured_output_chain(Person, verbose=True)\n",
"prompt_msgs = [\n",
" SystemMessage(\n",
" content=\"You are a world class algorithm for extracting information in structured formats.\"\n",
" ),\n",
" HumanMessage(content=\"Use the given format to extract information from the following input:\"),\n",
" HumanMessagePromptTemplate.from_template(\"{input}\"),\n",
" HumanMessage(content=\"Tips: Make sure to answer in the correct format\"),\n",
" ]\n",
"prompt = ChatPromptTemplate(messages=prompt_msgs)\n",
"\n",
"chain = create_structured_output_chain(Person, llm, prompt, verbose=True)\n",
"chain.run(\"Sally is 13\")"
]
},
@ -95,12 +121,12 @@
"id": "e3539936",
"metadata": {},
"source": [
"To extract arbitrarily many structured outputs of a given format, we can just create a wrapper Pydantic object that takes a sequence of the original object."
"To extract arbitrarily many structured outputs of a given format, we can just create a wrapper Pydantic class that takes a sequence of the original class."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "4d8ea815",
"metadata": {},
"outputs": [
@ -112,7 +138,10 @@
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mHuman: Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally, so she's 23.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
"Human: Use the given format to extract information from the following input:\n",
"Human: Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally, so she's 23.\n",
"Human: Tips: Make sure to answer in the correct format\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -125,7 +154,7 @@
" {'name': 'Caroline', 'age': 23, 'fav_food': ''}]}"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -137,7 +166,7 @@
" \"\"\"Identifying information about all people in a text.\"\"\"\n",
" people: Sequence[Person] = Field(..., description=\"The people in the text\")\n",
" \n",
"chain = create_structured_output_chain(People, verbose=True)\n",
"chain = create_structured_output_chain(People, llm, prompt, verbose=True)\n",
"chain.run(\"Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally, so she's 23.\")"
]
},
@ -148,12 +177,12 @@
"source": [
"### Using JsonSchema\n",
"\n",
"We can also pass in JsonSchema instead of Pydantic objects to specify the desired structure. When we do this, our chain will output json corresponding to the properties described in the JsonSchema, instead of a Pydantic object."
"We can also pass in JsonSchema instead of Pydantic classes to specify the desired structure. When we do this, our chain will output json corresponding to the properties described in the JsonSchema, instead of a Pydantic class."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "3484415e",
"metadata": {},
"outputs": [],
@ -188,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "be9b76b3",
"metadata": {},
"outputs": [
@ -200,7 +229,10 @@
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mHuman: Sally is 13\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
"Human: Use the given format to extract information from the following input:\n",
"Human: Sally is 13\n",
"Human: Tips: Make sure to answer in the correct format\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -211,13 +243,13 @@
"{'name': 'Sally', 'age': 13}"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain = create_structured_output_chain(json_schema, verbose=True)\n",
"chain = create_structured_output_chain(json_schema, llm, prompt, verbose=True)\n",
"chain.run(\"Sally is 13\")"
]
},
@ -231,7 +263,7 @@
"\n",
"Functions can be passed in as:\n",
"- dicts conforming to OpenAI functions spec,\n",
"- Pydantic objects, in which case they should have docstring descriptions of the function they represent and descriptions for each of the parameters,\n",
"- Pydantic classes, in which case they should have docstring descriptions of the function they represent and descriptions for each of the parameters,\n",
"- Python functions, in which case they should have docstring descriptions of the function and args, along with type hints.\n",
"\n",
"See here for relevant [reference docs](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_functions.base.create_openai_fn_chain.html)."
@ -242,12 +274,33 @@
"id": "ff19be25",
"metadata": {},
"source": [
"### Using Pydantic objects"
"### Using Pydantic classes"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "17f52508",
"metadata": {},
"outputs": [],
"source": [
"class RecordPerson(BaseModel):\n",
"    \"\"\"Record some identifying information about a person.\"\"\"\n",
" name: str = Field(..., description=\"The person's name\")\n",
" age: int = Field(..., description=\"The person's age\")\n",
" fav_food: Optional[str] = Field(None, description=\"The person's favorite food\")\n",
"\n",
" \n",
"class RecordDog(BaseModel):\n",
" \"\"\"Record some identifying information about a dog.\"\"\"\n",
" name: str = Field(..., description=\"The dog's name\")\n",
" color: str = Field(..., description=\"The dog's color\")\n",
" fav_food: Optional[str] = Field(None, description=\"The dog's favorite food\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a4658ad8",
"metadata": {},
"outputs": [
@ -259,7 +312,10 @@
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mHuman: Harry was a chubby brown beagle who loved chicken\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for recording entities\n",
"Human: Make calls to the relevant function to record the entities in the following input:\n",
"Human: Harry was a chubby brown beagle who loved chicken\n",
"Human: Tips: Make sure to answer in the correct format\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -270,25 +326,23 @@
"RecordDog(name='Harry', color='brown', fav_food='chicken')"
]
},
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class RecordPerson(BaseModel):\n",
" \"\"\"Record some identifying information about a pe.\"\"\"\n",
" name: str = Field(..., description=\"The person's name\")\n",
" age: int = Field(..., description=\"The person's age\")\n",
" fav_food: Optional[str] = Field(None, description=\"The person's favorite food\")\n",
"prompt_msgs = [\n",
" SystemMessage(\n",
" content=\"You are a world class algorithm for recording entities\"\n",
" ),\n",
" HumanMessage(content=\"Make calls to the relevant function to record the entities in the following input:\"),\n",
" HumanMessagePromptTemplate.from_template(\"{input}\"),\n",
" HumanMessage(content=\"Tips: Make sure to answer in the correct format\"),\n",
"]\n",
"prompt = ChatPromptTemplate(messages=prompt_msgs)\n",
"\n",
"class RecordDog(BaseModel):\n",
" \"\"\"Record some identifying information about a dog.\"\"\"\n",
" name: str = Field(..., description=\"The dog's name\")\n",
" color: str = Field(..., description=\"The dog's color\")\n",
" fav_food: Optional[str] = Field(None, description=\"The dog's favorite food\")\n",
"\n",
"chain = create_openai_fn_chain([RecordPerson, RecordDog], verbose=True)\n",
"chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt, verbose=True)\n",
"chain.run(\"Harry was a chubby brown beagle who loved chicken\")"
]
},
@ -298,14 +352,14 @@
"metadata": {},
"source": [
"### Using Python functions\n",
"We can pass in functions as Pydantic objects, directly as OpenAI function dicts, or Python functions. To pass Python function in directly, we'll want to make sure our parameters have type hints, we have a docstring, and we use [Google Python style docstrings](https://google.github.io/styleguide/pyguide.html#doc-function-args) to describe the parameters.\n",
"We can pass in functions as Pydantic classes, directly as OpenAI function dicts, or Python functions. To pass a Python function in directly, we'll want to make sure our parameters have type hints, we have a docstring, and we use [Google Python style docstrings](https://google.github.io/styleguide/pyguide.html#doc-function-args) to describe the parameters.\n",
"\n",
"**NOTE**: To use Python functions, make sure the function arguments are of primitive types (str, float, int, bool) or that they are Pydantic objects."
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 9,
"id": "95ac5825",
"metadata": {},
"outputs": [
@ -317,7 +371,10 @@
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mHuman: The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for recording entities\n",
"Human: Make calls to the relevant function to record the entities in the following input:\n",
"Human: The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\n",
"Human: Tips: Make sure to answer in the correct format\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -328,7 +385,7 @@
"{'name': 'Tommy', 'age': 12, 'fav_food': {'food': 'apple pie'}}"
]
},
"execution_count": 41,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -349,7 +406,7 @@
" return f\"Recording person {name} of age {age} with favorite food {fav_food.food}!\"\n",
"\n",
" \n",
"chain = create_openai_fn_chain([record_person], verbose=True)\n",
"chain = create_openai_fn_chain([record_person], llm, prompt, verbose=True)\n",
"chain.run(\"The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\")"
]
},
@ -366,7 +423,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 10,
"id": "8b0d11de",
"metadata": {},
"outputs": [
@ -378,7 +435,10 @@
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mHuman: I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for recording entities\n",
"Human: Make calls to the relevant function to record the entities in the following input:\n",
"Human: I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\n",
"Human: Tips: Make sure to answer in the correct format\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -386,11 +446,11 @@
{
"data": {
"text/plain": [
"{'name': 'report_dog',\n",
"{'name': 'record_dog',\n",
" 'arguments': {'name': 'Henry', 'color': 'brown', 'fav_food': {'food': None}}}"
]
},
"execution_count": 42,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@ -407,130 +467,10 @@
" return f\"Recording dog {name} of color {color} with favorite food {fav_food}!\"\n",
"\n",
"\n",
"chain = create_openai_fn_chain([record_person, report_dog], verbose=True)\n",
"chain = create_openai_fn_chain([record_person, record_dog], llm, prompt, verbose=True)\n",
"chain.run(\"I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\")"
]
},
{
"cell_type": "markdown",
"id": "4535ce33",
"metadata": {},
"source": [
"## Creating a Chain that runs the chosen function\n",
"We can go one step further and create a chain that actually executes the function chosen by the model."
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "43b0dfe0",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import inspect\n",
"from typing import Any, Callable, Dict, List, Optional\n",
"\n",
"from langchain.callbacks.manager import CallbackManagerForChainRun\n",
"from langchain.chains.base import Chain\n",
"from langchain.input import get_colored_text\n",
"\n",
"\n",
"class FunctionExecutorChain(Chain):\n",
" functions: Dict[str, Callable]\n",
" output_key: str = \"output\"\n",
" input_key: str = \"function\"\n",
"\n",
" @property\n",
" def input_keys(self) -> List[str]:\n",
" return [self.input_key]\n",
"\n",
" @property\n",
" def output_keys(self) -> List[str]:\n",
" return [self.output_key]\n",
"\n",
" def _call(\n",
" self,\n",
" inputs: Dict[str, Any],\n",
" run_manager: Optional[CallbackManagerForChainRun] = None,\n",
" ) -> Dict[str, Any]:\n",
" \"\"\"Run the logic of this chain and return the output.\"\"\"\n",
" _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()\n",
" name = inputs[\"function\"].pop(\"name\")\n",
" args = inputs[\"function\"].pop(\"arguments\")\n",
" _pretty_name = get_colored_text(name, \"green\")\n",
" _pretty_args = get_colored_text(json.dumps(args, indent=2), \"green\")\n",
" _text = f\"Calling function {_pretty_name} with arguments:\\n\" + _pretty_args\n",
" _run_manager.on_text(_text)\n",
" _args = {}\n",
" function = self.functions[name]\n",
" for arg_name, arg_type in inspect.getfullargspec(function).annotations.items():\n",
" if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):\n",
" args[arg_name] = arg_type.parse_obj(args[arg_name])\n",
" output = function(**args)\n",
" return {self.output_key: output}"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "b8391857",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Calling function \u001b[32;1m\u001b[1;3mrecord_person\u001b[0m with arguments:\n",
"\u001b[32;1m\u001b[1;3m{\n",
" \"name\": \"Tommy\",\n",
" \"age\": 12,\n",
" \"fav_food\": {\n",
" \"food\": \"apple pie\"\n",
" }\n",
"}\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Recording person Tommy of age 12 with favorite food apple pie!'"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains import SequentialChain\n",
"from langchain.chains.openai_functions.base import convert_to_openai_function\n",
"\n",
"functions = [record_person, record_dog]\n",
"openai_functions = [convert_to_openai_function(f) for f in functions]\n",
"fn_map = {\n",
" openai_fn[\"name\"]: fn for openai_fn, fn in zip(openai_functions, functions)\n",
"}\n",
"llm_chain = create_openai_fn_chain(functions)\n",
"exec_chain = FunctionExecutorChain(functions=fn_map, verbose=True)\n",
"chain = SequentialChain(\n",
" chains=[llm_chain, exec_chain],\n",
" input_variables=llm_chain.input_keys,\n",
" output_variables=[\"output\"],\n",
" verbose=True\n",
")\n",
"chain.run(\"The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\")"
]
},
{
"cell_type": "markdown",
"id": "5f93686b",

@ -7,12 +7,11 @@ from pydantic import BaseModel
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseLLMOutputParser
PYTHON_TO_JSON_TYPES = {
@ -117,13 +116,13 @@ def convert_python_function_to_openai_function(function: Callable) -> Dict[str,
def convert_to_openai_function(
function: Union[Dict[str, Any], BaseModel, Callable]
function: Union[Dict[str, Any], Type[BaseModel], Callable]
) -> Dict[str, Any]:
"""Convert a raw function/class to an OpenAI function.
Args:
function: Either a dictionary, a pydantic.BaseModel, or a Python function. If
a dictionary is passed in, it is assumed to already be a valid OpenAI
function: Either a dictionary, a pydantic.BaseModel class, or a Python function.
If a dictionary is passed in, it is assumed to already be a valid OpenAI
function.
Returns:
@ -150,7 +149,7 @@ def convert_to_openai_function(
def _get_openai_output_parser(
functions: Sequence[Union[Dict[str, Any], BaseModel, Callable]],
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
function_names: Sequence[str],
) -> BaseLLMOutputParser:
"""Get the appropriate function output parser given the user functions."""
@ -170,16 +169,17 @@ def _get_openai_output_parser(
def create_openai_fn_chain(
functions: Sequence[Union[Dict[str, Any], BaseModel, Callable]],
llm: Optional[BaseLanguageModel] = None,
prompt: Optional[BasePromptTemplate] = None,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any,
) -> LLMChain:
"""Create an LLM chain that uses OpenAI functions.
Args:
functions: A sequence of either dictionaries, pydantic.BaseModels, or
        functions: A sequence of either dictionaries, pydantic.BaseModel classes, or
Python functions. If dictionaries are passed in, they are assumed to
already be a valid OpenAI functions. If only a single
function is passed in, then it will be enforced that the model use that
@ -190,9 +190,7 @@ def create_openai_fn_chain(
Python functions should only use primitive types (str, int, float, bool) or
pydantic.BaseModels for arguments.
llm: Language model to use, assumed to support the OpenAI function-calling API.
Defaults to ChatOpenAI using model gpt-3.5-turbo-0613.
prompt: BasePromptTemplate to pass to the model. Defaults to a prompt that just
passes user input directly to model.
prompt: BasePromptTemplate to pass to the model.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
@ -208,6 +206,8 @@ def create_openai_fn_chain(
.. code-block:: python
from langchain.chains.openai_functions import create_openai_fn_chain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from pydantic import BaseModel, Field
@ -228,6 +228,16 @@ def create_openai_fn_chain(
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
prompt_msgs = [
SystemMessage(
content="You are a world class algorithm for recording entities"
),
HumanMessage(content="Make calls to the relevant function to record the entities in the following input:"),
HumanMessagePromptTemplate.from_template("{input}"),
HumanMessage(content="Tips: Make sure to answer in the correct format"),
]
prompt = ChatPromptTemplate(messages=prompt_msgs)
chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt)
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
@ -235,8 +245,6 @@ def create_openai_fn_chain(
if not functions:
raise ValueError("Need to pass in at least one function. Received zero.")
openai_functions = [convert_to_openai_function(f) for f in functions]
llm = llm or ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
prompt = prompt or ChatPromptTemplate.from_template("{input}")
fn_names = [oai_fn["name"] for oai_fn in openai_functions]
output_parser = output_parser or _get_openai_output_parser(functions, fn_names)
llm_kwargs: Dict[str, Any] = {
@ -256,23 +264,22 @@ def create_openai_fn_chain(
def create_structured_output_chain(
output_schema: Union[Dict[str, Any], BaseModel],
llm: Optional[BaseLanguageModel] = None,
prompt: Optional[BasePromptTemplate] = None,
output_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any,
) -> LLMChain:
"""Create an LLMChain that uses an OpenAI function to get a structured output.
Args:
output_schema: Either a dictionary or pydantic.BaseModel. If a dictionary is
passed in, it's assumed to already be a valid JsonSchema.
output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
is passed in, it's assumed to already be a valid JsonSchema.
For best results, pydantic.BaseModels should have docstrings describing what
the schema represents and descriptions for the parameters.
llm: Language model to use, assumed to support the OpenAI function-calling API.
Defaults to ChatOpenAI using model gpt-3.5-turbo-0613.
prompt: BasePromptTemplate to pass to the model. Defaults to a prompt that just
passes user input directly to model.
prompt: BasePromptTemplate to pass to the model.
output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
will be inferred from the function types. If pydantic.BaseModels are passed
in, then the OutputParser will try to parse outputs using those. Otherwise
@ -285,6 +292,8 @@ def create_structured_output_chain(
.. code-block:: python
from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from pydantic import BaseModel, Field
@ -295,7 +304,17 @@ def create_structured_output_chain(
color: str = Field(..., description="The dog's color")
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
chain = create_structured_output_chain([Dog])
llm = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
prompt_msgs = [
SystemMessage(
content="You are a world class algorithm for extracting information in structured formats."
),
HumanMessage(content="Use the given format to extract information from the following input:"),
HumanMessagePromptTemplate.from_template("{input}"),
HumanMessage(content="Tips: Make sure to answer in the correct format"),
]
prompt = ChatPromptTemplate(messages=prompt_msgs)
chain = create_structured_output_chain(Dog, llm, prompt)
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> Dog(name="Harry", color="brown", fav_food="chicken")
""" # noqa: E501
@ -311,5 +330,5 @@ def create_structured_output_chain(
)
function["parameters"] = parameters
return create_openai_fn_chain(
[function], llm=llm, prompt=prompt, output_parser=output_parser, **kwargs
[function], llm, prompt, output_parser=output_parser, **kwargs
)

Loading…
Cancel
Save