mirror of https://github.com/hwchase17/langchain
cli[minor]: Add first version of migrate (#20902)
Adds a first version of the migrate script.pull/20932/head
parent
d95e9fb67f
commit
6598757037
@ -0,0 +1,25 @@
|
||||
from enum import Enum
|
||||
from typing import List, Type
|
||||
|
||||
from libcst.codemod import ContextAwareTransformer
|
||||
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
|
||||
|
||||
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
|
||||
ReplaceImportsCodemod,
|
||||
)
|
||||
|
||||
|
||||
class Rule(str, Enum):
    """Identifiers of the migration rules that can be disabled via the CLI."""

    # str-valued so typer can parse it straight from the command line.
    R001 = "R001"
    """Replace imports that have been moved."""
|
||||
|
||||
|
||||
def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
    """Collect the codemod classes to run, skipping any disabled rules."""
    selected: List[Type[ContextAwareTransformer]] = []

    if Rule.R001 not in disabled:
        selected.append(ReplaceImportsCodemod)

    # The import add/remove visitors must always run after every other codemod.
    selected.extend([RemoveImportsVisitor, AddImportsVisitor])
    return selected
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,14 @@
|
||||
[
  [
    "langchain.chat_models.ChatOpenAI",
    "langchain_openai.ChatOpenAI"
  ],
  [
    "langchain.chat_models.ChatAnthropic",
    "langchain_anthropic.ChatAnthropic"
  ]
]
|
@ -0,0 +1,205 @@
|
||||
"""
|
||||
# Adapted from bump-pydantic
|
||||
# https://github.com/pydantic/bump-pydantic
|
||||
|
||||
This codemod deals with the following cases:
|
||||
|
||||
1. `from pydantic import BaseSettings`
|
||||
2. `from pydantic.settings import BaseSettings`
|
||||
3. `from pydantic import BaseSettings as <name>`
|
||||
4. `from pydantic.settings import BaseSettings as <name>` # TODO: This is not working.
|
||||
5. `import pydantic` -> `pydantic.BaseSettings`
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Iterable, List, Sequence, Tuple, TypeVar
|
||||
|
||||
import libcst as cst
|
||||
import libcst.matchers as m
|
||||
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
|
||||
from libcst.codemod.visitors import AddImportsVisitor
|
||||
|
||||
HERE = os.path.dirname(__file__)
|
||||
|
||||
|
||||
def _load_migrations_by_file(path: str):
    """Read one migration list (a JSON array of [old, new] pairs) shipped
    next to this module."""
    full_path = os.path.join(HERE, path)
    with open(full_path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _deduplicate_in_order(
|
||||
seq: Iterable[T], key: Callable[[T], str] = lambda x: x
|
||||
) -> List[T]:
|
||||
seen = set()
|
||||
seen_add = seen.add
|
||||
return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]
|
||||
|
||||
|
||||
def _load_migrations():
    """Load the migrations from the JSON files.

    Returns a mapping of ``"old.module:Class"`` to ``("new.module", "Class")``.
    """
    # Files listed earlier take precedence when the same old path appears in
    # more than one file (deduplication keeps the first occurrence).
    migration_files = [
        "migrations_v0.2_partner.json",
        "migrations_v0.2.json",
    ]

    raw_pairs = []
    for migration_file in migration_files:
        raw_pairs.extend(_load_migrations_by_file(migration_file))

    unique_pairs = _deduplicate_in_order(raw_pairs, key=lambda pair: pair[0])

    imports: Dict[str, Tuple[str, str]] = {}

    for old_path, new_path in unique_pairs:
        # Split 'pkg.module.Class' into module and class name; the key format
        # 'pkg.module:Class' matches get_import_from_from_str's input.
        old_module, _, old_class = old_path.rpartition(".")
        # The destination is kept as a (module, class) 2-tuple, which is the
        # argument shape AddImportsVisitor.add_needed_import expects.
        new_module, _, new_class = new_path.rpartition(".")
        imports[f"{old_module}:{old_class}"] = (new_module, new_class)

    return imports
|
||||
|
||||
|
||||
# Mapping of "old.module:Class" -> ("new.module", "Class"), built once at import.
IMPORTS = _load_migrations()
|
||||
|
||||
|
||||
def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
    """Converts a list of module parts to a `Name` or `Attribute` matcher node.

    E.g. ``["a"]`` -> ``Name("a")``;
    ``["a", "b", "c"]`` -> ``Attribute(Attribute(Name("a"), b), c)``.
    """
    if len(module_parts) == 1:
        return m.Name(module_parts[0])
    if len(module_parts) == 2:
        first, last = module_parts
        return m.Attribute(value=m.Name(first), attr=m.Name(last))
    # Recurse on a copy of the prefix: the previous implementation called
    # `module_parts.pop()`, mutating the caller's list as a side effect.
    *head, last_name = module_parts
    attr = resolve_module_parts(head)
    return m.Attribute(value=attr, attr=m.Name(last_name))
|
||||
|
||||
|
||||
def get_import_from_from_str(import_str: str) -> m.ImportFrom:
    """Build an `ImportFrom` matcher from a ``module:name`` string.

    Examples:
        >>> get_import_from_from_str("pydantic:BaseSettings")
        ImportFrom(
            module=Name("pydantic"),
            names=[ImportAlias(name=Name("BaseSettings"))],
        )
        >>> get_import_from_from_str("pydantic.settings:BaseSettings")
        ImportFrom(
            module=Attribute(value=Name("pydantic"), attr=Name("settings")),
            names=[ImportAlias(name=Name("BaseSettings"))],
        )
        >>> get_import_from_from_str("a.b.c:d")
        ImportFrom(
            module=Attribute(
                value=Attribute(value=Name("a"), attr=Name("b")), attr=Name("c")
            ),
            names=[ImportAlias(name=Name("d"))],
        )
    """
    module_str, object_name = import_str.split(":")
    module_matcher = resolve_module_parts(module_str.split("."))
    # ZeroOrMore on both sides lets the matcher fire even when the target
    # name is imported alongside other names in the same statement.
    return m.ImportFrom(
        module=module_matcher,
        names=[
            m.ZeroOrMore(),
            m.ImportAlias(name=m.Name(value=object_name)),
            m.ZeroOrMore(),
        ],
    )
|
||||
|
||||
|
||||
@dataclass
class ImportInfo:
    """Pairs a libcst matcher with the old and new import paths it represents."""

    # Matcher that fires on the old `from <module> import <name>` statement.
    import_from: m.ImportFrom
    # Old location, formatted as "module:ClassName".
    import_str: str
    # New location as a (module, ClassName) 2-tuple.
    to_import_str: tuple[str, str]
|
||||
|
||||
|
||||
# Precomputed matcher info for every known migration, built once at import time.
IMPORT_INFOS = [
    ImportInfo(
        import_from=get_import_from_from_str(import_str),
        import_str=import_str,
        to_import_str=to_import_str,
    )
    for import_str, to_import_str in IMPORTS.items()
]
# Single matcher that fires on any of the old import statements.
IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS])
|
||||
|
||||
|
||||
class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
    """Rewrite `from old.module import Name` statements to their new location.

    The new import is queued via AddImportsVisitor; the old alias is removed
    from (or the whole statement dropped out of) the original import.
    """

    @m.leave(IMPORT_MATCH)
    def leave_replace_import(
        self, _: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        for import_info in IMPORT_INFOS:
            if m.matches(updated_node, import_info.import_from):
                aliases: Sequence[cst.ImportAlias] = updated_node.names  # type: ignore
                # If multiple objects are imported in a single import statement,
                # we need to remove only the one we're replacing.
                AddImportsVisitor.add_needed_import(
                    self.context, *import_info.to_import_str
                )
                if len(updated_node.names) > 1:  # type: ignore
                    # NOTE(review): the filter compares against the NEW class
                    # name (to_import_str[-1]); this only removes the right
                    # alias because the migrations keep the class name
                    # unchanged — confirm if renames are ever added.
                    names = [
                        alias
                        for alias in aliases
                        if alias.name.value != import_info.to_import_str[-1]
                    ]
                    # The now-last alias must not keep a trailing comma.
                    names[-1] = names[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
                    updated_node = updated_node.with_changes(names=names)
                else:
                    # The statement imported only the migrated name: drop it.
                    return cst.RemoveFromParent()  # type: ignore[return-value]
        return updated_node
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc manual smoke test: run this module directly to watch the codemod
    # transform a small sample snippet.
    import textwrap

    from rich.console import Console

    console = Console()

    # NOTE(review): the sample uses pydantic imports (leftover from
    # bump-pydantic, which this module is adapted from); none of them appear
    # in IMPORTS, so the output should be unchanged — confirm whether a
    # langchain example was intended here.
    source = textwrap.dedent(
        """
        from pydantic.settings import BaseSettings
        from pydantic.color import Color
        from pydantic.payment import PaymentCardNumber, PaymentCardBrand
        from pydantic import Color
        from pydantic import Color as Potato


        class Potato(BaseSettings):
            color: Color
            payment: PaymentCardNumber
            brand: PaymentCardBrand
            potato: Potato
        """
    )
    console.print(source)
    console.print("=" * 80)

    mod = cst.parse_module(source)
    context = CodemodContext(filename="main.py")
    wrapper = cst.MetadataWrapper(mod)
    command = ReplaceImportsCodemod(context=context)

    mod = wrapper.visit(command)
    # Second pass materializes the imports queued by AddImportsVisitor.
    wrapper = cst.MetadataWrapper(mod)
    command = AddImportsVisitor(context=context)  # type: ignore[assignment]
    mod = wrapper.visit(command)
    console.print(mod.code)
|
@ -0,0 +1,52 @@
|
||||
# Adapted from bump-pydantic
|
||||
# https://github.com/pydantic/bump-pydantic
|
||||
import fnmatch
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# Regex fragment matching a single path separator ("/" or "\").
MATCH_SEP = r"(?:/|\\)"
# A separator or end-of-string; terminates a fully translated pattern.
MATCH_SEP_OR_END = r"(?:/|\\|\Z)"
# "*" within one segment: any run of characters that are not separators.
MATCH_NON_RECURSIVE = r"[^/\\]*"
# "**": any run of characters, separators included.
MATCH_RECURSIVE = r"(?:.*)"


def glob_to_re(pattern: str) -> str:
    """Translate a glob pattern to a regular expression for matching."""
    fragments: List[str] = []
    for segment in re.split(r"/|\\", pattern):
        if not segment:
            continue
        if segment == "**":
            # Drop the separator emitted after the previous segment so the
            # recursive wildcard can also match zero path segments.
            if fragments and fragments[-1] == MATCH_SEP:
                fragments.pop()
            fragments.append(MATCH_RECURSIVE)
        elif "**" in segment:
            raise ValueError(
                "invalid pattern: '**' can only be an entire path component"
            )
        else:
            # fnmatch.translate handles "*", "?" and "[...]"; rewrite its
            # output so "*" stops at separators and drop its anchors.
            translated = fnmatch.translate(segment)
            translated = translated.replace(r"(?s:", r"(?:")
            translated = translated.replace(r".*", MATCH_NON_RECURSIVE)
            translated = translated.replace(r"\Z", r"")
            fragments.append(translated)
        fragments.append(MATCH_SEP)
    # Swap the trailing separator for "separator or end of string".
    if fragments and fragments[-1] == MATCH_SEP:
        fragments.pop()
    fragments.append(MATCH_SEP_OR_END)
    return rf"(?s:{''.join(fragments)})"
|
||||
|
||||
|
||||
def match_glob(path: Path, pattern: str) -> bool:
    """Check if a path matches a glob pattern.

    If the pattern ends with a directory separator, the path must be a directory.
    """
    matched = re.fullmatch(glob_to_re(pattern), str(path)) is not None
    if pattern.endswith(("/", "\\")):
        return matched and path.is_dir()
    return matched
|
@ -0,0 +1,194 @@
|
||||
"""Migrate LangChain to the most recent version."""
|
||||
# Adapted from bump-pydantic
|
||||
# https://github.com/pydantic/bump-pydantic
|
||||
import difflib
|
||||
import functools
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
|
||||
|
||||
import libcst as cst
|
||||
import rich
|
||||
import typer
|
||||
from libcst.codemod import CodemodContext, ContextAwareTransformer
|
||||
from libcst.helpers import calculate_module_and_package
|
||||
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress
|
||||
from typer import Argument, Exit, Option, Typer
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods
|
||||
from langchain_cli.namespaces.migrate.glob_helpers import match_glob
|
||||
|
||||
app = Typer(invoke_without_command=True, add_completion=False)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
DEFAULT_IGNORES = [".venv/**"]
|
||||
|
||||
|
||||
@app.callback()
def main(
    path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
    disable: List[Rule] = Option(default=[], help="Disable a rule."),
    diff: bool = Option(False, help="Show diff instead of applying changes."),
    ignore: List[str] = Option(
        default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
    ),
    log_file: Path = Option("log.txt", help="Log errors to this file."),
):
    """Migrate langchain to the most recent version."""
    # Interactive confirmation before touching files; --diff is read-only.
    if not diff:
        # NOTE(review): "[/ bold red]" has a stray space; rich closing tags
        # are normally written "[/bold red]" — verify the markup renders.
        rich.print("[bold red]Alert![/ bold red] langchain-cli migrate", end=": ")
        if not typer.confirm(
            "The migration process will modify your files. "
            "The migration is a `best-effort` process and is not expected to "
            "be perfect. "
            "Do you want to continue?"
        ):
            raise Exit()
    console = Console(log_time=True)
    console.log("Start langchain-cli migrate")
    # NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
    os.environ["LIBCST_PARSER_TYPE"] = "native"

    # Single-file input: codemod just that file, rooted at its parent package.
    if os.path.isfile(path):
        package = path.parent
        all_files = [path]
    else:
        package = path
        all_files = sorted(package.glob("**/*.py"))

    # Drop anything matching an --ignore glob (defaults skip .venv).
    filtered_files = [
        file
        for file in all_files
        if not any(match_glob(file, pattern) for pattern in ignore)
    ]
    # NOTE(review): relative_to(".") requires `path` to be given relative to
    # the current working directory — absolute paths would raise; confirm.
    files = [str(file.relative_to(".")) for file in filtered_files]

    if len(files) == 1:
        console.log("Found 1 file to process.")
    elif len(files) > 1:
        console.log(f"Found {len(files)} files to process.")
    else:
        console.log("No files to process.")
        raise Exit()

    # Pre-resolve libcst metadata once so worker processes can reuse it.
    providers = {FullyQualifiedNameProvider, ScopeProvider}
    metadata_manager = FullRepoManager(".", files, providers=providers)  # type: ignore[arg-type]
    metadata_manager.resolve_cache()

    scratch: dict[str, Any] = {}
    # Used afterwards to detect which files were rewritten on disk.
    start_time = time.time()

    codemods = gather_codemods(disabled=disable)

    # NOTE(review): log_fp is never closed — consider a `with` block.
    log_fp = log_file.open("a+", encoding="utf8")
    # Bind everything except `filename`, which the pool supplies per file.
    partial_run_codemods = functools.partial(
        run_codemods, codemods, metadata_manager, scratch, package, diff
    )
    with Progress(*Progress.get_default_columns(), transient=True) as progress:
        task = progress.add_task(description="Executing codemods...", total=len(files))
        count_errors = 0
        difflines: List[List[str]] = []
        # Fan the per-file work out across processes; order of completion
        # does not matter, hence imap_unordered.
        with multiprocessing.Pool() as pool:
            for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
                progress.advance(task)

                if _difflines is not None:
                    difflines.append(_difflines)

                if error is not None:
                    count_errors += 1
                    log_fp.writelines(error)

    # A file counts as modified if its mtime moved past our start timestamp.
    modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]

    if not diff:
        if modified:
            console.log(f"Refactored {len(modified)} files.")
        else:
            console.log("No files were modified.")

    for _difflines in difflines:
        color_diff(console, _difflines)

    if count_errors > 0:
        console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
    else:
        console.log("Run successfully!")

    # In --diff mode a non-empty diff exits non-zero (useful for CI checks).
    if difflines:
        raise Exit(1)
|
||||
|
||||
|
||||
def run_codemods(
    codemods: List[Type[ContextAwareTransformer]],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Apply all codemods to a single file (run inside a worker process).

    Returns an ``(error, difflines)`` pair: ``error`` is a message string on
    failure (else ``None``); ``difflines`` is a unified diff when ``diff`` is
    set and the file changed (else ``None``).
    """
    try:
        module_and_package = calculate_module_and_package(str(package), filename)
        context = CodemodContext(
            metadata_manager=metadata_manager,
            filename=filename,
            full_module_name=module_and_package.name,
            full_package_name=module_and_package.package,
        )
        context.scratch.update(scratch)

        file_path = Path(filename)
        # "r+" lets us read and rewrite in place through the same handle.
        with file_path.open("r+", encoding="utf-8") as fp:
            code = fp.read()
            fp.seek(0)

            input_tree = cst.parse_module(code)

            # Chain the codemods: each one consumes the previous output tree.
            for codemod in codemods:
                transformer = codemod(context=context)
                output_tree = transformer.transform_module(input_tree)
                input_tree = output_tree

            output_code = input_tree.code
            if code != output_code:
                if diff:
                    lines = difflib.unified_diff(
                        code.splitlines(keepends=True),
                        output_code.splitlines(keepends=True),
                        fromfile=filename,
                        tofile=filename,
                    )
                    return None, list(lines)
                else:
                    fp.write(output_code)
                    # Shrink the file if the new code is shorter than the old.
                    fp.truncate()
        return None, None
    except cst.ParserSyntaxError as exc:
        # Fix: the messages previously said "(unknown)" (placeholder-less
        # f-strings); report the actual file so the log is actionable.
        return (
            f"A syntax error happened on {filename}. This file cannot be "
            f"formatted.\n"
            f"{exc}"
        ), None
    except Exception:
        return f"An error happened on {filename}.\n{traceback.format_exc()}", None
