Harrison/guarded output parser (#1804)

Co-authored-by: jerwelborn <jeremy.welborn@gmail.com>
This commit is contained in:
Harrison Chase 2023-03-21 22:07:23 -07:00 committed by GitHub
parent 8fa1764c60
commit ce5d97bcb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 567 additions and 75 deletions

View File

@ -14,6 +14,11 @@
"- `get_format_instructions() -> str`: A method which returns a string containing instructions for how the output of a language model should be formatted.\n", "- `get_format_instructions() -> str`: A method which returns a string containing instructions for how the output of a language model should be formatted.\n",
"- `parse(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and parses it into some structure.\n", "- `parse(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and parses it into some structure.\n",
"\n", "\n",
"And then one optional one:\n",
"\n",
"- `parse_with_prompt(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and a prompt (assumed to the prompt that generated such a response) and parses it into some structure. The prompt is largely provided in the event the OutputParser wants to retry or fix the output in some way, and needs information from the prompt to do so.\n",
"\n",
"\n",
"Below we go over some examples of output parsers." "Below we go over some examples of output parsers."
] ]
}, },
@ -75,7 +80,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"Joke(setup='Why did the chicken cross the playground?', punchline='To get to the other slide!')" "Joke(setup='Why did the chicken cross the road?', punchline='To get to the other side!')"
] ]
}, },
"execution_count": 4, "execution_count": 4,
@ -124,7 +129,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story', 'A League of Their Own'])" "Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story'])"
] ]
}, },
"execution_count": 5, "execution_count": 5,
@ -155,11 +160,297 @@
"parser.parse(output)" "parser.parse(output)"
] ]
}, },
{
"cell_type": "markdown",
"id": "4d6c0c86",
"metadata": {},
"source": [
"## Fixing Output Parsing Mistakes\n",
"\n",
"The above guardrail simply tries to parse the LLM response. If it does not parse correctly, then it errors.\n",
"\n",
"But we can do other things besides throw errors. Specifically, we can pass the misformatted output, along with the formatted instructions, to the model and ask it to fix it.\n",
"\n",
"For this example, we'll use the above OutputParser. Here's what happens if we pass it a result that does not comply with the schema:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "73beb20d",
"metadata": {},
"outputs": [],
"source": [
"misformatted = \"{'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f0e5ba80",
"metadata": {},
"outputs": [
{
"ename": "OutputParserException",
"evalue": "Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:23\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 22\u001b[0m json_str \u001b[38;5;241m=\u001b[39m match\u001b[38;5;241m.\u001b[39mgroup()\n\u001b[0;32m---> 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m \u001b[43mjson\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_str\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39mparse_obj(json_object)\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03mcontaining a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:353\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscan_once\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
"\u001b[0;31mJSONDecodeError\u001b[0m: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmisformatted\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n",
"\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)"
]
}
],
"source": [
"parser.parse(misformatted)"
]
},
{
"cell_type": "markdown",
"id": "6c7c82b6",
"metadata": {},
"source": [
"Now we can construct and use a `OutputFixingParser`. This output parser takes as an argument another output parser but also an LLM with which to try to correct any formatting mistakes."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "39b1a5ce",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers import OutputFixingParser\n",
"\n",
"new_parser = OutputFixingParser.from_llm(parser=parser, llm=ChatOpenAI())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0fd96d68",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Actor(name='Tom Hanks', film_names=['Forrest Gump'])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_parser.parse(misformatted)"
]
},
{
"cell_type": "markdown",
"id": "ea34eeaa",
"metadata": {},
"source": [
"## Fixing Output Parsing Mistakes with the original prompt\n",
"\n",
"While in some cases it is possible to fix any parsing mistakes by only looking at the output, in other cases it can't. An example of this is when the output is not just in the incorrect format, but is partially complete. Consider the below example."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "67c5e1ac",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Based on the user question, provide an Action and Action Input for what step should be taken.\n",
"{format_instructions}\n",
"Question: {query}\n",
"Response:\"\"\"\n",
"class Action(BaseModel):\n",
" action: str = Field(description=\"action to take\")\n",
" action_input: str = Field(description=\"input to the action\")\n",
" \n",
"parser = PydanticOutputParser(pydantic_object=Action)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "007aa87f",
"metadata": {},
"outputs": [],
"source": [
"prompt = PromptTemplate(\n",
" template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n",
" input_variables=[\"query\"],\n",
" partial_variables={\"format_instructions\": parser.get_format_instructions()}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "10d207ff",
"metadata": {},
"outputs": [],
"source": [
"prompt_value = prompt.format_prompt(query=\"who is leo di caprios gf?\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "68622837",
"metadata": {},
"outputs": [],
"source": [
"bad_response = '{\"action\": \"search\"}'"
]
},
{
"cell_type": "markdown",
"id": "25631465",
"metadata": {},
"source": [
"If we try to parse this response as is, we will get an error"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "894967c1",
"metadata": {},
"outputs": [
{
"ename": "OutputParserException",
"evalue": "Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:24\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(json_str)\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpydantic_object\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_object\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (json\u001b[38;5;241m.\u001b[39mJSONDecodeError, ValidationError) \u001b[38;5;28;01mas\u001b[39;00m e:\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:527\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:342\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for Action\naction_input\n field required (type=value_error.missing)",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbad_response\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n",
"\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)"
]
}
],
"source": [
"parser.parse(bad_response)"
]
},
{
"cell_type": "markdown",
"id": "f6b64696",
"metadata": {},
"source": [
"If we try to use the `OutputFixingParser` to fix this error, it will be confused - namely, it doesn't know what to actually put for action input."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "78b2b40d",
"metadata": {},
"outputs": [],
"source": [
"fix_parser = OutputFixingParser.from_llm(parser=parser, llm=ChatOpenAI())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4fe1301d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Action(action='search', action_input='keyword')"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fix_parser.parse(bad_response)"
]
},
{
"cell_type": "markdown",
"id": "9bd9ea7d",
"metadata": {},
"source": [
"Instead, we can use the RetryOutputParser, which passes in the prompt (as well as the original output) to try again to get a better response."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "7e8a8a28",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers import RetryWithErrorOutputParser"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "5c86e141",
"metadata": {},
"outputs": [],
"source": [
"retry_parser = RetryWithErrorOutputParser.from_llm(parser=parser, llm=ChatOpenAI())"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "9c04f731",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Action(action='search', action_input='leo di caprios girlfriend')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retry_parser.parse_with_prompt(bad_response, prompt_value)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "61f67890", "id": "61f67890",
"metadata": {}, "metadata": {},
"source": [ "source": [
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n",
"<br>\n", "<br>\n",
"<br>\n", "<br>\n",
"<br>\n", "<br>\n",
@ -168,6 +459,14 @@
"---" "---"
] ]
}, },
{
"cell_type": "markdown",
"id": "64bf525a",
"metadata": {},
"source": [
"# Older, less powerful parsers"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "91871002", "id": "91871002",
@ -180,7 +479,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 16,
"id": "b492997a", "id": "b492997a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -198,7 +497,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 17,
"id": "432ac44a", "id": "432ac44a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -220,7 +519,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 18,
"id": "593cfc25", "id": "593cfc25",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -243,7 +542,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 19,
"id": "106f1ba6", "id": "106f1ba6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -253,7 +552,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 20,
"id": "86d9d24f", "id": "86d9d24f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -264,7 +563,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 21,
"id": "956bdc99", "id": "956bdc99",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -274,7 +573,7 @@
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}" "{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
] ]
}, },
"execution_count": 11, "execution_count": 21,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -293,7 +592,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 22,
"id": "8f483d7d", "id": "8f483d7d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -303,7 +602,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 23,
"id": "f761cbf1", "id": "f761cbf1",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -319,7 +618,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 24,
"id": "edd73ae3", "id": "edd73ae3",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -330,7 +629,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 25,
"id": "a3c8b91e", "id": "a3c8b91e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -340,7 +639,7 @@
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}" "{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
] ]
}, },
"execution_count": 15, "execution_count": 25,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -361,7 +660,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 26,
"id": "872246d7", "id": "872246d7",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -371,7 +670,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 27,
"id": "c3f9aee6", "id": "c3f9aee6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -381,7 +680,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 28,
"id": "e77871b7", "id": "e77871b7",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -396,7 +695,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 29,
"id": "a71cb5d3", "id": "a71cb5d3",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -406,7 +705,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 30,
"id": "783d7d98", "id": "783d7d98",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -417,7 +716,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 31,
"id": "fcb81344", "id": "fcb81344",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -431,7 +730,7 @@
" 'Cookies and Cream']" " 'Cookies and Cream']"
] ]
}, },
"execution_count": 21, "execution_count": 31,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -457,7 +756,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.0" "version": "3.9.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -13,7 +13,6 @@ from langchain.agents.conversational_chat.prompt import (
) )
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.output_parsers.base import BaseOutputParser
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import ( from langchain.prompts.chat import (
ChatPromptTemplate, ChatPromptTemplate,
@ -26,6 +25,7 @@ from langchain.schema import (
AIMessage, AIMessage,
BaseLanguageModel, BaseLanguageModel,
BaseMessage, BaseMessage,
BaseOutputParser,
HumanMessage, HumanMessage,
) )
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool

View File

@ -1,4 +1,4 @@
from langchain.output_parsers.base import BaseOutputParser from langchain.output_parsers.fix import OutputFixingParser
from langchain.output_parsers.list import ( from langchain.output_parsers.list import (
CommaSeparatedListOutputParser, CommaSeparatedListOutputParser,
ListOutputParser, ListOutputParser,
@ -7,6 +7,7 @@ from langchain.output_parsers.pydantic import PydanticOutputParser
from langchain.output_parsers.rail_parser import GuardrailsOutputParser from langchain.output_parsers.rail_parser import GuardrailsOutputParser
from langchain.output_parsers.regex import RegexParser from langchain.output_parsers.regex import RegexParser
from langchain.output_parsers.regex_dict import RegexDictParser from langchain.output_parsers.regex_dict import RegexDictParser
from langchain.output_parsers.retry import RetryOutputParser, RetryWithErrorOutputParser
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser
__all__ = [ __all__ = [
@ -14,9 +15,11 @@ __all__ = [
"RegexDictParser", "RegexDictParser",
"ListOutputParser", "ListOutputParser",
"CommaSeparatedListOutputParser", "CommaSeparatedListOutputParser",
"BaseOutputParser",
"StructuredOutputParser", "StructuredOutputParser",
"ResponseSchema", "ResponseSchema",
"GuardrailsOutputParser", "GuardrailsOutputParser",
"PydanticOutputParser", "PydanticOutputParser",
"RetryOutputParser",
"RetryWithErrorOutputParser",
"OutputFixingParser",
] ]

View File

@ -1,28 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict
from pydantic import BaseModel
class BaseOutputParser(BaseModel, ABC):
"""Class to parse the output of an LLM call."""
@abstractmethod
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call."""
def get_format_instructions(self) -> str:
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict

View File

@ -0,0 +1,41 @@
from __future__ import annotations
from typing import Any
from langchain.chains.llm import LLMChain
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
class OutputFixingParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors."""
parser: BaseOutputParser
retry_chain: LLMChain
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
) -> OutputFixingParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
def parse(self, completion: str) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()

View File

@ -8,7 +8,10 @@ STRUCTURED_FORMAT_INSTRUCTIONS = """The output should be a markdown code snippet
}} }}
```""" ```"""
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below. For example, the object {{"foo": ["bar", "baz"]}} conforms to the schema {{"foo": {{"description": "a list of strings field", "type": "string"}}}}. PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}}}
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
Here is the output schema: Here is the output schema:
``` ```

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from typing import List from typing import List
from langchain.output_parsers.base import BaseOutputParser from langchain.schema import BaseOutputParser
class ListOutputParser(BaseOutputParser): class ListOutputParser(BaseOutputParser):

View File

@ -0,0 +1,22 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate
NAIVE_FIX = """Instructions:
--------------
{instructions}
--------------
Completion:
--------------
{completion}
--------------
Above, the Completion did not satisfy the constraints given in the Instructions.
Error:
--------------
{error}
--------------
Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:"""
NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX)

View File

@ -4,8 +4,8 @@ from typing import Any
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException
class PydanticOutputParser(BaseOutputParser): class PydanticOutputParser(BaseOutputParser):
@ -14,7 +14,9 @@ class PydanticOutputParser(BaseOutputParser):
def parse(self, text: str) -> BaseModel: def parse(self, text: str) -> BaseModel:
try: try:
# Greedy search for 1st json candidate. # Greedy search for 1st json candidate.
match = re.search("\{.*\}", text.strip()) match = re.search(
"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL
)
json_str = "" json_str = ""
if match: if match:
json_str = match.group() json_str = match.group()
@ -24,16 +26,17 @@ class PydanticOutputParser(BaseOutputParser):
except (json.JSONDecodeError, ValidationError) as e: except (json.JSONDecodeError, ValidationError) as e:
name = self.pydantic_object.__name__ name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {text}. Got: {e}" msg = f"Failed to parse {name} from completion {text}. Got: {e}"
raise ValueError(msg) raise OutputParserException(msg)
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
schema = self.pydantic_object.schema() schema = self.pydantic_object.schema()
# Remove extraneous fields. # Remove extraneous fields.
reduced_schema = { reduced_schema = schema
prop: {"description": data["description"], "type": data["type"]} if "title" in reduced_schema:
for prop, data in schema["properties"].items() del reduced_schema["title"]
} if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes. # Ensure json in context is well-formed with double quotes.
schema = json.dumps(reduced_schema) schema = json.dumps(reduced_schema)

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Any, Dict from typing import Any, Dict
from langchain.output_parsers.base import BaseOutputParser from langchain.schema import BaseOutputParser
class GuardrailsOutputParser(BaseOutputParser): class GuardrailsOutputParser(BaseOutputParser):

View File

@ -5,7 +5,7 @@ from typing import Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from langchain.output_parsers.base import BaseOutputParser from langchain.schema import BaseOutputParser
class RegexParser(BaseOutputParser, BaseModel): class RegexParser(BaseOutputParser, BaseModel):

View File

@ -5,7 +5,7 @@ from typing import Dict, Optional
from pydantic import BaseModel from pydantic import BaseModel
from langchain.output_parsers.base import BaseOutputParser from langchain.schema import BaseOutputParser
class RegexDictParser(BaseOutputParser, BaseModel): class RegexDictParser(BaseOutputParser, BaseModel):

View File

@ -0,0 +1,118 @@
from __future__ import annotations
from typing import Any
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseLanguageModel,
BaseOutputParser,
OutputParserException,
PromptValue,
)
NAIVE_COMPLETION_RETRY = """Prompt:
{prompt}
Completion:
{completion}
Above, the Completion did not satisfy the constraints given in the Prompt.
Please try again:"""
NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt:
{prompt}
Completion:
{completion}
Above, the Completion did not satisfy the constraints given in the Prompt.
Details: {error}
Please try again:"""
NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY)
NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
NAIVE_COMPLETION_RETRY_WITH_ERROR
)
class RetryOutputParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt and the completion to another
LLM, and telling it the completion did not satisfy criteria in the prompt.
"""
parser: BaseOutputParser
retry_chain: LLMChain
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
) -> RetryOutputParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
def parse(self, completion: str) -> Any:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)
def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()
class RetryWithErrorOutputParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors.
Does this by passing the original prompt, the completion, AND the error
that was raised to another language and telling it that the completion
did not work, and raised the given error. Differs from RetryOutputParser
in that this implementation provides the error that was raised back to the
LLM, which in theory should give it more information on how to fix it.
"""
parser: BaseOutputParser
retry_chain: LLMChain
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
) -> RetryWithErrorOutputParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
)
parsed_completion = self.parser.parse(new_completion)
return parsed_completion
def parse(self, completion: str) -> Any:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)
def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()

View File

@ -5,8 +5,8 @@ from typing import List
from pydantic import BaseModel from pydantic import BaseModel
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException
line_template = '\t"{name}": {type} // {description}' line_template = '\t"{name}": {type} // {description}'
@ -42,7 +42,7 @@ class StructuredOutputParser(BaseOutputParser):
json_obj = json.loads(json_string) json_obj = json.loads(json_string)
for schema in self.response_schemas: for schema in self.response_schemas:
if schema.name not in json_obj: if schema.name not in json_obj:
raise ValueError( raise OutputParserException(
f"Got invalid return object. Expected key `{schema.name}` " f"Got invalid return object. Expected key `{schema.name}` "
f"to be present, but got {json_obj}" f"to be present, but got {json_obj}"
) )

View File

@ -10,13 +10,7 @@ import yaml
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
from langchain.formatting import formatter from langchain.formatting import formatter
from langchain.output_parsers.base import BaseOutputParser from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue
from langchain.output_parsers.list import ( # noqa: F401
CommaSeparatedListOutputParser,
ListOutputParser,
)
from langchain.output_parsers.regex import RegexParser # noqa: F401
from langchain.schema import BaseMessage, HumanMessage, PromptValue
def jinja2_formatter(template: str, **kwargs: Any) -> str: def jinja2_formatter(template: str, **kwargs: Any) -> str:

View File

@ -259,3 +259,40 @@ class BaseMemory(BaseModel, ABC):
Memory = BaseMemory Memory = BaseMemory
class BaseOutputParser(BaseModel, ABC):
"""Class to parse the output of an LLM call."""
@abstractmethod
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call."""
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
return self.parse(completion)
def get_format_instructions(self) -> str:
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
class OutputParserException(Exception):
"""Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors
that also may arise inside the output parser. OutputParserExceptions will be
available to catch and handle in ways to fix the parsing error, while other
errors will be raised.
"""
pass

View File

@ -7,8 +7,8 @@ import pytest
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.loading import load_chain from langchain.chains.loading import load_chain
from langchain.output_parsers.base import BaseOutputParser
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser
from tests.unit_tests.llms.fake_llm import FakeLLM from tests.unit_tests.llms.fake_llm import FakeLLM