cli[minor]: Add first version of migrate (#20902)

Adds a first version of the migrate script.
pull/20932/head
Eugene Yurtsev 2 weeks ago committed by GitHub
parent d95e9fb67f
commit 6598757037
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,3 +1,4 @@
import importlib
from typing import Optional
import typer
@ -22,6 +23,13 @@ app.add_typer(
)
# If libcst is installed, add the migrate namespace.
# libcst is an optional dependency (Python >=3.9 only), so the `migrate`
# command group is simply hidden when it is not available.
# NOTE(review): this relies on `importlib.util` being importable as an
# attribute of `importlib`; confirm the file imports `importlib.util`
# (a bare `import importlib` does not guarantee the submodule is loaded).
if importlib.util.find_spec("libcst"):
    from langchain_cli.namespaces.migrate import main as migrate_namespace

    app.add_typer(migrate_namespace.app, name="migrate", help=migrate_namespace.__doc__)
def version_callback(show_version: bool) -> None:
    """Print the installed langchain-cli version when requested."""
    if not show_version:
        return
    typer.echo(f"langchain-cli {__version__}")

@ -0,0 +1,25 @@
from enum import Enum
from typing import List, Type
from libcst.codemod import ContextAwareTransformer
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
ReplaceImportsCodemod,
)
class Rule(str, Enum):
    """Identifiers for individual migration rules that can be disabled via the CLI."""

    # R001 controls whether ReplaceImportsCodemod runs (see gather_codemods).
    R001 = "R001"
    """Replace imports that have been moved."""
def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
    """Return the codemod classes to execute, skipping any disabled rules."""
    enabled: List[Type[ContextAwareTransformer]] = []
    if Rule.R001 not in disabled:
        enabled.append(ReplaceImportsCodemod)
    # The import cleanup visitors always run, and must run after every other
    # codemod so they can add/remove the imports those codemods scheduled.
    return [*enabled, RemoveImportsVisitor, AddImportsVisitor]

@ -0,0 +1,14 @@
[
    [
        "langchain.chat_models.ChatOpenAI",
        "langchain_openai.ChatOpenAI"
    ],
    [
        "langchain.chat_models.ChatAnthropic",
        "langchain_anthropic.ChatAnthropic"
    ]
]

@ -0,0 +1,205 @@
"""
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
This codemod deals with the following cases:
1. `from pydantic import BaseSettings`
2. `from pydantic.settings import BaseSettings`
3. `from pydantic import BaseSettings as <name>`
4. `from pydantic.settings import BaseSettings as <name>` # TODO: This is not working.
5. `import pydantic` -> `pydantic.BaseSettings`
"""
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Sequence, Tuple, TypeVar
import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor
# Directory containing this module; the migration JSON tables live beside it.
HERE = os.path.dirname(__file__)


def _load_migrations_by_file(path: str):
    """Load one migration table (a JSON list of ``[old, new]`` pairs) from disk."""
    table_path = os.path.join(HERE, path)
    with open(table_path, "r", encoding="utf-8") as table_file:
        return json.load(table_file)
T = TypeVar("T")
def _deduplicate_in_order(
seq: Iterable[T], key: Callable[[T], str] = lambda x: x
) -> List[T]:
seen = set()
seen_add = seen.add
return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]
def _load_migrations():
    """Load and merge all migration tables into a single lookup dict.

    Returns a dict mapping the old location (``"module.path:ClassName"``)
    to the new location as a ``(module, class_name)`` tuple.
    """
    # Files earlier in this list take precedence: the dedup below keeps the
    # FIRST occurrence of each old path.
    paths = [
        "migrations_v0.2_partner.json",
        "migrations_v0.2.json",
    ]

    data = []
    for path in paths:
        data.extend(_load_migrations_by_file(path))

    # Entries are [old_path, new_path]; dedupe on the old path only.
    data = _deduplicate_in_order(data, key=lambda x: x[0])

    imports: Dict[str, Tuple[str, str]] = {}

    for old_path, new_path in data:
        # Split the old dotted path ('langchain.chat_models.ChatOpenAI')
        # into module and class name, re-joined as 'module:ClassName'.
        old_parts = old_path.split(".")
        old_module = ".".join(old_parts[:-1])
        old_class = old_parts[-1]
        old_path_str = f"{old_module}:{old_class}"

        # Split the new dotted path into a (module, class_name) 2-tuple,
        # the form AddImportsVisitor.add_needed_import expects.
        new_parts = new_path.split(".")
        new_module = ".".join(new_parts[:-1])
        new_class = new_parts[-1]
        new_path_str = (new_module, new_class)

        imports[old_path_str] = new_path_str

    return imports


