diff --git a/docs/modules/prompts/examples/output_parsers.ipynb b/docs/modules/prompts/examples/output_parsers.ipynb index 76b71df7..d8fc9206 100644 --- a/docs/modules/prompts/examples/output_parsers.ipynb +++ b/docs/modules/prompts/examples/output_parsers.ipynb @@ -17,36 +17,175 @@ "Below we go over some examples of output parsers." ] }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5f0c8a33", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n", + "from langchain.llms import OpenAI\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, { "cell_type": "markdown", - "id": "91871002", + "id": "a1ae632a", "metadata": {}, "source": [ - "## Structured Output Parser\n", + "## PydanticOutputParser\n", + "This output parser allows users to specify an arbitrary JSON schema and query LLMs for JSON outputs that conform to that schema.\n", "\n", - "This output parser can be used when you want to return multiple fields." + "Keep in mind that large language models are leaky abstractions! You'll have to use an LLM with sufficient capacity to generate well-formed JSON. In the OpenAI family, DaVinci can do reliably but Curie's ability already drops off dramatically. \n", + "\n", + "Use Pydantic to declare your data model. Pydantic's BaseModel like a Python dataclass, but with actual type checking + coercion." ] }, { "cell_type": "code", - "execution_count": 1, - "id": "b492997a", + "execution_count": 2, + "id": "cba6d8e3", "metadata": {}, "outputs": [], "source": [ - "from langchain.output_parsers import StructuredOutputParser, ResponseSchema" + "from langchain.output_parsers import PydanticOutputParser\n", + "from pydantic import BaseModel, Field, validator\n", + "from typing import List" ] }, { "cell_type": "code", - "execution_count": 2, - "id": "ffb7fc57", + "execution_count": 3, + "id": "0a203100", "metadata": {}, "outputs": [], "source": [ - "from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n", - "from langchain.llms import OpenAI\n", - "from langchain.chat_models import ChatOpenAI" + "model_name = 'text-davinci-003'\n", + "temperature = 0.0\n", + "model = OpenAI(model_name=model_name, temperature=temperature)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b3f16168", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Joke(setup='Why did the chicken cross the playground?', punchline='To get to the other slide!')" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Define your desired data structure.\n", + "class Joke(BaseModel):\n", + " setup: str = Field(description=\"question to set up a joke\")\n", + " punchline: str = Field(description=\"answer to resolve the joke\")\n", + " \n", + " # You can add custom validation logic easily with Pydantic.\n", + " @validator('setup')\n", + " def question_ends_with_question_mark(cls, field):\n", + " if field[-1] != '?':\n", + " raise ValueError(\"Badly formed question!\")\n", + " return field\n", + "\n", + "# And a query intented to prompt a language model to populate the data structure.\n", + "joke_query = \"Tell me a joke.\"\n", + "\n", + "# Set up a parser + inject instructions into the prompt template.\n", + "parser = PydanticOutputParser(pydantic_object=Joke)\n", + "\n", + "prompt = PromptTemplate(\n", + " template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n", + " input_variables=[\"query\"],\n", + " partial_variables={\"format_instructions\": parser.get_format_instructions()}\n", + ")\n", + "\n", + "_input = prompt.format_prompt(query=joke_query)\n", + "\n", + "output = model(_input.to_string())\n", + "\n", + "parser.parse(output)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "03049f88", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story', 'A League of Their Own'])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Here's another example, but with a compound typed field.\n", + "class Actor(BaseModel):\n", + " name: str = Field(description=\"name of an actor\")\n", + " film_names: List[str] = Field(description=\"list of names of films they starred in\")\n", + " \n", + "actor_query = \"Generate the filmography for a random actor.\"\n", + "\n", + "parser = PydanticOutputParser(pydantic_object=Actor)\n", + "\n", + "prompt = PromptTemplate(\n", + " template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n", + " input_variables=[\"query\"],\n", + " partial_variables={\"format_instructions\": parser.get_format_instructions()}\n", + ")\n", + "\n", + "_input = prompt.format_prompt(query=actor_query)\n", + "\n", + "output = model(_input.to_string())\n", + "\n", + "parser.parse(output)" + ] + }, + { + "cell_type": "markdown", + "id": "61f67890", + "metadata": {}, + "source": [ + "
\n", + "
\n", + "
\n", + "
\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "91871002", + "metadata": {}, + "source": [ + "## Structured Output Parser\n", + "\n", + "While the Pydantic/JSON parser is more powerful, we initially experimented data structures having text fields only." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b492997a", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.output_parsers import StructuredOutputParser, ResponseSchema" ] }, { @@ -59,7 +198,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "id": "432ac44a", "metadata": {}, "outputs": [], @@ -81,7 +220,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "id": "593cfc25", "metadata": {}, "outputs": [], @@ -104,7 +243,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "id": "106f1ba6", "metadata": {}, "outputs": [], @@ -114,7 +253,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "id": "86d9d24f", "metadata": {}, "outputs": [], @@ -125,7 +264,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "id": "956bdc99", "metadata": {}, "outputs": [ @@ -135,7 +274,7 @@ "{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}" ] }, - "execution_count": 7, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -154,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "id": "8f483d7d", "metadata": {}, "outputs": [], @@ -164,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 13, "id": "f761cbf1", "metadata": {}, "outputs": [], @@ -180,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "id": "edd73ae3", "metadata": {}, "outputs": [], @@ -191,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "id": "a3c8b91e", "metadata": {}, "outputs": [ @@ -201,7 +340,7 @@ "{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}" ] }, - "execution_count": 11, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -217,12 +356,12 @@ "source": [ "## CommaSeparatedListOutputParser\n", "\n", - "This output parser can be used to get a list of items as output." + "Here's another parser strictly less powerful than Pydantic/JSON parsing." ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "id": "872246d7", "metadata": {}, "outputs": [], @@ -232,7 +371,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "id": "c3f9aee6", "metadata": {}, "outputs": [], @@ -242,7 +381,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 18, "id": "e77871b7", "metadata": {}, "outputs": [], @@ -257,7 +396,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 19, "id": "a71cb5d3", "metadata": {}, "outputs": [], @@ -267,7 +406,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 20, "id": "783d7d98", "metadata": {}, "outputs": [], @@ -278,7 +417,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 21, "id": "fcb81344", "metadata": {}, "outputs": [ @@ -292,7 +431,7 @@ " 'Cookies and Cream']" ] }, - "execution_count": 17, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -300,14 +439,6 @@ "source": [ "output_parser.parse(output)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cba6d8e3", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -326,7 +457,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.9.0" } }, "nbformat": 4, diff --git a/langchain/output_parsers/__init__.py b/langchain/output_parsers/__init__.py index 268f62a5..fd2f0f24 100644 --- a/langchain/output_parsers/__init__.py +++ b/langchain/output_parsers/__init__.py @@ -3,6 +3,7 @@ from langchain.output_parsers.list import ( CommaSeparatedListOutputParser, ListOutputParser, ) +from langchain.output_parsers.pydantic import PydanticOutputParser from langchain.output_parsers.rail_parser import GuardrailsOutputParser from langchain.output_parsers.regex import RegexParser from langchain.output_parsers.regex_dict import RegexDictParser @@ -17,4 +18,5 @@ __all__ = [ "StructuredOutputParser", "ResponseSchema", "GuardrailsOutputParser", + "PydanticOutputParser", ] diff --git a/langchain/output_parsers/format_instructions.py b/langchain/output_parsers/format_instructions.py index 3653d477..1c6639a9 100644 --- a/langchain/output_parsers/format_instructions.py +++ b/langchain/output_parsers/format_instructions.py @@ -7,3 +7,10 @@ STRUCTURED_FORMAT_INSTRUCTIONS = """The output should be a markdown code snippet {format} }} ```""" + +PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below. For example, the object {{"foo": ["bar", "baz"]}} conforms to the schema {{"foo": {{"description": "a list of strings field", "type": "string"}}}}. + +Here is the output schema: +``` +{schema} +```""" diff --git a/langchain/output_parsers/pydantic.py b/langchain/output_parsers/pydantic.py new file mode 100644 index 00000000..e441509b --- /dev/null +++ b/langchain/output_parsers/pydantic.py @@ -0,0 +1,40 @@ +import json +import re +from typing import Any + +from pydantic import BaseModel, ValidationError + +from langchain.output_parsers.base import BaseOutputParser +from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS + + +class PydanticOutputParser(BaseOutputParser): + pydantic_object: Any + + def parse(self, text: str) -> BaseModel: + try: + # Greedy search for 1st json candidate. + match = re.search("\{.*\}", text.strip()) + json_str = "" + if match: + json_str = match.group() + json_object = json.loads(json_str) + return self.pydantic_object.parse_obj(json_object) + + except (json.JSONDecodeError, ValidationError) as e: + name = self.pydantic_object.__name__ + msg = f"Failed to parse {name} from completion {text}. Got: {e}" + raise ValueError(msg) + + def get_format_instructions(self) -> str: + schema = self.pydantic_object.schema() + + # Remove extraneous fields. + reduced_schema = { + prop: {"description": data["description"], "type": data["type"]} + for prop, data in schema["properties"].items() + } + # Ensure json in context is well-formed with double quotes. + schema = json.dumps(reduced_schema) + + return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema)