mirror of https://github.com/hwchase17/langchain
cli[minor]: Add first version of migrate (#20902)
Adds a first version of the migrate script.pull/20932/head
parent
d95e9fb67f
commit
6598757037
@ -0,0 +1,25 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
from libcst.codemod import ContextAwareTransformer
|
||||||
|
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
|
||||||
|
|
||||||
|
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
|
||||||
|
ReplaceImportsCodemod,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Rule(str, Enum):
|
||||||
|
R001 = "R001"
|
||||||
|
"""Replace imports that have been moved."""
|
||||||
|
|
||||||
|
|
||||||
|
def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
|
||||||
|
codemods: List[Type[ContextAwareTransformer]] = []
|
||||||
|
|
||||||
|
if Rule.R001 not in disabled:
|
||||||
|
codemods.append(ReplaceImportsCodemod)
|
||||||
|
|
||||||
|
# Those codemods need to be the last ones.
|
||||||
|
codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
|
||||||
|
return codemods
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,14 @@
|
|||||||
|
[
|
||||||
|
[
|
||||||
|
"langchain.chat_models.ChatOpenAI",
|
||||||
|
"langchain_openai.ChatOpenAI"
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"langchain.chat_models.ChatOpenAI",
|
||||||
|
"langchain_openai.ChatOpenAI"
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"langchain.chat_models.ChatAnthropic",
|
||||||
|
"langchain_anthropic.ChatAnthropic"
|
||||||
|
]
|
||||||
|
]
|
@ -0,0 +1,205 @@
|
|||||||
|
"""
|
||||||
|
# Adapted from bump-pydantic
|
||||||
|
# https://github.com/pydantic/bump-pydantic
|
||||||
|
|
||||||
|
This codemod deals with the following cases:
|
||||||
|
|
||||||
|
1. `from pydantic import BaseSettings`
|
||||||
|
2. `from pydantic.settings import BaseSettings`
|
||||||
|
3. `from pydantic import BaseSettings as <name>`
|
||||||
|
4. `from pydantic.settings import BaseSettings as <name>` # TODO: This is not working.
|
||||||
|
5. `import pydantic` -> `pydantic.BaseSettings`
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Dict, Iterable, List, Sequence, Tuple, TypeVar
|
||||||
|
|
||||||
|
import libcst as cst
|
||||||
|
import libcst.matchers as m
|
||||||
|
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
|
||||||
|
from libcst.codemod.visitors import AddImportsVisitor
|
||||||
|
|
||||||
|
HERE = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_migrations_by_file(path: str):
|
||||||
|
migrations_path = os.path.join(HERE, path)
|
||||||
|
with open(migrations_path, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def _deduplicate_in_order(
|
||||||
|
seq: Iterable[T], key: Callable[[T], str] = lambda x: x
|
||||||
|
) -> List[T]:
|
||||||
|
seen = set()
|
||||||
|
seen_add = seen.add
|
||||||
|
return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_migrations():
|
||||||
|
"""Load the migrations from the JSON file."""
|
||||||
|
# Later earlier ones have higher precedence.
|
||||||
|
paths = [
|
||||||
|
"migrations_v0.2_partner.json",
|
||||||
|
"migrations_v0.2.json",
|
||||||
|
]
|
||||||
|
|
||||||
|
data = []
|
||||||
|
for path in paths:
|
||||||
|
data.extend(_load_migrations_by_file(path))
|
||||||
|
|
||||||
|
data = _deduplicate_in_order(data, key=lambda x: x[0])
|
||||||
|
|
||||||
|
imports: Dict[str, Tuple[str, str]] = {}
|
||||||
|
|
||||||
|
for old_path, new_path in data:
|
||||||
|
# Parse the old parse which is of the format 'langchain.chat_models.ChatOpenAI'
|
||||||
|
# into the module and class name.
|
||||||
|
old_parts = old_path.split(".")
|
||||||
|
old_module = ".".join(old_parts[:-1])
|
||||||
|
old_class = old_parts[-1]
|
||||||
|
old_path_str = f"{old_module}:{old_class}"
|
||||||
|
|
||||||
|
# Parse the new parse which is of the format 'langchain.chat_models.ChatOpenAI'
|
||||||
|
# Into a 2-tuple of the module and class name.
|
||||||
|
new_parts = new_path.split(".")
|
||||||
|
new_module = ".".join(new_parts[:-1])
|
||||||
|
new_class = new_parts[-1]
|
||||||
|
new_path_str = (new_module, new_class)
|
||||||
|
|
||||||
|
imports[old_path_str] = new_path_str
|
||||||
|
|
||||||
|
return imports
|
||||||
|
|
||||||
|
|
||||||
|
IMPORTS = _load_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
|
||||||
|
"""Converts a list of module parts to a `Name` or `Attribute` node."""
|
||||||
|
if len(module_parts) == 1:
|
||||||
|
return m.Name(module_parts[0])
|
||||||
|
if len(module_parts) == 2:
|
||||||
|
first, last = module_parts
|
||||||
|
return m.Attribute(value=m.Name(first), attr=m.Name(last))
|
||||||
|
last_name = module_parts.pop()
|
||||||
|
attr = resolve_module_parts(module_parts)
|
||||||
|
return m.Attribute(value=attr, attr=m.Name(last_name))
|
||||||
|
|
||||||
|
|
||||||
|
def get_import_from_from_str(import_str: str) -> m.ImportFrom:
|
||||||
|
"""Converts a string like `pydantic:BaseSettings` to Examples:
|
||||||
|
>>> get_import_from_from_str("pydantic:BaseSettings")
|
||||||
|
ImportFrom(
|
||||||
|
module=Name("pydantic"),
|
||||||
|
names=[ImportAlias(name=Name("BaseSettings"))],
|
||||||
|
)
|
||||||
|
>>> get_import_from_from_str("pydantic.settings:BaseSettings")
|
||||||
|
ImportFrom(
|
||||||
|
module=Attribute(value=Name("pydantic"), attr=Name("settings")),
|
||||||
|
names=[ImportAlias(name=Name("BaseSettings"))],
|
||||||
|
)
|
||||||
|
>>> get_import_from_from_str("a.b.c:d")
|
||||||
|
ImportFrom(
|
||||||
|
module=Attribute(
|
||||||
|
value=Attribute(value=Name("a"), attr=Name("b")), attr=Name("c")
|
||||||
|
),
|
||||||
|
names=[ImportAlias(name=Name("d"))],
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
module, name = import_str.split(":")
|
||||||
|
module_parts = module.split(".")
|
||||||
|
module_node = resolve_module_parts(module_parts)
|
||||||
|
return m.ImportFrom(
|
||||||
|
module=module_node,
|
||||||
|
names=[m.ZeroOrMore(), m.ImportAlias(name=m.Name(value=name)), m.ZeroOrMore()],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImportInfo:
|
||||||
|
import_from: m.ImportFrom
|
||||||
|
import_str: str
|
||||||
|
to_import_str: tuple[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
IMPORT_INFOS = [
|
||||||
|
ImportInfo(
|
||||||
|
import_from=get_import_from_from_str(import_str),
|
||||||
|
import_str=import_str,
|
||||||
|
to_import_str=to_import_str,
|
||||||
|
)
|
||||||
|
for import_str, to_import_str in IMPORTS.items()
|
||||||
|
]
|
||||||
|
IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS])
|
||||||
|
|
||||||
|
|
||||||
|
class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
|
||||||
|
@m.leave(IMPORT_MATCH)
|
||||||
|
def leave_replace_import(
|
||||||
|
self, _: cst.ImportFrom, updated_node: cst.ImportFrom
|
||||||
|
) -> cst.ImportFrom:
|
||||||
|
for import_info in IMPORT_INFOS:
|
||||||
|
if m.matches(updated_node, import_info.import_from):
|
||||||
|
aliases: Sequence[cst.ImportAlias] = updated_node.names # type: ignore
|
||||||
|
# If multiple objects are imported in a single import statement,
|
||||||
|
# we need to remove only the one we're replacing.
|
||||||
|
AddImportsVisitor.add_needed_import(
|
||||||
|
self.context, *import_info.to_import_str
|
||||||
|
)
|
||||||
|
if len(updated_node.names) > 1: # type: ignore
|
||||||
|
names = [
|
||||||
|
alias
|
||||||
|
for alias in aliases
|
||||||
|
if alias.name.value != import_info.to_import_str[-1]
|
||||||
|
]
|
||||||
|
names[-1] = names[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
|
||||||
|
updated_node = updated_node.with_changes(names=names)
|
||||||
|
else:
|
||||||
|
return cst.RemoveFromParent() # type: ignore[return-value]
|
||||||
|
return updated_node
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
source = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
from pydantic.settings import BaseSettings
|
||||||
|
from pydantic.color import Color
|
||||||
|
from pydantic.payment import PaymentCardNumber, PaymentCardBrand
|
||||||
|
from pydantic import Color
|
||||||
|
from pydantic import Color as Potato
|
||||||
|
|
||||||
|
|
||||||
|
class Potato(BaseSettings):
|
||||||
|
color: Color
|
||||||
|
payment: PaymentCardNumber
|
||||||
|
brand: PaymentCardBrand
|
||||||
|
potato: Potato
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
console.print(source)
|
||||||
|
console.print("=" * 80)
|
||||||
|
|
||||||
|
mod = cst.parse_module(source)
|
||||||
|
context = CodemodContext(filename="main.py")
|
||||||
|
wrapper = cst.MetadataWrapper(mod)
|
||||||
|
command = ReplaceImportsCodemod(context=context)
|
||||||
|
|
||||||
|
mod = wrapper.visit(command)
|
||||||
|
wrapper = cst.MetadataWrapper(mod)
|
||||||
|
command = AddImportsVisitor(context=context) # type: ignore[assignment]
|
||||||
|
mod = wrapper.visit(command)
|
||||||
|
console.print(mod.code)
|
@ -0,0 +1,52 @@
|
|||||||
|
# Adapted from bump-pydantic
|
||||||
|
# https://github.com/pydantic/bump-pydantic
|
||||||
|
import fnmatch
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
MATCH_SEP = r"(?:/|\\)"
|
||||||
|
MATCH_SEP_OR_END = r"(?:/|\\|\Z)"
|
||||||
|
MATCH_NON_RECURSIVE = r"[^/\\]*"
|
||||||
|
MATCH_RECURSIVE = r"(?:.*)"
|
||||||
|
|
||||||
|
|
||||||
|
def glob_to_re(pattern: str) -> str:
|
||||||
|
"""Translate a glob pattern to a regular expression for matching."""
|
||||||
|
fragments: List[str] = []
|
||||||
|
for segment in re.split(r"/|\\", pattern):
|
||||||
|
if segment == "":
|
||||||
|
continue
|
||||||
|
if segment == "**":
|
||||||
|
# Remove previous separator match, so the recursive match c
|
||||||
|
# can match zero or more segments.
|
||||||
|
if fragments and fragments[-1] == MATCH_SEP:
|
||||||
|
fragments.pop()
|
||||||
|
fragments.append(MATCH_RECURSIVE)
|
||||||
|
elif "**" in segment:
|
||||||
|
raise ValueError(
|
||||||
|
"invalid pattern: '**' can only be an entire path component"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
fragment = fnmatch.translate(segment)
|
||||||
|
fragment = fragment.replace(r"(?s:", r"(?:")
|
||||||
|
fragment = fragment.replace(r".*", MATCH_NON_RECURSIVE)
|
||||||
|
fragment = fragment.replace(r"\Z", r"")
|
||||||
|
fragments.append(fragment)
|
||||||
|
fragments.append(MATCH_SEP)
|
||||||
|
# Remove trailing MATCH_SEP, so it can be replaced with MATCH_SEP_OR_END.
|
||||||
|
if fragments and fragments[-1] == MATCH_SEP:
|
||||||
|
fragments.pop()
|
||||||
|
fragments.append(MATCH_SEP_OR_END)
|
||||||
|
return rf"(?s:{''.join(fragments)})"
|
||||||
|
|
||||||
|
|
||||||
|
def match_glob(path: Path, pattern: str) -> bool:
|
||||||
|
"""Check if a path matches a glob pattern.
|
||||||
|
|
||||||
|
If the pattern ends with a directory separator, the path must be a directory.
|
||||||
|
"""
|
||||||
|
match = bool(re.fullmatch(glob_to_re(pattern), str(path)))
|
||||||
|
if pattern.endswith("/") or pattern.endswith("\\"):
|
||||||
|
return match and path.is_dir()
|
||||||
|
return match
|
@ -0,0 +1,194 @@
|
|||||||
|
"""Migrate LangChain to the most recent version."""
|
||||||
|
# Adapted from bump-pydantic
|
||||||
|
# https://github.com/pydantic/bump-pydantic
|
||||||
|
import difflib
|
||||||
|
import functools
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
|
import libcst as cst
|
||||||
|
import rich
|
||||||
|
import typer
|
||||||
|
from libcst.codemod import CodemodContext, ContextAwareTransformer
|
||||||
|
from libcst.helpers import calculate_module_and_package
|
||||||
|
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.progress import Progress
|
||||||
|
from typer import Argument, Exit, Option, Typer
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods
|
||||||
|
from langchain_cli.namespaces.migrate.glob_helpers import match_glob
|
||||||
|
|
||||||
|
app = Typer(invoke_without_command=True, add_completion=False)
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
DEFAULT_IGNORES = [".venv/**"]
|
||||||
|
|
||||||
|
|
||||||
|
@app.callback()
|
||||||
|
def main(
|
||||||
|
path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
|
||||||
|
disable: List[Rule] = Option(default=[], help="Disable a rule."),
|
||||||
|
diff: bool = Option(False, help="Show diff instead of applying changes."),
|
||||||
|
ignore: List[str] = Option(
|
||||||
|
default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
|
||||||
|
),
|
||||||
|
log_file: Path = Option("log.txt", help="Log errors to this file."),
|
||||||
|
):
|
||||||
|
"""Migrate langchain to the most recent version."""
|
||||||
|
if not diff:
|
||||||
|
rich.print("[bold red]Alert![/ bold red] langchain-cli migrate", end=": ")
|
||||||
|
if not typer.confirm(
|
||||||
|
"The migration process will modify your files. "
|
||||||
|
"The migration is a `best-effort` process and is not expected to "
|
||||||
|
"be perfect. "
|
||||||
|
"Do you want to continue?"
|
||||||
|
):
|
||||||
|
raise Exit()
|
||||||
|
console = Console(log_time=True)
|
||||||
|
console.log("Start langchain-cli migrate")
|
||||||
|
# NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
|
||||||
|
os.environ["LIBCST_PARSER_TYPE"] = "native"
|
||||||
|
|
||||||
|
if os.path.isfile(path):
|
||||||
|
package = path.parent
|
||||||
|
all_files = [path]
|
||||||
|
else:
|
||||||
|
package = path
|
||||||
|
all_files = sorted(package.glob("**/*.py"))
|
||||||
|
|
||||||
|
filtered_files = [
|
||||||
|
file
|
||||||
|
for file in all_files
|
||||||
|
if not any(match_glob(file, pattern) for pattern in ignore)
|
||||||
|
]
|
||||||
|
files = [str(file.relative_to(".")) for file in filtered_files]
|
||||||
|
|
||||||
|
if len(files) == 1:
|
||||||
|
console.log("Found 1 file to process.")
|
||||||
|
elif len(files) > 1:
|
||||||
|
console.log(f"Found {len(files)} files to process.")
|
||||||
|
else:
|
||||||
|
console.log("No files to process.")
|
||||||
|
raise Exit()
|
||||||
|
|
||||||
|
providers = {FullyQualifiedNameProvider, ScopeProvider}
|
||||||
|
metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type]
|
||||||
|
metadata_manager.resolve_cache()
|
||||||
|
|
||||||
|
scratch: dict[str, Any] = {}
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
codemods = gather_codemods(disabled=disable)
|
||||||
|
|
||||||
|
log_fp = log_file.open("a+", encoding="utf8")
|
||||||
|
partial_run_codemods = functools.partial(
|
||||||
|
run_codemods, codemods, metadata_manager, scratch, package, diff
|
||||||
|
)
|
||||||
|
with Progress(*Progress.get_default_columns(), transient=True) as progress:
|
||||||
|
task = progress.add_task(description="Executing codemods...", total=len(files))
|
||||||
|
count_errors = 0
|
||||||
|
difflines: List[List[str]] = []
|
||||||
|
with multiprocessing.Pool() as pool:
|
||||||
|
for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
|
||||||
|
progress.advance(task)
|
||||||
|
|
||||||
|
if _difflines is not None:
|
||||||
|
difflines.append(_difflines)
|
||||||
|
|
||||||
|
if error is not None:
|
||||||
|
count_errors += 1
|
||||||
|
log_fp.writelines(error)
|
||||||
|
|
||||||
|
modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]
|
||||||
|
|
||||||
|
if not diff:
|
||||||
|
if modified:
|
||||||
|
console.log(f"Refactored {len(modified)} files.")
|
||||||
|
else:
|
||||||
|
console.log("No files were modified.")
|
||||||
|
|
||||||
|
for _difflines in difflines:
|
||||||
|
color_diff(console, _difflines)
|
||||||
|
|
||||||
|
if count_errors > 0:
|
||||||
|
console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
|
||||||
|
else:
|
||||||
|
console.log("Run successfully!")
|
||||||
|
|
||||||
|
if difflines:
|
||||||
|
raise Exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def run_codemods(
|
||||||
|
codemods: List[Type[ContextAwareTransformer]],
|
||||||
|
metadata_manager: FullRepoManager,
|
||||||
|
scratch: Dict[str, Any],
|
||||||
|
package: Path,
|
||||||
|
diff: bool,
|
||||||
|
filename: str,
|
||||||
|
) -> Tuple[Union[str, None], Union[List[str], None]]:
|
||||||
|
try:
|
||||||
|
module_and_package = calculate_module_and_package(str(package), filename)
|
||||||
|
context = CodemodContext(
|
||||||
|
metadata_manager=metadata_manager,
|
||||||
|
filename=filename,
|
||||||
|
full_module_name=module_and_package.name,
|
||||||
|
full_package_name=module_and_package.package,
|
||||||
|
)
|
||||||
|
context.scratch.update(scratch)
|
||||||
|
|
||||||
|
file_path = Path(filename)
|
||||||
|
with file_path.open("r+", encoding="utf-8") as fp:
|
||||||
|
code = fp.read()
|
||||||
|
fp.seek(0)
|
||||||
|
|
||||||
|
input_tree = cst.parse_module(code)
|
||||||
|
|
||||||
|
for codemod in codemods:
|
||||||
|
transformer = codemod(context=context)
|
||||||
|
output_tree = transformer.transform_module(input_tree)
|
||||||
|
input_tree = output_tree
|
||||||
|
|
||||||
|
output_code = input_tree.code
|
||||||
|
if code != output_code:
|
||||||
|
if diff:
|
||||||
|
lines = difflib.unified_diff(
|
||||||
|
code.splitlines(keepends=True),
|
||||||
|
output_code.splitlines(keepends=True),
|
||||||
|
fromfile=filename,
|
||||||
|
tofile=filename,
|
||||||
|
)
|
||||||
|
return None, list(lines)
|
||||||
|
else:
|
||||||
|
fp.write(output_code)
|
||||||
|
fp.truncate()
|
||||||
|
return None, None
|
||||||
|
except cst.ParserSyntaxError as exc:
|
||||||
|
return (
|
||||||
|
f"A syntax error happened on {filename}. This file cannot be "
|
||||||
|
f"formatted.\n"
|
||||||
|
f"{exc}"
|
||||||
|
), None
|
||||||
|
except Exception:
|
||||||
|
return f"An error happened on {filename}.\n{traceback.format_exc()}", None
|
||||||
|
|
||||||
|
|
||||||
|
def color_diff(console: Console, lines: Iterable[str]) -> None:
|
||||||
|
for line in lines:
|
||||||
|
line = line.rstrip("\n")
|
||||||
|
if line.startswith("+"):
|
||||||
|
console.print(line, style="green")
|
||||||
|
elif line.startswith("-"):
|
||||||
|
console.print(line, style="red")
|
||||||
|
elif line.startswith("^"):
|
||||||
|
console.print(line, style="blue")
|
||||||
|
else:
|
||||||
|
console.print(line, style="white")
|
@ -0,0 +1,13 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .file import File
|
||||||
|
from .folder import Folder
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Case:
|
||||||
|
source: Folder | File
|
||||||
|
expected: Folder | File
|
||||||
|
name: str
|
@ -0,0 +1,15 @@
|
|||||||
|
from tests.unit_tests.migrate.integration.case import Case
|
||||||
|
from tests.unit_tests.migrate.integration.cases import imports
|
||||||
|
from tests.unit_tests.migrate.integration.file import File
|
||||||
|
from tests.unit_tests.migrate.integration.folder import Folder
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
Case(
|
||||||
|
name="empty",
|
||||||
|
source=File("__init__.py", content=[]),
|
||||||
|
expected=File("__init__.py", content=[]),
|
||||||
|
),
|
||||||
|
*imports.cases,
|
||||||
|
]
|
||||||
|
before = Folder("project", *[case.source for case in cases])
|
||||||
|
expected = Folder("project", *[case.expected for case in cases])
|
@ -0,0 +1,32 @@
|
|||||||
|
from tests.unit_tests.migrate.integration.case import Case
|
||||||
|
from tests.unit_tests.migrate.integration.file import File
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
Case(
|
||||||
|
name="Imports",
|
||||||
|
source=File(
|
||||||
|
"app.py",
|
||||||
|
content=[
|
||||||
|
"from langchain.chat_models import ChatOpenAI",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"class foo:",
|
||||||
|
" a: int",
|
||||||
|
"",
|
||||||
|
"chain = ChatOpenAI()",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
expected=File(
|
||||||
|
"app.py",
|
||||||
|
content=[
|
||||||
|
"from langchain_openai import ChatOpenAI",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"class foo:",
|
||||||
|
" a: int",
|
||||||
|
"",
|
||||||
|
"chain = ChatOpenAI()",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
@ -0,0 +1,16 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
class File:
|
||||||
|
def __init__(self, name: str, content: list[str] | None = None) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.content = "\n".join(content or [])
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, File):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
if self.name != __value.name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.content == __value.content
|
@ -0,0 +1,59 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .file import File
|
||||||
|
|
||||||
|
|
||||||
|
class Folder:
|
||||||
|
def __init__(self, name: str, *files: Folder | File) -> None:
|
||||||
|
self.name = name
|
||||||
|
self._files = files
|
||||||
|
|
||||||
|
@property
|
||||||
|
def files(self) -> list[Folder | File]:
|
||||||
|
return sorted(self._files, key=lambda f: f.name)
|
||||||
|
|
||||||
|
def create_structure(self, root: Path) -> None:
|
||||||
|
path = root / self.name
|
||||||
|
path.mkdir()
|
||||||
|
|
||||||
|
for file in self.files:
|
||||||
|
if isinstance(file, Folder):
|
||||||
|
file.create_structure(path)
|
||||||
|
else:
|
||||||
|
(path / file.name).write_text(file.content, encoding="utf-8")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_structure(cls, root: Path) -> Folder:
|
||||||
|
name = root.name
|
||||||
|
files: list[File | Folder] = []
|
||||||
|
|
||||||
|
for path in root.iterdir():
|
||||||
|
if path.is_dir():
|
||||||
|
files.append(cls.from_structure(path))
|
||||||
|
else:
|
||||||
|
files.append(
|
||||||
|
File(path.name, path.read_text(encoding="utf-8").splitlines())
|
||||||
|
)
|
||||||
|
|
||||||
|
return Folder(name, *files)
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if isinstance(__value, File):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not isinstance(__value, Folder):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
if self.name != __value.name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if len(self.files) != len(__value.files):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for self_file, other_file in zip(self.files, __value.files):
|
||||||
|
if self_file != other_file:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
@ -0,0 +1,55 @@
|
|||||||
|
# ruff: noqa: E402
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytest.importorskip("libcst")
|
||||||
|
|
||||||
|
import difflib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from langchain_cli.namespaces.migrate.main import app
|
||||||
|
from tests.unit_tests.migrate.cli_runner.cases import before, expected
|
||||||
|
from tests.unit_tests.migrate.cli_runner.folder import Folder
|
||||||
|
|
||||||
|
|
||||||
|
def find_issue(current: Folder, expected: Folder) -> str:
|
||||||
|
for current_file, expected_file in zip(current.files, expected.files):
|
||||||
|
if current_file != expected_file:
|
||||||
|
if current_file.name != expected_file.name:
|
||||||
|
return (
|
||||||
|
f"Files have "
|
||||||
|
f"different names: {current_file.name} != {expected_file.name}"
|
||||||
|
)
|
||||||
|
if isinstance(current_file, Folder) and isinstance(expected_file, Folder):
|
||||||
|
return find_issue(current_file, expected_file)
|
||||||
|
elif isinstance(current_file, Folder) or isinstance(expected_file, Folder):
|
||||||
|
return (
|
||||||
|
f"One of the files is a "
|
||||||
|
f"folder: {current_file.name} != {expected_file.name}"
|
||||||
|
)
|
||||||
|
return "\n".join(
|
||||||
|
difflib.unified_diff(
|
||||||
|
current_file.content.splitlines(),
|
||||||
|
expected_file.content.splitlines(),
|
||||||
|
fromfile=current_file.name,
|
||||||
|
tofile=expected_file.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return "Unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def test_command_line(tmp_path: Path) -> None:
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
with runner.isolated_filesystem(temp_dir=tmp_path) as td:
|
||||||
|
before.create_structure(root=Path(td))
|
||||||
|
# The input is used to force through the confirmation.
|
||||||
|
result = runner.invoke(app, [before.name], input="y\n")
|
||||||
|
assert result.exit_code == 0, result.output
|
||||||
|
|
||||||
|
after = Folder.from_structure(Path(td) / before.name)
|
||||||
|
|
||||||
|
assert after == expected, find_issue(after, expected)
|
@ -0,0 +1,72 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_cli.namespaces.migrate.glob_helpers import glob_to_re, match_glob
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobHelpers:
|
||||||
|
match_glob_values: list[tuple[str, Path, bool]] = [
|
||||||
|
("foo", Path("foo"), True),
|
||||||
|
("foo", Path("bar"), False),
|
||||||
|
("foo", Path("foo/bar"), False),
|
||||||
|
("*", Path("foo"), True),
|
||||||
|
("*", Path("bar"), True),
|
||||||
|
("*", Path("foo/bar"), False),
|
||||||
|
("**", Path("foo"), True),
|
||||||
|
("**", Path("foo/bar"), True),
|
||||||
|
("**", Path("foo/bar/baz/qux"), True),
|
||||||
|
("foo/bar", Path("foo/bar"), True),
|
||||||
|
("foo/bar", Path("foo"), False),
|
||||||
|
("foo/bar", Path("far"), False),
|
||||||
|
("foo/bar", Path("foo/foo"), False),
|
||||||
|
("foo/*", Path("foo/bar"), True),
|
||||||
|
("foo/*", Path("foo/bar/baz"), False),
|
||||||
|
("foo/*", Path("foo"), False),
|
||||||
|
("foo/*", Path("bar"), False),
|
||||||
|
("foo/**", Path("foo/bar"), True),
|
||||||
|
("foo/**", Path("foo/bar/baz"), True),
|
||||||
|
("foo/**", Path("foo/bar/baz/qux"), True),
|
||||||
|
("foo/**", Path("foo"), True),
|
||||||
|
("foo/**", Path("bar"), False),
|
||||||
|
("foo/**/bar", Path("foo/bar"), True),
|
||||||
|
("foo/**/bar", Path("foo/baz/bar"), True),
|
||||||
|
("foo/**/bar", Path("foo/baz/qux/bar"), True),
|
||||||
|
("foo/**/bar", Path("foo/baz/qux"), False),
|
||||||
|
("foo/**/bar", Path("foo/bar/baz"), False),
|
||||||
|
("foo/**/bar", Path("foo/bar/bar"), True),
|
||||||
|
("foo/**/bar", Path("foo"), False),
|
||||||
|
("foo/**/bar", Path("bar"), False),
|
||||||
|
("foo/**/*/bar", Path("foo/bar"), False),
|
||||||
|
("foo/**/*/bar", Path("foo/baz/bar"), True),
|
||||||
|
("foo/**/*/bar", Path("foo/baz/qux/bar"), True),
|
||||||
|
("foo/**/*/bar", Path("foo/baz/qux"), False),
|
||||||
|
("foo/**/*/bar", Path("foo/bar/baz"), False),
|
||||||
|
("foo/**/*/bar", Path("foo/bar/bar"), True),
|
||||||
|
("foo/**/*/bar", Path("foo"), False),
|
||||||
|
("foo/**/*/bar", Path("bar"), False),
|
||||||
|
("foo/ba*", Path("foo/bar"), True),
|
||||||
|
("foo/ba*", Path("foo/baz"), True),
|
||||||
|
("foo/ba*", Path("foo/qux"), False),
|
||||||
|
("foo/ba*", Path("foo/baz/qux"), False),
|
||||||
|
("foo/ba*", Path("foo/bar/baz"), False),
|
||||||
|
("foo/ba*", Path("foo"), False),
|
||||||
|
("foo/ba*", Path("bar"), False),
|
||||||
|
("foo/**/ba*/*/qux", Path("foo/a/b/c/bar/a/qux"), True),
|
||||||
|
("foo/**/ba*/*/qux", Path("foo/a/b/c/baz/a/qux"), True),
|
||||||
|
("foo/**/ba*/*/qux", Path("foo/a/bar/a/qux"), True),
|
||||||
|
("foo/**/ba*/*/qux", Path("foo/baz/a/qux"), True),
|
||||||
|
("foo/**/ba*/*/qux", Path("foo/baz/qux"), False),
|
||||||
|
("foo/**/ba*/*/qux", Path("foo/a/b/c/qux/a/qux"), False),
|
||||||
|
("foo/**/ba*/*/qux", Path("foo"), False),
|
||||||
|
("foo/**/ba*/*/qux", Path("bar"), False),
|
||||||
|
]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("pattern", "path", "expected"), match_glob_values)
|
||||||
|
def test_match_glob(self, pattern: str, path: Path, expected: bool):
|
||||||
|
expr = glob_to_re(pattern)
|
||||||
|
assert (
|
||||||
|
match_glob(path, pattern) == expected
|
||||||
|
), f"path: {path}, pattern: {pattern}, expr: {expr}"
|
@ -0,0 +1,41 @@
|
|||||||
|
# ruff: noqa: E402
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytest.importorskip("libcst")
|
||||||
|
|
||||||
|
|
||||||
|
from libcst.codemod import CodemodTest
|
||||||
|
|
||||||
|
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
|
||||||
|
ReplaceImportsCodemod,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestReplaceImportsCommand(CodemodTest):
|
||||||
|
TRANSFORM = ReplaceImportsCodemod
|
||||||
|
|
||||||
|
def test_single_import(self) -> None:
|
||||||
|
before = """
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
"""
|
||||||
|
after = """
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
"""
|
||||||
|
self.assertCodemod(before, after)
|
||||||
|
|
||||||
|
def test_noop_import(self) -> None:
|
||||||
|
code = """
|
||||||
|
from foo import ChatOpenAI
|
||||||
|
"""
|
||||||
|
self.assertCodemod(code, code)
|
||||||
|
|
||||||
|
def test_mixed_imports(self) -> None:
|
||||||
|
before = """
|
||||||
|
from langchain.chat_models import ChatOpenAI, ChatAnthropic, foo
|
||||||
|
"""
|
||||||
|
after = """
|
||||||
|
from langchain.chat_models import foo
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
"""
|
||||||
|
self.assertCodemod(before, after)
|
Loading…
Reference in New Issue