cli[minor]: update code to generate migrations from langchain to community (#20946)

Updates code that generates migrations from langchain to community
pull/20949/head
Eugene Yurtsev 1 month ago committed by GitHub
parent 078c5d9bc6
commit 2fa0ff1a2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,52 +1,129 @@
"""Generate migrations from langchain to langchain-community or core packages.""" """Generate migrations from langchain to langchain-community or core packages."""
import glob import importlib
from pathlib import Path import inspect
import pkgutil
from typing import List, Tuple from typing import List, Tuple
from langchain_cli.namespaces.migrate.generate.utils import (
_get_current_module, def generate_raw_migrations_to_community() -> List[Tuple[str, str]]:
find_imports_from_package, """Scan the `langchain` package and generate migrations for all modules."""
) import langchain as package
HERE = Path(__file__).parent to_package = "langchain_community"
PKGS_ROOT = HERE.parent.parent.parent
LANGCHAIN_PKG = PKGS_ROOT / "langchain" items = []
COMMUNITY_PKG = PKGS_ROOT / "community" for importer, modname, ispkg in pkgutil.walk_packages(
PARTNER_PKGS = PKGS_ROOT / "partners" package.__path__, package.__name__ + "."
):
try:
def _generate_migrations_from_file( module = importlib.import_module(modname)
source_module: str, code: str, *, from_package: str except ModuleNotFoundError:
) -> List[Tuple[str, str]]: continue
"""Generate migrations"""
imports = find_imports_from_package(code, from_package=from_package) # Check if the module is an __init__ file and evaluate __all__
return [ try:
# Rewrite in a list comprehension has_all = hasattr(module, "__all__")
(f"{source_module}.{item}", f"{new_path}.{item}") except ImportError:
for new_path, item in imports has_all = False
]
if has_all:
all_objects = module.__all__
def _generate_migrations_from_file_in_pkg( for name in all_objects:
file: str, root_pkg: str # Attempt to fetch each object declared in __all__
) -> List[Tuple[str, str]]: try:
"""Generate migrations for a file that's relative to langchain pkg.""" obj = getattr(module, name, None)
# Read the file. except ImportError:
with open(file, encoding="utf-8") as f: continue
code = f.read() if obj and (inspect.isclass(obj) or inspect.isfunction(obj)):
if obj.__module__.startswith(to_package):
module_name = _get_current_module(file, root_pkg) items.append((f"{modname}.{name}", f"{obj.__module__}.{name}"))
return _generate_migrations_from_file(
module_name, code, from_package="langchain_community" # Iterate over all members of the module
) for name, obj in inspect.getmembers(module):
# Check if it's a class or function
if inspect.isclass(obj) or inspect.isfunction(obj):
def generate_migrations_from_langchain_to_community() -> List[Tuple[str, str]]: # Check if the module name of the obj starts with 'langchain_community'
"""Generate migrations from langchain to langchain-community.""" if obj.__module__.startswith(to_package):
migrations = [] items.append((f"{modname}.{name}", f"{obj.__module__}.{name}"))
# scanning files in pkg
for file_path in glob.glob(str(LANGCHAIN_PKG) + "**/*.py"): return items
migrations.extend(
_generate_migrations_from_file_in_pkg(file_path, str(LANGCHAIN_PKG))
) def generate_top_level_imports_community() -> List[Tuple[str, str]]:
return migrations """This code will look at all the top level modules in langchain_community.
It'll attempt to import everything from each __init__ file
for example,
langchain_community/
chat_models/
__init__.py # <-- import everything from here
llm/
__init__.py # <-- import everything from here
It'll collect all the imports, import the classes / functions it can find
there. It'll return a list of 2-tuples
Each tuple will contain the fully qualified path of the class / function to where
its logic is defined
(e.g., langchain_community.chat_models.xyz_implementation.ver2.XYZ)
and the second tuple will contain the path
to importing it from the top level namespaces
(e.g., langchain_community.chat_models.XYZ)
"""
import langchain_community as package
items = []
# Only iterate through top-level modules/packages
for finder, modname, ispkg in pkgutil.iter_modules(
package.__path__, package.__name__ + "."
):
if ispkg:
try:
module = importlib.import_module(modname)
except ModuleNotFoundError:
continue
if hasattr(module, "__all__"):
all_objects = getattr(module, "__all__")
for name in all_objects:
# Attempt to fetch each object declared in __all__
obj = getattr(module, name, None)
if obj and (inspect.isclass(obj) or inspect.isfunction(obj)):
# Capture the fully qualified name of the object
original_module = obj.__module__
original_name = obj.__name__
# Form the new import path from the top-level namespace
top_level_import = f"{modname}.{name}"
# Append the tuple with original and top-level paths
items.append(
(f"{original_module}.{original_name}", top_level_import)
)
return items
def generate_simplified_migrations() -> List[Tuple[str, str]]:
"""Get all the raw migrations, then simplify them if possible."""
raw_migrations = generate_raw_migrations_to_community()
top_level_simplifications = generate_top_level_imports_community()
top_level_dict = {full: top_level for full, top_level in top_level_simplifications}
simple_migrations = []
for migration in raw_migrations:
original, new = migration
replacement = top_level_dict.get(new, new)
simple_migrations.append((original, replacement))
# Now let's deduplicate the list based on the original path (which is
# the 1st element of the tuple)
deduped_migrations = []
seen = set()
for migration in simple_migrations:
original = migration[0]
if original not in seen:
deduped_migrations.append(migration)
seen.add(original)
return deduped_migrations

@ -5,7 +5,7 @@ import pkgutil
import click import click
from langchain_cli.namespaces.migrate.generate.langchain import ( from langchain_cli.namespaces.migrate.generate.langchain import (
generate_migrations_from_langchain_to_community, generate_simplified_migrations,
) )
from langchain_cli.namespaces.migrate.generate.partner import ( from langchain_cli.namespaces.migrate.generate.partner import (
get_migrations_for_partner_package, get_migrations_for_partner_package,
@ -27,9 +27,9 @@ def cli():
def langchain(output: str) -> None: def langchain(output: str) -> None:
"""Generate a migration script.""" """Generate a migration script."""
click.echo("Migration script generated.") click.echo("Migration script generated.")
migrations = generate_migrations_from_langchain_to_community() migrations = generate_simplified_migrations()
with open(output, "w") as f: with open(output, "w") as f:
f.write(json.dumps(migrations)) f.write(json.dumps(migrations, indent=2, sort_keys=True))
@cli.command() @cli.command()

@ -0,0 +1,25 @@
from langchain_cli.namespaces.migrate.generate.langchain import (
generate_simplified_migrations,
)
def test_create_json_agent_migration() -> None:
"""Test the migration of create_json_agent from langchain to langchain_community."""
raw_migrations = generate_simplified_migrations()
json_agent_migrations = [
migration for migration in raw_migrations if "create_json_agent" in migration[0]
]
assert json_agent_migrations == [
(
"langchain.agents.create_json_agent",
"langchain_community.agent_toolkits.create_json_agent",
),
(
"langchain.agents.agent_toolkits.create_json_agent",
"langchain_community.agent_toolkits.create_json_agent",
),
(
"langchain.agents.agent_toolkits.json.base.create_json_agent",
"langchain_community.agent_toolkits.create_json_agent",
),
]
Loading…
Cancel
Save