diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/__init__.py b/libs/cli/langchain_cli/namespaces/migrate/generate/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/langchain.py b/libs/cli/langchain_cli/namespaces/migrate/generate/langchain.py new file mode 100644 index 0000000000..3065e90e65 --- /dev/null +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/langchain.py @@ -0,0 +1,52 @@ +"""Generate migrations from langchain to langchain-community or core packages.""" +import glob +from pathlib import Path +from typing import List, Tuple + +from langchain_cli.namespaces.migrate.generate.utils import ( + _get_current_module, + find_imports_from_package, +) + +HERE = Path(__file__).parent +PKGS_ROOT = HERE.parent.parent.parent +LANGCHAIN_PKG = PKGS_ROOT / "langchain" +COMMUNITY_PKG = PKGS_ROOT / "community" +PARTNER_PKGS = PKGS_ROOT / "partners" + + +def _generate_migrations_from_file( + source_module: str, code: str, *, from_package: str +) -> List[Tuple[str, str]]: + """Generate migrations""" + imports = find_imports_from_package(code, from_package=from_package) + return [ + # Rewrite in a list comprehension + (f"{source_module}.{item}", f"{new_path}.{item}") + for new_path, item in imports + ] + + +def _generate_migrations_from_file_in_pkg( + file: str, root_pkg: str +) -> List[Tuple[str, str]]: + """Generate migrations for a file that's relative to langchain pkg.""" + # Read the file. + with open(file, encoding="utf-8") as f: + code = f.read() + + module_name = _get_current_module(file, root_pkg) + return _generate_migrations_from_file( + module_name, code, from_package="langchain_community" + ) + + +def generate_migrations_from_langchain_to_community() -> List[Tuple[str, str]]: + """Generate migrations from langchain to langchain-community.""" + migrations = [] + # scanning files in pkg + for file_path in glob.glob(str(LANGCHAIN_PKG) + "**/*.py"): + migrations.extend( + _generate_migrations_from_file_in_pkg(file_path, str(LANGCHAIN_PKG)) + ) + return migrations diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py b/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py new file mode 100644 index 0000000000..8bd473c724 --- /dev/null +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py @@ -0,0 +1,41 @@ +"""Generate migrations for partner packages.""" +import importlib +from typing import List, Tuple + +from langchain_core.embeddings import Embeddings +from langchain_core.language_models import BaseLanguageModel +from langchain_core.retrievers import BaseRetriever +from langchain_core.vectorstores import VectorStore + +from langchain_cli.namespaces.migrate.generate.utils import ( + COMMUNITY_PKG, + find_subclasses_in_module, + list_classes_by_package, +) + +# PUBLIC API + + +def get_migrations_for_partner_package(pkg_name: str) -> List[Tuple[str, str]]: + """Generate migrations from community package to partner package. + + This code works + + Args: + pkg_name (str): The name of the partner package. + + Returns: + List of 2-tuples containing old and new import paths. + """ + package = importlib.import_module(pkg_name) + classes_ = find_subclasses_in_module( + package, [BaseLanguageModel, Embeddings, BaseRetriever, VectorStore] + ) + community_classes = list_classes_by_package(str(COMMUNITY_PKG)) + + migrations = [ + (f"{community_module}.{community_class}", f"{pkg_name}.{community_class}") + for community_module, community_class in community_classes + if community_class in classes_ + ] + return migrations diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py new file mode 100644 index 0000000000..34ce40cc36 --- /dev/null +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py @@ -0,0 +1,111 @@ +import ast +import inspect +import os +import pathlib +from pathlib import Path +from typing import Any, List, Tuple, Type + +HERE = Path(__file__).parent +# Should bring us to [root]/src +PKGS_ROOT = HERE.parent.parent.parent.parent.parent + +LANGCHAIN_PKG = PKGS_ROOT / "langchain" +COMMUNITY_PKG = PKGS_ROOT / "community" +PARTNER_PKGS = PKGS_ROOT / "partners" + + +class ImportExtractor(ast.NodeVisitor): + def __init__(self, *, from_package: str) -> None: + """Extract all imports from the given package.""" + self.imports = [] + self.package = from_package + + def visit_ImportFrom(self, node): + if node.module and str(node.module).startswith(self.package): + for alias in node.names: + self.imports.append((node.module, alias.name)) + self.generic_visit(node) + + +def _get_class_names(code: str) -> List[str]: + """Extract class names from a code string.""" + # Parse the content of the file into an AST + tree = ast.parse(code) + + # Initialize a list to hold all class names + class_names = [] + + # Define a node visitor class to collect class names + class ClassVisitor(ast.NodeVisitor): + def visit_ClassDef(self, node): + class_names.append(node.name) + self.generic_visit(node) + + # Create an instance of the visitor and visit the AST + visitor = ClassVisitor() + visitor.visit(tree) + return class_names + + +def is_subclass(class_obj: Any, classes_: List[Type]) -> bool: + """Check if the given class object is a subclass of any class in list classes.""" + return any( + issubclass(class_obj, kls) + for kls in classes_ + if inspect.isclass(class_obj) and inspect.isclass(kls) + ) + + +def find_subclasses_in_module(module, classes_: List[Type]) -> List[str]: + """Find all classes in the module that inherit from one of the classes.""" + subclasses = [] + # Iterate over all attributes of the module that are classes + for name, obj in inspect.getmembers(module, inspect.isclass): + if is_subclass(obj, classes_): + subclasses.append(obj.__name__) + return subclasses + + +def _get_all_classnames_from_file(file: str, pkg: str) -> List[Tuple[str, str]]: + """Extract all class names from a file.""" + with open(file, encoding="utf-8") as f: + code = f.read() + module_name = _get_current_module(file, pkg) + class_names = _get_class_names(code) + return [(module_name, class_name) for class_name in class_names] + + +def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]: + """List all classes in a package.""" + module_classes = [] + files = list(Path(pkg_root).rglob("*.py")) + + for file in files: + rel_path = os.path.relpath(file, pkg_root) + if rel_path.startswith("tests"): + continue + module_classes.extend(_get_all_classnames_from_file(file, pkg_root)) + return module_classes + + +def find_imports_from_package(code: str, *, from_package: str) -> List[Tuple[str, str]]: + # Parse the code into an AST + tree = ast.parse(code) + # Create an instance of the visitor + extractor = ImportExtractor(from_package="langchain_community") + # Use the visitor to update the imports list + extractor.visit(tree) + return extractor.imports + + +def _get_current_module(path: str, pkg_root: str) -> str: + """Convert a path to a module name.""" + path_as_pathlib = pathlib.Path(os.path.abspath(path)) + relative_path = path_as_pathlib.relative_to(pkg_root).with_suffix("") + posix_path = relative_path.as_posix() + norm_path = os.path.normpath(str(posix_path)) + fully_qualified_module = norm_path.replace("/", ".") + # Strip __init__ if present + if fully_qualified_module.endswith(".__init__"): + return fully_qualified_module[:-9] + return fully_qualified_module diff --git a/libs/cli/scripts/__init__.py b/libs/cli/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/cli/scripts/generate_migrations.py b/libs/cli/scripts/generate_migrations.py new file mode 100644 index 0000000000..e428fc4bf1 --- /dev/null +++ b/libs/cli/scripts/generate_migrations.py @@ -0,0 +1,48 @@ +"""Script to generate migrations for the migration script.""" +import json + +import click + +from langchain_cli.namespaces.migrate.generate.langchain import ( + generate_migrations_from_langchain_to_community, +) +from langchain_cli.namespaces.migrate.generate.partner import ( + get_migrations_for_partner_package, +) + + +@click.group() +def cli(): + """Migration scripts management.""" + pass + + +@cli.command() +@click.option( + "--output", + default="langchain_migrations.json", + help="Output file for the migration script.", +) +def langchain(output: str) -> None: + """Generate a migration script.""" + click.echo("Migration script generated.") + migrations = generate_migrations_from_langchain_to_community() + with open(output, "w") as f: + f.write(json.dumps(migrations)) + + +@cli.command() +@click.argument("pkg") +@click.option("--output", default=None, help="Output file for the migration script.") +def partner(pkg: str, output: str) -> None: + """Generate migration scripts specifically for LangChain modules.""" + click.echo("Migration script for LangChain generated.") + migrations = get_migrations_for_partner_package(pkg) + output_name = f"partner_{pkg}.json" if output is None else output + with open(output_name, "w") as f: + f.write(json.dumps(migrations, indent=2, sort_keys=True)) + click.secho(f"LangChain migration script saved to {output_name}") + + +if __name__ == "__main__": + cli() diff --git a/libs/cli/scripts/migrations.json b/libs/cli/scripts/migrations.json new file mode 100644 index 0000000000..6263e4ed06 --- /dev/null +++ b/libs/cli/scripts/migrations.json @@ -0,0 +1 @@ +[["langchain_community.embeddings.openai.OpenAIEmbeddings", "langchain_openai.embeddings.base.OpenAIEmbeddings"], ["langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings", "langchain_openai.embeddings.azure.AzureOpenAIEmbeddings"], ["langchain_community.chat_models.openai.ChatOpenAI", "langchain_openai.chat_models.base.ChatOpenAI"], ["langchain_community.chat_models.azure_openai.AzureChatOpenAI", "langchain_openai.chat_models.azure.AzureChatOpenAI"]] \ No newline at end of file diff --git a/libs/cli/tests/unit_tests/migrate/generate/__init__.py b/libs/cli/tests/unit_tests/migrate/generate/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py b/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py new file mode 100644 index 0000000000..0754e8eef6 --- /dev/null +++ b/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py @@ -0,0 +1,31 @@ +import pytest + +from langchain_cli.namespaces.migrate.generate.partner import ( + get_migrations_for_partner_package, +) + +pytest.importorskip(modname="langchain_openai") + + +def test_generate_migrations() -> None: + migrations = get_migrations_for_partner_package("langchain_openai") + assert migrations == [ + ("langchain_community.llms.openai.OpenAI", "langchain_openai.OpenAI"), + ("langchain_community.llms.openai.AzureOpenAI", "langchain_openai.AzureOpenAI"), + ( + "langchain_community.embeddings.openai.OpenAIEmbeddings", + "langchain_openai.OpenAIEmbeddings", + ), + ( + "langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings", + "langchain_openai.AzureOpenAIEmbeddings", + ), + ( + "langchain_community.chat_models.openai.ChatOpenAI", + "langchain_openai.ChatOpenAI", + ), + ( + "langchain_community.chat_models.azure_openai.AzureChatOpenAI", + "langchain_openai.AzureChatOpenAI", + ), + ] diff --git a/libs/cli/tests/unit_tests/migrate/generate/test_utils.py b/libs/cli/tests/unit_tests/migrate/generate/test_utils.py new file mode 100644 index 0000000000..38974ff79a --- /dev/null +++ b/libs/cli/tests/unit_tests/migrate/generate/test_utils.py @@ -0,0 +1,5 @@ +from langchain_cli.namespaces.migrate.generate.utils import PKGS_ROOT + + +def test_root() -> None: + assert PKGS_ROOT.name == "libs"