diff --git a/docs/modules/prompts/examples/output_parsers.ipynb b/docs/modules/prompts/examples/output_parsers.ipynb index d8fc9206..449d0f4f 100644 --- a/docs/modules/prompts/examples/output_parsers.ipynb +++ b/docs/modules/prompts/examples/output_parsers.ipynb @@ -14,6 +14,11 @@ "- `get_format_instructions() -> str`: A method which returns a string containing instructions for how the output of a language model should be formatted.\n", "- `parse(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and parses it into some structure.\n", "\n", + "And then one optional one:\n", + "\n", + "- `parse_with_prompt(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and a prompt (assumed to the prompt that generated such a response) and parses it into some structure. The prompt is largely provided in the event the OutputParser wants to retry or fix the output in some way, and needs information from the prompt to do so.\n", + "\n", + "\n", "Below we go over some examples of output parsers." ] }, @@ -75,7 +80,7 @@ { "data": { "text/plain": [ - "Joke(setup='Why did the chicken cross the playground?', punchline='To get to the other slide!')" + "Joke(setup='Why did the chicken cross the road?', punchline='To get to the other side!')" ] }, "execution_count": 4, @@ -124,7 +129,7 @@ { "data": { "text/plain": [ - "Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story', 'A League of Their Own'])" + "Actor(name='Tom Hanks', film_names=['Forrest Gump', 'Saving Private Ryan', 'The Green Mile', 'Cast Away', 'Toy Story'])" ] }, "execution_count": 5, @@ -155,11 +160,297 @@ "parser.parse(output)" ] }, + { + "cell_type": "markdown", + "id": "4d6c0c86", + "metadata": {}, + "source": [ + "## Fixing Output Parsing Mistakes\n", + "\n", + "The above guardrail simply tries to parse the LLM response. If it does not parse correctly, then it errors.\n", + "\n", + "But we can do other things besides throw errors. Specifically, we can pass the misformatted output, along with the formatted instructions, to the model and ask it to fix it.\n", + "\n", + "For this example, we'll use the above OutputParser. Here's what happens if we pass it a result that does not comply with the schema:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "73beb20d", + "metadata": {}, + "outputs": [], + "source": [ + "misformatted = \"{'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f0e5ba80", + "metadata": {}, + "outputs": [ + { + "ename": "OutputParserException", + "evalue": "Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:23\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 22\u001b[0m json_str \u001b[38;5;241m=\u001b[39m match\u001b[38;5;241m.\u001b[39mgroup()\n\u001b[0;32m---> 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m \u001b[43mjson\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_str\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39mparse_obj(json_object)\n", + "File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03mcontaining a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n", + "File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/json/decoder.py:353\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscan_once\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n", + "\u001b[0;31mJSONDecodeError\u001b[0m: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmisformatted\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n", + "\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Actor from completion {'name': 'Tom Hanks', 'film_names': ['Forrest Gump']}. Got: Expecting property name enclosed in double quotes: line 1 column 2 (char 1)" + ] + } + ], + "source": [ + "parser.parse(misformatted)" + ] + }, + { + "cell_type": "markdown", + "id": "6c7c82b6", + "metadata": {}, + "source": [ + "Now we can construct and use a `OutputFixingParser`. This output parser takes as an argument another output parser but also an LLM with which to try to correct any formatting mistakes." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "39b1a5ce", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.output_parsers import OutputFixingParser\n", + "\n", + "new_parser = OutputFixingParser.from_llm(parser=parser, llm=ChatOpenAI())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0fd96d68", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Actor(name='Tom Hanks', film_names=['Forrest Gump'])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_parser.parse(misformatted)" + ] + }, + { + "cell_type": "markdown", + "id": "ea34eeaa", + "metadata": {}, + "source": [ + "## Fixing Output Parsing Mistakes with the original prompt\n", + "\n", + "While in some cases it is possible to fix any parsing mistakes by only looking at the output, in other cases it can't. An example of this is when the output is not just in the incorrect format, but is partially complete. Consider the below example." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "67c5e1ac", + "metadata": {}, + "outputs": [], + "source": [ + "template = \"\"\"Based on the user question, provide an Action and Action Input for what step should be taken.\n", + "{format_instructions}\n", + "Question: {query}\n", + "Response:\"\"\"\n", + "class Action(BaseModel):\n", + " action: str = Field(description=\"action to take\")\n", + " action_input: str = Field(description=\"input to the action\")\n", + " \n", + "parser = PydanticOutputParser(pydantic_object=Action)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "007aa87f", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = PromptTemplate(\n", + " template=\"Answer the user query.\\n{format_instructions}\\n{query}\\n\",\n", + " input_variables=[\"query\"],\n", + " partial_variables={\"format_instructions\": parser.get_format_instructions()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "10d207ff", + "metadata": {}, + "outputs": [], + "source": [ + "prompt_value = prompt.format_prompt(query=\"who is leo di caprios gf?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "68622837", + "metadata": {}, + "outputs": [], + "source": [ + "bad_response = '{\"action\": \"search\"}'" + ] + }, + { + "cell_type": "markdown", + "id": "25631465", + "metadata": {}, + "source": [ + "If we try to parse this response as is, we will get an error" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "894967c1", + "metadata": {}, + "outputs": [ + { + "ename": "OutputParserException", + "evalue": "Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:24\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 23\u001b[0m json_object \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(json_str)\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpydantic_object\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjson_object\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (json\u001b[38;5;241m.\u001b[39mJSONDecodeError, ValidationError) \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:527\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.parse_obj\u001b[0;34m()\u001b[0m\n", + "File \u001b[0;32m~/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pydantic/main.py:342\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n", + "\u001b[0;31mValidationError\u001b[0m: 1 validation error for Action\naction_input\n field required (type=value_error.missing)", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbad_response\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/pydantic.py:29\u001b[0m, in \u001b[0;36mPydanticOutputParser.parse\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 27\u001b[0m name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpydantic_object\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 28\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to parse \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m from completion \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(msg)\n", + "\u001b[0;31mOutputParserException\u001b[0m: Failed to parse Action from completion {\"action\": \"search\"}. Got: 1 validation error for Action\naction_input\n field required (type=value_error.missing)" + ] + } + ], + "source": [ + "parser.parse(bad_response)" + ] + }, + { + "cell_type": "markdown", + "id": "f6b64696", + "metadata": {}, + "source": [ + "If we try to use the `OutputFixingParser` to fix this error, it will be confused - namely, it doesn't know what to actually put for action input." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "78b2b40d", + "metadata": {}, + "outputs": [], + "source": [ + "fix_parser = OutputFixingParser.from_llm(parser=parser, llm=ChatOpenAI())" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "4fe1301d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Action(action='search', action_input='keyword')" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fix_parser.parse(bad_response)" + ] + }, + { + "cell_type": "markdown", + "id": "9bd9ea7d", + "metadata": {}, + "source": [ + "Instead, we can use the RetryOutputParser, which passes in the prompt (as well as the original output) to try again to get a better response." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "7e8a8a28", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.output_parsers import RetryWithErrorOutputParser" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "5c86e141", + "metadata": {}, + "outputs": [], + "source": [ + "retry_parser = RetryWithErrorOutputParser.from_llm(parser=parser, llm=ChatOpenAI())" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "9c04f731", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Action(action='search', action_input='leo di caprios girlfriend')" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "retry_parser.parse_with_prompt(bad_response, prompt_value)" + ] + }, { "cell_type": "markdown", "id": "61f67890", "metadata": {}, "source": [ + "
\n", + "
\n", + "
\n", + "
\n", "
\n", "
\n", "
\n", @@ -168,6 +459,14 @@ "---" ] }, + { + "cell_type": "markdown", + "id": "64bf525a", + "metadata": {}, + "source": [ + "# Older, less powerful parsers" + ] + }, { "cell_type": "markdown", "id": "91871002", @@ -180,7 +479,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 16, "id": "b492997a", "metadata": {}, "outputs": [], @@ -198,7 +497,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 17, "id": "432ac44a", "metadata": {}, "outputs": [], @@ -220,7 +519,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 18, "id": "593cfc25", "metadata": {}, "outputs": [], @@ -243,7 +542,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 19, "id": "106f1ba6", "metadata": {}, "outputs": [], @@ -253,7 +552,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 20, "id": "86d9d24f", "metadata": {}, "outputs": [], @@ -264,7 +563,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 21, "id": "956bdc99", "metadata": {}, "outputs": [ @@ -274,7 +573,7 @@ "{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}" ] }, - "execution_count": 11, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -293,7 +592,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 22, "id": "8f483d7d", "metadata": {}, "outputs": [], @@ -303,7 +602,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 23, "id": "f761cbf1", "metadata": {}, "outputs": [], @@ -319,7 +618,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 24, "id": "edd73ae3", "metadata": {}, "outputs": [], @@ -330,7 +629,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 25, "id": "a3c8b91e", "metadata": {}, "outputs": [ @@ -340,7 +639,7 @@ "{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}" ] }, - "execution_count": 15, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -361,7 +660,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 26, "id": "872246d7", "metadata": {}, "outputs": [], @@ -371,7 +670,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 27, "id": "c3f9aee6", "metadata": {}, "outputs": [], @@ -381,7 +680,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 28, "id": "e77871b7", "metadata": {}, "outputs": [], @@ -396,7 +695,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 29, "id": "a71cb5d3", "metadata": {}, "outputs": [], @@ -406,7 +705,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 30, "id": "783d7d98", "metadata": {}, "outputs": [], @@ -417,7 +716,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 31, "id": "fcb81344", "metadata": {}, "outputs": [ @@ -431,7 +730,7 @@ " 'Cookies and Cream']" ] }, - "execution_count": 21, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -457,7 +756,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.0" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/langchain/agents/conversational_chat/base.py b/langchain/agents/conversational_chat/base.py index e3cc9494..2a1c8b7f 100644 --- a/langchain/agents/conversational_chat/base.py +++ b/langchain/agents/conversational_chat/base.py @@ -13,7 +13,6 @@ from langchain.agents.conversational_chat.prompt import ( ) from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain -from langchain.output_parsers.base import BaseOutputParser from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ( ChatPromptTemplate, @@ -26,6 +25,7 @@ from langchain.schema import ( AIMessage, BaseLanguageModel, BaseMessage, + BaseOutputParser, HumanMessage, ) from langchain.tools.base import BaseTool diff --git a/langchain/output_parsers/__init__.py b/langchain/output_parsers/__init__.py index fd2f0f24..4133e4a1 100644 --- a/langchain/output_parsers/__init__.py +++ b/langchain/output_parsers/__init__.py @@ -1,4 +1,4 @@ -from langchain.output_parsers.base import BaseOutputParser +from langchain.output_parsers.fix import OutputFixingParser from langchain.output_parsers.list import ( CommaSeparatedListOutputParser, ListOutputParser, @@ -7,6 +7,7 @@ from langchain.output_parsers.pydantic import PydanticOutputParser from langchain.output_parsers.rail_parser import GuardrailsOutputParser from langchain.output_parsers.regex import RegexParser from langchain.output_parsers.regex_dict import RegexDictParser +from langchain.output_parsers.retry import RetryOutputParser, RetryWithErrorOutputParser from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser __all__ = [ @@ -14,9 +15,11 @@ __all__ = [ "RegexDictParser", "ListOutputParser", "CommaSeparatedListOutputParser", - "BaseOutputParser", "StructuredOutputParser", "ResponseSchema", "GuardrailsOutputParser", "PydanticOutputParser", + "RetryOutputParser", + "RetryWithErrorOutputParser", + "OutputFixingParser", ] diff --git a/langchain/output_parsers/base.py b/langchain/output_parsers/base.py deleted file mode 100644 index f3598416..00000000 --- a/langchain/output_parsers/base.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any, Dict - -from pydantic import BaseModel - - -class BaseOutputParser(BaseModel, ABC): - """Class to parse the output of an LLM call.""" - - @abstractmethod - def parse(self, text: str) -> Any: - """Parse the output of an LLM call.""" - - def get_format_instructions(self) -> str: - raise NotImplementedError - - @property - def _type(self) -> str: - """Return the type key.""" - raise NotImplementedError - - def dict(self, **kwargs: Any) -> Dict: - """Return dictionary representation of output parser.""" - output_parser_dict = super().dict() - output_parser_dict["_type"] = self._type - return output_parser_dict diff --git a/langchain/output_parsers/fix.py b/langchain/output_parsers/fix.py new file mode 100644 index 00000000..2654948a --- /dev/null +++ b/langchain/output_parsers/fix.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Any + +from langchain.chains.llm import LLMChain +from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT +from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException + + +class OutputFixingParser(BaseOutputParser): + """Wraps a parser and tries to fix parsing errors.""" + + parser: BaseOutputParser + retry_chain: LLMChain + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + parser: BaseOutputParser, + prompt: BasePromptTemplate = NAIVE_FIX_PROMPT, + ) -> OutputFixingParser: + chain = LLMChain(llm=llm, prompt=prompt) + return cls(parser=parser, retry_chain=chain) + + def parse(self, completion: str) -> Any: + try: + parsed_completion = self.parser.parse(completion) + except OutputParserException as e: + new_completion = self.retry_chain.run( + instructions=self.parser.get_format_instructions(), + completion=completion, + error=repr(e), + ) + parsed_completion = self.parser.parse(new_completion) + + return parsed_completion + + def get_format_instructions(self) -> str: + return self.parser.get_format_instructions() diff --git a/langchain/output_parsers/format_instructions.py b/langchain/output_parsers/format_instructions.py index 1c6639a9..c91d6ca7 100644 --- a/langchain/output_parsers/format_instructions.py +++ b/langchain/output_parsers/format_instructions.py @@ -8,7 +8,10 @@ STRUCTURED_FORMAT_INSTRUCTIONS = """The output should be a markdown code snippet }} ```""" -PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below. For example, the object {{"foo": ["bar", "baz"]}} conforms to the schema {{"foo": {{"description": "a list of strings field", "type": "string"}}}}. +PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below. + +As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}}} +the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted. Here is the output schema: ``` diff --git a/langchain/output_parsers/list.py b/langchain/output_parsers/list.py index aebbf409..32b26742 100644 --- a/langchain/output_parsers/list.py +++ b/langchain/output_parsers/list.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import abstractmethod from typing import List -from langchain.output_parsers.base import BaseOutputParser +from langchain.schema import BaseOutputParser class ListOutputParser(BaseOutputParser): diff --git a/langchain/output_parsers/prompts.py b/langchain/output_parsers/prompts.py new file mode 100644 index 00000000..5ea37b24 --- /dev/null +++ b/langchain/output_parsers/prompts.py @@ -0,0 +1,22 @@ +# flake8: noqa +from langchain.prompts.prompt import PromptTemplate + +NAIVE_FIX = """Instructions: +-------------- +{instructions} +-------------- +Completion: +-------------- +{completion} +-------------- + +Above, the Completion did not satisfy the constraints given in the Instructions. +Error: +-------------- +{error} +-------------- + +Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:""" + + +NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX) diff --git a/langchain/output_parsers/pydantic.py b/langchain/output_parsers/pydantic.py index e441509b..9f818cd9 100644 --- a/langchain/output_parsers/pydantic.py +++ b/langchain/output_parsers/pydantic.py @@ -4,8 +4,8 @@ from typing import Any from pydantic import BaseModel, ValidationError -from langchain.output_parsers.base import BaseOutputParser from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS +from langchain.schema import BaseOutputParser, OutputParserException class PydanticOutputParser(BaseOutputParser): @@ -14,7 +14,9 @@ class PydanticOutputParser(BaseOutputParser): def parse(self, text: str) -> BaseModel: try: # Greedy search for 1st json candidate. - match = re.search("\{.*\}", text.strip()) + match = re.search( + "\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL + ) json_str = "" if match: json_str = match.group() @@ -24,16 +26,17 @@ class PydanticOutputParser(BaseOutputParser): except (json.JSONDecodeError, ValidationError) as e: name = self.pydantic_object.__name__ msg = f"Failed to parse {name} from completion {text}. Got: {e}" - raise ValueError(msg) + raise OutputParserException(msg) def get_format_instructions(self) -> str: schema = self.pydantic_object.schema() # Remove extraneous fields. - reduced_schema = { - prop: {"description": data["description"], "type": data["type"]} - for prop, data in schema["properties"].items() - } + reduced_schema = schema + if "title" in reduced_schema: + del reduced_schema["title"] + if "type" in reduced_schema: + del reduced_schema["type"] # Ensure json in context is well-formed with double quotes. schema = json.dumps(reduced_schema) diff --git a/langchain/output_parsers/rail_parser.py b/langchain/output_parsers/rail_parser.py index 432ea851..0dab50d9 100644 --- a/langchain/output_parsers/rail_parser.py +++ b/langchain/output_parsers/rail_parser.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Any, Dict -from langchain.output_parsers.base import BaseOutputParser +from langchain.schema import BaseOutputParser class GuardrailsOutputParser(BaseOutputParser): diff --git a/langchain/output_parsers/regex.py b/langchain/output_parsers/regex.py index b58137a7..dd03556b 100644 --- a/langchain/output_parsers/regex.py +++ b/langchain/output_parsers/regex.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional from pydantic import BaseModel -from langchain.output_parsers.base import BaseOutputParser +from langchain.schema import BaseOutputParser class RegexParser(BaseOutputParser, BaseModel): diff --git a/langchain/output_parsers/regex_dict.py b/langchain/output_parsers/regex_dict.py index f17e8203..d37f2564 100644 --- a/langchain/output_parsers/regex_dict.py +++ b/langchain/output_parsers/regex_dict.py @@ -5,7 +5,7 @@ from typing import Dict, Optional from pydantic import BaseModel -from langchain.output_parsers.base import BaseOutputParser +from langchain.schema import BaseOutputParser class RegexDictParser(BaseOutputParser, BaseModel): diff --git a/langchain/output_parsers/retry.py b/langchain/output_parsers/retry.py new file mode 100644 index 00000000..7c6760ea --- /dev/null +++ b/langchain/output_parsers/retry.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from typing import Any + +from langchain.chains.llm import LLMChain +from langchain.prompts.base import BasePromptTemplate +from langchain.prompts.prompt import PromptTemplate +from langchain.schema import ( + BaseLanguageModel, + BaseOutputParser, + OutputParserException, + PromptValue, +) + +NAIVE_COMPLETION_RETRY = """Prompt: +{prompt} +Completion: +{completion} + +Above, the Completion did not satisfy the constraints given in the Prompt. +Please try again:""" + +NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt: +{prompt} +Completion: +{completion} + +Above, the Completion did not satisfy the constraints given in the Prompt. +Details: {error} +Please try again:""" + +NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY) +NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template( + NAIVE_COMPLETION_RETRY_WITH_ERROR +) + + +class RetryOutputParser(BaseOutputParser): + """Wraps a parser and tries to fix parsing errors. + + Does this by passing the original prompt and the completion to another + LLM, and telling it the completion did not satisfy criteria in the prompt. + """ + + parser: BaseOutputParser + retry_chain: LLMChain + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + parser: BaseOutputParser, + prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT, + ) -> RetryOutputParser: + chain = LLMChain(llm=llm, prompt=prompt) + return cls(parser=parser, retry_chain=chain) + + def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: + try: + parsed_completion = self.parser.parse(completion) + except OutputParserException: + new_completion = self.retry_chain.run( + prompt=prompt_value.to_string(), completion=completion + ) + parsed_completion = self.parser.parse(new_completion) + + return parsed_completion + + def parse(self, completion: str) -> Any: + raise NotImplementedError( + "This OutputParser can only be called by the `parse_with_prompt` method." + ) + + def get_format_instructions(self) -> str: + return self.parser.get_format_instructions() + + +class RetryWithErrorOutputParser(BaseOutputParser): + """Wraps a parser and tries to fix parsing errors. + + Does this by passing the original prompt, the completion, AND the error + that was raised to another language and telling it that the completion + did not work, and raised the given error. Differs from RetryOutputParser + in that this implementation provides the error that was raised back to the + LLM, which in theory should give it more information on how to fix it. + """ + + parser: BaseOutputParser + retry_chain: LLMChain + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + parser: BaseOutputParser, + prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT, + ) -> RetryWithErrorOutputParser: + chain = LLMChain(llm=llm, prompt=prompt) + return cls(parser=parser, retry_chain=chain) + + def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any: + try: + parsed_completion = self.parser.parse(completion) + except OutputParserException as e: + new_completion = self.retry_chain.run( + prompt=prompt_value.to_string(), completion=completion, error=repr(e) + ) + parsed_completion = self.parser.parse(new_completion) + + return parsed_completion + + def parse(self, completion: str) -> Any: + raise NotImplementedError( + "This OutputParser can only be called by the `parse_with_prompt` method." + ) + + def get_format_instructions(self) -> str: + return self.parser.get_format_instructions() diff --git a/langchain/output_parsers/structured.py b/langchain/output_parsers/structured.py index 4a32d23f..9921e7c6 100644 --- a/langchain/output_parsers/structured.py +++ b/langchain/output_parsers/structured.py @@ -5,8 +5,8 @@ from typing import List from pydantic import BaseModel -from langchain.output_parsers.base import BaseOutputParser from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS +from langchain.schema import BaseOutputParser, OutputParserException line_template = '\t"{name}": {type} // {description}' @@ -42,7 +42,7 @@ class StructuredOutputParser(BaseOutputParser): json_obj = json.loads(json_string) for schema in self.response_schemas: if schema.name not in json_obj: - raise ValueError( + raise OutputParserException( f"Got invalid return object. Expected key `{schema.name}` " f"to be present, but got {json_obj}" ) diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index af431373..166ac4d5 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -10,13 +10,7 @@ import yaml from pydantic import BaseModel, Extra, Field, root_validator from langchain.formatting import formatter -from langchain.output_parsers.base import BaseOutputParser -from langchain.output_parsers.list import ( # noqa: F401 - CommaSeparatedListOutputParser, - ListOutputParser, -) -from langchain.output_parsers.regex import RegexParser # noqa: F401 -from langchain.schema import BaseMessage, HumanMessage, PromptValue +from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue def jinja2_formatter(template: str, **kwargs: Any) -> str: diff --git a/langchain/schema.py b/langchain/schema.py index cb2d7403..84840bcd 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -259,3 +259,40 @@ class BaseMemory(BaseModel, ABC): Memory = BaseMemory + + +class BaseOutputParser(BaseModel, ABC): + """Class to parse the output of an LLM call.""" + + @abstractmethod + def parse(self, text: str) -> Any: + """Parse the output of an LLM call.""" + + def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any: + return self.parse(completion) + + def get_format_instructions(self) -> str: + raise NotImplementedError + + @property + def _type(self) -> str: + """Return the type key.""" + raise NotImplementedError + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of output parser.""" + output_parser_dict = super().dict() + output_parser_dict["_type"] = self._type + return output_parser_dict + + +class OutputParserException(Exception): + """Exception that output parsers should raise to signify a parsing error. + + This exists to differentiate parsing errors from other code or execution errors + that also may arise inside the output parser. OutputParserExceptions will be + available to catch and handle in ways to fix the parsing error, while other + errors will be raised. + """ + + pass diff --git a/tests/unit_tests/chains/test_llm.py b/tests/unit_tests/chains/test_llm.py index d01e14e5..66b42e70 100644 --- a/tests/unit_tests/chains/test_llm.py +++ b/tests/unit_tests/chains/test_llm.py @@ -7,8 +7,8 @@ import pytest from langchain.chains.llm import LLMChain from langchain.chains.loading import load_chain -from langchain.output_parsers.base import BaseOutputParser from langchain.prompts.prompt import PromptTemplate +from langchain.schema import BaseOutputParser from tests.unit_tests.llms.fake_llm import FakeLLM