cli[minor]: Add script to generate migrations for partner packages (#20932)

Add script to help generate migrations.

This works well for partner packages. Migrations are generated by inspecting the package at run time rather than through static analysis, which makes it much simpler to get the correct migrations.

The script for generating migrations from langchain to community still needs work.
pull/20951/head
Eugene Yurtsev 3 weeks ago committed by GitHub
parent fe1304afc4
commit 5653f36adc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,52 @@
"""Generate migrations from langchain to langchain-community or core packages."""
import glob
from pathlib import Path
from typing import List, Tuple
from langchain_cli.namespaces.migrate.generate.utils import (
_get_current_module,
find_imports_from_package,
)
# Directory containing this module (``langchain_cli/namespaces/migrate/generate``).
HERE = Path(__file__).parent
# Five levels up from ``generate/``: migrate -> namespaces -> langchain_cli ->
# the directory holding the ``langchain_cli`` package -> its parent, where the
# sibling package directories are expected to live. Climbing only three levels
# stops at the ``langchain_cli`` package directory itself.
PKGS_ROOT = HERE.parent.parent.parent.parent.parent
# Package locations derived from the shared root.
LANGCHAIN_PKG = PKGS_ROOT / "langchain"
COMMUNITY_PKG = PKGS_ROOT / "community"
PARTNER_PKGS = PKGS_ROOT / "partners"
def _generate_migrations_from_file(
    source_module: str, code: str, *, from_package: str
) -> List[Tuple[str, str]]:
    """Build (old path, new path) pairs for the imports found in ``code``.

    Each ``(module, name)`` pair reported by ``find_imports_from_package``
    becomes ``("<source_module>.<name>", "<module>.<name>")``.
    """
    pairs: List[Tuple[str, str]] = []
    discovered = find_imports_from_package(code, from_package=from_package)
    for target_module, name in discovered:
        old_path = f"{source_module}.{name}"
        new_path = f"{target_module}.{name}"
        pairs.append((old_path, new_path))
    return pairs
def _generate_migrations_from_file_in_pkg(
    file: str, root_pkg: str
) -> List[Tuple[str, str]]:
    """Generate migrations for one source file located under ``root_pkg``.

    The file is read as UTF-8, its dotted module name is derived from its
    position relative to ``root_pkg``, and the imports it makes from
    ``langchain_community`` are turned into migration pairs.
    """
    source_text = Path(file).read_text(encoding="utf-8")
    dotted_name = _get_current_module(file, root_pkg)
    return _generate_migrations_from_file(
        dotted_name,
        source_text,
        from_package="langchain_community",
    )
def generate_migrations_from_langchain_to_community() -> List[Tuple[str, str]]:
    """Generate migrations from langchain to langchain-community.

    Every ``*.py`` file found at any depth below ``LANGCHAIN_PKG`` is scanned
    and the migrations produced for each file are concatenated.

    Returns:
        List of 2-tuples containing old and new import paths.
    """
    # Build the pattern with the path operator so a separator sits between the
    # package directory and ``**``; plain string concatenation produced
    # ``.../langchain**/*.py``. ``recursive=True`` is required for ``**`` to
    # descend into sub-directories.
    pattern = str(LANGCHAIN_PKG / "**" / "*.py")
    migrations: List[Tuple[str, str]] = []
    for file_path in glob.glob(pattern, recursive=True):
        migrations.extend(
            _generate_migrations_from_file_in_pkg(file_path, str(LANGCHAIN_PKG))
        )
    return migrations

@ -0,0 +1,41 @@
"""Generate migrations for partner packages."""
import importlib
from typing import List, Tuple
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_cli.namespaces.migrate.generate.utils import (
COMMUNITY_PKG,
find_subclasses_in_module,
list_classes_by_package,
)
# PUBLIC API
def get_migrations_for_partner_package(pkg_name: str) -> List[Tuple[str, str]]:
    """Generate migrations from community package to partner package.

    The partner package is imported and its top-level namespace is searched
    for classes deriving from ``BaseLanguageModel``, ``Embeddings``,
    ``BaseRetriever`` or ``VectorStore``. Every class listed for
    ``COMMUNITY_PKG`` whose name equals one of those class names is mapped to
    ``<pkg_name>.<class name>``. Matching is done on the class name only.

    Args:
        pkg_name (str): The name of the partner package.

    Returns:
        List of 2-tuples containing old and new import paths.

    Raises:
        ImportError: If ``pkg_name`` cannot be imported.
    """
    package = importlib.import_module(pkg_name)
    # A set gives constant-time membership tests while filtering the
    # (potentially long) list of community classes below.
    partner_class_names = set(
        find_subclasses_in_module(
            package, [BaseLanguageModel, Embeddings, BaseRetriever, VectorStore]
        )
    )
    community_classes = list_classes_by_package(str(COMMUNITY_PKG))
    return [
        (f"{community_module}.{community_class}", f"{pkg_name}.{community_class}")
        for community_module, community_class in community_classes
        if community_class in partner_class_names
    ]

@ -0,0 +1,111 @@
import ast
import inspect
import os
import pathlib
from pathlib import Path
from typing import Any, List, Tuple, Type
# Directory that contains this module.
HERE = Path(__file__).parent
# Five levels above this module's directory; should bring us to the
# directory named ``libs``.
PKGS_ROOT = HERE.parent.parent.parent.parent.parent
# Package directories resolved relative to PKGS_ROOT.
LANGCHAIN_PKG = PKGS_ROOT / "langchain"
COMMUNITY_PKG = PKGS_ROOT / "community"
PARTNER_PKGS = PKGS_ROOT / "partners"
class ImportExtractor(ast.NodeVisitor):
    """Collect ``from <module> import <name>`` statements for one package.

    After ``visit`` has been called on a tree, ``imports`` holds one
    ``(module, name)`` tuple per imported name, in the order the import
    statements are visited. Plain ``import x`` statements are not recorded.
    """

    def __init__(self, *, from_package: str) -> None:
        """Extract all imports from the given package."""
        self.imports: List[Tuple[str, str]] = []
        self.package = from_package

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record the names imported by ``node`` when it targets the package."""
        module = node.module
        # ``module`` is None for ``from . import x``. Match the package itself
        # or one of its sub-modules; a bare prefix test would also accept
        # unrelated packages such as ``<package>_extra``.
        if module and (
            module == self.package or module.startswith(f"{self.package}.")
        ):
            self.imports.extend((module, alias.name) for alias in node.names)
        self.generic_visit(node)
def _get_class_names(code: str) -> List[str]:
    """Return the name of every class defined in ``code``.

    Nested classes are included. Names appear in depth-first, source order.
    """
    found: List[str] = []

    def _collect(node: ast.AST) -> None:
        # Record the class before descending so outer classes precede the
        # classes nested inside them.
        if isinstance(node, ast.ClassDef):
            found.append(node.name)
        for child in ast.iter_child_nodes(node):
            _collect(child)

    _collect(ast.parse(code))
    return found
def is_subclass(class_obj: Any, classes_: List[Type]) -> bool:
    """Tell whether ``class_obj`` is a class deriving from any entry of ``classes_``.

    Entries of ``classes_`` that are not classes are ignored, and a
    ``class_obj`` that is not a class never matches.
    """
    if not inspect.isclass(class_obj):
        return False
    for candidate in classes_:
        if inspect.isclass(candidate) and issubclass(class_obj, candidate):
            return True
    return False
def find_subclasses_in_module(module, classes_: List[Type]) -> List[str]:
    """Return the ``__name__`` of each class in ``module`` deriving from ``classes_``.

    Only module attributes that are classes are considered.
    """
    return [
        member.__name__
        for _, member in inspect.getmembers(module, inspect.isclass)
        if is_subclass(member, classes_)
    ]
def _get_all_classnames_from_file(file: str, pkg: str) -> List[Tuple[str, str]]:
    """Return ``(module name, class name)`` for every class defined in ``file``.

    The file is read as UTF-8 and its module name is derived from its
    location relative to ``pkg``.
    """
    with open(file, encoding="utf-8") as handle:
        source = handle.read()
    owning_module = _get_current_module(file, pkg)
    pairs = []
    for cls_name in _get_class_names(source):
        pairs.append((owning_module, cls_name))
    return pairs
def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
    """Collect ``(module name, class name)`` pairs for the files under ``pkg_root``.

    All ``*.py`` files are visited recursively, except those whose path
    relative to ``pkg_root`` starts with ``tests``.
    """
    collected: List[Tuple[str, str]] = []
    for py_file in Path(pkg_root).rglob("*.py"):
        if os.path.relpath(py_file, pkg_root).startswith("tests"):
            continue
        collected += _get_all_classnames_from_file(py_file, pkg_root)
    return collected
def find_imports_from_package(
    code: str, *, from_package: str
) -> List[Tuple[str, str]]:
    """Find the names that ``code`` imports from ``from_package``.

    Args:
        code: Python source code to analyse.
        from_package: Package whose ``from ... import ...`` statements are
            collected.

    Returns:
        List of ``(module, imported name)`` tuples.

    Raises:
        SyntaxError: If ``code`` cannot be parsed.
    """
    # Honour the caller's package; it used to be ignored in favour of a
    # hard-coded "langchain_community".
    extractor = ImportExtractor(from_package=from_package)
    extractor.visit(ast.parse(code))
    return extractor.imports
def _get_current_module(path: str, pkg_root: str) -> str:
    """Convert a path to a module name.

    Args:
        path: Path of a Python file located under ``pkg_root``.
        pkg_root: Directory the dotted module name is computed relative to.
            Both arguments may be relative; they are made absolute first.

    Returns:
        The dotted module name, with a trailing ``__init__`` component
        removed so that a package's ``__init__.py`` maps to the package.

    Raises:
        ValueError: If ``path`` is not located under ``pkg_root``.
    """
    absolute_path = pathlib.Path(os.path.abspath(path))
    absolute_root = pathlib.Path(os.path.abspath(pkg_root))
    # Join the path components directly instead of replacing "/" in a string:
    # the result then does not depend on the platform's path separator.
    parts = absolute_path.relative_to(absolute_root).with_suffix("").parts
    # Strip __init__ if present
    if len(parts) > 1 and parts[-1] == "__init__":
        parts = parts[:-1]
    return ".".join(parts)

@ -0,0 +1,48 @@
"""Script to generate migrations for the migration script."""
import json
import click
from langchain_cli.namespaces.migrate.generate.langchain import (
generate_migrations_from_langchain_to_community,
)
from langchain_cli.namespaces.migrate.generate.partner import (
get_migrations_for_partner_package,
)
# Root command group; the sub-commands below attach to it via ``@cli.command()``.
@click.group()
def cli() -> None:
    """Migration scripts management."""
@cli.command()
@click.option(
    "--output",
    default="langchain_migrations.json",
    help="Output file for the migration script.",
)
def langchain(output: str) -> None:
    """Generate a migration script."""
    # Writes the langchain -> langchain-community migrations to ``output`` as
    # a JSON list of [old path, new path] pairs.
    migrations = generate_migrations_from_langchain_to_community()
    with open(output, "w", encoding="utf-8") as f:
        f.write(json.dumps(migrations))
    # Report success only once the file has actually been written.
    click.echo("Migration script generated.")
@cli.command()
@click.argument("pkg")
@click.option("--output", default=None, help="Output file for the migration script.")
def partner(pkg: str, output: str) -> None:
    """Generate migration scripts specifically for LangChain modules."""
    # ``output`` is None when the option is omitted; the file name then
    # defaults to ``partner_<pkg>.json``.
    migrations = get_migrations_for_partner_package(pkg)
    # Announce generation only after it has succeeded.
    click.echo("Migration script for LangChain generated.")
    output_name = f"partner_{pkg}.json" if output is None else output
    with open(output_name, "w", encoding="utf-8") as f:
        f.write(json.dumps(migrations, indent=2, sort_keys=True))
    click.secho(f"LangChain migration script saved to {output_name}")
# Run the command group when this file is executed as a script.
if __name__ == "__main__":
    cli()

@ -0,0 +1 @@
[["langchain_community.embeddings.openai.OpenAIEmbeddings", "langchain_openai.embeddings.base.OpenAIEmbeddings"], ["langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings", "langchain_openai.embeddings.azure.AzureOpenAIEmbeddings"], ["langchain_community.chat_models.openai.ChatOpenAI", "langchain_openai.chat_models.base.ChatOpenAI"], ["langchain_community.chat_models.azure_openai.AzureChatOpenAI", "langchain_openai.chat_models.azure.AzureChatOpenAI"]]

@ -0,0 +1,31 @@
import pytest
from langchain_cli.namespaces.migrate.generate.partner import (
get_migrations_for_partner_package,
)
# Skip the tests in this module when ``langchain_openai`` cannot be imported.
pytest.importorskip(modname="langchain_openai")
def test_generate_migrations() -> None:
    """Check the migrations generated for the ``langchain_openai`` package."""
    expected = [
        ("langchain_community.llms.openai.OpenAI", "langchain_openai.OpenAI"),
        ("langchain_community.llms.openai.AzureOpenAI", "langchain_openai.AzureOpenAI"),
        (
            "langchain_community.embeddings.openai.OpenAIEmbeddings",
            "langchain_openai.OpenAIEmbeddings",
        ),
        (
            "langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings",
            "langchain_openai.AzureOpenAIEmbeddings",
        ),
        (
            "langchain_community.chat_models.openai.ChatOpenAI",
            "langchain_openai.ChatOpenAI",
        ),
        (
            "langchain_community.chat_models.azure_openai.AzureChatOpenAI",
            "langchain_openai.AzureChatOpenAI",
        ),
    ]
    assert get_migrations_for_partner_package("langchain_openai") == expected

@ -0,0 +1,5 @@
from langchain_cli.namespaces.migrate.generate.utils import PKGS_ROOT
def test_root() -> None:
    """Check that ``PKGS_ROOT`` resolves to a directory named ``libs``."""
    root_dir_name = PKGS_ROOT.name
    assert root_dir_name == "libs"
Loading…
Cancel
Save