2024-04-26 14:50:21 +00:00
"""Migrate LangChain to the most recent version."""
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
import difflib
import functools
import multiprocessing
import os
import time
import traceback
from pathlib import Path
2024-04-29 14:11:21 +00:00
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
2024-04-26 14:50:21 +00:00
import libcst as cst
import typer
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.helpers import calculate_module_and_package
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
from rich.console import Console
from rich.progress import Progress
from typer import Argument, Exit, Option, Typer
from typing_extensions import ParamSpec
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods
from langchain_cli.namespaces.migrate.glob_helpers import match_glob
# Typer application object for the migrate CLI (built with
# invoke_without_command=True and add_completion=False).
app = Typer(invoke_without_command=True, add_completion=False)
# NOTE(review): P and T are not referenced anywhere else in this file as shown;
# possibly leftovers from bump-pydantic — confirm before removing.
P = ParamSpec("P")
T = TypeVar("T")
# Glob patterns skipped by default; used as the default of the `ignore` option
# of `main`. Virtualenv contents are excluded.
DEFAULT_IGNORES = [".venv/**"]
# Entry point of the migrate CLI.
#   path:          a single file, or a directory scanned recursively for *.py
#                  (and *.ipynb when include_ipynb is set).
#   disable:       rules whose codemods are not run.
#   diff:          print a colored unified diff instead of rewriting files; the
#                  command then exits with status 1 if any diff was produced.
#   ignore:        glob patterns of paths to skip.
#   log_file:      per-file error messages are appended here.
#   include_ipynb: also process Jupyter notebooks found under a directory path.
# The docstring is kept to one line because Typer uses it as the help text.
@app.callback()
def main(
    path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
    disable: List[Rule] = Option(default=[], help="Disable a rule."),
    diff: bool = Option(False, help="Show diff instead of applying changes."),
    ignore: List[str] = Option(
        default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
    ),
    log_file: Path = Option("log.txt", help="Log errors to this file."),
    include_ipynb: bool = Option(
        False, help="Include Jupyter Notebook files in the migration."
    ),
) -> None:
    """Migrate langchain to the most recent version."""
    if not diff:
        # Files are rewritten in place, so ask for confirmation first.
        if not typer.confirm(
            "✈️ This script will help you migrate to a recent version of LangChain. "
            "This migration script will attempt to replace old imports in the code "
            "with new ones.\n\n"
            "🔄 You will need to run the migration script TWICE to migrate (e.g., "
            "to update llms import from langchain, the script will first move them to "
            "corresponding imports from the community package, and on the second "
            "run will migrate from the community package to the partner package "
            "when possible). \n\n"
            "🔍 You can preview the changes by running with the --diff flag. \n\n"
            "🚫 You can disable specific import changes by using the --disable "
            "flag. \n\n"
            "📄 Update your pyproject.toml or requirements.txt file to "
            "reflect any imports from new packages. For example, if you see new "
            "imports from langchain_openai, langchain_anthropic or "
            "langchain_text_splitters you "
            "should add them to your dependencies! \n\n"
            '⚠️ This script is a "best-effort", and is likely to make some '
            "mistakes.\n\n"
            "🛡️ Backup your code prior to running the migration script -- it will "
            "modify your files!\n\n"
            "❓ Do you want to continue?"
        ):
            raise Exit()

    console = Console(log_time=True)
    console.log("Start langchain-cli migrate")
    # NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
    os.environ["LIBCST_PARSER_TYPE"] = "native"

    if os.path.isfile(path):
        # A single file: its directory acts as the package root.
        package = path.parent
        all_files = [path]
    else:
        package = path
        all_files = sorted(package.glob("**/*.py"))
        if include_ipynb:
            all_files.extend(sorted(package.glob("**/*.ipynb")))

    filtered_files = [
        file
        for file in all_files
        if not any(match_glob(file, pattern) for pattern in ignore)
    ]
    # Paths are expressed relative to the CWD, which is also the repo root
    # handed to FullRepoManager below.
    files = [str(file.relative_to(".")) for file in filtered_files]

    if len(files) == 1:
        console.log("Found 1 file to process.")
    elif len(files) > 1:
        console.log(f"Found {len(files)} files to process.")
    else:
        console.log("No files to process.")
        raise Exit()

    providers = {FullyQualifiedNameProvider, ScopeProvider}
    metadata_manager = FullRepoManager(".", files, providers=providers)  # type: ignore[arg-type]
    # Resolve the provider caches once, in the parent process, before the
    # manager is shipped to the pool workers.
    metadata_manager.resolve_cache()

    scratch: dict[str, Any] = {}
    # Used afterwards to detect which files were rewritten (mtime comparison).
    start_time = time.time()
    # Bind everything except the filename, which the pool supplies per task.
    partial_run_codemods = functools.partial(
        get_and_run_codemods, disable, metadata_manager, scratch, package, diff
    )

    count_errors = 0
    difflines: List[List[str]] = []
    # The log file is managed by `with` so the handle is always flushed and
    # closed, even if the run is interrupted.
    with log_file.open("a+", encoding="utf8") as log_fp, Progress(
        *Progress.get_default_columns(), transient=True
    ) as progress:
        task = progress.add_task(description="Executing codemods...", total=len(files))
        with multiprocessing.Pool() as pool:
            for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
                progress.advance(task)

                if _difflines is not None:
                    difflines.append(_difflines)

                if error is not None:
                    count_errors += 1
                    # Terminate every message so consecutive errors do not
                    # run together in the log.
                    log_fp.write(error if error.endswith("\n") else f"{error}\n")

    modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]

    if not diff:
        if modified:
            console.log(f"Refactored {len(modified)} files.")
        else:
            console.log("No files were modified.")

    for _difflines in difflines:
        color_diff(console, _difflines)

    if count_errors > 0:
        console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
    else:
        console.log("Run successfully!")

    # In --diff mode a non-zero exit status signals that changes are pending.
    if difflines:
        raise Exit(1)
