diff --git a/docs/modules/chains/examples/extraction.ipynb b/docs/modules/chains/examples/extraction.ipynb
new file mode 100644
index 0000000000..45124d9608
--- /dev/null
+++ b/docs/modules/chains/examples/extraction.ipynb
@@ -0,0 +1,240 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "6605e7f7",
+   "metadata": {},
+   "source": [
+    "# Extraction\n",
+    "\n",
+    "The extraction chain uses the OpenAI `functions` parameter to specify a schema to extract entities from a document. This helps us make sure that the model outputs exactly the schema of entities and properties that we want, with their appropriate types.\n",
+    "\n",
+    "The extraction chain should be used when we want to extract several entities with their properties from the same passage (e.g., which people were mentioned in this passage?)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "34f04daf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chat_models import ChatOpenAI\n",
+    "from langchain.chains import create_extraction_chain, create_extraction_chain_pydantic\n",
+    "from langchain.prompts import ChatPromptTemplate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "a2648974",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = ChatOpenAI(temperature=0,\n",
+    "                 model=\"gpt-3.5-turbo-0613\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5ef034ce",
+   "metadata": {},
+   "source": [
+    "## Extracting entities"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "78ff9df9",
+   "metadata": {},
+   "source": [
+    "To extract entities, we need to create a schema like the following, where we specify all the properties we want to find and the type we expect them to have. We can also specify which of these properties are required and which are optional."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "4ac43eba",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "schema = {\n",
+    "    \"properties\": {\n",
+    "        \"person_name\": {\"type\": \"string\"},\n",
+    "        \"person_height\": {\"type\": \"integer\"},\n",
+    "        \"person_hair_color\": {\"type\": \"string\"},\n",
+    "        \"dog_name\": {\"type\": \"string\"},\n",
+    "        \"dog_breed\": {\"type\": \"string\"}\n",
+    "    },\n",
+    "    \"required\": [\"person_name\", \"person_height\"]\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "640bd005",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "inp = \"\"\"\n",
+    "Alex is 5 feet tall. Claudia is 4 feet taller than Alex and jumps higher than him. Claudia is a brunette and Alex is blonde.\n",
+    "Alex's dog Frosty is a labrador and likes to play hide and seek.\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "64313214",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chain = create_extraction_chain(schema, llm)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "17c48adb",
+   "metadata": {},
+   "source": [
+    "As we can see, we extracted the required entities and their properties in the desired format:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "cc5436ed",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'person_name': 'Alex',\n",
+       "  'person_height': 5,\n",
+       "  'person_hair_color': 'blonde',\n",
+       "  'dog_name': 'Frosty',\n",
+       "  'dog_breed': 'labrador'},\n",
+       " {'person_name': 'Claudia',\n",
+       "  'person_height': 9,\n",
+       "  'person_hair_color': 'brunette',\n",
+       "  'dog_name': '',\n",
+       "  'dog_breed': ''}]"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chain.run(inp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "698b4c4d",
+   "metadata": {},
+   "source": [
+    "## Pydantic example"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6504a6d9",
+   "metadata": {},
+   "source": [
+    "We can also use a Pydantic schema to choose the required properties and types, marking as `Optional` those that are not strictly required.\n",
+    "\n",
+    "By using the `create_extraction_chain_pydantic` function, we can send a Pydantic schema as input and the output will be an instantiated object that respects our desired schema.\n",
+    "\n",
+    "In this way, we can specify our schema in the same manner that we would a new class or function in Python - with purely Pythonic types."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "6792866b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Optional, List\n",
+    "from pydantic import BaseModel, Field"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "36a63761",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Properties(BaseModel):\n",
+    "    person_name: str\n",
+    "    person_height: int\n",
+    "    person_hair_color: str\n",
+    "    dog_breed: Optional[str]\n",
+    "    dog_name: Optional[str]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "8ffd1e57",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chain = create_extraction_chain_pydantic(pydantic_schema=Properties, llm=llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "24baa954",
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[Properties(person_name='Alex', person_height=5, person_hair_color='blonde', dog_breed='labrador', dog_name='Frosty'),\n",
+       " Properties(person_name='Claudia', person_height=9, person_hair_color='brunette', dog_breed=None, dog_name=None)]"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "inp = \"\"\"\n",
+    "Alex is 5 feet tall. Claudia is 4 feet taller than Alex and jumps higher than him. Claudia is a brunette and Alex is blonde.\n",
+    "Alex's dog Frosty is a labrador and likes to play hide and seek.\n",
+    "\"\"\"\n",
+    "chain.run(inp)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "general",
+   "language": "python",
+   "name": "general"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
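Note for reviewers: it may help to see what `create_extraction_chain` actually sends to the model. The dict schema in the notebook above is wrapped in a single OpenAI function definition whose only argument is an `info` array (see `_get_extraction_functions` and `_convert_schema` in `langchain/chains/openai_functions.py` later in this diff). A sketch of the resulting `functions` payload, abbreviated for brevity:

```python
# Sketch of the `functions` payload built from the notebook's schema above;
# _convert_schema adds a "title" to each property and carries over "required".
functions = [
    {
        "name": "information_extraction",
        "description": "Extracts the relevant information from the passage.",
        "parameters": {
            "type": "object",
            "properties": {
                "info": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "person_name": {"title": "person_name", "type": "string"},
                            "person_height": {"title": "person_height", "type": "integer"},
                            # ...person_hair_color, dog_name, dog_breed converted the same way
                        },
                        "required": ["person_name", "person_height"],
                    },
                }
            },
            "required": ["info"],
        },
    }
]
```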
diff --git a/docs/modules/chains/examples/tagging.ipynb b/docs/modules/chains/examples/tagging.ipynb
new file mode 100644
index 0000000000..d513963d41
--- /dev/null
+++ b/docs/modules/chains/examples/tagging.ipynb
@@ -0,0 +1,389 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "a13ea924",
+   "metadata": {},
+   "source": [
+    "# Tagging\n",
+    "\n",
+    "The tagging chain uses the OpenAI `functions` parameter to specify a schema to tag a document with. This helps us make sure that the model outputs exactly the tags that we want, with their appropriate types.\n",
+    "\n",
+    "The tagging chain should be used when we want to tag a passage with a specific attribute (e.g., what is the sentiment of this message?)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "bafb496a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chat_models import ChatOpenAI\n",
+    "from langchain.chains import create_tagging_chain, create_tagging_chain_pydantic\n",
+    "from langchain.prompts import ChatPromptTemplate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "39f3ce3e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = ChatOpenAI(\n",
+    "    temperature=0,\n",
+    "    model=\"gpt-3.5-turbo-0613\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "832ddcd9",
+   "metadata": {},
+   "source": [
+    "## Simplest approach, only specifying type"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4fc8d766",
+   "metadata": {},
+   "source": [
+    "We can start by specifying a few properties with their expected type in our schema."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "8329f943",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "schema = {\n",
+    "    \"properties\": {\n",
+    "        \"sentiment\": {\"type\": \"string\"},\n",
+    "        \"aggressiveness\": {\"type\": \"integer\"},\n",
+    "        \"language\": {\"type\": \"string\"},\n",
+    "    }\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "6146ae70",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chain = create_tagging_chain(schema, llm)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9e306ca3",
+   "metadata": {},
+   "source": [
+    "As we can see in the examples, the model correctly interprets what we want, but the results vary: we get, for example, sentiments in different languages ('positive', 'enojado', etc.).\n",
+    "\n",
+    "We will see how to control these results in the next section."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 59,
+   "id": "5509b6a6",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'sentiment': 'positive', 'language': 'Spanish'}"
+      ]
+     },
+     "execution_count": 59,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "inp = \"Estoy increiblemente contento de haberte conocido! Creo que seremos muy buenos amigos!\"\n",
+    "chain.run(inp)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "id": "9154474c",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'Spanish'}"
+      ]
+     },
+     "execution_count": 60,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "inp = \"Estoy muy enojado con vos! Te voy a dar tu merecido!\"\n",
+    "chain.run(inp)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 61,
+   "id": "aae85b27",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'sentiment': 'positive', 'aggressiveness': 0, 'language': 'English'}"
+      ]
+     },
+     "execution_count": 61,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "inp = \"Weather is ok here, I can go outside without much more than a coat\"\n",
+    "chain.run(inp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bebb2f83",
+   "metadata": {},
+   "source": [
+    "## More control\n",
+    "\n",
+    "By being smart about how we define our schema, we can have more control over the model's output. Specifically, we can define:\n",
+    "\n",
+    "- possible values for each property\n",
+    "- a description to make sure that the model understands the property\n",
+    "- required properties to be returned"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "69ef0b9a",
+   "metadata": {},
+   "source": [
+    "Following is an example of how we can use _enum_, _description_ and _required_ to control each of the aspects mentioned above:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "6a5f7961",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "schema = {\n",
+    "    \"properties\": {\n",
+    "        \"sentiment\": {\"type\": \"string\", \"enum\": [\"happy\", \"neutral\", \"sad\"]},\n",
+    "        \"aggressiveness\": {\"type\": \"integer\", \"enum\": [1, 2, 3, 4, 5], \"description\": \"describes how aggressive the statement is, the higher the number the more aggressive\"},\n",
+    "        \"language\": {\"type\": \"string\", \"enum\": [\"spanish\", \"english\", \"french\", \"german\", \"italian\"]},\n",
+    "    },\n",
+    "    \"required\": [\"language\", \"sentiment\", \"aggressiveness\"]\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "e5a5881f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chain = create_tagging_chain(schema, llm)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5ded2332",
+   "metadata": {},
+   "source": [
+    "Now the answers are much better!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "d9b9d53d",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'sentiment': 'happy', 'aggressiveness': 0, 'language': 'spanish'}"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "inp = \"Estoy increiblemente contento de haberte conocido! Creo que seremos muy buenos amigos!\"\n",
+    "chain.run(inp)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "1c12fa00",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'sentiment': 'sad', 'aggressiveness': 10, 'language': 'spanish'}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "inp = \"Estoy muy enojado con vos! Te voy a dar tu merecido!\"\n",
+    "chain.run(inp)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "0bdfcb05",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'sentiment': 'neutral', 'aggressiveness': 0, 'language': 'english'}"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "inp = \"Weather is ok here, I can go outside without much more than a coat\"\n",
+    "chain.run(inp)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e68ad17e",
+   "metadata": {},
+   "source": [
+    "## Specifying schema with Pydantic"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2f5970ec",
+   "metadata": {},
+   "source": [
+    "We can also use a Pydantic schema to specify the required properties and types, and we can pass other arguments, such as 'enum' or 'description', as can be seen in the example below.\n",
+    "\n",
+    "By using the `create_tagging_chain_pydantic` function, we can send a Pydantic schema as input and the output will be an instantiated object that respects our desired schema.\n",
+    "\n",
+    "In this way, we can specify our schema in the same manner that we would a new class or function in Python - with purely Pythonic types."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "bf1f367e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from enum import Enum\n",
+    "from pydantic import BaseModel, Field"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "83a2e826",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Tags(BaseModel):\n",
+    "    sentiment: str = Field(..., enum=[\"happy\", \"neutral\", \"sad\"])\n",
+    "    aggressiveness: int = Field(..., description=\"describes how aggressive the statement is, the higher the number the more aggressive\", enum=[1, 2, 3, 4, 5])\n",
+    "    language: str = Field(..., enum=[\"spanish\", \"english\", \"french\", \"german\", \"italian\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "6e404892",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chain = create_tagging_chain_pydantic(Tags, llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "b5fc43c4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "inp = \"Estoy muy enojado con vos! Te voy a dar tu merecido!\"\n",
+    "res = chain.run(inp)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "5074bcc3",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Tags(sentiment='sad', aggressiveness=10, language='spanish')"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "res"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "general",
+   "language": "python",
+   "name": "general"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/modules/chains/generic/llm_chain.ipynb b/docs/modules/chains/generic/llm_chain.ipynb
index 372c46ed74..fddaa171b5 100644
--- a/docs/modules/chains/generic/llm_chain.ipynb
+++ b/docs/modules/chains/generic/llm_chain.ipynb
@@ -137,7 +137,6 @@
    ]
   },
   {
-   "attachments": {},
    "cell_type": "markdown",
    "id": "a178173b-b183-432a-a517-250fe3191173",
    "metadata": {},
@@ -352,7 +351,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.10"
+   "version": "3.9.1"
   }
  },
 "nbformat": 4,
diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py
index 72c46963d2..219c71ef3b 100644
--- a/langchain/chains/__init__.py
+++ b/langchain/chains/__init__.py
@@ -22,6 +22,12 @@ from langchain.chains.llm_summarization_checker.base import LLMSummarizationChec
 from langchain.chains.loading import load_chain
 from langchain.chains.mapreduce import MapReduceChain
 from langchain.chains.moderation import OpenAIModerationChain
+from langchain.chains.openai_functions import (
+    create_extraction_chain,
+    create_extraction_chain_pydantic,
+    create_tagging_chain,
+    create_tagging_chain_pydantic,
+)
 from langchain.chains.pal.base import PALChain
 from langchain.chains.qa_generation.base import QAGenerationChain
 from langchain.chains.qa_with_sources.base import QAWithSourcesChain
@@ -69,4 +75,8 @@ __all__ = [
     "OpenAPIEndpointChain",
     "FlareChain",
     "NebulaGraphQAChain",
+    "create_extraction_chain",
+    "create_tagging_chain",
+    "create_extraction_chain_pydantic",
+    "create_tagging_chain_pydantic",
 ]
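With these exports in place, all four constructors become importable from the package root. A minimal usage sketch, assuming the module below lands as written (the one-property schema is a placeholder):

```python
from langchain.chains import (
    create_extraction_chain,
    create_extraction_chain_pydantic,
    create_tagging_chain,
    create_tagging_chain_pydantic,
)
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
schema = {"properties": {"sentiment": {"type": "string"}}}  # placeholder schema

# Dict-schema variants return plain Python objects:
extraction_chain = create_extraction_chain(schema, llm)  # run() -> list of dicts
tagging_chain = create_tagging_chain(schema, llm)        # run() -> single dict

# The Pydantic variants take a BaseModel subclass instead of a dict and return
# instances of it, e.g. create_tagging_chain_pydantic(Tags, llm) as in the
# tagging notebook above.
```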
+ """ + if isinstance(schema, list): + for i, item in enumerate(schema): + schema[i] = _resolve_schema_references(item, definitions) + elif isinstance(schema, dict): + if "$ref" in schema: + ref_key = schema.pop("$ref").split("/")[-1] + ref = definitions.get(ref_key, {}) + schema.update(ref) + else: + for key, value in schema.items(): + schema[key] = _resolve_schema_references(value, definitions) + return schema + + +def _get_function_arguments(inputs: dict) -> str: + message = inputs["input"] + try: + func_call = message.additional_kwargs["function_call"] + except ValueError as exc: + raise ValueError(f"Could not parse function call: {exc}") + + return func_call["arguments"] + + +def _parse_tag(inputs: dict) -> dict: + args = _get_function_arguments(inputs) + return {"output": json.loads(args)} + + +def _parse_tag_pydantic(inputs: dict, pydantic_schema: Any) -> dict: + args = _get_function_arguments(inputs) + args = pydantic_schema.parse_raw(args) + return {"output": args} + + +def _parse_entities(inputs: dict) -> dict: + args = _get_function_arguments(inputs) + return {"output": json.loads(args)["info"]} + + +def _parse_entities_pydantic(inputs: dict, pydantic_schema: Any) -> dict: + args = _get_function_arguments(inputs) + pydantic_args = pydantic_schema.parse_raw(args) + return {"output": pydantic_args.info} + + +class OpenAIFunctionsChain(Chain): + prompt: BasePromptTemplate + llm: BaseLanguageModel + functions: List[Dict] + + @property + def input_keys(self) -> List[str]: + return self.prompt.input_variables + + @property + def output_keys(self) -> List[str]: + return ["output"] + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables} + prompt = self.prompt.format_prompt(**_inputs) + messages = prompt.to_messages() + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + predicted_message = self.llm.predict_messages( + messages, functions=self.functions, callbacks=callbacks + ) + return {"output": predicted_message} + + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables} + prompt = self.prompt.format_prompt(**_inputs) + messages = prompt.to_messages() + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + predicted_message = await self.llm.apredict_messages( + messages, functions=self.functions, callbacks=callbacks + ) + return {"output": predicted_message} + + +def _convert_schema(schema: dict) -> dict: + props = {k: {"title": k, **v} for k, v in schema["properties"].items()} + return { + "type": "object", + "properties": props, + "required": schema.get("required", []), + } + + +def _get_extraction_functions(entity_schema: dict) -> List[dict]: + return [ + { + "name": "information_extraction", + "description": "Extracts the relevant information from the passage.", + "parameters": { + "type": "object", + "properties": { + "info": {"type": "array", "items": _convert_schema(entity_schema)} + }, + "required": ["info"], + }, + } + ] + + +def _get_tagging_functions(schema: dict) -> List[dict]: + return [ + { + "name": "information_extraction", + "description": "Extracts the relevant information from the passage.", + "parameters": 
_convert_schema(schema), + } + ] + + +_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\ + in the following passage together with their properties. + +Passage: +{input} +""" + + +def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain: + functions = _get_extraction_functions(schema) + prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) + chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions) + parsing_chain = TransformChain( + transform=_parse_entities, + input_variables=["input"], + output_variables=["output"], + ) + return SimpleSequentialChain(chains=[chain, parsing_chain]) + + +def create_extraction_chain_pydantic( + pydantic_schema: Any, llm: BaseLanguageModel +) -> Chain: + class PydanticSchema(BaseModel): + info: List[pydantic_schema] # type: ignore + + openai_schema = PydanticSchema.schema() + openai_schema = _resolve_schema_references( + openai_schema, openai_schema["definitions"] + ) + + functions = _get_extraction_functions(openai_schema) + prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) + chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions) + pydantic_parsing_chain = TransformChain( + transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema), + input_variables=["input"], + output_variables=["output"], + ) + return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain]) + + +_TAGGING_TEMPLATE = """Extract the desired information from the following passage. + +Passage: +{input} +""" + + +def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain: + functions = _get_tagging_functions(schema) + prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) + chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions) + parsing_chain = TransformChain( + transform=_parse_tag, input_variables=["input"], output_variables=["output"] + ) + return SimpleSequentialChain(chains=[chain, parsing_chain]) + + +def create_tagging_chain_pydantic( + pydantic_schema: Any, llm: BaseLanguageModel +) -> Chain: + openai_schema = pydantic_schema.schema() + + functions = _get_tagging_functions(openai_schema) + prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) + chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions) + pydantic_parsing_chain = TransformChain( + transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema), + input_variables=["input"], + output_variables=["output"], + ) + + return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain])
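One implementation note worth illustrating: `create_extraction_chain_pydantic` calls `_resolve_schema_references` because `PydanticSchema.schema()` represents the nested item model as a `$ref` pointer into `definitions`, which the OpenAI functions endpoint would not dereference on its own. A small sketch of the transformation, using a hypothetical schema shaped like Pydantic's output:

```python
# Hypothetical shape of PydanticSchema.schema() before resolution:
schema = {
    "properties": {
        "info": {"type": "array", "items": {"$ref": "#/definitions/Properties"}}
    },
    "definitions": {
        "Properties": {
            "type": "object",
            "properties": {"person_name": {"type": "string"}},
        }
    },
}

resolved = _resolve_schema_references(schema, schema["definitions"])
# The $ref node is rewritten in place into the referenced definition:
# resolved["properties"]["info"]["items"]
#   == {"type": "object", "properties": {"person_name": {"type": "string"}}}
```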