mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
c7ca350cd3
In LangChain, all module classes are enumerated in the `__init__.py` file of the corresponding module, but some classes were missed and were not included in the module's `__init__.py`. This PR: - added the missing classes to the module `__init__.py` files - sorted the value of the `__all__` variable (the list of class names) in each `__init__.py` - renamed `langchain.tools.sql_database.tool.QueryCheckerTool` to `QuerySQLCheckerTool` because it conflicted with `langchain.tools.spark_sql.tool.QueryCheckerTool` - changes to `pyproject.toml`: - added `pgvector` to `pyproject.toml:extended_testing` - added `pandas` to `pyproject.toml:[tool.poetry.group.test.dependencies]` - commented out `streamlit` in `callbacks/__init__.py`, because `streamlit` now requires Python >=3.7, !=3.9.7 - fixed duplicate names in `tools` - fixed the corresponding unit tests #### Who can review? @hwchase17 @dev2049
54 lines
1.6 KiB
Python
54 lines
1.6 KiB
Python
"""Test the BaseOutputParser class and its sub-classes."""
|
|
from abc import ABC
|
|
from typing import List, Optional, Set, Type
|
|
|
|
import pytest
|
|
|
|
from langchain.schema import BaseOutputParser
|
|
|
|
|
|
def non_abstract_subclasses(
    cls: Type[ABC], to_skip: Optional[Set] = None
) -> List[Type]:
    """Recursively find all non-abstract subclasses of a class.

    Args:
        cls: Root class whose subclass tree is walked via ``__subclasses__()``.
        to_skip: Class names to leave out of the result. A skipped class is
            still descended into, so its own subclasses can be reported.

    Returns:
        Every direct or indirect subclass that has no outstanding abstract
        methods and whose name is not in ``to_skip``, in depth-first
        discovery order, with each class listed exactly once.
    """
    _to_skip = to_skip or set()
    found: List[Type] = []
    for subclass in cls.__subclasses__():
        # A missing or empty ``__abstractmethods__`` means the class is concrete.
        is_concrete = not getattr(subclass, "__abstractmethods__", None)
        if is_concrete and subclass.__name__ not in _to_skip:
            found.append(subclass)
        # Descend regardless of the checks above: abstract or skipped classes
        # may still have concrete descendants.
        found.extend(non_abstract_subclasses(subclass, to_skip=_to_skip))
    # With multiple inheritance a class is reachable through each of its bases;
    # dict.fromkeys drops the repeats while keeping first-seen order.
    return list(dict.fromkeys(found))
|
|
|
|
|
|
# Names of parser classes that are not defined in the output_parsers module;
# they are passed as ``to_skip`` and so left out of the collected list.
_PARSERS_TO_SKIP = {
    "BaseOutputParser",
    "FakeOutputParser",
    "FinishedOutputParser",
    "RouterOutputParser",
}
# Concrete BaseOutputParser subclasses, collected once when this module is
# imported.
_NON_ABSTRACT_PARSERS = non_abstract_subclasses(
    BaseOutputParser,
    to_skip=_PARSERS_TO_SKIP,
)
|
|
|
|
|
|
@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS)
def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
    """Fail if reading ``_type`` on the parser class raises NotImplementedError.

    NOTE(review): ``_type`` is read from the class, not from an instance. If it
    is declared as a property (as the failure message suggests), class-level
    access returns the property object without running its getter, so this
    check may never fail -- confirm against BaseOutputParser.
    """
    failure_message = f"_type property is not implemented in class {cls.__name__}"
    try:
        # Only the attribute access matters; the value itself is discarded.
        _ = cls._type
    except NotImplementedError:
        pytest.fail(failure_message)
|
|
|
|
|
|
def test_all_subclasses_implement_unique_type() -> None:
    """Check that no two collected parser classes expose the same ``_type``.

    Each class in ``_NON_ABSTRACT_PARSERS`` contributes its class-level
    ``_type`` value; any value seen more than once is reported in the
    assertion message.

    NOTE(review): ``_type`` is read from the class rather than an instance; if
    it is a property, the values compared here are property objects rather
    than the type strings -- confirm against BaseOutputParser.
    """
    seen = set()
    dups = set()
    for cls in _NON_ABSTRACT_PARSERS:
        try:
            type_key = cls._type
        except NotImplementedError:
            # A missing implementation is reported by
            # test_subclass_implements_type, so it is ignored here.
            continue
        # Single pass with set lookups instead of calling list.count() for
        # every element, which was quadratic in the number of parsers.
        if type_key in seen:
            dups.add(type_key)
        else:
            seen.add(type_key)
    assert not dups, f"Duplicate types: {dups}"
|