forked from Archives/langchain
Add _type for all parsers (#4189)
Used for serialization. Also add test that recurses through our subclasses to check they have them implemented Would fix https://github.com/hwchase17/langchain/issues/3217 Blocking: https://github.com/mlflow/mlflow/pull/8297 --------- Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>parallel_dir_loader
parent
b21d7c138c
commit
812e5f43f5
@ -0,0 +1,47 @@
|
||||
"""Test the BaseOutputParser class and its sub-classes."""
|
||||
from abc import ABC
|
||||
from typing import List, Optional, Set, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
def non_abstract_subclasses(
|
||||
cls: Type[ABC], to_skip: Optional[Set] = None
|
||||
) -> List[Type]:
|
||||
"""Recursively find all non-abstract subclasses of a class."""
|
||||
_to_skip = to_skip or set()
|
||||
subclasses = []
|
||||
for subclass in cls.__subclasses__():
|
||||
if not getattr(subclass, "__abstractmethods__", None):
|
||||
if subclass.__name__ not in _to_skip:
|
||||
subclasses.append(subclass)
|
||||
subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip))
|
||||
return subclasses
|
||||
|
||||
|
||||
_PARSERS_TO_SKIP = {"FakeOutputParser", "BaseOutputParser"}
|
||||
_NON_ABSTRACT_PARSERS = non_abstract_subclasses(
|
||||
BaseOutputParser, to_skip=_PARSERS_TO_SKIP
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS)
|
||||
def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
|
||||
try:
|
||||
cls._type
|
||||
except NotImplementedError:
|
||||
pytest.fail(f"_type property is not implemented in class {cls.__name__}")
|
||||
|
||||
|
||||
def test_all_subclasses_implement_unique_type() -> None:
|
||||
types = []
|
||||
for cls in _NON_ABSTRACT_PARSERS:
|
||||
try:
|
||||
types.append(cls._type)
|
||||
except NotImplementedError:
|
||||
# This is handled in the previous test
|
||||
pass
|
||||
dups = set([t for t in types if types.count(t) > 1])
|
||||
assert not dups, f"Duplicate types: {dups}"
|
Loading…
Reference in New Issue