From 5653f36adc8890e643995828743cbe22745e6bc8 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 26 Apr 2024 11:17:20 -0400 Subject: [PATCH] cli[minor]: Add script to generate migrations for partner packages (#20932) Add script to help generate migrations. This works well for partner packages. Migrations are generated based on run time rather than static analysis (much simpler to get the correct migrations implemented). The script for generating migrations from langchain to community still needs work. --- .../namespaces/migrate/generate/__init__.py | 0 .../namespaces/migrate/generate/langchain.py | 52 ++++++++ .../namespaces/migrate/generate/partner.py | 41 +++++++ .../namespaces/migrate/generate/utils.py | 111 ++++++++++++++++++ libs/cli/scripts/__init__.py | 0 libs/cli/scripts/generate_migrations.py | 48 ++++++++ libs/cli/scripts/migrations.json | 1 + .../unit_tests/migrate/generate/__init__.py | 0 .../generate/test_partner_migrations.py | 31 +++++ .../unit_tests/migrate/generate/test_utils.py | 5 + 10 files changed, 289 insertions(+) create mode 100644 libs/cli/langchain_cli/namespaces/migrate/generate/__init__.py create mode 100644 libs/cli/langchain_cli/namespaces/migrate/generate/langchain.py create mode 100644 libs/cli/langchain_cli/namespaces/migrate/generate/partner.py create mode 100644 libs/cli/langchain_cli/namespaces/migrate/generate/utils.py create mode 100644 libs/cli/scripts/__init__.py create mode 100644 libs/cli/scripts/generate_migrations.py create mode 100644 libs/cli/scripts/migrations.json create mode 100644 libs/cli/tests/unit_tests/migrate/generate/__init__.py create mode 100644 libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py create mode 100644 libs/cli/tests/unit_tests/migrate/generate/test_utils.py diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/__init__.py b/libs/cli/langchain_cli/namespaces/migrate/generate/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/libs/cli/langchain_cli/namespaces/migrate/generate/langchain.py b/libs/cli/langchain_cli/namespaces/migrate/generate/langchain.py new file mode 100644 index 0000000000..3065e90e65 --- /dev/null +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/langchain.py @@ -0,0 +1,52 @@ +"""Generate migrations from langchain to langchain-community or core packages.""" +import glob +from pathlib import Path +from typing import List, Tuple + +from langchain_cli.namespaces.migrate.generate.utils import ( + _get_current_module, + find_imports_from_package, +) + +HERE = Path(__file__).parent +PKGS_ROOT = HERE.parent.parent.parent.parent.parent # [repo root]/libs +LANGCHAIN_PKG = PKGS_ROOT / "langchain" +COMMUNITY_PKG = PKGS_ROOT / "community" +PARTNER_PKGS = PKGS_ROOT / "partners" + + +def _generate_migrations_from_file( + source_module: str, code: str, *, from_package: str +) -> List[Tuple[str, str]]: + """Generate (old, new) import paths for names code imports from from_package.""" + imports = find_imports_from_package(code, from_package=from_package) + return [ + # Old path: name under source_module; new path: module it is imported from + (f"{source_module}.{item}", f"{new_path}.{item}") + for new_path, item in imports + ] + + +def _generate_migrations_from_file_in_pkg( + file: str, root_pkg: str +) -> List[Tuple[str, str]]: + """Generate migrations for a file that's relative to langchain pkg.""" + # Read the file. 
+ with open(file, encoding="utf-8") as f: + code = f.read() + + module_name = _get_current_module(file, root_pkg) + return _generate_migrations_from_file( + module_name, code, from_package="langchain_community" + ) + + +def generate_migrations_from_langchain_to_community() -> List[Tuple[str, str]]: + """Generate migrations from langchain to langchain-community.""" + migrations = [] + # Recursively scan every .py file under the langchain package directory + for file_path in glob.glob(str(LANGCHAIN_PKG / "**" / "*.py"), recursive=True): + migrations.extend( + _generate_migrations_from_file_in_pkg(file_path, str(LANGCHAIN_PKG)) + ) + return migrations diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py b/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py new file mode 100644 index 0000000000..8bd473c724 --- /dev/null +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py @@ -0,0 +1,41 @@ +"""Generate migrations for partner packages.""" +import importlib +from typing import List, Tuple + +from langchain_core.embeddings import Embeddings +from langchain_core.language_models import BaseLanguageModel +from langchain_core.retrievers import BaseRetriever +from langchain_core.vectorstores import VectorStore + +from langchain_cli.namespaces.migrate.generate.utils import ( + COMMUNITY_PKG, + find_subclasses_in_module, + list_classes_by_package, +) + +# PUBLIC API + + +def get_migrations_for_partner_package(pkg_name: str) -> List[Tuple[str, str]]: + """Generate migrations from community package to partner package. + + This code works by matching partner classes by name against community classes. + + Args: + pkg_name (str): The name of the partner package. + + Returns: + List of 2-tuples containing old and new import paths. 
+ """ + package = importlib.import_module(pkg_name) + classes_ = find_subclasses_in_module( + package, [BaseLanguageModel, Embeddings, BaseRetriever, VectorStore] + ) + community_classes = list_classes_by_package(str(COMMUNITY_PKG)) + + migrations = [ + (f"{community_module}.{community_class}", f"{pkg_name}.{community_class}") + for community_module, community_class in community_classes + if community_class in classes_ + ] + return migrations diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py new file mode 100644 index 0000000000..34ce40cc36 --- /dev/null +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py @@ -0,0 +1,111 @@ +import ast +import inspect +import os +import pathlib +from pathlib import Path +from typing import Any, List, Tuple, Type + +HERE = Path(__file__).parent +# Should bring us to [root]/libs +PKGS_ROOT = HERE.parent.parent.parent.parent.parent + +LANGCHAIN_PKG = PKGS_ROOT / "langchain" +COMMUNITY_PKG = PKGS_ROOT / "community" +PARTNER_PKGS = PKGS_ROOT / "partners" + + +class ImportExtractor(ast.NodeVisitor): + def __init__(self, *, from_package: str) -> None: + """Extract all imports from the given package.""" + self.imports = [] + self.package = from_package + + def visit_ImportFrom(self, node): + if node.module and f"{node.module}.".startswith(f"{self.package}."): + for alias in node.names: + self.imports.append((node.module, alias.name)) + self.generic_visit(node) + + +def _get_class_names(code: str) -> List[str]: + """Extract class names from a code string.""" + # Parse the content of the file into an AST + tree = ast.parse(code) + + # Initialize a list to hold all class names + class_names = [] + + # Define a node visitor class to collect class names + class ClassVisitor(ast.NodeVisitor): + def visit_ClassDef(self, node): + class_names.append(node.name) + self.generic_visit(node) + + # Create an instance of the visitor and visit the AST + visitor = 
ClassVisitor() + visitor.visit(tree) + return class_names + + +def is_subclass(class_obj: Any, classes_: List[Type]) -> bool: + """Check if the given class object is a subclass of any class in list classes.""" + return any( + issubclass(class_obj, kls) + for kls in classes_ + if inspect.isclass(class_obj) and inspect.isclass(kls) + ) + + +def find_subclasses_in_module(module, classes_: List[Type]) -> List[str]: + """Find all classes in the module that inherit from one of the classes.""" + subclasses = [] + # Iterate over all attributes of the module that are classes + for name, obj in inspect.getmembers(module, inspect.isclass): + if is_subclass(obj, classes_): + subclasses.append(obj.__name__) + return subclasses + + +def _get_all_classnames_from_file(file: str, pkg: str) -> List[Tuple[str, str]]: + """Extract all class names from a file.""" + with open(file, encoding="utf-8") as f: + code = f.read() + module_name = _get_current_module(file, pkg) + class_names = _get_class_names(code) + return [(module_name, class_name) for class_name in class_names] + + +def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]: + """List all classes in a package.""" + module_classes = [] + files = list(Path(pkg_root).rglob("*.py")) + + for file in files: + rel_path = os.path.relpath(file, pkg_root) + if rel_path.startswith("tests"): + continue + module_classes.extend(_get_all_classnames_from_file(file, pkg_root)) + return module_classes + + +def find_imports_from_package(code: str, *, from_package: str) -> List[Tuple[str, str]]: + # Parse the code into an AST + tree = ast.parse(code) + # Collect only the imports that come from the requested package + extractor = ImportExtractor(from_package=from_package) + # Use the visitor to update the imports list + extractor.visit(tree) + return extractor.imports + + +def _get_current_module(path: str, pkg_root: str) -> str: + """Convert a path to a module name.""" + path_as_pathlib = pathlib.Path(os.path.abspath(path)) + relative_path = 
path_as_pathlib.relative_to(pkg_root).with_suffix("") + posix_path = relative_path.as_posix() + norm_path = os.path.normpath(str(posix_path)).replace(os.sep, "/") + fully_qualified_module = norm_path.replace("/", ".") + # Strip __init__ if present + if fully_qualified_module.endswith(".__init__"): + return fully_qualified_module[:-9] + return fully_qualified_module diff --git a/libs/cli/scripts/__init__.py b/libs/cli/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/cli/scripts/generate_migrations.py b/libs/cli/scripts/generate_migrations.py new file mode 100644 index 0000000000..e428fc4bf1 --- /dev/null +++ b/libs/cli/scripts/generate_migrations.py @@ -0,0 +1,48 @@ +"""Script to generate migrations for the migration script.""" +import json + +import click + +from langchain_cli.namespaces.migrate.generate.langchain import ( + generate_migrations_from_langchain_to_community, +) +from langchain_cli.namespaces.migrate.generate.partner import ( + get_migrations_for_partner_package, +) + + +@click.group() +def cli(): + """Migration scripts management.""" + pass + + +@cli.command() +@click.option( + "--output", + default="langchain_migrations.json", + help="Output file for the migration script.", +) +def langchain(output: str) -> None: + """Generate a migration script.""" + migrations = generate_migrations_from_langchain_to_community() + with open(output, "w") as f: + f.write(json.dumps(migrations)) + click.echo("Migration script generated.") + + +@cli.command() +@click.argument("pkg") +@click.option("--output", default=None, help="Output file for the migration script.") +def partner(pkg: str, output: str) -> None: + """Generate a migration script for the given partner package.""" + click.echo("Migration script for LangChain generated.") + migrations = get_migrations_for_partner_package(pkg) + output_name = f"partner_{pkg}.json" if output is None else output + with open(output_name, "w") as f: + f.write(json.dumps(migrations, indent=2, 
sort_keys=True)) + click.secho(f"LangChain migration script saved to {output_name}") + + +if __name__ == "__main__": + cli() diff --git a/libs/cli/scripts/migrations.json b/libs/cli/scripts/migrations.json new file mode 100644 index 0000000000..6263e4ed06 --- /dev/null +++ b/libs/cli/scripts/migrations.json @@ -0,0 +1 @@ +[["langchain_community.embeddings.openai.OpenAIEmbeddings", "langchain_openai.embeddings.base.OpenAIEmbeddings"], ["langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings", "langchain_openai.embeddings.azure.AzureOpenAIEmbeddings"], ["langchain_community.chat_models.openai.ChatOpenAI", "langchain_openai.chat_models.base.ChatOpenAI"], ["langchain_community.chat_models.azure_openai.AzureChatOpenAI", "langchain_openai.chat_models.azure.AzureChatOpenAI"]] \ No newline at end of file diff --git a/libs/cli/tests/unit_tests/migrate/generate/__init__.py b/libs/cli/tests/unit_tests/migrate/generate/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py b/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py new file mode 100644 index 0000000000..0754e8eef6 --- /dev/null +++ b/libs/cli/tests/unit_tests/migrate/generate/test_partner_migrations.py @@ -0,0 +1,31 @@ +import pytest + +from langchain_cli.namespaces.migrate.generate.partner import ( + get_migrations_for_partner_package, +) + +pytest.importorskip(modname="langchain_openai") + + +def test_generate_migrations() -> None: + migrations = get_migrations_for_partner_package("langchain_openai") + assert migrations == [ + ("langchain_community.llms.openai.OpenAI", "langchain_openai.OpenAI"), + ("langchain_community.llms.openai.AzureOpenAI", "langchain_openai.AzureOpenAI"), + ( + "langchain_community.embeddings.openai.OpenAIEmbeddings", + "langchain_openai.OpenAIEmbeddings", + ), + ( + "langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings", + 
"langchain_openai.AzureOpenAIEmbeddings", + ), + ( + "langchain_community.chat_models.openai.ChatOpenAI", + "langchain_openai.ChatOpenAI", + ), + ( + "langchain_community.chat_models.azure_openai.AzureChatOpenAI", + "langchain_openai.AzureChatOpenAI", + ), + ] diff --git a/libs/cli/tests/unit_tests/migrate/generate/test_utils.py b/libs/cli/tests/unit_tests/migrate/generate/test_utils.py new file mode 100644 index 0000000000..38974ff79a --- /dev/null +++ b/libs/cli/tests/unit_tests/migrate/generate/test_utils.py @@ -0,0 +1,5 @@ +from langchain_cli.namespaces.migrate.generate.utils import PKGS_ROOT + + +def test_root() -> None: + assert PKGS_ROOT.name == "libs"