|
||||
|
||||
|
||||
def color_diff(console: Console, lines: Iterable[str]) -> None:
    """Print unified-diff lines with the conventional colors:
    additions green, removals red, markers blue, context white."""
    prefix_styles = (("+", "green"), ("-", "red"), ("^", "blue"))
    for raw_line in lines:
        stripped = raw_line.rstrip("\n")
        for prefix, style in prefix_styles:
            if stripped.startswith(prefix):
                console.print(stripped, style=style)
                break
        else:
            console.print(stripped, style="white")
|
@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
|
||||
|
||||
@dataclass
class Case:
    """A single migration scenario: an input tree and its expected result."""

    # Tree (or single file) fed to the migration CLI.
    source: Folder | File
    # Tree (or single file) the migration is expected to produce.
    expected: Folder | File
    # Human-readable identifier for the scenario.
    name: str
|
@ -0,0 +1,15 @@
|
||||
from tests.unit_tests.migrate.integration.case import Case
|
||||
from tests.unit_tests.migrate.integration.cases import imports
|
||||
from tests.unit_tests.migrate.integration.file import File
|
||||
from tests.unit_tests.migrate.integration.folder import Folder
|
||||
|
||||
# All integration scenarios: an empty file (which must survive untouched)
# plus every import-migration case.
cases = [
    Case(
        name="empty",
        source=File("__init__.py", content=[]),
        expected=File("__init__.py", content=[]),
    ),
    *imports.cases,
]
# The fixture tree handed to the CLI, and the tree expected after migration.
before = Folder("project", *[case.source for case in cases])
expected = Folder("project", *[case.expected for case in cases])
|
@ -0,0 +1,32 @@
|
||||
from tests.unit_tests.migrate.integration.case import Case
|
||||
from tests.unit_tests.migrate.integration.file import File
|
||||
|
||||
# Import-migration scenarios: only the import line should change; the rest
# of the module (class body, call sites) must be preserved verbatim.
cases = [
    Case(
        name="Imports",
        source=File(
            "app.py",
            content=[
                "from langchain.chat_models import ChatOpenAI",
                "",
                "",
                "class foo:",
                "    a: int",
                "",
                "chain = ChatOpenAI()",
            ],
        ),
        expected=File(
            "app.py",
            content=[
                "from langchain_openai import ChatOpenAI",
                "",
                "",
                "class foo:",
                "    a: int",
                "",
                "chain = ChatOpenAI()",
            ],
        ),
    ),
]
|
@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class File:
    """In-memory representation of a text file used by the CLI test fixtures."""

    def __init__(self, name: str, content: list[str] | None = None) -> None:
        self.name = name
        # Content is stored as one newline-joined string.
        self.content = "\n".join(content or [])

    def __eq__(self, __value: object) -> bool:
        # Equality means same name and same content; other types defer.
        if not isinstance(__value, File):
            return NotImplemented
        return self.name == __value.name and self.content == __value.content
