mirror of https://github.com/hwchase17/langchain
Compare commits: c262cef1fb ... 5fb7d3b4ba (23 commits)
Author | SHA1 | Date |
---|---|---|
Chester Curme | 5fb7d3b4ba | 3 weeks ago |
Chester Curme | 402298e376 | 3 weeks ago |
ccurme | 989e4a92c2 | 3 weeks ago |
Eugene Yurtsev | 2fa0ff1a2d | 3 weeks ago |
Chester Curme | 267ee9db4c | 3 weeks ago |
Erick Friis | 078c5d9bc6 | 3 weeks ago |
Leonid Kuligin | d4aec8fc8f | 3 weeks ago |
Chester Curme | bc7af5fd7e | 3 weeks ago |
Chester Curme | abf1f4c124 | 3 weeks ago |
ccurme | bf16cefd18 | 3 weeks ago |
Chester Curme | c9fc0447ec | 3 weeks ago |
Erick Friis | 38eccab3ae | 3 weeks ago |
Sean | e1c2e2fdfa | 3 weeks ago |
ccurme | 84b8e67c9c | 3 weeks ago |
ccurme | 465fbaa30b | 3 weeks ago |
Eugene Yurtsev | 12c906f6ce | 3 weeks ago |
Eugene Yurtsev | 5653f36adc | 3 weeks ago |
ccurme | fe1304afc4 | 3 weeks ago |
Eugene Yurtsev | 6598757037 | 3 weeks ago |
Pengcheng Liu | d95e9fb67f | 3 weeks ago |
Lei Zhang | 9281841cfe | 3 weeks ago |
ccurme | 7d8d0229fa | 3 weeks ago |
William FH | 4c437ebb9c | 3 weeks ago |
@@ -0,0 +1,25 @@
from enum import Enum
from typing import List, Type

from libcst.codemod import ContextAwareTransformer
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

from langchain_cli.namespaces.migrate.codemods.replace_imports import (
    ReplaceImportsCodemod,
)


class Rule(str, Enum):
    R001 = "R001"
    """Replace imports that have been moved."""


def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
    codemods: List[Type[ContextAwareTransformer]] = []

    if Rule.R001 not in disabled:
        codemods.append(ReplaceImportsCodemod)

    # Those codemods need to be the last ones.
    codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
    return codemods
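For orientation, a minimal sketch of how `gather_codemods` is meant to be consumed (the same pattern the `main.py` CLI further down uses): each returned class is instantiated with a shared `CodemodContext` and applied in order.

from libcst.codemod import CodemodContext

from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods

# Collect every enabled codemod class, then instantiate each with a shared context.
codemods = gather_codemods(disabled=[])  # pass [Rule.R001] to skip import rewrites
context = CodemodContext(filename="app.py")
transformers = [codemod(context=context) for codemod in codemods]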
@@ -0,0 +1,18 @@
[
  [
    "langchain_community.llms.anthropic.Anthropic",
    "langchain_anthropic.Anthropic"
  ],
  [
    "langchain_community.chat_models.anthropic.ChatAnthropic",
    "langchain_anthropic.ChatAnthropic"
  ],
  [
    "langchain_community.llms.Anthropic",
    "langchain_anthropic.Anthropic"
  ],
  [
    "langchain_community.chat_models.ChatAnthropic",
    "langchain_anthropic.ChatAnthropic"
  ]
]
@@ -0,0 +1,18 @@
[
  [
    "langchain_community.llms.fireworks.Fireworks",
    "langchain_fireworks.Fireworks"
  ],
  [
    "langchain_community.chat_models.fireworks.ChatFireworks",
    "langchain_fireworks.ChatFireworks"
  ],
  [
    "langchain_community.llms.Fireworks",
    "langchain_fireworks.Fireworks"
  ],
  [
    "langchain_community.chat_models.ChatFireworks",
    "langchain_fireworks.ChatFireworks"
  ]
]
@@ -0,0 +1,10 @@
[
  [
    "langchain_community.llms.watsonxllm.WatsonxLLM",
    "langchain_ibm.WatsonxLLM"
  ],
  [
    "langchain_community.llms.WatsonxLLM",
    "langchain_ibm.WatsonxLLM"
  ]
]
File diff suppressed because it is too large.
@@ -0,0 +1,50 @@
[
  [
    "langchain_community.llms.openai.OpenAI",
    "langchain_openai.OpenAI"
  ],
  [
    "langchain_community.llms.openai.AzureOpenAI",
    "langchain_openai.AzureOpenAI"
  ],
  [
    "langchain_community.embeddings.openai.OpenAIEmbeddings",
    "langchain_openai.OpenAIEmbeddings"
  ],
  [
    "langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings",
    "langchain_openai.AzureOpenAIEmbeddings"
  ],
  [
    "langchain_community.chat_models.openai.ChatOpenAI",
    "langchain_openai.ChatOpenAI"
  ],
  [
    "langchain_community.chat_models.azure_openai.AzureChatOpenAI",
    "langchain_openai.AzureChatOpenAI"
  ],
  [
    "langchain_community.llms.AzureOpenAI",
    "langchain_openai.AzureOpenAI"
  ],
  [
    "langchain_community.llms.OpenAI",
    "langchain_openai.OpenAI"
  ],
  [
    "langchain_community.embeddings.AzureOpenAIEmbeddings",
    "langchain_openai.AzureOpenAIEmbeddings"
  ],
  [
    "langchain_community.embeddings.OpenAIEmbeddings",
    "langchain_openai.OpenAIEmbeddings"
  ],
  [
    "langchain_community.chat_models.AzureChatOpenAI",
    "langchain_openai.AzureChatOpenAI"
  ],
  [
    "langchain_community.chat_models.ChatOpenAI",
    "langchain_openai.ChatOpenAI"
  ]
]
@@ -0,0 +1,10 @@
[
  [
    "langchain_community.vectorstores.pinecone.Pinecone",
    "langchain_pinecone.Pinecone"
  ],
  [
    "langchain_community.vectorstores.Pinecone",
    "langchain_pinecone.Pinecone"
  ]
]
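Each fixture is a flat JSON list of `[old_path, new_path]` pairs. A minimal sketch of reading one directly, mirroring `_load_migrations_by_file` in the codemod that follows (the relative path here is illustrative; the codemod resolves the directory next to its own module):

import json
from pathlib import Path

# Hypothetical location of the fixture directory.
fixture = Path("migrations") / "pinecone.json"
for old_path, new_path in json.loads(fixture.read_text(encoding="utf-8")):
    print(f"{old_path} -> {new_path}")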
@@ -0,0 +1,214 @@
"""
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic

This codemod deals with the following cases:

1. `from pydantic import BaseSettings`
2. `from pydantic.settings import BaseSettings`
3. `from pydantic import BaseSettings as <name>`
4. `from pydantic.settings import BaseSettings as <name>` # TODO: This is not working.
5. `import pydantic` -> `pydantic.BaseSettings`
"""
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Sequence, Tuple, TypeVar

import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor

HERE = os.path.dirname(__file__)


def _load_migrations_by_file(path: str):
    migrations_path = os.path.join(HERE, "migrations", path)
    with open(migrations_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data


T = TypeVar("T")


def _deduplicate_in_order(
    seq: Iterable[T], key: Callable[[T], str] = lambda x: x
) -> List[T]:
    seen = set()
    seen_add = seen.add
    return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]


PARTNERS = [
    "anthropic.json",
    "ibm.json",
    "openai.json",
    "pinecone.json",
    "fireworks.json",
]


def _load_migrations_from_fixtures() -> List[Tuple[str, str]]:
    """Load migrations from fixtures."""
    paths: List[str] = PARTNERS + ["langchain.json"]
    data = []
    for path in paths:
        data.extend(_load_migrations_by_file(path))
    data = _deduplicate_in_order(data, key=lambda x: x[0])
    return data


def _load_migrations():
    """Load the migrations from the JSON files."""
    # Earlier entries have higher precedence (duplicates were already dropped in order).
    imports: Dict[str, Tuple[str, str]] = {}
    data = _load_migrations_from_fixtures()

    for old_path, new_path in data:
        # Parse the old path, which is of the format 'langchain.chat_models.ChatOpenAI',
        # into the module and class name.
        old_parts = old_path.split(".")
        old_module = ".".join(old_parts[:-1])
        old_class = old_parts[-1]
        old_path_str = f"{old_module}:{old_class}"

        # Parse the new path, which has the same format,
        # into a 2-tuple of the module and class name.
        new_parts = new_path.split(".")
        new_module = ".".join(new_parts[:-1])
        new_class = new_parts[-1]
        new_path_str = (new_module, new_class)

        imports[old_path_str] = new_path_str

    return imports


IMPORTS = _load_migrations()


def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
    """Convert a list of module parts to a `Name` or `Attribute` node."""
    if len(module_parts) == 1:
        return m.Name(module_parts[0])
    if len(module_parts) == 2:
        first, last = module_parts
        return m.Attribute(value=m.Name(first), attr=m.Name(last))
    last_name = module_parts.pop()
    attr = resolve_module_parts(module_parts)
    return m.Attribute(value=attr, attr=m.Name(last_name))


def get_import_from_from_str(import_str: str) -> m.ImportFrom:
    """Convert a string like `pydantic:BaseSettings` to an `ImportFrom` matcher.

    Examples:
    >>> get_import_from_from_str("pydantic:BaseSettings")
    ImportFrom(
        module=Name("pydantic"),
        names=[ImportAlias(name=Name("BaseSettings"))],
    )
    >>> get_import_from_from_str("pydantic.settings:BaseSettings")
    ImportFrom(
        module=Attribute(value=Name("pydantic"), attr=Name("settings")),
        names=[ImportAlias(name=Name("BaseSettings"))],
    )
    >>> get_import_from_from_str("a.b.c:d")
    ImportFrom(
        module=Attribute(
            value=Attribute(value=Name("a"), attr=Name("b")), attr=Name("c")
        ),
        names=[ImportAlias(name=Name("d"))],
    )
    """
    module, name = import_str.split(":")
    module_parts = module.split(".")
    module_node = resolve_module_parts(module_parts)
    return m.ImportFrom(
        module=module_node,
        names=[m.ZeroOrMore(), m.ImportAlias(name=m.Name(value=name)), m.ZeroOrMore()],
    )


@dataclass
class ImportInfo:
    import_from: m.ImportFrom
    import_str: str
    to_import_str: tuple[str, str]


IMPORT_INFOS = [
    ImportInfo(
        import_from=get_import_from_from_str(import_str),
        import_str=import_str,
        to_import_str=to_import_str,
    )
    for import_str, to_import_str in IMPORTS.items()
]
IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS])


class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
    @m.leave(IMPORT_MATCH)
    def leave_replace_import(
        self, _: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        for import_info in IMPORT_INFOS:
            if m.matches(updated_node, import_info.import_from):
                aliases: Sequence[cst.ImportAlias] = updated_node.names  # type: ignore
                # If multiple objects are imported in a single import statement,
                # we need to remove only the one we're replacing.
                AddImportsVisitor.add_needed_import(
                    self.context, *import_info.to_import_str
                )
                if len(updated_node.names) > 1:  # type: ignore
                    names = [
                        alias
                        for alias in aliases
                        if alias.name.value != import_info.to_import_str[-1]
                    ]
                    names[-1] = names[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
                    updated_node = updated_node.with_changes(names=names)
                else:
                    return cst.RemoveFromParent()  # type: ignore[return-value]
        return updated_node


if __name__ == "__main__":
    import textwrap

    from rich.console import Console

    console = Console()

    source = textwrap.dedent(
        """
        from pydantic.settings import BaseSettings
        from pydantic.color import Color
        from pydantic.payment import PaymentCardNumber, PaymentCardBrand
        from pydantic import Color
        from pydantic import Color as Potato


        class Potato(BaseSettings):
            color: Color
            payment: PaymentCardNumber
            brand: PaymentCardBrand
            potato: Potato
        """
    )
    console.print(source)
    console.print("=" * 80)

    mod = cst.parse_module(source)
    context = CodemodContext(filename="main.py")
    wrapper = cst.MetadataWrapper(mod)
    command = ReplaceImportsCodemod(context=context)

    mod = wrapper.visit(command)
    wrapper = cst.MetadataWrapper(mod)
    command = AddImportsVisitor(context=context)  # type: ignore[assignment]
    mod = wrapper.visit(command)
    console.print(mod.code)
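A minimal sketch of exercising the codemod on a string in isolation (assuming, as the CodemodTest-based tests further down do, that `transform_module` also applies the scheduled import additions and removals):

import libcst as cst
from libcst.codemod import CodemodContext

from langchain_cli.namespaces.migrate.codemods.replace_imports import (
    ReplaceImportsCodemod,
)

context = CodemodContext(filename="example.py")
codemod = ReplaceImportsCodemod(context=context)
source = "from langchain_community.chat_models import ChatOpenAI\n"
updated = codemod.transform_module(cst.parse_module(source))
print(updated.code)  # expected: "from langchain_openai import ChatOpenAI"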
@@ -0,0 +1,129 @@
"""Generate migrations from langchain to langchain-community or core packages."""
import importlib
import inspect
import pkgutil
from typing import List, Tuple


def generate_raw_migrations_to_community() -> List[Tuple[str, str]]:
    """Scan the `langchain` package and generate migrations for all modules."""
    import langchain as package

    to_package = "langchain_community"

    items = []
    for importer, modname, ispkg in pkgutil.walk_packages(
        package.__path__, package.__name__ + "."
    ):
        try:
            module = importlib.import_module(modname)
        except ModuleNotFoundError:
            continue

        # Check if the module is an __init__ file and evaluate __all__
        try:
            has_all = hasattr(module, "__all__")
        except ImportError:
            has_all = False

        if has_all:
            all_objects = module.__all__
            for name in all_objects:
                # Attempt to fetch each object declared in __all__
                try:
                    obj = getattr(module, name, None)
                except ImportError:
                    continue
                if obj and (inspect.isclass(obj) or inspect.isfunction(obj)):
                    if obj.__module__.startswith(to_package):
                        items.append((f"{modname}.{name}", f"{obj.__module__}.{name}"))

        # Iterate over all members of the module
        for name, obj in inspect.getmembers(module):
            # Check if it's a class or function
            if inspect.isclass(obj) or inspect.isfunction(obj):
                # Check if the module name of the obj starts with 'langchain_community'
                if obj.__module__.startswith(to_package):
                    items.append((f"{modname}.{name}", f"{obj.__module__}.{name}"))

    return items


def generate_top_level_imports_community() -> List[Tuple[str, str]]:
    """Look at all the top-level modules in langchain_community.

    This attempts to import everything from each __init__ file. For example:

    langchain_community/
        chat_models/
            __init__.py  # <-- import everything from here
        llm/
            __init__.py  # <-- import everything from here

    It collects all the imports, importing the classes / functions it can find
    there, and returns a list of 2-tuples.

    Each tuple contains the fully qualified path of the class / function where
    its logic is defined
    (e.g., langchain_community.chat_models.xyz_implementation.ver2.XYZ)
    and, as its second element, the path for importing it from the top-level
    namespace (e.g., langchain_community.chat_models.XYZ).
    """
    import langchain_community as package

    items = []
    # Only iterate through top-level modules/packages
    for finder, modname, ispkg in pkgutil.iter_modules(
        package.__path__, package.__name__ + "."
    ):
        if ispkg:
            try:
                module = importlib.import_module(modname)
            except ModuleNotFoundError:
                continue

            if hasattr(module, "__all__"):
                all_objects = getattr(module, "__all__")
                for name in all_objects:
                    # Attempt to fetch each object declared in __all__
                    obj = getattr(module, name, None)
                    if obj and (inspect.isclass(obj) or inspect.isfunction(obj)):
                        # Capture the fully qualified name of the object
                        original_module = obj.__module__
                        original_name = obj.__name__
                        # Form the new import path from the top-level namespace
                        top_level_import = f"{modname}.{name}"
                        # Append the tuple with original and top-level paths
                        items.append(
                            (f"{original_module}.{original_name}", top_level_import)
                        )

    return items


def generate_simplified_migrations() -> List[Tuple[str, str]]:
    """Get all the raw migrations, then simplify them if possible."""
    raw_migrations = generate_raw_migrations_to_community()
    top_level_simplifications = generate_top_level_imports_community()
    top_level_dict = {full: top_level for full, top_level in top_level_simplifications}
    simple_migrations = []
    for migration in raw_migrations:
        original, new = migration
        replacement = top_level_dict.get(new, new)
        simple_migrations.append((original, replacement))

    # Now let's deduplicate the list based on the original path (which is
    # the 1st element of the tuple)
    deduped_migrations = []
    seen = set()
    for migration in simple_migrations:
        original = migration[0]
        if original not in seen:
            deduped_migrations.append(migration)
            seen.add(original)

    return deduped_migrations
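A minimal sketch of calling the generator directly (it requires `langchain` and `langchain_community` to be importable, since the walk imports every submodule):

from langchain_cli.namespaces.migrate.generate.langchain import (
    generate_simplified_migrations,
)

migrations = generate_simplified_migrations()
# Each entry is (old import path, preferred new import path), e.g. the pair
# ("langchain.agents.create_json_agent",
#  "langchain_community.agent_toolkits.create_json_agent")
# asserted in the tests further down.
print(len(migrations))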
@@ -0,0 +1,54 @@
"""Generate migrations for partner packages."""
import importlib
from typing import List, Tuple

from langchain_core.documents import BaseDocumentCompressor, BaseDocumentTransformer
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore

from langchain_cli.namespaces.migrate.generate.utils import (
    COMMUNITY_PKG,
    find_subclasses_in_module,
    list_classes_by_package,
    list_init_imports_by_package,
)

# PUBLIC API


def get_migrations_for_partner_package(pkg_name: str) -> List[Tuple[str, str]]:
    """Generate migrations from community package to partner package.

    This works by finding the classes in the partner package that subclass the
    core base abstractions and matching them against the classes and re-exports
    found in langchain_community.

    Args:
        pkg_name (str): The name of the partner package.

    Returns:
        List of 2-tuples containing old and new import paths.
    """
    package = importlib.import_module(pkg_name)
    classes_ = find_subclasses_in_module(
        package,
        [
            BaseLanguageModel,
            Embeddings,
            BaseRetriever,
            VectorStore,
            BaseDocumentTransformer,
            BaseDocumentCompressor,
        ],
    )
    community_classes = list_classes_by_package(str(COMMUNITY_PKG))
    imports_for_pkg = list_init_imports_by_package(str(COMMUNITY_PKG))

    old_paths = community_classes + imports_for_pkg

    migrations = [
        (f"{module}.{item}", f"{pkg_name}.{item}")
        for module, item in old_paths
        if item in classes_
    ]
    return migrations
@@ -0,0 +1,158 @@
import ast
import inspect
import os
import pathlib
from pathlib import Path
from typing import Any, List, Optional, Tuple, Type

HERE = Path(__file__).parent
# Should bring us to [root]/src
PKGS_ROOT = HERE.parent.parent.parent.parent.parent

LANGCHAIN_PKG = PKGS_ROOT / "langchain"
COMMUNITY_PKG = PKGS_ROOT / "community"
PARTNER_PKGS = PKGS_ROOT / "partners"


class ImportExtractor(ast.NodeVisitor):
    def __init__(self, *, from_package: Optional[str] = None) -> None:
        """Extract all imports from the given code, optionally filtering by package."""
        self.imports = []
        self.package = from_package

    def visit_ImportFrom(self, node):
        if node.module and (
            self.package is None or str(node.module).startswith(self.package)
        ):
            for alias in node.names:
                self.imports.append((node.module, alias.name))
        self.generic_visit(node)


def _get_class_names(code: str) -> List[str]:
    """Extract class names from a code string."""
    # Parse the content of the file into an AST
    tree = ast.parse(code)

    # Initialize a list to hold all class names
    class_names = []

    # Define a node visitor class to collect class names
    class ClassVisitor(ast.NodeVisitor):
        def visit_ClassDef(self, node):
            class_names.append(node.name)
            self.generic_visit(node)

    # Create an instance of the visitor and visit the AST
    visitor = ClassVisitor()
    visitor.visit(tree)
    return class_names


def is_subclass(class_obj: Any, classes_: List[Type]) -> bool:
    """Check if the given class object is a subclass of any class in the list."""
    return any(
        issubclass(class_obj, kls)
        for kls in classes_
        if inspect.isclass(class_obj) and inspect.isclass(kls)
    )


def find_subclasses_in_module(module, classes_: List[Type]) -> List[str]:
    """Find all classes in the module that inherit from one of the classes."""
    subclasses = []
    # Iterate over all attributes of the module that are classes
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if is_subclass(obj, classes_):
            subclasses.append(obj.__name__)
    return subclasses


def _get_all_classnames_from_file(file: str, pkg: str) -> List[Tuple[str, str]]:
    """Extract all class names from a file."""
    with open(file, encoding="utf-8") as f:
        code = f.read()
    module_name = _get_current_module(file, pkg)
    class_names = _get_class_names(code)

    return [(module_name, class_name) for class_name in class_names]


def identify_all_imports_in_file(
    file: str, *, from_package: Optional[str] = None
) -> List[Tuple[str, str]]:
    """Identify all the imports in the given file."""
    with open(file, encoding="utf-8") as f:
        code = f.read()
    return find_imports_from_package(code, from_package=from_package)


def identify_pkg_source(pkg_root: str) -> pathlib.Path:
    """Identify the source of the package.

    Args:
        pkg_root: the root of the package. This contains source + tests, and other
            things like pyproject.toml, lock files, etc.

    Returns:
        Returns the path to the source code for the package.
    """
    dirs = [d for d in Path(pkg_root).iterdir() if d.is_dir()]
    matching_dirs = [d for d in dirs if d.name.startswith("langchain_")]
    assert len(matching_dirs) == 1, "There should be only one langchain package."
    return matching_dirs[0]


def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
    """List all classes in a package."""
    module_classes = []
    pkg_source = identify_pkg_source(pkg_root)
    files = list(pkg_source.rglob("*.py"))

    for file in files:
        rel_path = os.path.relpath(file, pkg_root)
        if rel_path.startswith("tests"):
            continue
        module_classes.extend(_get_all_classnames_from_file(file, pkg_root))
    return module_classes


def list_init_imports_by_package(pkg_root: str) -> List[Tuple[str, str]]:
    """List all the things that are being imported in a package by module."""
    imports = []
    pkg_source = identify_pkg_source(pkg_root)
    # Scan all the files in the package
    files = list(Path(pkg_source).rglob("*.py"))

    for file in files:
        if not file.name == "__init__.py":
            continue
        import_in_file = identify_all_imports_in_file(str(file))
        module_name = _get_current_module(file, pkg_root)
        imports.extend([(module_name, item) for _, item in import_in_file])
    return imports


def find_imports_from_package(
    code: str, *, from_package: Optional[str] = None
) -> List[Tuple[str, str]]:
    # Parse the code into an AST
    tree = ast.parse(code)
    # Create an instance of the visitor
    extractor = ImportExtractor(from_package=from_package)
    # Use the visitor to update the imports list
    extractor.visit(tree)
    return extractor.imports


def _get_current_module(path: str, pkg_root: str) -> str:
    """Convert a path to a module name."""
    path_as_pathlib = pathlib.Path(os.path.abspath(path))
    relative_path = path_as_pathlib.relative_to(pkg_root).with_suffix("")
    posix_path = relative_path.as_posix()
    norm_path = os.path.normpath(str(posix_path))
    fully_qualified_module = norm_path.replace("/", ".")
    # Strip __init__ if present
    if fully_qualified_module.endswith(".__init__"):
        return fully_qualified_module[:-9]
    return fully_qualified_module
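As a small illustration of the helpers above, `find_imports_from_package` only records `from ... import ...` statements whose module matches the given prefix; a sketch:

from langchain_cli.namespaces.migrate.generate.utils import find_imports_from_package

code = (
    "from langchain_community.chat_models import ChatOpenAI\n"
    "from typing import List\n"
)
# Only the langchain_community import is collected.
print(find_imports_from_package(code, from_package="langchain_community"))
# -> [("langchain_community.chat_models", "ChatOpenAI")]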
@@ -0,0 +1,52 @@
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
import fnmatch
import re
from pathlib import Path
from typing import List

MATCH_SEP = r"(?:/|\\)"
MATCH_SEP_OR_END = r"(?:/|\\|\Z)"
MATCH_NON_RECURSIVE = r"[^/\\]*"
MATCH_RECURSIVE = r"(?:.*)"


def glob_to_re(pattern: str) -> str:
    """Translate a glob pattern to a regular expression for matching."""
    fragments: List[str] = []
    for segment in re.split(r"/|\\", pattern):
        if segment == "":
            continue
        if segment == "**":
            # Remove the previous separator match, so the recursive match
            # can match zero or more segments.
            if fragments and fragments[-1] == MATCH_SEP:
                fragments.pop()
            fragments.append(MATCH_RECURSIVE)
        elif "**" in segment:
            raise ValueError(
                "invalid pattern: '**' can only be an entire path component"
            )
        else:
            fragment = fnmatch.translate(segment)
            fragment = fragment.replace(r"(?s:", r"(?:")
            fragment = fragment.replace(r".*", MATCH_NON_RECURSIVE)
            fragment = fragment.replace(r"\Z", r"")
            fragments.append(fragment)
        fragments.append(MATCH_SEP)
    # Remove the trailing MATCH_SEP, so it can be replaced with MATCH_SEP_OR_END.
    if fragments and fragments[-1] == MATCH_SEP:
        fragments.pop()
    fragments.append(MATCH_SEP_OR_END)
    return rf"(?s:{''.join(fragments)})"


def match_glob(path: Path, pattern: str) -> bool:
    """Check if a path matches a glob pattern.

    If the pattern ends with a directory separator, the path must be a directory.
    """
    match = bool(re.fullmatch(glob_to_re(pattern), str(path)))
    if pattern.endswith("/") or pattern.endswith("\\"):
        return match and path.is_dir()
    return match
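For example, the default ignore pattern used by the CLI below (`.venv/**`) matches anything under a `.venv` directory; a quick sketch:

from pathlib import Path

from langchain_cli.namespaces.migrate.glob_helpers import match_glob

print(match_glob(Path(".venv/lib/python3.11/site-packages/foo.py"), ".venv/**"))  # True
print(match_glob(Path("src/app.py"), ".venv/**"))  # False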
@@ -0,0 +1,194 @@
"""Migrate LangChain to the most recent version."""
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
import difflib
import functools
import multiprocessing
import os
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union

import libcst as cst
import rich
import typer
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.helpers import calculate_module_and_package
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
from rich.console import Console
from rich.progress import Progress
from typer import Argument, Exit, Option, Typer
from typing_extensions import ParamSpec

from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods
from langchain_cli.namespaces.migrate.glob_helpers import match_glob

app = Typer(invoke_without_command=True, add_completion=False)

P = ParamSpec("P")
T = TypeVar("T")

DEFAULT_IGNORES = [".venv/**"]


@app.callback()
def main(
    path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
    disable: List[Rule] = Option(default=[], help="Disable a rule."),
    diff: bool = Option(False, help="Show diff instead of applying changes."),
    ignore: List[str] = Option(
        default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
    ),
    log_file: Path = Option("log.txt", help="Log errors to this file."),
):
    """Migrate langchain to the most recent version."""
    if not diff:
        rich.print("[bold red]Alert![/ bold red] langchain-cli migrate", end=": ")
        if not typer.confirm(
            "The migration process will modify your files. "
            "The migration is a `best-effort` process and is not expected to "
            "be perfect. "
            "Do you want to continue?"
        ):
            raise Exit()
    console = Console(log_time=True)
    console.log("Start langchain-cli migrate")
    # NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
    os.environ["LIBCST_PARSER_TYPE"] = "native"

    if os.path.isfile(path):
        package = path.parent
        all_files = [path]
    else:
        package = path
        all_files = sorted(package.glob("**/*.py"))

    filtered_files = [
        file
        for file in all_files
        if not any(match_glob(file, pattern) for pattern in ignore)
    ]
    files = [str(file.relative_to(".")) for file in filtered_files]

    if len(files) == 1:
        console.log("Found 1 file to process.")
    elif len(files) > 1:
        console.log(f"Found {len(files)} files to process.")
    else:
        console.log("No files to process.")
        raise Exit()

    providers = {FullyQualifiedNameProvider, ScopeProvider}
    metadata_manager = FullRepoManager(".", files, providers=providers)  # type: ignore[arg-type]
    metadata_manager.resolve_cache()

    scratch: dict[str, Any] = {}
    start_time = time.time()

    codemods = gather_codemods(disabled=disable)

    log_fp = log_file.open("a+", encoding="utf8")
    partial_run_codemods = functools.partial(
        run_codemods, codemods, metadata_manager, scratch, package, diff
    )
    with Progress(*Progress.get_default_columns(), transient=True) as progress:
        task = progress.add_task(description="Executing codemods...", total=len(files))
        count_errors = 0
        difflines: List[List[str]] = []
        with multiprocessing.Pool() as pool:
            for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
                progress.advance(task)

                if _difflines is not None:
                    difflines.append(_difflines)

                if error is not None:
                    count_errors += 1
                    log_fp.writelines(error)

    modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]

    if not diff:
        if modified:
            console.log(f"Refactored {len(modified)} files.")
        else:
            console.log("No files were modified.")

    for _difflines in difflines:
        color_diff(console, _difflines)

    if count_errors > 0:
        console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
    else:
        console.log("Run successfully!")

    if difflines:
        raise Exit(1)


def run_codemods(
    codemods: List[Type[ContextAwareTransformer]],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    try:
        module_and_package = calculate_module_and_package(str(package), filename)
        context = CodemodContext(
            metadata_manager=metadata_manager,
            filename=filename,
            full_module_name=module_and_package.name,
            full_package_name=module_and_package.package,
        )
        context.scratch.update(scratch)

        file_path = Path(filename)
        with file_path.open("r+", encoding="utf-8") as fp:
            code = fp.read()
            fp.seek(0)

            input_tree = cst.parse_module(code)

            for codemod in codemods:
                transformer = codemod(context=context)
                output_tree = transformer.transform_module(input_tree)
                input_tree = output_tree

            output_code = input_tree.code
            if code != output_code:
                if diff:
                    lines = difflib.unified_diff(
                        code.splitlines(keepends=True),
                        output_code.splitlines(keepends=True),
                        fromfile=filename,
                        tofile=filename,
                    )
                    return None, list(lines)
                else:
                    fp.write(output_code)
                    fp.truncate()
        return None, None
    except cst.ParserSyntaxError as exc:
        return (
            f"A syntax error happened on {filename}. This file cannot be "
            f"formatted.\n"
            f"{exc}"
        ), None
    except Exception:
        return f"An error happened on {filename}.\n{traceback.format_exc()}", None


def color_diff(console: Console, lines: Iterable[str]) -> None:
    for line in lines:
        line = line.rstrip("\n")
        if line.startswith("+"):
            console.print(line, style="green")
        elif line.startswith("-"):
            console.print(line, style="red")
        elif line.startswith("^"):
            console.print(line, style="blue")
        else:
            console.print(line, style="white")
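The app is normally wired up behind the `langchain-cli migrate` entry point (as the log messages above suggest); it can also be driven programmatically, which is what the test suite further down does. A sketch, assuming the `--diff` flag Typer generates for the boolean option above:

from typer.testing import CliRunner

from langchain_cli.namespaces.migrate.main import app

runner = CliRunner()
# Preview the rewrite of ./my_project without modifying any files
# (--diff skips the confirmation prompt and prints unified diffs).
result = runner.invoke(app, ["my_project", "--diff"])
print(result.output)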
@@ -0,0 +1,78 @@
"""Script to generate migrations for the migration script."""
import json
import pkgutil

import click

from langchain_cli.namespaces.migrate.generate.langchain import (
    generate_simplified_migrations,
)
from langchain_cli.namespaces.migrate.generate.partner import (
    get_migrations_for_partner_package,
)


@click.group()
def cli():
    """Migration scripts management."""
    pass


@cli.command()
@click.option(
    "--output",
    default="langchain_migrations.json",
    help="Output file for the migration script.",
)
def langchain(output: str) -> None:
    """Generate a migration script."""
    click.echo("Migration script generated.")
    migrations = generate_simplified_migrations()
    with open(output, "w") as f:
        f.write(json.dumps(migrations, indent=2, sort_keys=True))


@cli.command()
@click.argument("pkg")
@click.option("--output", default=None, help="Output file for the migration script.")
def partner(pkg: str, output: str) -> None:
    """Generate migration scripts specifically for LangChain modules."""
    click.echo("Migration script for LangChain generated.")
    migrations = get_migrations_for_partner_package(pkg)
    # Run with python 3.9+
    output_name = f"{pkg.removeprefix('langchain_')}.json" if output is None else output
    if migrations:
        with open(output_name, "w") as f:
            f.write(json.dumps(migrations, indent=2, sort_keys=True))
        click.secho(f"LangChain migration script saved to {output_name}")
    else:
        click.secho(f"No migrations found for {pkg}", fg="yellow")


@cli.command()
def all_installed_partner_pkgs() -> None:
    """Generate migration scripts for all LangChain modules."""
    # Will generate migrations for all partner packages,
    # defined as "langchain_<partner_name>".
    # First let's determine which packages are installed in the environment
    # and then generate migrations for them.
    langchain_pkgs = [
        name
        for _, name, _ in pkgutil.iter_modules()
        if name.startswith("langchain_")
        and name not in {"langchain_core", "langchain_cli", "langchain_community"}
    ]
    for pkg in langchain_pkgs:
        migrations = get_migrations_for_partner_package(pkg)
        # Run with python 3.9+
        output_name = f"{pkg.removeprefix('langchain_')}.json"
        if migrations:
            with open(output_name, "w") as f:
                f.write(json.dumps(migrations, indent=2, sort_keys=True))
            click.secho(f"LangChain migration script saved to {output_name}")
        else:
            click.secho(f"No migrations found for {pkg}", fg="yellow")


if __name__ == "__main__":
    cli()
@@ -0,0 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass

from .file import File
from .folder import Folder


@dataclass
class Case:
    source: Folder | File
    expected: Folder | File
    name: str
@@ -0,0 +1,15 @@
from tests.unit_tests.migrate.cli_runner.case import Case
from tests.unit_tests.migrate.cli_runner.cases import imports
from tests.unit_tests.migrate.cli_runner.file import File
from tests.unit_tests.migrate.cli_runner.folder import Folder

cases = [
    Case(
        name="empty",
        source=File("__init__.py", content=[]),
        expected=File("__init__.py", content=[]),
    ),
    *imports.cases,
]
before = Folder("project", *[case.source for case in cases])
expected = Folder("project", *[case.expected for case in cases])
@@ -0,0 +1,32 @@
from tests.unit_tests.migrate.cli_runner.case import Case
from tests.unit_tests.migrate.cli_runner.file import File

cases = [
    Case(
        name="Imports",
        source=File(
            "app.py",
            content=[
                "from langchain_community.chat_models import ChatOpenAI",
                "",
                "",
                "class foo:",
                "    a: int",
                "",
                "chain = ChatOpenAI()",
            ],
        ),
        expected=File(
            "app.py",
            content=[
                "from langchain_openai import ChatOpenAI",
                "",
                "",
                "class foo:",
                "    a: int",
                "",
                "chain = ChatOpenAI()",
            ],
        ),
    ),
]
@@ -0,0 +1,16 @@
from __future__ import annotations


class File:
    def __init__(self, name: str, content: list[str] | None = None) -> None:
        self.name = name
        self.content = "\n".join(content or [])

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, File):
            return NotImplemented

        if self.name != __value.name:
            return False

        return self.content == __value.content
@@ -0,0 +1,59 @@
from __future__ import annotations

from pathlib import Path

from .file import File


class Folder:
    def __init__(self, name: str, *files: Folder | File) -> None:
        self.name = name
        self._files = files

    @property
    def files(self) -> list[Folder | File]:
        return sorted(self._files, key=lambda f: f.name)

    def create_structure(self, root: Path) -> None:
        path = root / self.name
        path.mkdir()

        for file in self.files:
            if isinstance(file, Folder):
                file.create_structure(path)
            else:
                (path / file.name).write_text(file.content, encoding="utf-8")

    @classmethod
    def from_structure(cls, root: Path) -> Folder:
        name = root.name
        files: list[File | Folder] = []

        for path in root.iterdir():
            if path.is_dir():
                files.append(cls.from_structure(path))
            else:
                files.append(
                    File(path.name, path.read_text(encoding="utf-8").splitlines())
                )

        return Folder(name, *files)

    def __eq__(self, __value: object) -> bool:
        if isinstance(__value, File):
            return False

        if not isinstance(__value, Folder):
            return NotImplemented

        if self.name != __value.name:
            return False

        if len(self.files) != len(__value.files):
            return False

        for self_file, other_file in zip(self.files, __value.files):
            if self_file != other_file:
                return False

        return True
@@ -0,0 +1,55 @@
# ruff: noqa: E402
from __future__ import annotations

import pytest

pytest.importorskip("libcst")

import difflib
from pathlib import Path

from typer.testing import CliRunner

from langchain_cli.namespaces.migrate.main import app
from tests.unit_tests.migrate.cli_runner.cases import before, expected
from tests.unit_tests.migrate.cli_runner.folder import Folder


def find_issue(current: Folder, expected: Folder) -> str:
    for current_file, expected_file in zip(current.files, expected.files):
        if current_file != expected_file:
            if current_file.name != expected_file.name:
                return (
                    f"Files have "
                    f"different names: {current_file.name} != {expected_file.name}"
                )
            if isinstance(current_file, Folder) and isinstance(expected_file, Folder):
                return find_issue(current_file, expected_file)
            elif isinstance(current_file, Folder) or isinstance(expected_file, Folder):
                return (
                    f"One of the files is a "
                    f"folder: {current_file.name} != {expected_file.name}"
                )
            return "\n".join(
                difflib.unified_diff(
                    current_file.content.splitlines(),
                    expected_file.content.splitlines(),
                    fromfile=current_file.name,
                    tofile=expected_file.name,
                )
            )
    return "Unknown"


def test_command_line(tmp_path: Path) -> None:
    runner = CliRunner()

    with runner.isolated_filesystem(temp_dir=tmp_path) as td:
        before.create_structure(root=Path(td))
        # The input is used to force through the confirmation.
        result = runner.invoke(app, [before.name], input="y\n")
        assert result.exit_code == 0, result.output

        after = Folder.from_structure(Path(td) / before.name)

        assert after == expected, find_issue(after, expected)
@@ -0,0 +1,25 @@
from langchain_cli.namespaces.migrate.generate.langchain import (
    generate_simplified_migrations,
)


def test_create_json_agent_migration() -> None:
    """Test the migration of create_json_agent from langchain to langchain_community."""
    raw_migrations = generate_simplified_migrations()
    json_agent_migrations = [
        migration for migration in raw_migrations if "create_json_agent" in migration[0]
    ]
    assert json_agent_migrations == [
        (
            "langchain.agents.create_json_agent",
            "langchain_community.agent_toolkits.create_json_agent",
        ),
        (
            "langchain.agents.agent_toolkits.create_json_agent",
            "langchain_community.agent_toolkits.create_json_agent",
        ),
        (
            "langchain.agents.agent_toolkits.json.base.create_json_agent",
            "langchain_community.agent_toolkits.create_json_agent",
        ),
    ]
@@ -0,0 +1,46 @@
import pytest

from langchain_cli.namespaces.migrate.generate.partner import (
    get_migrations_for_partner_package,
)

pytest.importorskip(modname="langchain_openai")


def test_generate_migrations() -> None:
    migrations = get_migrations_for_partner_package("langchain_openai")
    assert migrations == [
        ("langchain_community.llms.openai.OpenAI", "langchain_openai.OpenAI"),
        ("langchain_community.llms.openai.AzureOpenAI", "langchain_openai.AzureOpenAI"),
        (
            "langchain_community.embeddings.openai.OpenAIEmbeddings",
            "langchain_openai.OpenAIEmbeddings",
        ),
        (
            "langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings",
            "langchain_openai.AzureOpenAIEmbeddings",
        ),
        (
            "langchain_community.chat_models.openai.ChatOpenAI",
            "langchain_openai.ChatOpenAI",
        ),
        (
            "langchain_community.chat_models.azure_openai.AzureChatOpenAI",
            "langchain_openai.AzureChatOpenAI",
        ),
        ("langchain_community.llms.AzureOpenAI", "langchain_openai.AzureOpenAI"),
        ("langchain_community.llms.OpenAI", "langchain_openai.OpenAI"),
        (
            "langchain_community.embeddings.AzureOpenAIEmbeddings",
            "langchain_openai.AzureOpenAIEmbeddings",
        ),
        (
            "langchain_community.embeddings.OpenAIEmbeddings",
            "langchain_openai.OpenAIEmbeddings",
        ),
        (
            "langchain_community.chat_models.AzureChatOpenAI",
            "langchain_openai.AzureChatOpenAI",
        ),
        ("langchain_community.chat_models.ChatOpenAI", "langchain_openai.ChatOpenAI"),
    ]
@@ -0,0 +1,5 @@
from langchain_cli.namespaces.migrate.generate.utils import PKGS_ROOT


def test_root() -> None:
    assert PKGS_ROOT.name == "libs"
@@ -0,0 +1,72 @@
from __future__ import annotations

from pathlib import Path

import pytest

from langchain_cli.namespaces.migrate.glob_helpers import glob_to_re, match_glob


class TestGlobHelpers:
    match_glob_values: list[tuple[str, Path, bool]] = [
        ("foo", Path("foo"), True),
        ("foo", Path("bar"), False),
        ("foo", Path("foo/bar"), False),
        ("*", Path("foo"), True),
        ("*", Path("bar"), True),
        ("*", Path("foo/bar"), False),
        ("**", Path("foo"), True),
        ("**", Path("foo/bar"), True),
        ("**", Path("foo/bar/baz/qux"), True),
        ("foo/bar", Path("foo/bar"), True),
        ("foo/bar", Path("foo"), False),
        ("foo/bar", Path("far"), False),
        ("foo/bar", Path("foo/foo"), False),
        ("foo/*", Path("foo/bar"), True),
        ("foo/*", Path("foo/bar/baz"), False),
        ("foo/*", Path("foo"), False),
        ("foo/*", Path("bar"), False),
        ("foo/**", Path("foo/bar"), True),
        ("foo/**", Path("foo/bar/baz"), True),
        ("foo/**", Path("foo/bar/baz/qux"), True),
        ("foo/**", Path("foo"), True),
        ("foo/**", Path("bar"), False),
        ("foo/**/bar", Path("foo/bar"), True),
        ("foo/**/bar", Path("foo/baz/bar"), True),
        ("foo/**/bar", Path("foo/baz/qux/bar"), True),
        ("foo/**/bar", Path("foo/baz/qux"), False),
        ("foo/**/bar", Path("foo/bar/baz"), False),
        ("foo/**/bar", Path("foo/bar/bar"), True),
        ("foo/**/bar", Path("foo"), False),
        ("foo/**/bar", Path("bar"), False),
        ("foo/**/*/bar", Path("foo/bar"), False),
        ("foo/**/*/bar", Path("foo/baz/bar"), True),
        ("foo/**/*/bar", Path("foo/baz/qux/bar"), True),
        ("foo/**/*/bar", Path("foo/baz/qux"), False),
        ("foo/**/*/bar", Path("foo/bar/baz"), False),
        ("foo/**/*/bar", Path("foo/bar/bar"), True),
        ("foo/**/*/bar", Path("foo"), False),
        ("foo/**/*/bar", Path("bar"), False),
        ("foo/ba*", Path("foo/bar"), True),
        ("foo/ba*", Path("foo/baz"), True),
        ("foo/ba*", Path("foo/qux"), False),
        ("foo/ba*", Path("foo/baz/qux"), False),
        ("foo/ba*", Path("foo/bar/baz"), False),
        ("foo/ba*", Path("foo"), False),
        ("foo/ba*", Path("bar"), False),
        ("foo/**/ba*/*/qux", Path("foo/a/b/c/bar/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/a/b/c/baz/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/a/bar/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/baz/a/qux"), True),
        ("foo/**/ba*/*/qux", Path("foo/baz/qux"), False),
        ("foo/**/ba*/*/qux", Path("foo/a/b/c/qux/a/qux"), False),
        ("foo/**/ba*/*/qux", Path("foo"), False),
        ("foo/**/ba*/*/qux", Path("bar"), False),
    ]

    @pytest.mark.parametrize(("pattern", "path", "expected"), match_glob_values)
    def test_match_glob(self, pattern: str, path: Path, expected: bool):
        expr = glob_to_re(pattern)
        assert (
            match_glob(path, pattern) == expected
        ), f"path: {path}, pattern: {pattern}, expr: {expr}"
@@ -0,0 +1,51 @@
# ruff: noqa: E402
import pytest

pytest.importorskip("libcst")


from libcst.codemod import CodemodTest

from langchain_cli.namespaces.migrate.codemods.replace_imports import (
    ReplaceImportsCodemod,
)


class TestReplaceImportsCommand(CodemodTest):
    TRANSFORM = ReplaceImportsCodemod

    def test_single_import(self) -> None:
        before = """
        from langchain.chat_models import ChatOpenAI
        """
        after = """
        from langchain_community.chat_models import ChatOpenAI
        """
        self.assertCodemod(before, after)

    def test_from_community_to_partner(self) -> None:
        """Test that we can replace imports from community to partner."""
        before = """
        from langchain_community.chat_models import ChatOpenAI
        """
        after = """
        from langchain_openai import ChatOpenAI
        """
        self.assertCodemod(before, after)

    def test_noop_import(self) -> None:
        code = """
        from foo import ChatOpenAI
        """
        self.assertCodemod(code, code)

    def test_mixed_imports(self) -> None:
        before = """
        from langchain_community.chat_models import ChatOpenAI, ChatAnthropic, foo
        """
        after = """
        from langchain_community.chat_models import foo
        from langchain_anthropic import ChatAnthropic
        from langchain_openai import ChatOpenAI
        """
        self.assertCodemod(before, after)
@@ -1,12 +1,12 @@
 import os
 
-from langchain_upstage import GroundednessCheck
+from langchain_upstage import UpstageGroundednessCheck
 
 os.environ["UPSTAGE_API_KEY"] = "foo"
 
 
 def test_initialization() -> None:
     """Test embedding model initialization."""
-    GroundednessCheck()
-    GroundednessCheck(upstage_api_key="key")
-    GroundednessCheck(api_key="key")
+    UpstageGroundednessCheck()
+    UpstageGroundednessCheck(upstage_api_key="key")
+    UpstageGroundednessCheck(api_key="key")