You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py

55 lines
1.5 KiB
Python

"""Generate migrations for partner packages."""
import importlib
from typing import List, Tuple
from langchain_core.documents import BaseDocumentCompressor, BaseDocumentTransformer
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_cli.namespaces.migrate.generate.utils import (
COMMUNITY_PKG,
find_subclasses_in_module,
list_classes_by_package,
list_init_imports_by_package,
)
# PUBLIC API
def get_migrations_for_partner_package(pkg_name: str) -> List[Tuple[str, str]]:
"""Generate migrations from community package to partner package.
This code works
Args:
pkg_name (str): The name of the partner package.
Returns:
List of 2-tuples containing old and new import paths.
"""
package = importlib.import_module(pkg_name)
classes_ = find_subclasses_in_module(
package,
[
BaseLanguageModel,
Embeddings,
BaseRetriever,
VectorStore,
BaseDocumentTransformer,
BaseDocumentCompressor,
],
)
community_classes = list_classes_by_package(str(COMMUNITY_PKG))
imports_for_pkg = list_init_imports_by_package(str(COMMUNITY_PKG))
old_paths = community_classes + imports_for_pkg
migrations = [
(f"{module}.{item}", f"{pkg_name}.{item}")
for module, item in old_paths
if item in classes_
]
return migrations