mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
32 lines
706 B
Python
32 lines
706 B
Python
|
from enum import Enum
|
||
|
|
||
|
from langchain.output_parsers.enum import EnumOutputParser
|
||
|
from langchain.schema import OutputParserException
|
||
|
|
||
|
|
||
|
class Colors(Enum):
    """Small three-member enum used as the parse target in these tests.

    Each member's value is the lowercase string the parser is fed.
    """

    # Member order is RED, GREEN, BLUE; values are the lowercase names.
    RED = "red"
    GREEN = "green"
    BLUE = "blue"
|
||
|
|
||
|
|
||
|
def test_enum_output_parser_parse() -> None:
    """Check that EnumOutputParser maps enum values to members and rejects others.

    Each listed string must parse to the corresponding ``Colors`` member, and a
    string matching no member value must raise ``OutputParserException``.
    """
    parser = EnumOutputParser(enum=Colors)

    # Test valid inputs: literal string -> expected enum member.
    valid_cases = (
        ("red", Colors.RED),
        ("green", Colors.GREEN),
        ("blue", Colors.BLUE),
    )
    for text, expected in valid_cases:
        result = parser.parse(text)
        assert result == expected, f"parse({text!r}) returned {result!r}"

    # Test invalid input. Keep only the call under `try`, and signal the
    # missing exception with an explicit raise: a bare `assert False` is
    # stripped under `python -O`, which would let this check pass silently.
    try:
        parser.parse("INVALID")
    except OutputParserException:
        pass
    else:
        raise AssertionError("Should have raised OutputParserException")
|