|
@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .file import File
|
||||
|
||||
|
||||
class Folder:
    """In-memory directory tree of Folder/File nodes used by test fixtures."""

    def __init__(self, name: str, *files: Folder | File) -> None:
        self.name = name
        self._files = files

    @property
    def files(self) -> list[Folder | File]:
        # Sorted by name so iteration (and equality) is deterministic.
        return sorted(self._files, key=lambda entry: entry.name)

    def create_structure(self, root: Path) -> None:
        """Materialize this tree on disk underneath `root`."""
        base = root / self.name
        base.mkdir()

        for entry in self.files:
            if isinstance(entry, Folder):
                entry.create_structure(base)
            else:
                (base / entry.name).write_text(entry.content, encoding="utf-8")

    @classmethod
    def from_structure(cls, root: Path) -> Folder:
        """Read a directory tree from disk back into Folder/File objects."""
        entries: list[File | Folder] = []

        for child in root.iterdir():
            if child.is_dir():
                entries.append(cls.from_structure(child))
            else:
                entries.append(
                    File(child.name, child.read_text(encoding="utf-8").splitlines())
                )

        return Folder(root.name, *entries)

    def __eq__(self, __value: object) -> bool:
        # A Folder never equals a File; unrelated types defer to the other side.
        if isinstance(__value, File):
            return False
        if not isinstance(__value, Folder):
            return NotImplemented

        if self.name != __value.name:
            return False
        if len(self.files) != len(__value.files):
            return False

        return all(
            mine == theirs for mine, theirs in zip(self.files, __value.files)
        )
