Merge branch 'master' into cc/retriever_score

pull/20800/head
Chester Curme 1 month ago
commit c9fc0447ec

@ -17,15 +17,14 @@
"from typing import List\n",
"\n",
"from langchain_community.vectorstores import DocArrayInMemorySearch\n",
"from langchain_core.documents.base import Document\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_core.runnables.base import RunnableSerializable\n",
"from langchain_upstage import (\n",
" ChatUpstage,\n",
" GroundednessCheck,\n",
" UpstageEmbeddings,\n",
" UpstageGroundednessCheck,\n",
" UpstageLayoutAnalysisLoader,\n",
")\n",
"\n",
@ -50,7 +49,7 @@
"\n",
"retrieved_docs = retriever.get_relevant_documents(\"How many parameters in SOLAR model?\")\n",
"\n",
"groundedness_check = GroundednessCheck()\n",
"groundedness_check = UpstageGroundednessCheck()\n",
"groundedness = \"\"\n",
"while groundedness != \"grounded\":\n",
" chain: RunnableSerializable = RunnablePassthrough() | prompt | model | output_parser\n",
@ -62,14 +61,10 @@
" }\n",
" )\n",
"\n",
" # convert all Documents to string\n",
" def formatDocumentsAsString(docs: List[Document]) -> str:\n",
" return \"\\n\".join([doc.page_content for doc in docs])\n",
"\n",
" groundedness = groundedness_check.run(\n",
" groundedness = groundedness_check.invoke(\n",
" {\n",
" \"context\": formatDocumentsAsString(retrieved_docs),\n",
" \"query\": result,\n",
" \"context\": retrieved_docs,\n",
" \"answer\": result,\n",
" }\n",
" )"
]
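For reference, the end state of this cell in plain Python, as a minimal sketch (assumes `UPSTAGE_API_KEY` is set; the document content is illustrative):

from langchain_core.documents import Document
from langchain_upstage import UpstageGroundednessCheck

groundedness_check = UpstageGroundednessCheck()
# `context` now accepts a plain string or a list of Documents, and the
# former `query` field is named `answer`.
groundedness = groundedness_check.invoke(
    {
        "context": [Document(page_content="SOLAR is a 10.7B-parameter model.")],
        "answer": "The SOLAR model has 10.7 billion parameters.",
    }
)
print(groundedness)  # "grounded", "notGrounded", or "notSure"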

@ -141,12 +141,71 @@
"chatLLM(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Calling\n",
"ChatTongyi supports tool calling API that lets you describe tools and their arguments, and have the model return a JSON object with a tool to invoke and the inputs to that tool."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'name': 'get_current_weather', 'arguments': '{\"location\": \"San Francisco\"}'}, 'id': '', 'type': 'function'}]}, response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'tool_calls', 'request_id': 'dae79197-8780-9b7e-8c15-6a83e2a53534', 'token_usage': {'input_tokens': 229, 'output_tokens': 19, 'total_tokens': 248}}, id='run-9e06f837-582b-473b-bb1f-5e99a68ecc10-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco'}, 'id': ''}])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.chat_models.tongyi import ChatTongyi\n",
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_time\",\n",
" \"description\": \"当你想知道现在的时间时非常有用。\",\n",
" \"parameters\": {},\n",
" },\n",
" },\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"当你想查询指定城市的天气时非常有用。\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"城市或县区,比如北京市、杭州市、余杭区等。\",\n",
" }\n",
" },\n",
" },\n",
" \"required\": [\"location\"],\n",
" },\n",
" },\n",
"]\n",
"\n",
"messages = [\n",
" SystemMessage(content=\"You are a helpful assistant.\"),\n",
" HumanMessage(content=\"What is the weather like in San Francisco?\"),\n",
"]\n",
"chatLLM = ChatTongyi()\n",
"llm_kwargs = {\"tools\": tools, \"result_format\": \"message\"}\n",
"ai_message = chatLLM.bind(**llm_kwargs).invoke(messages)\n",
"ai_message"
]
}
],
"metadata": {

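As an aside, the `tool_calls` in the output above can be dispatched to a local function; a minimal sketch, where `get_current_weather` is a hypothetical implementation (the model only proposes the call, it does not execute anything):

def get_current_weather(location: str) -> str:
    # Hypothetical local implementation, not part of this diff.
    return f"It is sunny in {location}."

for tool_call in ai_message.tool_calls:
    if tool_call["name"] == "get_current_weather":
        print(get_current_weather(**tool_call["args"]))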
@ -52,7 +52,7 @@
"| --- | --- | --- | --- |\n",
"| Chat | Build assistants using Solar Mini Chat | `from langchain_upstage import ChatUpstage` | [Go](../../chat/upstage) |\n",
"| Text Embedding | Embed strings to vectors | `from langchain_upstage import UpstageEmbeddings` | [Go](../../text_embedding/upstage) |\n",
"| Groundedness Check | Verify groundedness of assistant's response | `from langchain_upstage import GroundednessCheck` | [Go](../../tools/upstage_groundedness_check) |\n",
"| Groundedness Check | Verify groundedness of assistant's response | `from langchain_upstage import UpstageGroundednessCheck` | [Go](../../tools/upstage_groundedness_check) |\n",
"| Layout Analysis | Serialize documents with tables and figures | `from langchain_upstage import UpstageLayoutAnalysisLoader` | [Go](../../document_loaders/upstage) |\n",
"\n",
"See [documentations](https://developers.upstage.ai/) for more details about the features."
@ -145,15 +145,15 @@
},
"outputs": [],
"source": [
"from langchain_upstage import GroundednessCheck\n",
"from langchain_upstage import UpstageGroundednessCheck\n",
"\n",
"groundedness_check = GroundednessCheck()\n",
"groundedness_check = UpstageGroundednessCheck()\n",
"\n",
"request_input = {\n",
" \"context\": \"Mauna Kea is an inactive volcano on the island of Hawaii. Its peak is 4,207.3 m above sea level, making it the highest point in Hawaii and second-highest peak of an island on Earth.\",\n",
" \"query\": \"Mauna Kea is 5,207.3 meters tall.\",\n",
" \"answer\": \"Mauna Kea is 5,207.3 meters tall.\",\n",
"}\n",
"response = groundedness_check.run(request_input)\n",
"response = groundedness_check.invoke(request_input)\n",
"print(response)"
]
},

@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "a83d4da0",
"metadata": {},
"outputs": [],
@ -65,21 +65,21 @@
"source": [
"## Usage\n",
"\n",
"Initialize `GroundednessCheck` class."
"Initialize `UpstageGroundednessCheck` class."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "b7373380c01cefbe",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from langchain_upstage import GroundednessCheck\n",
"from langchain_upstage import UpstageGroundednessCheck\n",
"\n",
"groundedness_check = GroundednessCheck()"
"groundedness_check = UpstageGroundednessCheck()"
]
},
{
@ -92,38 +92,22 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "1e0115e3b511f57",
"metadata": {
"collapsed": false,
"is_executing": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"content='notGrounded' response_metadata={'token_usage': {'completion_tokens': 6, 'prompt_tokens': 198, 'total_tokens': 204}, 'model_name': 'solar-1-mini-answer-verification', 'system_fingerprint': '', 'finish_reason': 'stop', 'logprobs': None} id='run-ce7b5787-2ed0-4a68-9de4-c0e91a824147-0'\n"
]
}
],
"outputs": [],
"source": [
"request_input = {\n",
" \"context\": \"Mauna Kea is an inactive volcano on the island of Hawai'i. Its peak is 4,207.3 m above sea level, making it the highest point in Hawaii and second-highest peak of an island on Earth.\",\n",
" \"query\": \"Mauna Kea is 5,207.3 meters tall.\",\n",
" \"answer\": \"Mauna Kea is 5,207.3 meters tall.\",\n",
"}\n",
"\n",
"response = groundedness_check.run(request_input)\n",
"response = groundedness_check.invoke(request_input)\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "054b5031",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

@ -1,3 +1,4 @@
import importlib
from typing import Optional
import typer
@ -22,6 +23,13 @@ app.add_typer(
)
# If libcst is installed, add the migrate namespace
if importlib.util.find_spec("libcst"):
from langchain_cli.namespaces.migrate import main as migrate_namespace
app.add_typer(migrate_namespace.app, name="migrate", help=migrate_namespace.__doc__)
def version_callback(show_version: bool) -> None:
if show_version:
typer.echo(f"langchain-cli {__version__}")

@ -0,0 +1,25 @@
from enum import Enum
from typing import List, Type
from libcst.codemod import ContextAwareTransformer
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
ReplaceImportsCodemod,
)
class Rule(str, Enum):
R001 = "R001"
"""Replace imports that have been moved."""
def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
codemods: List[Type[ContextAwareTransformer]] = []
if Rule.R001 not in disabled:
codemods.append(ReplaceImportsCodemod)
# These codemods must run last.
codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
return codemods
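A quick sketch of the resulting order; the two libcst visitors come last on purpose, so they run after `ReplaceImportsCodemod` has registered which imports to add and remove:

# With no rules disabled, gather_codemods returns, in order:
#   [ReplaceImportsCodemod, RemoveImportsVisitor, AddImportsVisitor]
codemods = gather_codemods(disabled=[])
assert codemods == [ReplaceImportsCodemod, RemoveImportsVisitor, AddImportsVisitor]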

@ -0,0 +1,18 @@
[
[
"langchain_community.llms.anthropic.Anthropic",
"langchain_anthropic.Anthropic"
],
[
"langchain_community.chat_models.anthropic.ChatAnthropic",
"langchain_anthropic.ChatAnthropic"
],
[
"langchain_community.llms.Anthropic",
"langchain_anthropic.Anthropic"
],
[
"langchain_community.chat_models.ChatAnthropic",
"langchain_anthropic.ChatAnthropic"
]
]

@ -0,0 +1,18 @@
[
[
"langchain_community.llms.fireworks.Fireworks",
"langchain_fireworks.Fireworks"
],
[
"langchain_community.chat_models.fireworks.ChatFireworks",
"langchain_fireworks.ChatFireworks"
],
[
"langchain_community.llms.Fireworks",
"langchain_fireworks.Fireworks"
],
[
"langchain_community.chat_models.ChatFireworks",
"langchain_fireworks.ChatFireworks"
]
]

@ -0,0 +1,10 @@
[
[
"langchain_community.llms.watsonxllm.WatsonxLLM",
"langchain_ibm.WatsonxLLM"
],
[
"langchain_community.llms.WatsonxLLM",
"langchain_ibm.WatsonxLLM"
]
]

@ -0,0 +1,50 @@
[
[
"langchain_community.llms.openai.OpenAI",
"langchain_openai.OpenAI"
],
[
"langchain_community.llms.openai.AzureOpenAI",
"langchain_openai.AzureOpenAI"
],
[
"langchain_community.embeddings.openai.OpenAIEmbeddings",
"langchain_openai.OpenAIEmbeddings"
],
[
"langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings",
"langchain_openai.AzureOpenAIEmbeddings"
],
[
"langchain_community.chat_models.openai.ChatOpenAI",
"langchain_openai.ChatOpenAI"
],
[
"langchain_community.chat_models.azure_openai.AzureChatOpenAI",
"langchain_openai.AzureChatOpenAI"
],
[
"langchain_community.llms.AzureOpenAI",
"langchain_openai.AzureOpenAI"
],
[
"langchain_community.llms.OpenAI",
"langchain_openai.OpenAI"
],
[
"langchain_community.embeddings.AzureOpenAIEmbeddings",
"langchain_openai.AzureOpenAIEmbeddings"
],
[
"langchain_community.embeddings.OpenAIEmbeddings",
"langchain_openai.OpenAIEmbeddings"
],
[
"langchain_community.chat_models.AzureChatOpenAI",
"langchain_openai.AzureChatOpenAI"
],
[
"langchain_community.chat_models.ChatOpenAI",
"langchain_openai.ChatOpenAI"
]
]

@ -0,0 +1,10 @@
[
[
"langchain_community.vectorstores.pinecone.Pinecone",
"langchain_pinecone.Pinecone"
],
[
"langchain_community.vectorstores.Pinecone",
"langchain_pinecone.Pinecone"
]
]

@ -0,0 +1,214 @@
"""
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
This codemod deals with the following cases (examples retained from bump-pydantic):
1. `from pydantic import BaseSettings`
2. `from pydantic.settings import BaseSettings`
3. `from pydantic import BaseSettings as <name>`
4. `from pydantic.settings import BaseSettings as <name>` # TODO: This is not working.
5. `import pydantic` -> `pydantic.BaseSettings`
"""
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Sequence, Tuple, TypeVar
import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor
HERE = os.path.dirname(__file__)
def _load_migrations_by_file(path: str):
migrations_path = os.path.join(HERE, "migrations", path)
with open(migrations_path, "r", encoding="utf-8") as f:
data = json.load(f)
return data
T = TypeVar("T")
def _deduplicate_in_order(
seq: Iterable[T], key: Callable[[T], str] = lambda x: x
) -> List[T]:
seen = set()
seen_add = seen.add
return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]
PARTNERS = [
"anthropic.json",
"ibm.json",
"openai.json",
"pinecone.json",
"fireworks.json",
]
def _load_migrations_from_fixtures() -> List[Tuple[str, str]]:
"""Load migrations from fixtures."""
paths: List[str] = PARTNERS + ["langchain.json"]
data = []
for path in paths:
data.extend(_load_migrations_by_file(path))
data = _deduplicate_in_order(data, key=lambda x: x[0])
return data
def _load_migrations():
"""Load the migrations from the JSON file."""
# Earlier entries have higher precedence (duplicates were removed above).
imports: Dict[str, Tuple[str, str]] = {}
data = _load_migrations_from_fixtures()
for old_path, new_path in data:
# Parse the old path, which is of the format 'langchain.chat_models.ChatOpenAI',
# into the module and class name.
old_parts = old_path.split(".")
old_module = ".".join(old_parts[:-1])
old_class = old_parts[-1]
old_path_str = f"{old_module}:{old_class}"
# Parse the new path, which has the same format,
# into a 2-tuple of the module and class name.
new_parts = new_path.split(".")
new_module = ".".join(new_parts[:-1])
new_class = new_parts[-1]
new_path_str = (new_module, new_class)
imports[old_path_str] = new_path_str
return imports
IMPORTS = _load_migrations()
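Illustratively, a fixture entry like `["langchain_community.chat_models.openai.ChatOpenAI", "langchain_openai.ChatOpenAI"]` lands in the mapping as a `module:class` key with a `(new_module, class)` value:

assert IMPORTS["langchain_community.chat_models.openai:ChatOpenAI"] == (
    "langchain_openai",
    "ChatOpenAI",
)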
def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
"""Converts a list of module parts to a `Name` or `Attribute` node."""
if len(module_parts) == 1:
return m.Name(module_parts[0])
if len(module_parts) == 2:
first, last = module_parts
return m.Attribute(value=m.Name(first), attr=m.Name(last))
last_name = module_parts.pop()
attr = resolve_module_parts(module_parts)
return m.Attribute(value=attr, attr=m.Name(last_name))
def get_import_from_from_str(import_str: str) -> m.ImportFrom:
"""Converts a string like `pydantic:BaseSettings` to Examples:
>>> get_import_from_from_str("pydantic:BaseSettings")
ImportFrom(
module=Name("pydantic"),
names=[ImportAlias(name=Name("BaseSettings"))],
)
>>> get_import_from_from_str("pydantic.settings:BaseSettings")
ImportFrom(
module=Attribute(value=Name("pydantic"), attr=Name("settings")),
names=[ImportAlias(name=Name("BaseSettings"))],
)
>>> get_import_from_from_str("a.b.c:d")
ImportFrom(
module=Attribute(
value=Attribute(value=Name("a"), attr=Name("b")), attr=Name("c")
),
names=[ImportAlias(name=Name("d"))],
)
"""
module, name = import_str.split(":")
module_parts = module.split(".")
module_node = resolve_module_parts(module_parts)
return m.ImportFrom(
module=module_node,
names=[m.ZeroOrMore(), m.ImportAlias(name=m.Name(value=name)), m.ZeroOrMore()],
)
@dataclass
class ImportInfo:
import_from: m.ImportFrom
import_str: str
to_import_str: tuple[str, str]
IMPORT_INFOS = [
ImportInfo(
import_from=get_import_from_from_str(import_str),
import_str=import_str,
to_import_str=to_import_str,
)
for import_str, to_import_str in IMPORTS.items()
]
IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS])
class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
@m.leave(IMPORT_MATCH)
def leave_replace_import(
self, _: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
for import_info in IMPORT_INFOS:
if m.matches(updated_node, import_info.import_from):
aliases: Sequence[cst.ImportAlias] = updated_node.names # type: ignore
# If multiple objects are imported in a single import statement,
# we need to remove only the one we're replacing.
AddImportsVisitor.add_needed_import(
self.context, *import_info.to_import_str
)
if len(updated_node.names) > 1: # type: ignore
names = [
alias
for alias in aliases
if alias.name.value != import_info.to_import_str[-1]
]
names[-1] = names[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
updated_node = updated_node.with_changes(names=names)
else:
return cst.RemoveFromParent() # type: ignore[return-value]
return updated_node
if __name__ == "__main__":
import textwrap
from rich.console import Console
console = Console()
source = textwrap.dedent(
"""
from pydantic.settings import BaseSettings
from pydantic.color import Color
from pydantic.payment import PaymentCardNumber, PaymentCardBrand
from pydantic import Color
from pydantic import Color as Potato
class Potato(BaseSettings):
color: Color
payment: PaymentCardNumber
brand: PaymentCardBrand
potato: Potato
"""
)
console.print(source)
console.print("=" * 80)
mod = cst.parse_module(source)
context = CodemodContext(filename="main.py")
wrapper = cst.MetadataWrapper(mod)
command = ReplaceImportsCodemod(context=context)
mod = wrapper.visit(command)
wrapper = cst.MetadataWrapper(mod)
command = AddImportsVisitor(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
console.print(mod.code)

@ -0,0 +1,52 @@
"""Generate migrations from langchain to langchain-community or core packages."""
import glob
from pathlib import Path
from typing import List, Tuple
from langchain_cli.namespaces.migrate.generate.utils import (
_get_current_module,
find_imports_from_package,
)
HERE = Path(__file__).parent
PKGS_ROOT = HERE.parent.parent.parent
LANGCHAIN_PKG = PKGS_ROOT / "langchain"
COMMUNITY_PKG = PKGS_ROOT / "community"
PARTNER_PKGS = PKGS_ROOT / "partners"
def _generate_migrations_from_file(
source_module: str, code: str, *, from_package: str
) -> List[Tuple[str, str]]:
"""Generate migrations"""
imports = find_imports_from_package(code, from_package=from_package)
return [
# Map each re-exported item to the module it is actually imported from.
(f"{source_module}.{item}", f"{new_path}.{item}")
for new_path, item in imports
]
def _generate_migrations_from_file_in_pkg(
file: str, root_pkg: str
) -> List[Tuple[str, str]]:
"""Generate migrations for a file that's relative to langchain pkg."""
# Read the file.
with open(file, encoding="utf-8") as f:
code = f.read()
module_name = _get_current_module(file, root_pkg)
return _generate_migrations_from_file(
module_name, code, from_package="langchain_community"
)
def generate_migrations_from_langchain_to_community() -> List[Tuple[str, str]]:
"""Generate migrations from langchain to langchain-community."""
migrations = []
# scanning files in pkg
for file_path in glob.glob(str(LANGCHAIN_PKG) + "/**/*.py", recursive=True):
migrations.extend(
_generate_migrations_from_file_in_pkg(file_path, str(LANGCHAIN_PKG))
)
return migrations

@ -0,0 +1,54 @@
"""Generate migrations for partner packages."""
import importlib
from typing import List, Tuple
from langchain_core.documents import BaseDocumentCompressor, BaseDocumentTransformer
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_cli.namespaces.migrate.generate.utils import (
COMMUNITY_PKG,
find_subclasses_in_module,
list_classes_by_package,
list_init_imports_by_package,
)
# PUBLIC API
def get_migrations_for_partner_package(pkg_name: str) -> List[Tuple[str, str]]:
"""Generate migrations from community package to partner package.
This works by importing the partner package and matching its classes against those exported by the community package.
Args:
pkg_name (str): The name of the partner package.
Returns:
List of 2-tuples containing old and new import paths.
"""
package = importlib.import_module(pkg_name)
classes_ = find_subclasses_in_module(
package,
[
BaseLanguageModel,
Embeddings,
BaseRetriever,
VectorStore,
BaseDocumentTransformer,
BaseDocumentCompressor,
],
)
community_classes = list_classes_by_package(str(COMMUNITY_PKG))
imports_for_pkg = list_init_imports_by_package(str(COMMUNITY_PKG))
old_paths = community_classes + imports_for_pkg
migrations = [
(f"{module}.{item}", f"{pkg_name}.{item}")
for module, item in old_paths
if item in classes_
]
return migrations

@ -0,0 +1,158 @@
import ast
import inspect
import os
import pathlib
from pathlib import Path
from typing import Any, List, Optional, Tuple, Type
HERE = Path(__file__).parent
# Should bring us to [root]/src
PKGS_ROOT = HERE.parent.parent.parent.parent.parent
LANGCHAIN_PKG = PKGS_ROOT / "langchain"
COMMUNITY_PKG = PKGS_ROOT / "community"
PARTNER_PKGS = PKGS_ROOT / "partners"
class ImportExtractor(ast.NodeVisitor):
def __init__(self, *, from_package: Optional[str] = None) -> None:
"""Extract all imports from the given code, optionally filtering by package."""
self.imports = []
self.package = from_package
def visit_ImportFrom(self, node):
if node.module and (
self.package is None or str(node.module).startswith(self.package)
):
for alias in node.names:
self.imports.append((node.module, alias.name))
self.generic_visit(node)
def _get_class_names(code: str) -> List[str]:
"""Extract class names from a code string."""
# Parse the content of the file into an AST
tree = ast.parse(code)
# Initialize a list to hold all class names
class_names = []
# Define a node visitor class to collect class names
class ClassVisitor(ast.NodeVisitor):
def visit_ClassDef(self, node):
class_names.append(node.name)
self.generic_visit(node)
# Create an instance of the visitor and visit the AST
visitor = ClassVisitor()
visitor.visit(tree)
return class_names
def is_subclass(class_obj: Any, classes_: List[Type]) -> bool:
"""Check if the given class object is a subclass of any class in list classes."""
return any(
issubclass(class_obj, kls)
for kls in classes_
if inspect.isclass(class_obj) and inspect.isclass(kls)
)
def find_subclasses_in_module(module, classes_: List[Type]) -> List[str]:
"""Find all classes in the module that inherit from one of the classes."""
subclasses = []
# Iterate over all attributes of the module that are classes
for name, obj in inspect.getmembers(module, inspect.isclass):
if is_subclass(obj, classes_):
subclasses.append(obj.__name__)
return subclasses
def _get_all_classnames_from_file(file: str, pkg: str) -> List[Tuple[str, str]]:
"""Extract all class names from a file."""
with open(file, encoding="utf-8") as f:
code = f.read()
module_name = _get_current_module(file, pkg)
class_names = _get_class_names(code)
return [(module_name, class_name) for class_name in class_names]
def identify_all_imports_in_file(
file: str, *, from_package: Optional[str] = None
) -> List[Tuple[str, str]]:
"""Let's also identify all the imports in the given file."""
with open(file, encoding="utf-8") as f:
code = f.read()
return find_imports_from_package(code, from_package=from_package)
def identify_pkg_source(pkg_root: str) -> pathlib.Path:
"""Identify the source of the package.
Args:
pkg_root: the root of the package. This contains source + tests, and other
things like pyproject.toml, lock files etc
Returns:
Returns the path to the source code for the package.
"""
dirs = [d for d in Path(pkg_root).iterdir() if d.is_dir()]
matching_dirs = [d for d in dirs if d.name.startswith("langchain_")]
assert len(matching_dirs) == 1, "There should be only one langchain package."
return matching_dirs[0]
def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
"""List all classes in a package."""
module_classes = []
pkg_source = identify_pkg_source(pkg_root)
files = list(pkg_source.rglob("*.py"))
for file in files:
rel_path = os.path.relpath(file, pkg_root)
if rel_path.startswith("tests"):
continue
module_classes.extend(_get_all_classnames_from_file(file, pkg_root))
return module_classes
def list_init_imports_by_package(pkg_root: str) -> List[Tuple[str, str]]:
"""List all the things that are being imported in a package by module."""
imports = []
pkg_source = identify_pkg_source(pkg_root)
# Scan all the files in the package
files = list(Path(pkg_source).rglob("*.py"))
for file in files:
if not file.name == "__init__.py":
continue
import_in_file = identify_all_imports_in_file(str(file))
module_name = _get_current_module(file, pkg_root)
imports.extend([(module_name, item) for _, item in import_in_file])
return imports
def find_imports_from_package(
code: str, *, from_package: Optional[str] = None
) -> List[Tuple[str, str]]:
# Parse the code into an AST
tree = ast.parse(code)
# Create an instance of the visitor
extractor = ImportExtractor(from_package=from_package)
# Use the visitor to update the imports list
extractor.visit(tree)
return extractor.imports
def _get_current_module(path: str, pkg_root: str) -> str:
"""Convert a path to a module name."""
path_as_pathlib = pathlib.Path(os.path.abspath(path))
relative_path = path_as_pathlib.relative_to(pkg_root).with_suffix("")
posix_path = relative_path.as_posix()
norm_path = os.path.normpath(str(posix_path))
fully_qualified_module = norm_path.replace("/", ".")
# Strip __init__ if present
if fully_qualified_module.endswith(".__init__"):
return fully_qualified_module[:-9]
return fully_qualified_module
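A worked example; the paths are hypothetical, but the function touches no filesystem, so this runs as-is:

import os

pkg_root = os.path.abspath("community")
path = os.path.join(pkg_root, "langchain_community", "chat_models", "__init__.py")
# "community/langchain_community/chat_models/__init__.py"
#   -> "langchain_community.chat_models"
assert _get_current_module(path, pkg_root) == "langchain_community.chat_models"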

@ -0,0 +1,52 @@
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
import fnmatch
import re
from pathlib import Path
from typing import List
MATCH_SEP = r"(?:/|\\)"
MATCH_SEP_OR_END = r"(?:/|\\|\Z)"
MATCH_NON_RECURSIVE = r"[^/\\]*"
MATCH_RECURSIVE = r"(?:.*)"
def glob_to_re(pattern: str) -> str:
"""Translate a glob pattern to a regular expression for matching."""
fragments: List[str] = []
for segment in re.split(r"/|\\", pattern):
if segment == "":
continue
if segment == "**":
# Remove the previous separator match, so the recursive match
# can match zero or more segments.
if fragments and fragments[-1] == MATCH_SEP:
fragments.pop()
fragments.append(MATCH_RECURSIVE)
elif "**" in segment:
raise ValueError(
"invalid pattern: '**' can only be an entire path component"
)
else:
fragment = fnmatch.translate(segment)
fragment = fragment.replace(r"(?s:", r"(?:")
fragment = fragment.replace(r".*", MATCH_NON_RECURSIVE)
fragment = fragment.replace(r"\Z", r"")
fragments.append(fragment)
fragments.append(MATCH_SEP)
# Remove trailing MATCH_SEP, so it can be replaced with MATCH_SEP_OR_END.
if fragments and fragments[-1] == MATCH_SEP:
fragments.pop()
fragments.append(MATCH_SEP_OR_END)
return rf"(?s:{''.join(fragments)})"
def match_glob(path: Path, pattern: str) -> bool:
"""Check if a path matches a glob pattern.
If the pattern ends with a directory separator, the path must be a directory.
"""
match = bool(re.fullmatch(glob_to_re(pattern), str(path)))
if pattern.endswith("/") or pattern.endswith("\\"):
return match and path.is_dir()
return match
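A few concrete cases, mirroring the parametrized tests further down; note that `*` stays within one path segment while `**` crosses separators:

from pathlib import Path

assert match_glob(Path("foo/bar/baz"), "foo/**")      # `**` spans segments
assert not match_glob(Path("foo/bar/baz"), "foo/*")   # `*` does not
# This is how DEFAULT_IGNORES (".venv/**") skips virtualenvs in `migrate`:
assert match_glob(Path(".venv/lib/site-packages"), ".venv/**")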

@ -0,0 +1,194 @@
"""Migrate LangChain to the most recent version."""
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
import difflib
import functools
import multiprocessing
import os
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
import libcst as cst
import rich
import typer
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.helpers import calculate_module_and_package
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
from rich.console import Console
from rich.progress import Progress
from typer import Argument, Exit, Option, Typer
from typing_extensions import ParamSpec
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods
from langchain_cli.namespaces.migrate.glob_helpers import match_glob
app = Typer(invoke_without_command=True, add_completion=False)
P = ParamSpec("P")
T = TypeVar("T")
DEFAULT_IGNORES = [".venv/**"]
@app.callback()
def main(
path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
disable: List[Rule] = Option(default=[], help="Disable a rule."),
diff: bool = Option(False, help="Show diff instead of applying changes."),
ignore: List[str] = Option(
default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
),
log_file: Path = Option("log.txt", help="Log errors to this file."),
):
"""Migrate langchain to the most recent version."""
if not diff:
rich.print("[bold red]Alert![/ bold red] langchain-cli migrate", end=": ")
if not typer.confirm(
"The migration process will modify your files. "
"The migration is a `best-effort` process and is not expected to "
"be perfect. "
"Do you want to continue?"
):
raise Exit()
console = Console(log_time=True)
console.log("Start langchain-cli migrate")
# NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
os.environ["LIBCST_PARSER_TYPE"] = "native"
if os.path.isfile(path):
package = path.parent
all_files = [path]
else:
package = path
all_files = sorted(package.glob("**/*.py"))
filtered_files = [
file
for file in all_files
if not any(match_glob(file, pattern) for pattern in ignore)
]
files = [str(file.relative_to(".")) for file in filtered_files]
if len(files) == 1:
console.log("Found 1 file to process.")
elif len(files) > 1:
console.log(f"Found {len(files)} files to process.")
else:
console.log("No files to process.")
raise Exit()
providers = {FullyQualifiedNameProvider, ScopeProvider}
metadata_manager = FullRepoManager(".", files, providers=providers) # type: ignore[arg-type]
metadata_manager.resolve_cache()
scratch: dict[str, Any] = {}
start_time = time.time()
codemods = gather_codemods(disabled=disable)
log_fp = log_file.open("a+", encoding="utf8")
partial_run_codemods = functools.partial(
run_codemods, codemods, metadata_manager, scratch, package, diff
)
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Executing codemods...", total=len(files))
count_errors = 0
difflines: List[List[str]] = []
with multiprocessing.Pool() as pool:
for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
progress.advance(task)
if _difflines is not None:
difflines.append(_difflines)
if error is not None:
count_errors += 1
log_fp.writelines(error)
modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]
if not diff:
if modified:
console.log(f"Refactored {len(modified)} files.")
else:
console.log("No files were modified.")
for _difflines in difflines:
color_diff(console, _difflines)
if count_errors > 0:
console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
else:
console.log("Run successfully!")
if difflines:
raise Exit(1)
def run_codemods(
codemods: List[Type[ContextAwareTransformer]],
metadata_manager: FullRepoManager,
scratch: Dict[str, Any],
package: Path,
diff: bool,
filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
try:
module_and_package = calculate_module_and_package(str(package), filename)
context = CodemodContext(
metadata_manager=metadata_manager,
filename=filename,
full_module_name=module_and_package.name,
full_package_name=module_and_package.package,
)
context.scratch.update(scratch)
file_path = Path(filename)
with file_path.open("r+", encoding="utf-8") as fp:
code = fp.read()
fp.seek(0)
input_tree = cst.parse_module(code)
for codemod in codemods:
transformer = codemod(context=context)
output_tree = transformer.transform_module(input_tree)
input_tree = output_tree
output_code = input_tree.code
if code != output_code:
if diff:
lines = difflib.unified_diff(
code.splitlines(keepends=True),
output_code.splitlines(keepends=True),
fromfile=filename,
tofile=filename,
)
return None, list(lines)
else:
fp.write(output_code)
fp.truncate()
return None, None
except cst.ParserSyntaxError as exc:
return (
f"A syntax error happened on {filename}. This file cannot be "
f"formatted.\n"
f"{exc}"
), None
except Exception:
return f"An error happened on {filename}.\n{traceback.format_exc()}", None
def color_diff(console: Console, lines: Iterable[str]) -> None:
for line in lines:
line = line.rstrip("\n")
if line.startswith("+"):
console.print(line, style="green")
elif line.startswith("-"):
console.print(line, style="red")
elif line.startswith("^"):
console.print(line, style="blue")
else:
console.print(line, style="white")

libs/cli/poetry.lock (generated, 45 changes)

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "aiohttp"
@ -819,6 +819,46 @@ orjson = ">=3.9.14,<4.0.0"
pydantic = ">=1,<3"
requests = ">=2,<3"
[[package]]
name = "libcst"
version = "1.3.1"
description = "A concrete syntax tree with AST-like properties for Python 3.0 through 3.12 programs."
optional = false
python-versions = ">=3.9"
files = [
{file = "libcst-1.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:de93193cba6d54f2a4419e94ba2de642b111f52e4fa01bb6e2c655914585f65b"},
{file = "libcst-1.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2d64d86dcd6c80a5dac2e243c5ed7a7a193242209ac33bad4b0639b24f6d131"},
{file = "libcst-1.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db084f7bbf825c7bd5ed256290066d0336df6a7dc3a029c9870a64cd2298b87f"},
{file = "libcst-1.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16880711be03a1f5da7028fe791ba5b482a50d830225a70272dc332dfd927652"},
{file = "libcst-1.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:189bb28c19c5dd3c64583f969b72f7732dbdb1dee9eca3acc85099e4cef9148b"},
{file = "libcst-1.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:181372386c986e3de07d7a93f269214cd825adc714f1f9da8252b44f05e181c4"},
{file = "libcst-1.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8c2020f7449270e3ff0bdc481ae244d812f2d9a8b7dbff0ea66b830f4b350f54"},
{file = "libcst-1.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:be3bf9aaafebda6a21e241e819f0ab67e186e898c3562704e41241cf8738353a"},
{file = "libcst-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a0d250fb6a2c1d158f30d25ba5e33e3ed3672d2700d480dd47beffd1431a008"},
{file = "libcst-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ad5741b251d901f3da1819ac539192230cc6f8f81aaf04eb4ec0009c1c97285"},
{file = "libcst-1.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b740dc0c3d1adbd91442fb878343d5a11e23a0e3ac5484be301fd8d148bcb085"},
{file = "libcst-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:9e6bc95fa7dde79cde63a34a0412cd4a3d9fcc27be781a590f8c45c840c83658"},
{file = "libcst-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4186076ce12141609ce950d61867b2a73ea199a7a9870dbafa76ad600e075b3c"},
{file = "libcst-1.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4ed52a1a2fe4d8603de51649db5e438317b8116ebb9fc09ec68703535fe6c1c8"},
{file = "libcst-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0886a9963597367b227345f19b24931b3ed6a4703fff237760745f90f0e6a20"},
{file = "libcst-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:904c4cc5c801a5747e64b43e0accc87c67a4c804842d977ee215872c4cf8cf88"},
{file = "libcst-1.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cdb7e8a118b60e064a02f6cbfa4d328212a3a115d125244495190f405709d5f"},
{file = "libcst-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:431badf0e544b79c0ac9682dbd291ff63ddbc3c3aca0d13d3cc7a10c3a9db8a2"},
{file = "libcst-1.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:701f5335e4fd566871497b9af1e871c98e1ef10c30b3b244f39343d709213401"},
{file = "libcst-1.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7c6e709623b68ca9148e8ecbdc145f7b83befb26032e4bf6a8122500ba558b17"},
{file = "libcst-1.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ede0f026a82b03b33a559ec566860085ece2e76d8f9bc21cb053eedf9cde8c79"},
{file = "libcst-1.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c12b7b01d8745f82dd86a82acd2a9f8e8e7d6c94ddcfda996896e83d1a8d5c42"},
{file = "libcst-1.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2995ca687118a9d3d41876f7270bc29305a2d402f4b8c81a3cff0aeee6d4c81"},
{file = "libcst-1.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:2dbac1ac0a9d59ea7bbc3f87cdcca5bfe98835e31c668e95cb6f3d907ffc53fc"},
{file = "libcst-1.3.1.tar.gz", hash = "sha256:03b1df1ae02456f1d465fcd5ead1d0d454bb483caefd8c8e6bde515ffdb53d1b"},
]
[package.dependencies]
pyyaml = ">=5.2"
[package.extras]
dev = ["Sphinx (>=5.1.1)", "black (==23.12.1)", "build (>=0.10.0)", "coverage (>=4.5.4)", "fixit (==2.1.0)", "flake8 (==7.0.0)", "hypothesis (>=4.36.0)", "hypothesmith (>=0.0.4)", "jinja2 (==3.1.3)", "jupyter (>=1.0.0)", "maturin (>=0.8.3,<1.5)", "nbsphinx (>=0.4.2)", "prompt-toolkit (>=2.0.9)", "pyre-check (==0.9.18)", "setuptools-rust (>=1.5.2)", "setuptools-scm (>=6.0.1)", "slotscheck (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
@ -1322,7 +1362,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -1850,4 +1889,4 @@ serve = []
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "8232264abd652b61dac3fd47833745c1df6c3418599dc14f9fe09773f3f80f13"
content-hash = "4576fb13ecd9e13bc6c85e4cd6f56520708c7c1468f4b81bc6a346b128c9f695"

@ -17,6 +17,7 @@ gitpython = "^3.1.40"
langserve = { extras = ["all"], version = ">=0.0.51" }
uvicorn = "^0.23.2"
tomlkit = "^0.12.2"
libcst = { version = "^1.3.1", python = "^3.9" }
[tool.poetry.scripts]
langchain = "langchain_cli.cli:app"
@ -49,7 +50,7 @@ select = [
]
[tool.poe.tasks]
test = "poetry run pytest"
test = "poetry run pytest tests"
watch = "poetry run ptw"
version = "poetry version --short"
bump = ["_bump_1", "_bump_2"]

@ -0,0 +1,78 @@
"""Script to generate migrations for the migration script."""
import json
import pkgutil
import click
from langchain_cli.namespaces.migrate.generate.langchain import (
generate_migrations_from_langchain_to_community,
)
from langchain_cli.namespaces.migrate.generate.partner import (
get_migrations_for_partner_package,
)
@click.group()
def cli():
"""Migration scripts management."""
pass
@cli.command()
@click.option(
"--output",
default="langchain_migrations.json",
help="Output file for the migration script.",
)
def langchain(output: str) -> None:
"""Generate a migration script."""
click.echo("Migration script generated.")
migrations = generate_migrations_from_langchain_to_community()
with open(output, "w") as f:
f.write(json.dumps(migrations))
@cli.command()
@click.argument("pkg")
@click.option("--output", default=None, help="Output file for the migration script.")
def partner(pkg: str, output: str) -> None:
"""Generate migration scripts specifically for LangChain modules."""
click.echo("Migration script for LangChain generated.")
migrations = get_migrations_for_partner_package(pkg)
# Run with python 3.9+
output_name = f"{pkg.removeprefix('langchain_')}.json" if output is None else output
if migrations:
with open(output_name, "w") as f:
f.write(json.dumps(migrations, indent=2, sort_keys=True))
click.secho(f"LangChain migration script saved to {output_name}")
else:
click.secho(f"No migrations found for {pkg}", fg="yellow")
@cli.command()
def all_installed_partner_pkgs() -> None:
"""Generate migration scripts for all LangChain modules."""
# Will generate migrations for all partner packages,
# defined as "langchain_<partner_name>".
# First let's determine which packages are installed in the environment
# and then generate migrations for them.
langchain_pkgs = [
name
for _, name, _ in pkgutil.iter_modules()
if name.startswith("langchain_")
and name not in {"langchain_core", "langchain_cli", "langchain_community"}
]
for pkg in langchain_pkgs:
migrations = get_migrations_for_partner_package(pkg)
# Run with python 3.9+
output_name = f"{pkg.removeprefix('langchain_')}.json"
if migrations:
with open(output_name, "w") as f:
f.write(json.dumps(migrations, indent=2, sort_keys=True))
click.secho(f"LangChain migration script saved to {output_name}")
else:
click.secho(f"No migrations found for {pkg}", fg="yellow")
if __name__ == "__main__":
cli()

@ -0,0 +1,13 @@
from __future__ import annotations
from dataclasses import dataclass
from .file import File
from .folder import Folder
@dataclass
class Case:
source: Folder | File
expected: Folder | File
name: str

@ -0,0 +1,15 @@
from tests.unit_tests.migrate.cli_runner.case import Case
from tests.unit_tests.migrate.cli_runner.cases import imports
from tests.unit_tests.migrate.cli_runner.file import File
from tests.unit_tests.migrate.cli_runner.folder import Folder
cases = [
Case(
name="empty",
source=File("__init__.py", content=[]),
expected=File("__init__.py", content=[]),
),
*imports.cases,
]
before = Folder("project", *[case.source for case in cases])
expected = Folder("project", *[case.expected for case in cases])

@ -0,0 +1,32 @@
from tests.unit_tests.migrate.cli_runner.case import Case
from tests.unit_tests.migrate.cli_runner.file import File
cases = [
Case(
name="Imports",
source=File(
"app.py",
content=[
"from langchain_community.chat_models import ChatOpenAI",
"",
"",
"class foo:",
" a: int",
"",
"chain = ChatOpenAI()",
],
),
expected=File(
"app.py",
content=[
"from langchain_openai import ChatOpenAI",
"",
"",
"class foo:",
" a: int",
"",
"chain = ChatOpenAI()",
],
),
),
]

@ -0,0 +1,16 @@
from __future__ import annotations
class File:
def __init__(self, name: str, content: list[str] | None = None) -> None:
self.name = name
self.content = "\n".join(content or [])
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, File):
return NotImplemented
if self.name != __value.name:
return False
return self.content == __value.content

@ -0,0 +1,59 @@
from __future__ import annotations
from pathlib import Path
from .file import File
class Folder:
def __init__(self, name: str, *files: Folder | File) -> None:
self.name = name
self._files = files
@property
def files(self) -> list[Folder | File]:
return sorted(self._files, key=lambda f: f.name)
def create_structure(self, root: Path) -> None:
path = root / self.name
path.mkdir()
for file in self.files:
if isinstance(file, Folder):
file.create_structure(path)
else:
(path / file.name).write_text(file.content, encoding="utf-8")
@classmethod
def from_structure(cls, root: Path) -> Folder:
name = root.name
files: list[File | Folder] = []
for path in root.iterdir():
if path.is_dir():
files.append(cls.from_structure(path))
else:
files.append(
File(path.name, path.read_text(encoding="utf-8").splitlines())
)
return Folder(name, *files)
def __eq__(self, __value: object) -> bool:
if isinstance(__value, File):
return False
if not isinstance(__value, Folder):
return NotImplemented
if self.name != __value.name:
return False
if len(self.files) != len(__value.files):
return False
for self_file, other_file in zip(self.files, __value.files):
if self_file != other_file:
return False
return True

@ -0,0 +1,55 @@
# ruff: noqa: E402
from __future__ import annotations
import pytest
pytest.importorskip("libcst")
import difflib
from pathlib import Path
from typer.testing import CliRunner
from langchain_cli.namespaces.migrate.main import app
from tests.unit_tests.migrate.cli_runner.cases import before, expected
from tests.unit_tests.migrate.cli_runner.folder import Folder
def find_issue(current: Folder, expected: Folder) -> str:
for current_file, expected_file in zip(current.files, expected.files):
if current_file != expected_file:
if current_file.name != expected_file.name:
return (
f"Files have "
f"different names: {current_file.name} != {expected_file.name}"
)
if isinstance(current_file, Folder) and isinstance(expected_file, Folder):
return find_issue(current_file, expected_file)
elif isinstance(current_file, Folder) or isinstance(expected_file, Folder):
return (
f"One of the files is a "
f"folder: {current_file.name} != {expected_file.name}"
)
return "\n".join(
difflib.unified_diff(
current_file.content.splitlines(),
expected_file.content.splitlines(),
fromfile=current_file.name,
tofile=expected_file.name,
)
)
return "Unknown"
def test_command_line(tmp_path: Path) -> None:
runner = CliRunner()
with runner.isolated_filesystem(temp_dir=tmp_path) as td:
before.create_structure(root=Path(td))
# The input is used to force through the confirmation.
result = runner.invoke(app, [before.name], input="y\n")
assert result.exit_code == 0, result.output
after = Folder.from_structure(Path(td) / before.name)
assert after == expected, find_issue(after, expected)

@ -0,0 +1,46 @@
import pytest
from langchain_cli.namespaces.migrate.generate.partner import (
get_migrations_for_partner_package,
)
pytest.importorskip(modname="langchain_openai")
def test_generate_migrations() -> None:
migrations = get_migrations_for_partner_package("langchain_openai")
assert migrations == [
("langchain_community.llms.openai.OpenAI", "langchain_openai.OpenAI"),
("langchain_community.llms.openai.AzureOpenAI", "langchain_openai.AzureOpenAI"),
(
"langchain_community.embeddings.openai.OpenAIEmbeddings",
"langchain_openai.OpenAIEmbeddings",
),
(
"langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings",
"langchain_openai.AzureOpenAIEmbeddings",
),
(
"langchain_community.chat_models.openai.ChatOpenAI",
"langchain_openai.ChatOpenAI",
),
(
"langchain_community.chat_models.azure_openai.AzureChatOpenAI",
"langchain_openai.AzureChatOpenAI",
),
("langchain_community.llms.AzureOpenAI", "langchain_openai.AzureOpenAI"),
("langchain_community.llms.OpenAI", "langchain_openai.OpenAI"),
(
"langchain_community.embeddings.AzureOpenAIEmbeddings",
"langchain_openai.AzureOpenAIEmbeddings",
),
(
"langchain_community.embeddings.OpenAIEmbeddings",
"langchain_openai.OpenAIEmbeddings",
),
(
"langchain_community.chat_models.AzureChatOpenAI",
"langchain_openai.AzureChatOpenAI",
),
("langchain_community.chat_models.ChatOpenAI", "langchain_openai.ChatOpenAI"),
]

@ -0,0 +1,5 @@
from langchain_cli.namespaces.migrate.generate.utils import PKGS_ROOT
def test_root() -> None:
assert PKGS_ROOT.name == "libs"

@ -0,0 +1,72 @@
from __future__ import annotations
from pathlib import Path
import pytest
from langchain_cli.namespaces.migrate.glob_helpers import glob_to_re, match_glob
class TestGlobHelpers:
match_glob_values: list[tuple[str, Path, bool]] = [
("foo", Path("foo"), True),
("foo", Path("bar"), False),
("foo", Path("foo/bar"), False),
("*", Path("foo"), True),
("*", Path("bar"), True),
("*", Path("foo/bar"), False),
("**", Path("foo"), True),
("**", Path("foo/bar"), True),
("**", Path("foo/bar/baz/qux"), True),
("foo/bar", Path("foo/bar"), True),
("foo/bar", Path("foo"), False),
("foo/bar", Path("far"), False),
("foo/bar", Path("foo/foo"), False),
("foo/*", Path("foo/bar"), True),
("foo/*", Path("foo/bar/baz"), False),
("foo/*", Path("foo"), False),
("foo/*", Path("bar"), False),
("foo/**", Path("foo/bar"), True),
("foo/**", Path("foo/bar/baz"), True),
("foo/**", Path("foo/bar/baz/qux"), True),
("foo/**", Path("foo"), True),
("foo/**", Path("bar"), False),
("foo/**/bar", Path("foo/bar"), True),
("foo/**/bar", Path("foo/baz/bar"), True),
("foo/**/bar", Path("foo/baz/qux/bar"), True),
("foo/**/bar", Path("foo/baz/qux"), False),
("foo/**/bar", Path("foo/bar/baz"), False),
("foo/**/bar", Path("foo/bar/bar"), True),
("foo/**/bar", Path("foo"), False),
("foo/**/bar", Path("bar"), False),
("foo/**/*/bar", Path("foo/bar"), False),
("foo/**/*/bar", Path("foo/baz/bar"), True),
("foo/**/*/bar", Path("foo/baz/qux/bar"), True),
("foo/**/*/bar", Path("foo/baz/qux"), False),
("foo/**/*/bar", Path("foo/bar/baz"), False),
("foo/**/*/bar", Path("foo/bar/bar"), True),
("foo/**/*/bar", Path("foo"), False),
("foo/**/*/bar", Path("bar"), False),
("foo/ba*", Path("foo/bar"), True),
("foo/ba*", Path("foo/baz"), True),
("foo/ba*", Path("foo/qux"), False),
("foo/ba*", Path("foo/baz/qux"), False),
("foo/ba*", Path("foo/bar/baz"), False),
("foo/ba*", Path("foo"), False),
("foo/ba*", Path("bar"), False),
("foo/**/ba*/*/qux", Path("foo/a/b/c/bar/a/qux"), True),
("foo/**/ba*/*/qux", Path("foo/a/b/c/baz/a/qux"), True),
("foo/**/ba*/*/qux", Path("foo/a/bar/a/qux"), True),
("foo/**/ba*/*/qux", Path("foo/baz/a/qux"), True),
("foo/**/ba*/*/qux", Path("foo/baz/qux"), False),
("foo/**/ba*/*/qux", Path("foo/a/b/c/qux/a/qux"), False),
("foo/**/ba*/*/qux", Path("foo"), False),
("foo/**/ba*/*/qux", Path("bar"), False),
]
@pytest.mark.parametrize(("pattern", "path", "expected"), match_glob_values)
def test_match_glob(self, pattern: str, path: Path, expected: bool):
expr = glob_to_re(pattern)
assert (
match_glob(path, pattern) == expected
), f"path: {path}, pattern: {pattern}, expr: {expr}"

@ -0,0 +1,51 @@
# ruff: noqa: E402
import pytest
pytest.importorskip("libcst")
from libcst.codemod import CodemodTest
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
ReplaceImportsCodemod,
)
class TestReplaceImportsCommand(CodemodTest):
TRANSFORM = ReplaceImportsCodemod
def test_single_import(self) -> None:
before = """
from langchain.chat_models import ChatOpenAI
"""
after = """
from langchain_community.chat_models import ChatOpenAI
"""
self.assertCodemod(before, after)
def test_from_community_to_partner(self) -> None:
"""Test that we can replace imports from community to partner."""
before = """
from langchain_community.chat_models import ChatOpenAI
"""
after = """
from langchain_openai import ChatOpenAI
"""
self.assertCodemod(before, after)
def test_noop_import(self) -> None:
code = """
from foo import ChatOpenAI
"""
self.assertCodemod(code, code)
def test_mixed_imports(self) -> None:
before = """
from langchain_community.chat_models import ChatOpenAI, ChatAnthropic, foo
"""
after = """
from langchain_community.chat_models import foo
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
"""
self.assertCodemod(before, after)

@ -12,7 +12,7 @@ def test_async_recursive_url_loader() -> None:
check_response_status=True,
)
docs = loader.load()
assert len(docs) == 513
assert len(docs) == 512
assert docs[0].page_content == "placeholder"

@ -124,7 +124,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
error="Malformed args.",
error=None,
)
)
values["tool_calls"] = tool_calls

@ -149,7 +149,7 @@ def default_tool_parser(
name=function_name,
args=tool_call["function"]["arguments"],
id=tool_call.get("id"),
error="Malformed args.",
error=None,
)
)
return tool_calls, invalid_tool_calls

@ -147,6 +147,8 @@ def _get_trace_callbacks(
def _tracing_v2_is_enabled() -> bool:
return (
env_var_is_set("LANGCHAIN_TRACING_V2")
or env_var_is_set("LANGSMITH_TRACING")
or env_var_is_set("LANGSMITH_TRACING_V2")
or tracing_v2_callback_var.get() is not None
or get_run_tree_context() is not None
)
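In practice this means tracing can now be enabled with the shorter `LANGSMITH_TRACING` variable as well; a minimal sketch (assuming `env_var_is_set` treats any non-empty, non-false value as set):

import os

# Any one of these switches tracing v2 on; LANGSMITH_TRACING is the new alias.
os.environ["LANGSMITH_TRACING"] = "true"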

@ -53,7 +53,7 @@ def test_serdes_message_chunk() -> None:
"name": "foobad",
"args": "blah",
"id": "booz",
"error": "Malformed args.",
"error": None,
}
],
"tool_call_chunks": [

@ -306,8 +306,8 @@ def test_message_chunk_to_message() -> None:
{"name": "tool2", "args": {}, "id": "2"},
],
invalid_tool_calls=[
{"name": "tool3", "args": None, "id": "3", "error": "Malformed args."},
{"name": "tool4", "args": "abc", "id": "4", "error": "Malformed args."},
{"name": "tool3", "args": None, "id": "3", "error": None},
{"name": "tool4", "args": "abc", "id": "4", "error": None},
],
)
assert message_chunk_to_message(chunk) == expected

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -389,7 +389,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.45"
version = "0.1.46"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -1062,4 +1062,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "bfac6e5ad2828fe02c95b280d68c737f719dc517fc158b0ab66204b97e7fa591"
content-hash = "567868376ce31e29a3795431cb8b53ce7860a50652f233a1b8bee9827d5c9871"

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-mistralai"
version = "0.1.3"
version = "0.1.4"
description = "An integration package connecting Mistral and LangChain"
authors = []
readme = "README.md"
@ -12,7 +12,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.42"
langchain-core = "^0.1.46"
tokenizers = "^0.15.1"
httpx = ">=0.25.2,<1"
httpx-sse = ">=0.3.1,<1"

@ -385,7 +385,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.42"
version = "0.1.46"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -1286,4 +1286,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "625e7565d37b9633874f61ee5660220e8e330658715d8b56ef2340f06dc1c625"
content-hash = "f8a406a4ebd93e5c2ef3fcf4a3cebdd588ce09e288dc31b7b9b6b1560285575a"

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-openai"
version = "0.1.3"
version = "0.1.4"
description = "An integration package connecting OpenAI and LangChain"
authors = []
readme = "README.md"
@ -12,7 +12,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.42"
langchain-core = "^0.1.46"
openai = "^1.10.0"
tiktoken = ">=0.5.2,<1"

@ -19,6 +19,7 @@ from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
_convert_dict_to_message,
_convert_message_to_dict,
_format_message_content,
)
@ -287,3 +288,36 @@ def test_custom_token_counting() -> None:
llm = ChatOpenAI(custom_get_token_ids=token_encoder)
assert llm.get_token_ids("foo") == [1, 2, 3]
def test_format_message_content() -> None:
content: Any = "hello"
assert content == _format_message_content(content)
content = None
assert content == _format_message_content(content)
content = []
assert content == _format_message_content(content)
content = [
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "url.com",
},
},
]
assert content == _format_message_content(content)
content = [
{"type": "text", "text": "hello"},
{
"type": "tool_use",
"id": "toolu_01A09q90qw90lq917835lq9",
"name": "get_weather",
"input": {"location": "San Francisco, CA", "unit": "celsius"},
},
]
assert [{"type": "text", "text": "hello"}] == _format_message_content(content)

@ -2,12 +2,16 @@ from langchain_upstage.chat_models import ChatUpstage
from langchain_upstage.embeddings import UpstageEmbeddings
from langchain_upstage.layout_analysis import UpstageLayoutAnalysisLoader
from langchain_upstage.layout_analysis_parsers import UpstageLayoutAnalysisParser
from langchain_upstage.tools.groundedness_check import GroundednessCheck
from langchain_upstage.tools.groundedness_check import (
GroundednessCheck,
UpstageGroundednessCheck,
)
__all__ = [
"ChatUpstage",
"UpstageEmbeddings",
"UpstageLayoutAnalysisLoader",
"UpstageLayoutAnalysisParser",
"UpstageGroundednessCheck",
"GroundednessCheck",
]

@ -1,10 +1,12 @@
import os
from typing import Any, Literal, Optional, Type, Union
from typing import Any, List, Literal, Optional, Type, Union
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.tools import BaseTool
@ -13,16 +15,18 @@ from langchain_core.utils import convert_to_secret_str
from langchain_upstage import ChatUpstage
class GroundednessCheckInput(BaseModel):
class UpstageGroundednessCheckInput(BaseModel):
"""Input for the Groundedness Check tool."""
context: str = Field(description="context in which the answer should be verified")
query: str = Field(
context: Union[str, List[Document]] = Field(
description="context in which the answer should be verified"
)
answer: str = Field(
description="assistant's reply or a text that is subject to groundedness check"
)
class GroundednessCheck(BaseTool):
class UpstageGroundednessCheck(BaseTool):
"""Tool that checks the groundedness of a context and an assistant message.
To use, you should have the environment variable `UPSTAGE_API_KEY`
@ -31,15 +35,15 @@ class GroundednessCheck(BaseTool):
Example:
.. code-block:: python
from langchain_upstage import GroundednessCheck
from langchain_upstage import UpstageGroundednessCheck
tool = GroundednessCheck()
tool = UpstageGroundednessCheck()
"""
name: str = "groundedness_check"
description: str = (
"A tool that checks the groundedness of an assistant response "
"to user-provided context. GroundednessCheck ensures that "
"to user-provided context. UpstageGroundednessCheck ensures that "
"the assistants response is not only relevant but also "
"precisely aligned with the user's initial context, "
"promoting a more reliable and context-aware interaction. "
@ -50,7 +54,7 @@ class GroundednessCheck(BaseTool):
upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
api_wrapper: ChatUpstage
args_schema: Type[BaseModel] = GroundednessCheckInput
args_schema: Type[BaseModel] = UpstageGroundednessCheckInput
def __init__(self, **kwargs: Any) -> None:
upstage_api_key = kwargs.get("upstage_api_key", None)
@ -73,25 +77,41 @@ class GroundednessCheck(BaseTool):
)
super().__init__(upstage_api_key=upstage_api_key, api_wrapper=api_wrapper)
def formatDocumentsAsString(self, docs: List[Document]) -> str:
return "\n".join([doc.page_content for doc in docs])
def _run(
self,
context: str,
query: str,
context: Union[str, List[Document]],
answer: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]:
"""Use the tool."""
if isinstance(context, List):
context = self.formatDocumentsAsString(context)
response = self.api_wrapper.invoke(
[HumanMessage(context), AIMessage(query)], stream=False
[HumanMessage(context), AIMessage(answer)], stream=False
)
return str(response.content)
async def _arun(
self,
context: str,
query: str,
context: Union[str, List[Document]],
answer: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]:
if isinstance(context, List):
context = self.formatDocumentsAsString(context)
response = await self.api_wrapper.ainvoke(
[HumanMessage(context), AIMessage(query)], stream=False
[HumanMessage(context), AIMessage(answer)], stream=False
)
return str(response.content)
@deprecated(
since="0.1.3",
removal="0.2.0",
alternative_import="langchain_upstage.UpstageGroundednessCheck",
)
class GroundednessCheck(UpstageGroundednessCheck):
pass
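For downstream users the old name stays importable until 0.2.0; a minimal sketch of the transition (the warning class comes from langchain_core's deprecation helper):

from langchain_upstage import GroundednessCheck

# Still works, but emits a LangChainDeprecationWarning recommending
# langchain_upstage.UpstageGroundednessCheck instead.
tool = GroundednessCheck()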

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-upstage"
version = "0.1.2"
version = "0.1.3"
description = "An integration package connecting Upstage and LangChain"
authors = []
readme = "README.md"

@ -2,34 +2,62 @@ import os
import openai
import pytest
from langchain_core.documents import Document
from langchain_upstage import GroundednessCheck
from langchain_upstage import GroundednessCheck, UpstageGroundednessCheck
def test_langchain_upstage_groundedness_check() -> None:
def test_langchain_upstage_groundedness_check_deprecated() -> None:
"""Test Upstage Groundedness Check."""
tool = GroundednessCheck()
output = tool.run({"context": "foo bar", "query": "bar foo"})
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
api_key = os.environ.get("UPSTAGE_API_KEY", None)
tool = GroundednessCheck(upstage_api_key=api_key)
output = tool.run({"context": "foo bar", "query": "bar foo"})
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
def test_langchain_upstage_groundedness_check() -> None:
"""Test Upstage Groundedness Check."""
tool = UpstageGroundednessCheck()
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
api_key = os.environ.get("UPSTAGE_API_KEY", None)
tool = UpstageGroundednessCheck(upstage_api_key=api_key)
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
def test_langchain_upstage_groundedness_check_with_documents_input() -> None:
"""Test Upstage Groundedness Check."""
tool = UpstageGroundednessCheck()
docs = [
Document(page_content="foo bar"),
Document(page_content="bar foo"),
]
output = tool.invoke({"context": docs, "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
def test_langchain_upstage_groundedness_check_fail_with_wrong_api_key() -> None:
tool = GroundednessCheck(api_key="wrong-key")
tool = UpstageGroundednessCheck(api_key="wrong-key")
with pytest.raises(openai.AuthenticationError):
tool.run({"context": "foo bar", "query": "bar foo"})
tool.invoke({"context": "foo bar", "answer": "bar foo"})
async def test_langchain_upstage_groundedness_check_async() -> None:
"""Test Upstage Groundedness Check asynchronous."""
tool = GroundednessCheck()
output = await tool.arun({"context": "foo bar", "query": "bar foo"})
tool = UpstageGroundednessCheck()
output = await tool.ainvoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]

@ -1,12 +1,12 @@
import os
from langchain_upstage import GroundednessCheck
from langchain_upstage import UpstageGroundednessCheck
os.environ["UPSTAGE_API_KEY"] = "foo"
def test_initialization() -> None:
"""Test embedding model initialization."""
GroundednessCheck()
GroundednessCheck(upstage_api_key="key")
GroundednessCheck(api_key="key")
UpstageGroundednessCheck()
UpstageGroundednessCheck(upstage_api_key="key")
UpstageGroundednessCheck(api_key="key")

@ -2,10 +2,11 @@ from langchain_upstage import __all__
EXPECTED_ALL = [
"ChatUpstage",
"GroundednessCheck",
"UpstageEmbeddings",
"UpstageLayoutAnalysisLoader",
"UpstageLayoutAnalysisParser",
"GroundednessCheck",
"UpstageGroundednessCheck",
]
