add enum output parser (#5165)

This commit is contained in:
Harrison Chase 2023-05-27 20:58:23 -07:00 committed by GitHub
parent 465a970724
commit 179ddbe88b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 237 additions and 0 deletions

View File

@ -0,0 +1,173 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0360be02",
"metadata": {},
"source": [
"# Enum Output Parser\n",
"\n",
"This notebook shows how to use an Enum output parser"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2f039b4b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers.enum import EnumOutputParser"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9a35d1a7",
"metadata": {},
"outputs": [],
"source": [
"from enum import Enum\n",
"\n",
"class Colors(Enum):\n",
" RED = \"red\"\n",
" GREEN = \"green\"\n",
" BLUE = \"blue\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a90a66f5",
"metadata": {},
"outputs": [],
"source": [
"parser = EnumOutputParser(enum=Colors)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c48b88cb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Colors.RED: 'red'>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parser.parse(\"red\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7d313e41",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Colors.GREEN: 'green'>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Can handle spaces\n",
"parser.parse(\" green\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "976ae42d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Colors.BLUE: 'blue'>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# And new lines\n",
"parser.parse(\"blue\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "636a48ab",
"metadata": {},
"outputs": [
{
"ename": "OutputParserException",
"evalue": "Response 'yellow' is not one of the expected values: ['red', 'green', 'blue']",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/enum.py:25\u001b[0m, in \u001b[0;36mEnumOutputParser.parse\u001b[0;34m(self, response)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 25\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresponse\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstrip\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m:\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/enum.py:315\u001b[0m, in \u001b[0;36mEnumMeta.__call__\u001b[0;34m(cls, value, names, module, qualname, type, start)\u001b[0m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m names \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;66;03m# simple value lookup\u001b[39;00m\n\u001b[0;32m--> 315\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__new__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;66;03m# otherwise, functional API: we're creating a new Enum type\u001b[39;00m\n",
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/enum.py:611\u001b[0m, in \u001b[0;36mEnum.__new__\u001b[0;34m(cls, value)\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m exc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 611\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ve_exc\n\u001b[1;32m 612\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m exc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[0;31mValueError\u001b[0m: 'yellow' is not a valid Colors",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# And raises errors when appropriate\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43myellow\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/enum.py:27\u001b[0m, in \u001b[0;36mEnumOutputParser.parse\u001b[0;34m(self, response)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menum(response\u001b[38;5;241m.\u001b[39mstrip())\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m:\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(\n\u001b[1;32m 28\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mResponse \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresponse\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m is not one of the \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 29\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexpected values: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_valid_values\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 30\u001b[0m )\n",
"\u001b[0;31mOutputParserException\u001b[0m: Response 'yellow' is not one of the expected values: ['red', 'green', 'blue']"
]
}
],
"source": [
"# And raises errors when appropriate\n",
"parser.parse(\"yellow\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c517f447",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,33 @@
from enum import Enum
from typing import Any, Dict, List, Type
from pydantic import root_validator
from langchain.schema import BaseOutputParser, OutputParserException
class EnumOutputParser(BaseOutputParser):
enum: Type[Enum]
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
enum = values["enum"]
if not all(isinstance(e.value, str) for e in enum):
raise ValueError("Enum values must be strings")
return values
@property
def _valid_values(self) -> List[str]:
return [e.value for e in self.enum]
def parse(self, response: str) -> Any:
try:
return self.enum(response.strip())
except ValueError:
raise OutputParserException(
f"Response '{response}' is not one of the "
f"expected values: {self._valid_values}"
)
def get_format_instructions(self) -> str:
return f"Select one of the following options: {', '.join(self._valid_values)}"

View File

@ -0,0 +1,31 @@
from enum import Enum
from langchain.output_parsers.enum import EnumOutputParser
from langchain.schema import OutputParserException
class Colors(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"
def test_enum_output_parser_parse() -> None:
parser = EnumOutputParser(enum=Colors)
# Test valid inputs
result = parser.parse("red")
assert result == Colors.RED
result = parser.parse("green")
assert result == Colors.GREEN
result = parser.parse("blue")
assert result == Colors.BLUE
# Test invalid input
try:
parser.parse("INVALID")
assert False, "Should have raised OutputParserException"
except OutputParserException:
pass