|
@ -0,0 +1,55 @@
|
||||
# ruff: noqa: E402
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("libcst")
|
||||
|
||||
import difflib
|
||||
from pathlib import Path
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from langchain_cli.namespaces.migrate.main import app
|
||||
from tests.unit_tests.migrate.cli_runner.cases import before, expected
|
||||
from tests.unit_tests.migrate.cli_runner.folder import Folder
|
||||
|
||||
|
||||
def find_issue(current: Folder, expected: Folder) -> str:
    """Describe the first difference between two folder trees.

    Used to build a readable assertion message when the migrated tree does
    not match the expected one.
    """
    # NOTE(review): zip() stops at the shorter list, so a missing/extra
    # trailing entry falls through to "Unknown" — confirm that's acceptable.
    for current_file, expected_file in zip(current.files, expected.files):
        if current_file != expected_file:
            if current_file.name != expected_file.name:
                return (
                    f"Files have "
                    f"different names: {current_file.name} != {expected_file.name}"
                )
            # Same name, both folders: recurse to locate the inner difference.
            if isinstance(current_file, Folder) and isinstance(expected_file, Folder):
                return find_issue(current_file, expected_file)
            elif isinstance(current_file, Folder) or isinstance(expected_file, Folder):
                return (
                    f"One of the files is a "
                    f"folder: {current_file.name} != {expected_file.name}"
                )
            # Both are files: show a unified diff of their contents.
            return "\n".join(
                difflib.unified_diff(
                    current_file.content.splitlines(),
                    expected_file.content.splitlines(),
                    fromfile=current_file.name,
                    tofile=expected_file.name,
                )
            )
    return "Unknown"