# Computed once at import time: old "module:Name" -> (new_module, new_name).
IMPORTS = _load_migrations()
def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
    """Build a matcher `Name`/`Attribute` chain from dotted-module parts.

    ``["a"]`` -> ``Name("a")``; ``["a", "b", "c"]`` ->
    ``Attribute(Attribute(Name("a"), Name("b")), Name("c"))``.
    """
    if len(module_parts) == 1:
        return m.Name(module_parts[0])
    if len(module_parts) == 2:
        head, tail = module_parts
        return m.Attribute(value=m.Name(head), attr=m.Name(tail))
    # Three or more parts: peel the last segment off and recurse on the rest.
    tail_name = module_parts.pop()
    return m.Attribute(value=resolve_module_parts(module_parts), attr=m.Name(tail_name))
def get_import_from_from_str(import_str: str) -> m.ImportFrom:
    """Build an ``ImportFrom`` matcher from a ``module.path:Name`` string.

    The matcher accepts any ``from <module> import ...`` statement whose
    imported names include ``Name`` anywhere in the list.

    Examples:
    >>> get_import_from_from_str("pydantic:BaseSettings")
    ImportFrom(
        module=Name("pydantic"),
        names=[ImportAlias(name=Name("BaseSettings"))],
    )
    >>> get_import_from_from_str("pydantic.settings:BaseSettings")
    ImportFrom(
        module=Attribute(value=Name("pydantic"), attr=Name("settings")),
        names=[ImportAlias(name=Name("BaseSettings"))],
    )
    >>> get_import_from_from_str("a.b.c:d")
    ImportFrom(
        module=Attribute(
            value=Attribute(value=Name("a"), attr=Name("b")), attr=Name("c")
        ),
        names=[ImportAlias(name=Name("d"))],
    )
    """
    module, name = import_str.split(":")
    module_parts = module.split(".")
    module_node = resolve_module_parts(module_parts)
    # ZeroOrMore on both sides: the target name may appear anywhere within a
    # multi-name import statement.
    return m.ImportFrom(
        module=module_node,
        names=[m.ZeroOrMore(), m.ImportAlias(name=m.Name(value=name)), m.ZeroOrMore()],
    )
@dataclass
class ImportInfo:
    """Precomputed matcher plus source/target data for one migration entry."""

    # Matcher that recognizes the old `from <module> import <Name>` statement.
    import_from: m.ImportFrom
    # Old location in `module.path:Name` form.
    import_str: str
    # New location as a (module, name) tuple.
    to_import_str: tuple[str, str]