def get_and_run_codemods(
    disabled_rules: List[Rule],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Run codemods from rules.

    Wrapper around run_codemods to be used with multiprocessing.Pool.

    The positional order matters: every argument except *filename* is bound
    with ``functools.partial``, and the pool supplies *filename* per task.

    Args:
        disabled_rules: Rules passed to ``gather_codemods`` as ``disabled``.
        metadata_manager: Repository-wide libcst metadata manager.
        scratch: Initial contents for each ``CodemodContext.scratch``.
        package: Root used to compute the file's module and package names.
        diff: Return a diff instead of rewriting the file.
        filename: Path of the file to process.

    Returns:
        Whatever ``run_codemods`` returns: ``(error_message_or_None,
        diff_lines_or_None)``.
    """
    codemods = gather_codemods(disabled=disabled_rules)
    return run_codemods(codemods, metadata_manager, scratch, package, diff, filename)
def _rewrite_file(
    filename: str,
    codemods: List[Type[ContextAwareTransformer]],
    diff: bool,
    context: CodemodContext,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Apply *codemods*, in order, to a Python source file.

    Args:
        filename: Path of the ``.py`` file to process.
        codemods: Transformer classes; each is instantiated with *context*.
        diff: When true, return a unified diff instead of writing the file.
        context: Codemod context shared by all transformers for this file.

    Returns:
        ``(None, diff_lines)`` when *diff* is true and the code changed;
        otherwise ``(None, None)``. When *diff* is false and the code
        changed, the file is rewritten in place.

    Raises:
        libcst.ParserSyntaxError: If the file cannot be parsed.
    """
    file_path = Path(filename)
    # newline="" disables newline translation, so the file's original line
    # endings (e.g. CRLF) are parsed and written back unchanged.
    with file_path.open("r+", encoding="utf-8", newline="") as fp:
        code = fp.read()

        tree = cst.parse_module(code)
        for codemod in codemods:
            transformer = codemod(context=context)
            tree = transformer.transform_module(tree)
        output_code = tree.code

        if code == output_code:
            return None, None

        if diff:
            lines = difflib.unified_diff(
                code.splitlines(keepends=True),
                output_code.splitlines(keepends=True),
                fromfile=filename,
                tofile=filename,
            )
            return None, list(lines)

        # Overwrite in place: rewind, write, and drop any leftover tail.
        fp.seek(0)
        fp.write(output_code)
        fp.truncate()
    return None, None
def _rewrite_notebook(
    filename: str,
    codemods: List[Type[ContextAwareTransformer]],
    diff: bool,
    context: CodemodContext,
) -> Tuple[Optional[str], Optional[List[str]]]:
    """Try to rewrite a Jupyter Notebook file.

    Each code cell is transformed independently. Cells containing lines that
    start with ``!`` or ``%`` (shell escapes / magics) are skipped.

    Args:
        filename: Path of the ``.ipynb`` file.
        codemods: Transformer classes applied to each code cell, in order.
        diff: When true, return a unified diff instead of writing the file.
        context: Accepted for signature parity with ``_rewrite_file``; not
            used, because each cell gets a fresh ``CodemodContext``.

    Returns:
        ``(None, diff_lines)`` when *diff* is true and at least one cell
        changed; otherwise ``(None, None)``. When *diff* is false, the
        notebook is written back only if some cell changed.

    Raises:
        ValueError: If *filename* does not have the ``.ipynb`` suffix.
        libcst.ParserSyntaxError: If a non-skipped cell cannot be parsed.
    """
    import nbformat

    file_path = Path(filename)
    if file_path.suffix != ".ipynb":
        raise ValueError("Only Jupyter Notebook files (.ipynb) are supported.")

    with file_path.open("r", encoding="utf-8") as fp:
        notebook = nbformat.read(fp, as_version=4)

    diffs: List[str] = []
    changed = False

    for cell in notebook.cells:
        if cell.cell_type != "code":
            continue
        code = "".join(cell.source)

        # Skip code if any of the lines begin with a magic command or
        # a ! command.
        # We can try to handle later.
        if any(
            line.startswith("!") or line.startswith("%")
            for line in code.splitlines()
        ):
            continue

        tree = cst.parse_module(code)

        # TODO(Team): Quick hack, need to figure out
        # how to handle this correctly.
        # This prevents the code from trying to re-insert the imports
        # for every cell in the notebook.
        local_context = CodemodContext()

        for codemod in codemods:
            transformer = codemod(context=local_context)
            tree = transformer.transform_module(tree)
        output_code = tree.code

        if code == output_code:
            continue

        changed = True
        cell.source = output_code.splitlines(keepends=True)
        if diff:
            diffs.extend(
                difflib.unified_diff(
                    code.splitlines(keepends=True),
                    output_code.splitlines(keepends=True),
                    fromfile=filename,
                    tofile=filename,
                )
            )

    if not changed:
        # Nothing to report or write: returning None (not an empty list) keeps
        # the caller from treating this notebook as having a diff, and the
        # file, including its mtime, is left untouched.
        return None, None

    if diff:
        return None, diffs

    with file_path.open("w", encoding="utf-8") as fp:
        nbformat.write(notebook, fp)
    return None, None
def run_codemods(
    codemods: List[Type[ContextAwareTransformer]],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Run *codemods* on one file, turning any failure into an error message.

    Notebooks (``.ipynb``) are dispatched to ``_rewrite_notebook``; every
    other file goes to ``_rewrite_file``.

    Args:
        codemods: Transformer classes to apply, in order.
        metadata_manager: Repository-wide libcst metadata manager.
        scratch: Initial contents merged into the context's ``scratch``.
        package: Root used to compute the file's module and package names.
        diff: Return a diff instead of rewriting the file.
        filename: Path of the file to process.

    Returns:
        ``(error, diff_lines)``. On failure, ``error`` is a human-readable
        message and ``diff_lines`` is None. On success, ``error`` is None and
        ``diff_lines`` is whatever the rewriter returned.
    """
    try:
        module_and_package = calculate_module_and_package(str(package), filename)
        context = CodemodContext(
            metadata_manager=metadata_manager,
            filename=filename,
            full_module_name=module_and_package.name,
            full_package_name=module_and_package.package,
        )
        context.scratch.update(scratch)

        if filename.endswith(".ipynb"):
            return _rewrite_notebook(filename, codemods, diff, context)
        return _rewrite_file(filename, codemods, diff, context)
    except cst.ParserSyntaxError as exc:
        return (
            f"A syntax error happened on {filename}. This file cannot be "
            f"formatted.\n"
            f"{exc}"
        ), None
    except Exception:
        # Deliberate catch-all: one broken file must not abort the whole run;
        # the traceback is returned so the caller can log it.
        return f"An error happened on {filename}.\n{traceback.format_exc()}", None
def color_diff(console: "Console", lines: Iterable[str]) -> None:
    """Print diff *lines* to *console*, one ``console.print`` call per line.

    Lines starting with ``+`` are printed green, ``-`` red, ``^`` blue, and
    all other lines white. Trailing ``\\r``/``\\n`` characters are stripped
    before printing.

    Args:
        console: Object with a ``print(text, style=...)`` method, e.g. a
            ``rich.console.Console``.
        lines: Diff lines, typically from ``difflib.unified_diff``.
    """
    for raw_line in lines:
        text = raw_line.rstrip("\r\n")
        if text.startswith("+"):
            style = "green"
        elif text.startswith("-"):
            style = "red"
        elif text.startswith("^"):
            style = "blue"
        else:
            style = "white"
        # Exactly one print per line; previously a missing `else` also printed
        # every colored line a second time in white.
        console.print(text, style=style)