|
||||
|
||||
|
||||
def test_command_line(tmp_path: Path) -> None:
    """End-to-end check: run the migrate CLI over a fixture project tree and
    compare the resulting tree against the expected one."""
    runner = CliRunner()

    with runner.isolated_filesystem(temp_dir=tmp_path) as td:
        before.create_structure(root=Path(td))
        # The input is used to force through the confirmation.
        result = runner.invoke(app, [before.name], input="y\n")
        assert result.exit_code == 0, result.output

        # Read the migrated tree back from disk for comparison.
        after = Folder.from_structure(Path(td) / before.name)

    assert after == expected, find_issue(after, expected)
|
@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_cli.namespaces.migrate.glob_helpers import glob_to_re, match_glob
|
||||
|
||||
|
||||
class TestGlobHelpers:
    """Table-driven tests for glob_to_re / match_glob.

    Each row is (pattern, path, expected match). Patterns here never end in a
    separator, so match_glob's is_dir() branch is not exercised.
    """

    match_glob_values: list[tuple[str, Path, bool]] = [
        ("foo", Path("foo"), True),
        ("foo", Path("bar"), False),
        ("foo", Path("foo/bar"), False),
        ("*", Path("foo"), True),
        ("*", Path("bar"), True),
        ("*", Path("foo/bar"), False),
        ("**", Path("foo"), True),
        ("**", Path("foo/bar"), True),
        ("**", Path("foo/bar/baz/qux"), True),
        ("foo/bar", Path("foo/bar"), True),
        ("foo/bar", Path("foo"), False),
        ("foo/bar", Path("far"), False),
        ("foo/bar", Path("foo/foo"), False),
        ("foo/*", Path("foo/bar"), True),
        ("foo/*", Path("foo/bar/baz"), False),
        ("foo/*", Path("foo"), False),
        ("foo/*", Path("bar"), False),
        ("foo/**", Path("foo/bar"), True),
        ("foo/**", Path("foo/bar/baz"), True),
        ("foo/**", Path("foo/bar/baz/qux"), True),
        ("foo/**", Path("foo"), True),
        ("foo/**", Path("bar"), False),
        ("foo/**/bar", Path("foo/bar"), True),
        ("foo/**/bar", Path("foo/baz/bar"), True),
        ("foo/**/bar", Path("foo/baz/qux/bar"), True),
        ("foo/**/bar", Path("foo/baz/qux"), False),
        ("foo/**/bar", Path("foo/bar/baz"), False),
        ("foo/**/bar", Path("foo/bar/bar"), True),
        ("foo/**/bar", Path("foo"), False),
        ("foo/**/bar", Path("bar"), False),
        ("foo/**/*/bar", Path("foo/bar"), False),
        ("foo/**/*/bar", Path("foo/baz/bar"), True),
        ("foo/**/*/bar", Path("foo/baz/qux/bar"), True),
        ("foo/**/*/bar", Path("foo/baz/qux"), False),
        ("foo/**/*/bar", Path("foo/bar/baz"), False),
        ("foo/**/*/bar", Path("foo/bar/bar"), True),
        ("foo/**/*/bar", Path("foo"), False),
        ("foo/**/*/bar", Path("bar"), False),
        ("foo/ba*", Path("foo/bar"), True),
        ("foo/ba*", Path("foo/baz"), True),
        ("foo/ba*", Path("foo/qux"), False),
        ("foo/ba*", Path("foo/baz/qux"), False),
        ("foo/ba*", Path("foo/bar/baz"), False),
        ("foo/ba*", Path("foo"), False),
        ("foo/ba*", Path("bar"), False),
        ("foo/**/ba*/*/qux", Path("foo/a/b/c/bar/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/a/b/c/baz/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/a/bar/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/baz/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/baz/qux"), False),
        ("foo/**/ba*/*/qux", Path("foo/a/b/c/qux/a/qux"), False),
        ("foo/**/ba*/*/qux", Path("foo"), False),
        ("foo/**/ba*/*/qux", Path("bar"), False),
    ]

    @pytest.mark.parametrize(("pattern", "path", "expected"), match_glob_values)
    def test_match_glob(self, pattern: str, path: Path, expected: bool):
        # The compiled expression is included in the failure message to make
        # regex-translation bugs easy to diagnose.
        expr = glob_to_re(pattern)
        assert (
            match_glob(path, pattern) == expected
        ), f"path: {path}, pattern: {pattern}, expr: {expr}"
