From 1c8cff32f19786f9bc721daa1d623d94ff94aa03 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Fri, 7 Jul 2023 05:44:53 -0400 Subject: [PATCH] Generic OpenAI fn chain (#7270) Add loading functions for openai function chains and add docs page --- .../chains/popular/openai_functions.ipynb | 578 ++++++++++++++++++ langchain/chains/openai_functions/__init__.py | 6 + langchain/chains/openai_functions/base.py | 315 ++++++++++ langchain/output_parsers/openai_functions.py | 40 +- 4 files changed, 932 insertions(+), 7 deletions(-) create mode 100644 docs/extras/modules/chains/popular/openai_functions.ipynb create mode 100644 langchain/chains/openai_functions/base.py diff --git a/docs/extras/modules/chains/popular/openai_functions.ipynb b/docs/extras/modules/chains/popular/openai_functions.ipynb new file mode 100644 index 0000000000..b21183c778 --- /dev/null +++ b/docs/extras/modules/chains/popular/openai_functions.ipynb @@ -0,0 +1,578 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "54ccb772", + "metadata": {}, + "source": [ + "# Using OpenAI functions\n", + "This walkthrough demonstrates how to incorporate OpenAI function-calling API's in a chain. We'll go over: \n", + "1. How to use functions to get structured outputs from ChatOpenAI\n", + "2. How to create a generic chain that uses (multiple) functions\n", + "3. How to create a chain that actually executes the chosen function" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "767ac575", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Optional\n", + "\n", + "from langchain.chains.openai_functions import (\n", + " create_openai_fn_chain, create_structured_output_chain\n", + ")\n", + "from langchain.prompts import ChatPromptTemplate" + ] + }, + { + "cell_type": "markdown", + "id": "976b6496", + "metadata": {}, + "source": [ + "## Getting structured outputs\n", + "We can take advantage of OpenAI functions to try and force the model to return a particular kind of structured output. We'll use the `create_structured_output_chain` to create our chain, which takes the desired structured output either as a Pydantic object or as JsonSchema.\n", + "\n", + "See here for relevant [reference docs](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_functions.base.create_structured_output_chain.html)." + ] + }, + { + "cell_type": "markdown", + "id": "e052faae", + "metadata": {}, + "source": [ + "### Using Pydantic objects\n", + "When passing in Pydantic objects to structure our text, we need to make sure to have a docstring description for the class. It also helps to have descriptions for each of the object attributes." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b459a33e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mHuman: Sally is 13\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'name': 'Sally', 'age': 13}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from pydantic import BaseModel, Field\n", + "\n", + "class Person(BaseModel):\n", + " \"\"\"Identifying information about a person.\"\"\"\n", + " name: str = Field(..., description=\"The person's name\")\n", + " age: int = Field(..., description=\"The person's age\")\n", + " fav_food: Optional[str] = Field(None, description=\"The person's favorite food\")\n", + " \n", + "chain = create_structured_output_chain(Person, verbose=True)\n", + "chain.run(\"Sally is 13\")" + ] + }, + { + "cell_type": "markdown", + "id": "e3539936", + "metadata": {}, + "source": [ + "To extract arbitrarily many structured outputs of a given format, we can just create a wrapper Pydantic object that takes a sequence of the original object." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4d8ea815", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mHuman: Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally, so she's 23.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'people': [{'name': 'Sally', 'age': 13, 'fav_food': ''},\n", + " {'name': 'Joey', 'age': 12, 'fav_food': 'spinach'},\n", + " {'name': 'Caroline', 'age': 23, 'fav_food': ''}]}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import Sequence\n", + "\n", + "class People(BaseModel):\n", + " \"\"\"Identifying information about all people in a text.\"\"\"\n", + " people: Sequence[Person] = Field(..., description=\"The people in the text\")\n", + " \n", + "chain = create_structured_output_chain(People, verbose=True)\n", + "chain.run(\"Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally, so she's 23.\")" + ] + }, + { + "cell_type": "markdown", + "id": "ea66e10e", + "metadata": {}, + "source": [ + "### Using JsonSchema\n", + "\n", + "We can also pass in JsonSchema instead of Pydantic objects to specify the desired structure. When we do this, our chain will output json corresponding to the properties described in the JsonSchema, instead of a Pydantic object." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3484415e", + "metadata": {}, + "outputs": [], + "source": [ + "json_schema = {\n", + " \"title\": \"Person\",\n", + " \"description\": \"Identifying information about a person.\",\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"name\": {\n", + " \"title\": \"Name\",\n", + " \"description\": \"The person's name\",\n", + " \"type\": \"string\"\n", + " },\n", + " \"age\": {\n", + " \"title\": \"Age\",\n", + " \"description\": \"The person's age\",\n", + " \"type\": \"integer\"\n", + " },\n", + " \"fav_food\": {\n", + " \"title\": \"Fav Food\",\n", + " \"description\": \"The person's favorite food\",\n", + " \"type\": \"string\"\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"name\",\n", + " \"age\"\n", + " ]\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "be9b76b3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mHuman: Sally is 13\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'name': 'Sally', 'age': 13}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain = create_structured_output_chain(json_schema, verbose=True)\n", + "chain.run(\"Sally is 13\")" + ] + }, + { + "cell_type": "markdown", + "id": "12394696", + "metadata": {}, + "source": [ + "## Creating a generic OpenAI functions chain\n", + "To create a generic OpenAI functions chain, we can use the `create_openai_fn_chain` method. This is the same as `create_structured_output_chain` except that instead of taking a single output schema, it takes a sequence of function definitions.\n", + "\n", + "Functions can be passed in as:\n", + "- dicts conforming to OpenAI functions spec,\n", + "- Pydantic objects, in which case they should have docstring descriptions of the function they represent and descriptions for each of the parameters,\n", + "- Python functions, in which case they should have docstring descriptions of the function and args, along with type hints.\n", + "\n", + "See here for relevant [reference docs](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_functions.base.create_openai_fn_chain.html)." + ] + }, + { + "cell_type": "markdown", + "id": "ff19be25", + "metadata": {}, + "source": [ + "### Using Pydantic objects" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a4658ad8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mHuman: Harry was a chubby brown beagle who loved chicken\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RecordDog(name='Harry', color='brown', fav_food='chicken')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class RecordPerson(BaseModel):\n", + " \"\"\"Record some identifying information about a pe.\"\"\"\n", + " name: str = Field(..., description=\"The person's name\")\n", + " age: int = Field(..., description=\"The person's age\")\n", + " fav_food: Optional[str] = Field(None, description=\"The person's favorite food\")\n", + "\n", + "class RecordDog(BaseModel):\n", + " \"\"\"Record some identifying information about a dog.\"\"\"\n", + " name: str = Field(..., description=\"The dog's name\")\n", + " color: str = Field(..., description=\"The dog's color\")\n", + " fav_food: Optional[str] = Field(None, description=\"The dog's favorite food\")\n", + "\n", + "chain = create_openai_fn_chain([RecordPerson, RecordDog], verbose=True)\n", + "chain.run(\"Harry was a chubby brown beagle who loved chicken\")" + ] + }, + { + "cell_type": "markdown", + "id": "df6d9147", + "metadata": {}, + "source": [ + "### Using Python functions\n", + "We can pass in functions as Pydantic objects, directly as OpenAI function dicts, or Python functions. To pass Python function in directly, we'll want to make sure our parameters have type hints, we have a docstring, and we use [Google Python style docstrings](https://google.github.io/styleguide/pyguide.html#doc-function-args) to describe the parameters.\n", + "\n", + "**NOTE**: To use Python functions, make sure the function arguments are of primitive types (str, float, int, bool) or that they are Pydantic objects." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "95ac5825", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mHuman: The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'name': 'Tommy', 'age': 12, 'fav_food': {'food': 'apple pie'}}" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class OptionalFavFood(BaseModel):\n", + " \"\"\"Either a food or null.\"\"\"\n", + " food: Optional[str] = Field(None, description=\"Either the name of a food or null. Should be null if the food isn't known.\")\n", + "\n", + "def record_person(name: str, age: int, fav_food: OptionalFavFood) -> str:\n", + " \"\"\"Record some basic identifying information about a person.\n", + " \n", + " Args:\n", + " name: The person's name.\n", + " age: The person's age in years.\n", + " fav_food: An OptionalFavFood object that either contains the person's favorite food or a null value. Food should be null if it's not known.\n", + " \"\"\"\n", + " return f\"Recording person {name} of age {age} with favorite food {fav_food.food}!\"\n", + "\n", + " \n", + "chain = create_openai_fn_chain([record_person], verbose=True)\n", + "chain.run(\"The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\")" + ] + }, + { + "cell_type": "markdown", + "id": "403ea5dd", + "metadata": {}, + "source": [ + "If we pass in multiple Python functions or OpenAI functions, then the returned output will be of the form\n", + "```python\n", + "{\"name\": \"<>\", \"arguments\": {<>}}\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "8b0d11de", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mHuman: I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'name': 'report_dog',\n", + " 'arguments': {'name': 'Henry', 'color': 'brown', 'fav_food': {'food': None}}}" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def record_dog(name: str, color: str, fav_food: OptionalFavFood) -> str:\n", + " \"\"\"Record some basic identifying information about a dog.\n", + " \n", + " Args:\n", + " name: The dog's name.\n", + " color: The dog's color.\n", + " fav_food: An OptionalFavFood object that either contains the dog's favorite food or a null value. Food should be null if it's not known.\n", + " \"\"\"\n", + " return f\"Recording dog {name} of color {color} with favorite food {fav_food}!\"\n", + "\n", + "\n", + "chain = create_openai_fn_chain([record_person, report_dog], verbose=True)\n", + "chain.run(\"I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\")" + ] + }, + { + "cell_type": "markdown", + "id": "4535ce33", + "metadata": {}, + "source": [ + "## Creating a Chain that runs the chosen function\n", + "We can go one step further and create a chain that actually executes the function chosen by the model." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "43b0dfe0", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import inspect\n", + "from typing import Any, Callable, Dict, List, Optional\n", + "\n", + "from langchain.callbacks.manager import CallbackManagerForChainRun\n", + "from langchain.chains.base import Chain\n", + "from langchain.input import get_colored_text\n", + "\n", + "\n", + "class FunctionExecutorChain(Chain):\n", + " functions: Dict[str, Callable]\n", + " output_key: str = \"output\"\n", + " input_key: str = \"function\"\n", + "\n", + " @property\n", + " def input_keys(self) -> List[str]:\n", + " return [self.input_key]\n", + "\n", + " @property\n", + " def output_keys(self) -> List[str]:\n", + " return [self.output_key]\n", + "\n", + " def _call(\n", + " self,\n", + " inputs: Dict[str, Any],\n", + " run_manager: Optional[CallbackManagerForChainRun] = None,\n", + " ) -> Dict[str, Any]:\n", + " \"\"\"Run the logic of this chain and return the output.\"\"\"\n", + " _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()\n", + " name = inputs[\"function\"].pop(\"name\")\n", + " args = inputs[\"function\"].pop(\"arguments\")\n", + " _pretty_name = get_colored_text(name, \"green\")\n", + " _pretty_args = get_colored_text(json.dumps(args, indent=2), \"green\")\n", + " _text = f\"Calling function {_pretty_name} with arguments:\\n\" + _pretty_args\n", + " _run_manager.on_text(_text)\n", + " _args = {}\n", + " function = self.functions[name]\n", + " for arg_name, arg_type in inspect.getfullargspec(function).annotations.items():\n", + " if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):\n", + " args[arg_name] = arg_type.parse_obj(args[arg_name])\n", + " output = function(**args)\n", + " return {self.output_key: output}" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "b8391857", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "Calling function \u001b[32;1m\u001b[1;3mrecord_person\u001b[0m with arguments:\n", + "\u001b[32;1m\u001b[1;3m{\n", + " \"name\": \"Tommy\",\n", + " \"age\": 12,\n", + " \"fav_food\": {\n", + " \"food\": \"apple pie\"\n", + " }\n", + "}\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Recording person Tommy of age 12 with favorite food apple pie!'" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.chains import SequentialChain\n", + "from langchain.chains.openai_functions.base import convert_to_openai_function\n", + "\n", + "functions = [record_person, record_dog]\n", + "openai_functions = [convert_to_openai_function(f) for f in functions]\n", + "fn_map = {\n", + " openai_fn[\"name\"]: fn for openai_fn, fn in zip(openai_functions, functions)\n", + "}\n", + "llm_chain = create_openai_fn_chain(functions)\n", + "exec_chain = FunctionExecutorChain(functions=fn_map, verbose=True)\n", + "chain = SequentialChain(\n", + " chains=[llm_chain, exec_chain],\n", + " input_variables=llm_chain.input_keys,\n", + " output_variables=[\"output\"],\n", + " verbose=True\n", + ")\n", + "chain.run(\"The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\")" + ] + }, + { + "cell_type": "markdown", + "id": "5f93686b", + "metadata": {}, + "source": [ + "## Other Chains using OpenAI functions\n", + "\n", + "There are a number of more specific chains that use OpenAI functions.\n", + "- [Extraction](/docs/modules/chains/additional/extraction): very similar to structured output chain, intended for information/entity extraction specifically.\n", + "- [Tagging](/docs/modules/chains/additional/tagging): tag inputs.\n", + "- [OpenAPI](/docs/modules/chains/additional/openapi_openai): take an OpenAPI spec and create + execute valid requests against the API, using OpenAI functions under the hood.\n", + "- [QA with citations](/docs/modules/chains/additional/qa_citations): use OpenAI functions ability to extract citations from text." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93425c66", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/openai_functions/__init__.py b/langchain/chains/openai_functions/__init__.py index 55b38dbfc8..3ac797b849 100644 --- a/langchain/chains/openai_functions/__init__.py +++ b/langchain/chains/openai_functions/__init__.py @@ -1,3 +1,7 @@ +from langchain.chains.openai_functions.base import ( + create_openai_fn_chain, + create_structured_output_chain, +) from langchain.chains.openai_functions.citation_fuzzy_match import ( create_citation_fuzzy_match_chain, ) @@ -22,4 +26,6 @@ __all__ = [ "create_citation_fuzzy_match_chain", "create_qa_with_structure_chain", "create_qa_with_sources_chain", + "create_structured_output_chain", + "create_openai_fn_chain", ] diff --git a/langchain/chains/openai_functions/base.py b/langchain/chains/openai_functions/base.py new file mode 100644 index 0000000000..89f178f585 --- /dev/null +++ b/langchain/chains/openai_functions/base.py @@ -0,0 +1,315 @@ +"""Methods for creating chains that use OpenAI function-calling APIs.""" +import inspect +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union + +from pydantic import BaseModel + +from langchain.base_language import BaseLanguageModel +from langchain.chains import LLMChain +from langchain.chat_models import ChatOpenAI +from langchain.output_parsers.openai_functions import ( + JsonOutputFunctionsParser, + PydanticOutputFunctionsParser, +) +from langchain.prompts import BasePromptTemplate, ChatPromptTemplate +from langchain.schema import BaseLLMOutputParser + +PYTHON_TO_JSON_TYPES = { + "str": "string", + "int": "number", + "float": "number", + "bool": "boolean", +} + + +def _get_python_function_name(function: Callable) -> str: + """Get the name of a Python function.""" + source = inspect.getsource(function) + return re.search(r"^def (.*)\(", source).groups()[0] # type: ignore + + +def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]: + """Parse the function and argument descriptions from the docstring of a function. + + Assumes the function docstring follows Google Python style guide. + """ + docstring = inspect.getdoc(function) + if docstring: + docstring_blocks = docstring.split("\n\n") + descriptors = [] + args_block = None + past_descriptors = False + for block in docstring_blocks: + if block.startswith("Args:"): + args_block = block + break + elif block.startswith("Returns:") or block.startswith("Example:"): + # Don't break in case Args come after + past_descriptors = True + elif not past_descriptors: + descriptors.append(block) + else: + continue + description = " ".join(descriptors) + else: + description = "" + args_block = None + arg_descriptions = {} + if args_block: + arg = None + for line in args_block.split("\n")[1:]: + if ":" in line: + arg, desc = line.split(":") + arg_descriptions[arg.strip()] = desc.strip() + elif arg: + arg_descriptions[arg.strip()] += " " + line.strip() + return description, arg_descriptions + + +def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict: + """Get JsonSchema describing a Python functions arguments. + + Assumes all function arguments are of primitive types (int, float, str, bool) or + are subclasses of pydantic.BaseModel. + """ + properties = {} + annotations = inspect.getfullargspec(function).annotations + for arg, arg_type in annotations.items(): + if arg == "return": + continue + if isinstance(arg_type, type) and issubclass(arg_type, BaseModel): + properties[arg] = arg_type.schema() + elif arg_type.__name__ in PYTHON_TO_JSON_TYPES: + properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]} + if arg in arg_descriptions: + if arg not in properties: + properties[arg] = {} + properties[arg]["description"] = arg_descriptions[arg] + return properties + + +def _get_python_function_required_args(function: Callable) -> List[str]: + """Get the required arguments for a Python function.""" + spec = inspect.getfullargspec(function) + required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args + required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})] + return required + + +def convert_python_function_to_openai_function(function: Callable) -> Dict[str, Any]: + """Convert a Python function to an OpenAI function-calling API compatible dict. + + Assumes the Python function has type hints and a docstring with a description. If + the docstring has Google Python style argument descriptions, these will be + included as well. + """ + description, arg_descriptions = _parse_python_function_docstring(function) + return { + "name": _get_python_function_name(function), + "description": description, + "parameters": { + "type": "object", + "properties": _get_python_function_arguments(function, arg_descriptions), + "required": _get_python_function_required_args(function), + }, + } + + +def convert_to_openai_function( + function: Union[Dict[str, Any], BaseModel, Callable] +) -> Dict[str, Any]: + """Convert a raw function/class to an OpenAI function. + + Args: + function: Either a dictionary, a pydantic.BaseModel, or a Python function. If + a dictionary is passed in, it is assumed to already be a valid OpenAI + function. + + Returns: + A dict version of the passed in function which is compatible with the + OpenAI function-calling API. + """ + if isinstance(function, dict): + return function + elif isinstance(function, type) and issubclass(function, BaseModel): + schema = function.schema() + return { + "name": schema["title"], + "description": schema["description"], + "parameters": schema, + } + elif callable(function): + return convert_python_function_to_openai_function(function) + + else: + raise ValueError( + f"Unsupported function type {type(function)}. Functions must be passed in" + f" as Dict, pydantic.BaseModel, or Callable." + ) + + +def _get_openai_output_parser( + functions: Sequence[Union[Dict[str, Any], BaseModel, Callable]], + function_names: Sequence[str], +) -> BaseLLMOutputParser: + """Get the appropriate function output parser given the user functions.""" + if isinstance(functions[0], type) and issubclass(functions[0], BaseModel): + if len(functions) > 1: + pydantic_schema: Union[Dict, Type[BaseModel]] = { + name: fn for name, fn in zip(function_names, functions) + } + else: + pydantic_schema = functions[0] + output_parser: BaseLLMOutputParser = PydanticOutputFunctionsParser( + pydantic_schema=pydantic_schema + ) + else: + output_parser = JsonOutputFunctionsParser(args_only=len(functions) <= 1) + return output_parser + + +def create_openai_fn_chain( + functions: Sequence[Union[Dict[str, Any], BaseModel, Callable]], + llm: Optional[BaseLanguageModel] = None, + prompt: Optional[BasePromptTemplate] = None, + output_parser: Optional[BaseLLMOutputParser] = None, + **kwargs: Any, +) -> LLMChain: + """Create an LLM chain that uses OpenAI functions. + + Args: + functions: A sequence of either dictionaries, pydantic.BaseModels, or + Python functions. If dictionaries are passed in, they are assumed to + already be a valid OpenAI functions. If only a single + function is passed in, then it will be enforced that the model use that + function. pydantic.BaseModels and Python functions should have docstrings + describing what the function does. For best results, pydantic.BaseModels + should have descriptions of the parameters and Python functions should have + Google Python style args descriptions in the docstring. Additionally, + Python functions should only use primitive types (str, int, float, bool) or + pydantic.BaseModels for arguments. + llm: Language model to use, assumed to support the OpenAI function-calling API. + Defaults to ChatOpenAI using model gpt-3.5-turbo-0613. + prompt: BasePromptTemplate to pass to the model. Defaults to a prompt that just + passes user input directly to model. + output_parser: BaseLLMOutputParser to use for parsing model outputs. By default + will be inferred from the function types. If pydantic.BaseModels are passed + in, then the OutputParser will try to parse outputs using those. Otherwise + model outputs will simply be parsed as JSON. If multiple functions are + passed in and they are not pydantic.BaseModels, the chain output will + include both the name of the function that was returned and the arguments + to pass to the function. + + Returns: + An LLMChain that will pass in the given functions to the model when run. + + Example: + .. code-block:: python + + from langchain.chains.openai_functions import create_openai_fn_chain + + from pydantic import BaseModel, Field + + + class RecordPerson(BaseModel): + \"\"\"Record some identifying information about a person.\"\"\" + + name: str = Field(..., description="The person's name") + age: int = Field(..., description="The person's age") + fav_food: Optional[str] = Field(None, description="The person's favorite food") + + + class RecordDog(BaseModel): + \"\"\"Record some identifying information about a dog.\"\"\" + + name: str = Field(..., description="The dog's name") + color: str = Field(..., description="The dog's color") + fav_food: Optional[str] = Field(None, description="The dog's favorite food") + + + chain = create_openai_fn_chain([RecordPerson, RecordDog]) + chain.run("Harry was a chubby brown beagle who loved chicken") + # -> RecordDog(name="Harry", color="brown", fav_food="chicken") + """ # noqa: E501 + if not functions: + raise ValueError("Need to pass in at least one function. Received zero.") + openai_functions = [convert_to_openai_function(f) for f in functions] + llm = llm or ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0) + prompt = prompt or ChatPromptTemplate.from_template("{input}") + fn_names = [oai_fn["name"] for oai_fn in openai_functions] + output_parser = output_parser or _get_openai_output_parser(functions, fn_names) + llm_kwargs: Dict[str, Any] = { + "functions": openai_functions, + } + if len(openai_functions) == 1: + llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]} + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + output_parser=output_parser, + llm_kwargs=llm_kwargs, + output_key="function", + **kwargs, + ) + return llm_chain + + +def create_structured_output_chain( + output_schema: Union[Dict[str, Any], BaseModel], + llm: Optional[BaseLanguageModel] = None, + prompt: Optional[BasePromptTemplate] = None, + output_parser: Optional[BaseLLMOutputParser] = None, + **kwargs: Any, +) -> LLMChain: + """Create an LLMChain that uses an OpenAI function to get a structured output. + + Args: + output_schema: Either a dictionary or pydantic.BaseModel. If a dictionary is + passed in, it's assumed to already be a valid JsonSchema. + For best results, pydantic.BaseModels should have docstrings describing what + the schema represents and descriptions for the parameters. + llm: Language model to use, assumed to support the OpenAI function-calling API. + Defaults to ChatOpenAI using model gpt-3.5-turbo-0613. + prompt: BasePromptTemplate to pass to the model. Defaults to a prompt that just + passes user input directly to model. + output_parser: BaseLLMOutputParser to use for parsing model outputs. By default + will be inferred from the function types. If pydantic.BaseModels are passed + in, then the OutputParser will try to parse outputs using those. Otherwise + model outputs will simply be parsed as JSON. + + Returns: + An LLMChain that will pass the given function to the model. + + Example: + .. code-block:: python + + from langchain.chains.openai_functions import create_structured_output_chain + + from pydantic import BaseModel, Field + + class Dog(BaseModel): + \"\"\"Identifying information about a dog.\"\"\" + + name: str = Field(..., description="The dog's name") + color: str = Field(..., description="The dog's color") + fav_food: Optional[str] = Field(None, description="The dog's favorite food") + + chain = create_structured_output_chain([Dog]) + chain.run("Harry was a chubby brown beagle who loved chicken") + # -> Dog(name="Harry", color="brown", fav_food="chicken") + """ # noqa: E501 + function: Dict = { + "name": "output_formatter", + "description": ( + "Output formatter. Should always be used to format your response to the" + " user." + ), + } + parameters = ( + output_schema if isinstance(output_schema, dict) else output_schema.schema() + ) + function["parameters"] = parameters + return create_openai_fn_chain( + [function], llm=llm, prompt=prompt, output_parser=output_parser, **kwargs + ) diff --git a/langchain/output_parsers/openai_functions.py b/langchain/output_parsers/openai_functions.py index 245aa107ea..7d9d8aed86 100644 --- a/langchain/output_parsers/openai_functions.py +++ b/langchain/output_parsers/openai_functions.py @@ -1,7 +1,14 @@ import json -from typing import Any, List +from typing import Any, Dict, List, Type, Union -from langchain.schema import BaseLLMOutputParser, ChatGeneration, Generation +from pydantic import BaseModel, root_validator + +from langchain.schema import ( + BaseLLMOutputParser, + ChatGeneration, + Generation, + OutputParserException, +) class OutputFunctionsParser(BaseLLMOutputParser[Any]): @@ -10,14 +17,14 @@ class OutputFunctionsParser(BaseLLMOutputParser[Any]): def parse_result(self, result: List[Generation]) -> Any: generation = result[0] if not isinstance(generation, ChatGeneration): - raise ValueError( + raise OutputParserException( "This output parser can only be used with a chat generation." ) message = generation.message try: func_call = message.additional_kwargs["function_call"] except ValueError as exc: - raise ValueError(f"Could not parse function call: {exc}") + raise OutputParserException(f"Could not parse function call: {exc}") if self.args_only: return func_call["arguments"] @@ -42,11 +49,30 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): class PydanticOutputFunctionsParser(OutputFunctionsParser): - pydantic_schema: Any + pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]] + + @root_validator(pre=True) + def validate_schema(cls, values: Dict) -> Dict: + schema = values["pydantic_schema"] + if "args_only" not in values: + values["args_only"] = isinstance(schema, type) and issubclass( + schema, BaseModel + ) + elif values["args_only"] and isinstance(schema, Dict): + raise ValueError( + "If multiple pydantic schemas are provided then args_only should be" + " False." + ) + return values def parse_result(self, result: List[Generation]) -> Any: - _args = super().parse_result(result) - pydantic_args = self.pydantic_schema.parse_raw(_args) + _result = super().parse_result(result) + if self.args_only: + pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore + else: + fn_name = _result["name"] + _args = _result["arguments"] + pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore # noqa: E501 return pydantic_args