mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
115 lines
3.7 KiB
Python
115 lines
3.7 KiB
Python
import importlib
|
|
import inspect
|
|
import pkgutil
|
|
from types import ModuleType
|
|
|
|
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
|
|
|
|
|
|
def import_all_modules(package_name: str) -> dict:
    """Collect every LangChain-serializable class reachable from a package.

    Imports ``package_name`` and then every submodule found by
    :func:`pkgutil.walk_packages`. In each module, every name listed by
    ``dir(module)`` or ``module.__all__`` (if defined) is resolved with
    ``getattr``. A class is recorded when it defines ``is_lc_serializable()``
    and that call returns exactly ``True``; truthy non-bool results are ignored.

    Args:
        package_name: Dotted name of the package to scan. It must be a package,
            i.e. expose ``__path__``, because its submodules are walked.

    Returns:
        A dict mapping ``tuple(cls.lc_id())`` to the tuple of dotted-path parts
        where the class is defined, ``(*cls.__module__.split("."), cls.__name__)``.

    Raises:
        ValueError: If two classes report the same ``lc_id()`` but live at
            different import paths.
    """
    package = importlib.import_module(package_name)
    classes: dict = {}

    def _handle_module(module: ModuleType) -> None:
        """Record the serializable classes exposed by a single module."""
        names = dir(module)
        # Names in __all__ may not all appear in dir() (e.g. if the module
        # resolves some of them lazily), so include both sources.
        if hasattr(module, "__all__"):
            names += list(module.__all__)

        for name in sorted(set(names)):
            attr = getattr(module, name)
            # inspect.isclass(x) already means isinstance(x, type).
            if not inspect.isclass(attr) or not hasattr(attr, "is_lc_serializable"):
                continue
            # Only an explicit bool True counts (bool cannot be subclassed, so
            # this equals "isinstance(result, bool) and result"); call it once.
            if attr.is_lc_serializable() is not True:  # type: ignore
                continue

            key = tuple(attr.lc_id())  # type: ignore
            value = (*attr.__module__.split("."), attr.__name__)
            # The same class re-exported from several modules yields the same
            # value and is fine; a different class claiming the key is a clash.
            if key in classes and classes[key] != value:
                raise ValueError(
                    f"Conflicting import paths for lc_id {key!r}: "
                    f"{classes[key]!r} vs {value!r}"
                )
            classes[key] = value

    _handle_module(package)

    # Visit every submodule. Ones that fail with ModuleNotFoundError (e.g. an
    # optional dependency is not installed) are skipped.
    for _, modname, _ in pkgutil.walk_packages(
        package.__path__, package.__name__ + "."
    ):
        try:
            module = importlib.import_module(modname)
        except ModuleNotFoundError:
            continue
        _handle_module(module)

    return classes
|
|
|
|
|
|
def test_import_all_modules() -> None:
    """Check which chat-model classes are discovered under ``langchain.chat_models``.

    Only four-part keys in the ``("langchain", "chat_models")`` namespace are
    compared against the expected list. This test will need to be updated if
    new serializable classes are added to community.
    """
    expected = [
        ("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"),
        ("langchain", "chat_models", "bedrock", "BedrockChat"),
        ("langchain", "chat_models", "anthropic", "ChatAnthropic"),
        ("langchain", "chat_models", "fireworks", "ChatFireworks"),
        ("langchain", "chat_models", "google_palm", "ChatGooglePalm"),
        ("langchain", "chat_models", "openai", "ChatOpenAI"),
        ("langchain", "chat_models", "vertexai", "ChatVertexAI"),
    ]
    namespace = ("langchain", "chat_models")

    discovered = []
    for lc_key in import_all_modules("langchain"):
        # Keys are tuples already, so a slice compares directly with `namespace`.
        if len(lc_key) == 4 and lc_key[:2] == namespace:
            discovered.append(lc_key)

    assert sorted(discovered) == sorted(expected)
|
|
|
|
|
|
def test_serializable_mapping() -> None:
    """Check that SERIALIZABLE_MAPPING agrees with the classes found in langchain.

    The test asserts three things:

    - Every mapping key is either discovered by ``import_all_modules`` or
      listed in ``to_skip``.
    - Every discovered key is present in the mapping.
    - Each discovered import path resolves to a class whose ``lc_id()`` equals
      its key.
    """
    # Mapping entries that are deliberately not discoverable by walking the
    # ``langchain`` package. Only the keys are used below; the values record
    # where each entry actually lives.
    to_skip = {
        # This should have had a different namespace, as it was never
        # exported from the langchain module, but we keep for whoever has
        # already serialized it.
        ("langchain", "prompts", "image", "ImagePromptTemplate"): (
            "langchain_core",
            "prompts",
            "image",
            "ImagePromptTemplate",
        ),
        # This is not exported from langchain, only langchain_core
        ("langchain_core", "prompts", "structured", "StructuredPrompt"): (
            "langchain_core",
            "prompts",
            "structured",
            "StructuredPrompt",
        ),
    }
    discovered = import_all_modules("langchain")

    accounted_for = set(discovered) | set(to_skip)
    missing = set(SERIALIZABLE_MAPPING) - accounted_for
    assert missing == set()

    extra = set(discovered) - set(SERIALIZABLE_MAPPING)
    assert extra == set()

    for lc_key, path_parts in discovered.items():
        *module_parts, class_name = path_parts
        owner = importlib.import_module(".".join(module_parts))
        klass = getattr(owner, class_name)
        assert list(lc_key) == klass.lc_id()
|