# One ImportInfo per migration table entry, built once at import time.
IMPORT_INFOS = [
    ImportInfo(
        import_from=get_import_from_from_str(import_str),
        import_str=import_str,
        to_import_str=to_import_str,
    )
    for import_str, to_import_str in IMPORTS.items()
]
# Matches any import statement containing at least one migrated name.
IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS])
class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
    """Rewrite `from old.module import Name` statements to their new location.

    For each matched statement the replacement import is scheduled through
    ``AddImportsVisitor``; the old name is removed from the statement (or the
    whole statement is removed when it imported only that one name).
    """

    @m.leave(IMPORT_MATCH)
    def leave_replace_import(
        self, _: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        for import_info in IMPORT_INFOS:
            if m.matches(updated_node, import_info.import_from):
                aliases: Sequence[cst.ImportAlias] = updated_node.names  # type: ignore
                # If multiple objects are imported in a single import statement,
                # we need to remove only the one we're replacing.
                AddImportsVisitor.add_needed_import(
                    self.context, *import_info.to_import_str
                )
                # The name being migrated as it appears in the OLD statement.
                # BUG FIX: the filter previously compared against the NEW
                # class name (to_import_str[-1]), which only works when a
                # migration keeps the class name unchanged; a renaming
                # migration would fail to remove the old alias.
                old_name = import_info.import_str.split(":")[-1]
                if len(updated_node.names) > 1:  # type: ignore
                    names = [
                        alias
                        for alias in aliases
                        if alias.name.value != old_name
                    ]
                    # The last remaining alias must not keep a trailing comma.
                    names[-1] = names[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
                    updated_node = updated_node.with_changes(names=names)
                else:
                    # Sole imported name: drop the entire statement.
                    return cst.RemoveFromParent()  # type: ignore[return-value]
        return updated_node
if __name__ == "__main__":
import textwrap
from rich.console import Console
console = Console()
source = textwrap.dedent(
"""
from pydantic.settings import BaseSettings
from pydantic.color import Color
from pydantic.payment import PaymentCardNumber, PaymentCardBrand
from pydantic import Color
from pydantic import Color as Potato
class Potato(BaseSettings):
color: Color
payment: PaymentCardNumber
brand: PaymentCardBrand
potato: Potato
"""
)
console.print(source)
console.print("=" * 80)
mod = cst.parse_module(source)
context = CodemodContext(filename="main.py")
wrapper = cst.MetadataWrapper(mod)
command = ReplaceImportsCodemod(context=context)
mod = wrapper.visit(command)
wrapper = cst.MetadataWrapper(mod)
command = AddImportsVisitor(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
console.print(mod.code)

@ -0,0 +1,52 @@
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
import fnmatch
import re
from pathlib import Path
from typing import List
MATCH_SEP = r"(?:/|\\)"
MATCH_SEP_OR_END = r"(?:/|\\|\Z)"
MATCH_NON_RECURSIVE = r"[^/\\]*"
MATCH_RECURSIVE = r"(?:.*)"
def glob_to_re(pattern: str) -> str:
"""Translate a glob pattern to a regular expression for matching."""
fragments: List[str] = []
for segment in re.split(r"/|\\", pattern):
if segment == "":
continue
if segment == "**":
# Remove previous separator match, so the recursive match c
# can match zero or more segments.
if fragments and fragments[-1] == MATCH_SEP:
fragments.pop()
fragments.append(MATCH_RECURSIVE)
elif "**" in segment:
raise ValueError(
"invalid pattern: '**' can only be an entire path component"
)
else:
fragment = fnmatch.translate(segment)
fragment = fragment.replace(r"(?s:", r"(?:")
fragment = fragment.replace(r".*", MATCH_NON_RECURSIVE)
fragment = fragment.replace(r"\Z", r"")
fragments.append(fragment)
fragments.append(MATCH_SEP)
# Remove trailing MATCH_SEP, so it can be replaced with MATCH_SEP_OR_END.
if fragments and fragments[-1] == MATCH_SEP:
fragments.pop()
fragments.append(MATCH_SEP_OR_END)
return rf"(?s:{''.join(fragments)})"
def match_glob(path: Path, pattern: str) -> bool:
    """Return True if *path* matches the glob *pattern*.

    If the pattern ends with a directory separator, the path must also be an
    existing directory to match.
    """
    matched = re.fullmatch(glob_to_re(pattern), str(path)) is not None
    if pattern.endswith(("/", "\\")):
        return matched and path.is_dir()
    return matched

@ -0,0 +1,194 @@
"""Migrate LangChain to the most recent version."""
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
import difflib
import functools
import multiprocessing
import os
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
import libcst as cst
import rich
import typer
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.helpers import calculate_module_and_package
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
from rich.console import Console
from rich.progress import Progress
from typer import Argument, Exit, Option, Typer
from typing_extensions import ParamSpec
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods
from langchain_cli.namespaces.migrate.glob_helpers import match_glob
# Typer application for `langchain-cli migrate`; the `main` callback below is
# the single entry point (invoke_without_command=True).
app = Typer(invoke_without_command=True, add_completion=False)

P = ParamSpec("P")
T = TypeVar("T")

# Path glob patterns excluded from processing by default.
DEFAULT_IGNORES = [".venv/**"]
@app.callback()
def main(
    path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
    disable: List[Rule] = Option(default=[], help="Disable a rule."),
    diff: bool = Option(False, help="Show diff instead of applying changes."),
    ignore: List[str] = Option(
        default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
    ),
    log_file: Path = Option("log.txt", help="Log errors to this file."),
):
    """Migrate langchain to the most recent version.

    Runs the gathered codemods over every non-ignored ``.py`` file under
    ``path`` (or over a single given file), either applying changes in place
    or, with ``--diff``, printing a colorized unified diff. Errors are
    appended to ``log_file``.
    """
    # Applying changes mutates files in place, so require confirmation unless
    # we are only showing a diff.
    if not diff:
        rich.print("[bold red]Alert![/ bold red] langchain-cli migrate", end=": ")
        if not typer.confirm(
            "The migration process will modify your files. "
            "The migration is a `best-effort` process and is not expected to "
            "be perfect. "
            "Do you want to continue?"
        ):
            raise Exit()
    console = Console(log_time=True)
    console.log("Start langchain-cli migrate")
    # NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
    os.environ["LIBCST_PARSER_TYPE"] = "native"

    # A single file is processed relative to its parent package; a directory
    # is scanned recursively for Python files.
    if os.path.isfile(path):
        package = path.parent
        all_files = [path]
    else:
        package = path
        all_files = sorted(package.glob("**/*.py"))

    filtered_files = [
        file
        for file in all_files
        if not any(match_glob(file, pattern) for pattern in ignore)
    ]
    # NOTE(review): relative_to(".") assumes `path` was given relative to the
    # current working directory — confirm behavior for absolute paths.
    files = [str(file.relative_to(".")) for file in filtered_files]

    if len(files) == 1:
        console.log("Found 1 file to process.")
    elif len(files) > 1:
        console.log(f"Found {len(files)} files to process.")
    else:
        console.log("No files to process.")
        raise Exit()

    providers = {FullyQualifiedNameProvider, ScopeProvider}
    metadata_manager = FullRepoManager(".", files, providers=providers)  # type: ignore[arg-type]
    metadata_manager.resolve_cache()

    scratch: dict[str, Any] = {}
    start_time = time.time()

    codemods = gather_codemods(disabled=disable)

    # NOTE(review): log_fp is opened here and never explicitly closed; it is
    # only released on process exit.
    log_fp = log_file.open("a+", encoding="utf8")
    partial_run_codemods = functools.partial(
        run_codemods, codemods, metadata_manager, scratch, package, diff
    )
    # Fan the per-file work out across a process pool, tracking progress.
    with Progress(*Progress.get_default_columns(), transient=True) as progress:
        task = progress.add_task(description="Executing codemods...", total=len(files))
        count_errors = 0
        difflines: List[List[str]] = []
        with multiprocessing.Pool() as pool:
            for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
                progress.advance(task)
                if _difflines is not None:
                    difflines.append(_difflines)
                if error is not None:
                    count_errors += 1
                    log_fp.writelines(error)

    # A file's mtime moving past start_time is the "was modified" signal
    # (worker processes write the files in place).
    modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]

    if not diff:
        if modified:
            console.log(f"Refactored {len(modified)} files.")
        else:
            console.log("No files were modified.")

    for _difflines in difflines:
        color_diff(console, _difflines)

    if count_errors > 0:
        console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
    else:
        console.log("Run successfully!")

    # In diff mode a non-empty diff exits non-zero (useful for CI checks).
    if difflines:
        raise Exit(1)
def run_codemods(
    codemods: List[Type[ContextAwareTransformer]],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Run the given codemods over a single file.

    Returns ``(error, difflines)``: ``error`` is a formatted message when the
    file could not be processed (None on success); ``difflines`` holds the
    unified diff when ``diff`` mode is on and the file changed (None
    otherwise). In non-diff mode a changed file is rewritten in place.
    """
    try:
        module_and_package = calculate_module_and_package(str(package), filename)
        context = CodemodContext(
            metadata_manager=metadata_manager,
            filename=filename,
            full_module_name=module_and_package.name,
            full_package_name=module_and_package.package,
        )
        context.scratch.update(scratch)

        file_path = Path(filename)
        with file_path.open("r+", encoding="utf-8") as fp:
            code = fp.read()
            fp.seek(0)

            input_tree = cst.parse_module(code)

            # Chain the codemods: each consumes the previous one's output.
            for codemod in codemods:
                transformer = codemod(context=context)
                output_tree = transformer.transform_module(input_tree)
                input_tree = output_tree

            output_code = input_tree.code
            if code != output_code:
                if diff:
                    lines = difflib.unified_diff(
                        code.splitlines(keepends=True),
                        output_code.splitlines(keepends=True),
                        fromfile=filename,
                        tofile=filename,
                    )
                    return None, list(lines)
                else:
                    fp.write(output_code)
                    fp.truncate()
        return None, None
    except cst.ParserSyntaxError as exc:
        # BUG FIX: the messages previously contained the literal text
        # "(unknown)" (f-strings without a placeholder); include the failing
        # filename so the error log is actionable.
        return (
            f"A syntax error happened on {filename}. This file cannot be "
            f"formatted.\n"
            f"{exc}"
        ), None
    except Exception:
        return f"An error happened on {filename}.\n{traceback.format_exc()}", None
def color_diff(console: Console, lines: Iterable[str]) -> None:
    """Print unified-diff lines to *console*, color-coded by their prefix."""
    # Diff-line prefixes mapped to rich styles; anything else prints white.
    prefix_styles = (("+", "green"), ("-", "red"), ("^", "blue"))
    for raw_line in lines:
        text = raw_line.rstrip("\n")
        for prefix, style in prefix_styles:
            if text.startswith(prefix):
                console.print(text, style=style)
                break
        else:
            console.print(text, style="white")

45
libs/cli/poetry.lock generated

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "aiohttp"
@ -819,6 +819,46 @@ orjson = ">=3.9.14,<4.0.0"
pydantic = ">=1,<3"
requests = ">=2,<3"
[[package]]
name = "libcst"
version = "1.3.1"
description = "A concrete syntax tree with AST-like properties for Python 3.0 through 3.12 programs."
optional = false
python-versions = ">=3.9"
files = [
{file = "libcst-1.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:de93193cba6d54f2a4419e94ba2de642b111f52e4fa01bb6e2c655914585f65b"},
{file = "libcst-1.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2d64d86dcd6c80a5dac2e243c5ed7a7a193242209ac33bad4b0639b24f6d131"},
{file = "libcst-1.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db084f7bbf825c7bd5ed256290066d0336df6a7dc3a029c9870a64cd2298b87f"},
{file = "libcst-1.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16880711be03a1f5da7028fe791ba5b482a50d830225a70272dc332dfd927652"},
{file = "libcst-1.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:189bb28c19c5dd3c64583f969b72f7732dbdb1dee9eca3acc85099e4cef9148b"},
{file = "libcst-1.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:181372386c986e3de07d7a93f269214cd825adc714f1f9da8252b44f05e181c4"},
{file = "libcst-1.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8c2020f7449270e3ff0bdc481ae244d812f2d9a8b7dbff0ea66b830f4b350f54"},
{file = "libcst-1.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:be3bf9aaafebda6a21e241e819f0ab67e186e898c3562704e41241cf8738353a"},
{file = "libcst-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a0d250fb6a2c1d158f30d25ba5e33e3ed3672d2700d480dd47beffd1431a008"},
{file = "libcst-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ad5741b251d901f3da1819ac539192230cc6f8f81aaf04eb4ec0009c1c97285"},
{file = "libcst-1.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b740dc0c3d1adbd91442fb878343d5a11e23a0e3ac5484be301fd8d148bcb085"},
{file = "libcst-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:9e6bc95fa7dde79cde63a34a0412cd4a3d9fcc27be781a590f8c45c840c83658"},
{file = "libcst-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4186076ce12141609ce950d61867b2a73ea199a7a9870dbafa76ad600e075b3c"},
{file = "libcst-1.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4ed52a1a2fe4d8603de51649db5e438317b8116ebb9fc09ec68703535fe6c1c8"},
{file = "libcst-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0886a9963597367b227345f19b24931b3ed6a4703fff237760745f90f0e6a20"},
{file = "libcst-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:904c4cc5c801a5747e64b43e0accc87c67a4c804842d977ee215872c4cf8cf88"},
{file = "libcst-1.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cdb7e8a118b60e064a02f6cbfa4d328212a3a115d125244495190f405709d5f"},
{file = "libcst-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:431badf0e544b79c0ac9682dbd291ff63ddbc3c3aca0d13d3cc7a10c3a9db8a2"},
{file = "libcst-1.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:701f5335e4fd566871497b9af1e871c98e1ef10c30b3b244f39343d709213401"},
{file = "libcst-1.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7c6e709623b68ca9148e8ecbdc145f7b83befb26032e4bf6a8122500ba558b17"},
{file = "libcst-1.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ede0f026a82b03b33a559ec566860085ece2e76d8f9bc21cb053eedf9cde8c79"},
{file = "libcst-1.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c12b7b01d8745f82dd86a82acd2a9f8e8e7d6c94ddcfda996896e83d1a8d5c42"},
{file = "libcst-1.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2995ca687118a9d3d41876f7270bc29305a2d402f4b8c81a3cff0aeee6d4c81"},
{file = "libcst-1.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:2dbac1ac0a9d59ea7bbc3f87cdcca5bfe98835e31c668e95cb6f3d907ffc53fc"},
{file = "libcst-1.3.1.tar.gz", hash = "sha256:03b1df1ae02456f1d465fcd5ead1d0d454bb483caefd8c8e6bde515ffdb53d1b"},
]
[package.dependencies]
pyyaml = ">=5.2"
[package.extras]
dev = ["Sphinx (>=5.1.1)", "black (==23.12.1)", "build (>=0.10.0)", "coverage (>=4.5.4)", "fixit (==2.1.0)", "flake8 (==7.0.0)", "hypothesis (>=4.36.0)", "hypothesmith (>=0.0.4)", "jinja2 (==3.1.3)", "jupyter (>=1.0.0)", "maturin (>=0.8.3,<1.5)", "nbsphinx (>=0.4.2)", "prompt-toolkit (>=2.0.9)", "pyre-check (==0.9.18)", "setuptools-rust (>=1.5.2)", "setuptools-scm (>=6.0.1)", "slotscheck (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
@ -1322,7 +1362,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -1850,4 +1889,4 @@ serve = []
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "8232264abd652b61dac3fd47833745c1df6c3418599dc14f9fe09773f3f80f13"
content-hash = "4576fb13ecd9e13bc6c85e4cd6f56520708c7c1468f4b81bc6a346b128c9f695"

@ -17,6 +17,7 @@ gitpython = "^3.1.40"
langserve = { extras = ["all"], version = ">=0.0.51" }
uvicorn = "^0.23.2"
tomlkit = "^0.12.2"
libcst = { version = "^1.3.1", python = "^3.9" }
[tool.poetry.scripts]
langchain = "langchain_cli.cli:app"

@ -0,0 +1,13 @@
from __future__ import annotations
from dataclasses import dataclass
from .file import File
from .folder import Folder
@dataclass
class Case:
    """One migration test case: an input tree, the expected result, and a label."""

    # Tree (or single file) fed to the migrate CLI.
    source: Folder | File
    # Tree (or single file) the migration is expected to produce.
    expected: Folder | File
    # Human-readable case name.
    name: str

@ -0,0 +1,15 @@
from tests.unit_tests.migrate.integration.case import Case
from tests.unit_tests.migrate.integration.cases import imports
from tests.unit_tests.migrate.integration.file import File
from tests.unit_tests.migrate.integration.folder import Folder
# All integration cases: a trivial empty-module case (nothing to rewrite)
# plus the import-rewrite cases defined in `imports.py`.
cases = [
    Case(
        name="empty",
        source=File("__init__.py", content=[]),
        expected=File("__init__.py", content=[]),
    ),
    *imports.cases,
]

# Aggregate before/after project trees consumed by the CLI runner test.
before = Folder("project", *[case.source for case in cases])
expected = Folder("project", *[case.expected for case in cases])

@ -0,0 +1,32 @@
from tests.unit_tests.migrate.integration.case import Case
from tests.unit_tests.migrate.integration.file import File
# Import-rewrite cases: a migrated import statement must be replaced while
# the surrounding code (class body, usage site) is left untouched.
cases = [
    Case(
        name="Imports",
        source=File(
            "app.py",
            content=[
                "from langchain.chat_models import ChatOpenAI",
                "",
                "",
                "class foo:",
                " a: int",
                "",
                "chain = ChatOpenAI()",
            ],
        ),
        expected=File(
            "app.py",
            content=[
                "from langchain_openai import ChatOpenAI",
                "",
                "",
                "class foo:",
                " a: int",
                "",
                "chain = ChatOpenAI()",
            ],
        ),
    ),
]

@ -0,0 +1,16 @@
from __future__ import annotations
class File:
    """An in-memory file: a name plus its content lines joined with newlines."""

    def __init__(self, name: str, content: list[str] | None = None) -> None:
        self.name = name
        self.content = "\n".join(content or [])

    def __eq__(self, __value: object) -> bool:
        # Only comparable to other File instances.
        if not isinstance(__value, File):
            return NotImplemented
        return self.name == __value.name and self.content == __value.content

@ -0,0 +1,59 @@
from __future__ import annotations
from pathlib import Path
from .file import File
class Folder:
    """An in-memory directory tree whose children are Folder and File nodes."""

    def __init__(self, name: str, *files: Folder | File) -> None:
        self.name = name
        self._files = files

    @property
    def files(self) -> list[Folder | File]:
        """Children sorted by name, for deterministic comparison."""
        return sorted(self._files, key=lambda child: child.name)

    def create_structure(self, root: Path) -> None:
        """Materialize this tree on disk beneath *root*."""
        base = root / self.name
        base.mkdir()
        for child in self.files:
            if isinstance(child, Folder):
                child.create_structure(base)
            else:
                (base / child.name).write_text(child.content, encoding="utf-8")

    @classmethod
    def from_structure(cls, root: Path) -> Folder:
        """Load a tree back from an on-disk directory."""
        children: list[File | Folder] = []
        for entry in root.iterdir():
            if entry.is_dir():
                children.append(cls.from_structure(entry))
            else:
                children.append(
                    File(entry.name, entry.read_text(encoding="utf-8").splitlines())
                )
        return Folder(root.name, *children)

    def __eq__(self, __value: object) -> bool:
        # A Folder never equals a File; other types defer to their __eq__.
        if isinstance(__value, File):
            return False
        if not isinstance(__value, Folder):
            return NotImplemented
        if self.name != __value.name:
            return False
        if len(self.files) != len(__value.files):
            return False
        return all(
            mine == theirs for mine, theirs in zip(self.files, __value.files)
        )

@ -0,0 +1,55 @@
# ruff: noqa: E402
from __future__ import annotations
import pytest
pytest.importorskip("libcst")
import difflib
from pathlib import Path
from typer.testing import CliRunner
from langchain_cli.namespaces.migrate.main import app
from tests.unit_tests.migrate.cli_runner.cases import before, expected
from tests.unit_tests.migrate.cli_runner.folder import Folder
def find_issue(current: Folder, expected: Folder) -> str:
    """Describe the first difference between two folder trees.

    Used as the assertion message when the migrated tree does not match the
    expected one. Returns "Unknown" when no differing pair is found at this
    level (e.g. the trees differ only in child count).
    """
    for current_file, expected_file in zip(current.files, expected.files):
        if current_file != expected_file:
            if current_file.name != expected_file.name:
                return (
                    f"Files have "
                    f"different names: {current_file.name} != {expected_file.name}"
                )
            if isinstance(current_file, Folder) and isinstance(expected_file, Folder):
                # Same-named subfolders: recurse to pinpoint the difference.
                return find_issue(current_file, expected_file)
            elif isinstance(current_file, Folder) or isinstance(expected_file, Folder):
                return (
                    f"One of the files is a "
                    f"folder: {current_file.name} != {expected_file.name}"
                )
            # Same name, both plain files: show a unified diff of contents.
            return "\n".join(
                difflib.unified_diff(
                    current_file.content.splitlines(),
                    expected_file.content.splitlines(),
                    fromfile=current_file.name,
                    tofile=expected_file.name,
                )
            )
    return "Unknown"
def test_command_line(tmp_path: Path) -> None:
    """End-to-end check: run the migrate CLI over a sample project tree."""
    runner = CliRunner()

    with runner.isolated_filesystem(temp_dir=tmp_path) as td:
        # Write the "before" tree to disk and run the CLI against it.
        before.create_structure(root=Path(td))
        # The input is used to force through the confirmation.
        result = runner.invoke(app, [before.name], input="y\n")
        assert result.exit_code == 0, result.output

        # Read the migrated tree back and compare against the expectation.
        after = Folder.from_structure(Path(td) / before.name)

    assert after == expected, find_issue(after, expected)

@ -0,0 +1,72 @@
from __future__ import annotations
from pathlib import Path
import pytest
from langchain_cli.namespaces.migrate.glob_helpers import glob_to_re, match_glob
class TestGlobHelpers:
    """Table-driven tests for `glob_to_re`/`match_glob` matching semantics."""

    # (pattern, path, should_match) triples covering literal segments,
    # single-segment `*`, recursive `**`, and combined patterns.
    match_glob_values: list[tuple[str, Path, bool]] = [
        ("foo", Path("foo"), True),
        ("foo", Path("bar"), False),
        ("foo", Path("foo/bar"), False),
        ("*", Path("foo"), True),
        ("*", Path("bar"), True),
        ("*", Path("foo/bar"), False),
        ("**", Path("foo"), True),
        ("**", Path("foo/bar"), True),
        ("**", Path("foo/bar/baz/qux"), True),
        ("foo/bar", Path("foo/bar"), True),
        ("foo/bar", Path("foo"), False),
        ("foo/bar", Path("far"), False),
        ("foo/bar", Path("foo/foo"), False),
        ("foo/*", Path("foo/bar"), True),
        ("foo/*", Path("foo/bar/baz"), False),
        ("foo/*", Path("foo"), False),
        ("foo/*", Path("bar"), False),
        ("foo/**", Path("foo/bar"), True),
        ("foo/**", Path("foo/bar/baz"), True),
        ("foo/**", Path("foo/bar/baz/qux"), True),
        ("foo/**", Path("foo"), True),
        ("foo/**", Path("bar"), False),
        ("foo/**/bar", Path("foo/bar"), True),
        ("foo/**/bar", Path("foo/baz/bar"), True),
        ("foo/**/bar", Path("foo/baz/qux/bar"), True),
        ("foo/**/bar", Path("foo/baz/qux"), False),
        ("foo/**/bar", Path("foo/bar/baz"), False),
        ("foo/**/bar", Path("foo/bar/bar"), True),
        ("foo/**/bar", Path("foo"), False),
        ("foo/**/bar", Path("bar"), False),
        ("foo/**/*/bar", Path("foo/bar"), False),
        ("foo/**/*/bar", Path("foo/baz/bar"), True),
        ("foo/**/*/bar", Path("foo/baz/qux/bar"), True),
        ("foo/**/*/bar", Path("foo/baz/qux"), False),
        ("foo/**/*/bar", Path("foo/bar/baz"), False),
        ("foo/**/*/bar", Path("foo/bar/bar"), True),
        ("foo/**/*/bar", Path("foo"), False),
        ("foo/**/*/bar", Path("bar"), False),
        ("foo/ba*", Path("foo/bar"), True),
        ("foo/ba*", Path("foo/baz"), True),
        ("foo/ba*", Path("foo/qux"), False),
        ("foo/ba*", Path("foo/baz/qux"), False),
        ("foo/ba*", Path("foo/bar/baz"), False),
        ("foo/ba*", Path("foo"), False),
        ("foo/ba*", Path("bar"), False),
        ("foo/**/ba*/*/qux", Path("foo/a/b/c/bar/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/a/b/c/baz/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/a/bar/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/baz/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/baz/qux"), False),
        ("foo/**/ba*/*/qux", Path("foo/a/b/c/qux/a/qux"), False),
        ("foo/**/ba*/*/qux", Path("foo"), False),
        ("foo/**/ba*/*/qux", Path("bar"), False),
    ]

    @pytest.mark.parametrize(("pattern", "path", "expected"), match_glob_values)
    def test_match_glob(self, pattern: str, path: Path, expected: bool):
        """Each (pattern, path) pair matches — or not — exactly as tabulated."""
        # Include the compiled expression in the failure message for debugging.
        expr = glob_to_re(pattern)
        assert (
            match_glob(path, pattern) == expected
        ), f"path: {path}, pattern: {pattern}, expr: {expr}"

@ -0,0 +1,41 @@
# ruff: noqa: E402
import pytest
pytest.importorskip("libcst")
from libcst.codemod import CodemodTest
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
ReplaceImportsCodemod,
)
class TestReplaceImportsCommand(CodemodTest):
    """Unit tests for ReplaceImportsCodemod via libcst's CodemodTest harness."""

    TRANSFORM = ReplaceImportsCodemod

    def test_single_import(self) -> None:
        """A lone migrated import is rewritten to its new package."""
        before = """
from langchain.chat_models import ChatOpenAI
"""
        after = """
from langchain_openai import ChatOpenAI
"""
        self.assertCodemod(before, after)

    def test_noop_import(self) -> None:
        """An import that is not in the migration table is left untouched."""
        code = """
from foo import ChatOpenAI
"""
        self.assertCodemod(code, code)

    def test_mixed_imports(self) -> None:
        """Migrated names are split out; unmigrated names stay on the old line."""
        before = """
from langchain.chat_models import ChatOpenAI, ChatAnthropic, foo
"""
        after = """
from langchain.chat_models import foo
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
"""
        self.assertCodemod(before, after)
Loading…
Cancel
Save