|
@ -0,0 +1,41 @@
|
||||
# ruff: noqa: E402
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("libcst")
|
||||
|
||||
|
||||
from libcst.codemod import CodemodTest
|
||||
|
||||
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
|
||||
ReplaceImportsCodemod,
|
||||
)
|
||||
|
||||
|
||||
class TestReplaceImportsCommand(CodemodTest):
    """Unit tests for ReplaceImportsCodemod via libcst's CodemodTest harness."""

    TRANSFORM = ReplaceImportsCodemod

    def test_single_import(self) -> None:
        # A lone migrated import is rewritten to its new module.
        before = """
        from langchain.chat_models import ChatOpenAI
        """
        after = """
        from langchain_openai import ChatOpenAI
        """
        self.assertCodemod(before, after)

    def test_noop_import(self) -> None:
        # Imports outside the migration table must be left untouched.
        code = """
        from foo import ChatOpenAI
        """
        self.assertCodemod(code, code)

    def test_mixed_imports(self) -> None:
        # Only migrated names move; the remaining name keeps the old import.
        before = """
        from langchain.chat_models import ChatOpenAI, ChatAnthropic, foo
        """
        after = """
        from langchain.chat_models import foo
        from langchain_anthropic import ChatAnthropic
        from langchain_openai import ChatOpenAI
        """
        self.assertCodemod(before, after)
|
Loading…
Reference in New Issue