langchain[minor]: Generic configurable model (#23419)

Alternative to
[#23244](https://github.com/langchain-ai/langchain/pull/23244). Allows
you to use chat model declarative methods on a configurable model.

![Screenshot 2024-06-25 at 1 07 10
PM](https://github.com/langchain-ai/langchain/assets/22008038/910d1694-9b7b-46bc-bc2e-3792df9321d6)
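A minimal sketch of the new usage (the tool spec and model choices below are illustrative; assumes the relevant provider packages are installed and API keys are set):

```python
from langchain.chat_models import init_chat_model

# No default model: "model" (and "model_provider") become runtime-configurable.
llm = init_chat_model(temperature=0)

# Declarative methods like bind_tools() now work before any concrete model
# exists; the operation is queued and replayed once a model is built from the
# runtime config.
llm_with_tools = llm.bind_tools(
    [{"name": "get_time", "description": "Get the current time", "parameters": {}}]
)

llm_with_tools.invoke(
    "what time is it?", config={"configurable": {"model": "gpt-4o"}}
)
llm_with_tools.invoke(
    "what time is it?",
    config={"configurable": {"model": "claude-3-5-sonnet-20240620"}},
)
```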
Bagatur committed 2024-07-14 18:11:01 -07:00 (via GitHub)
parent d0728b0ba0
commit 0da5078cad
9 changed files with 1067 additions and 90 deletions

View File

@@ -25,7 +25,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain langchain-openai langchain-anthropic langchain-google-vertexai"
"%pip install -qU langchain>=0.2.7 langchain-openai langchain-anthropic langchain-google-vertexai"
]
},
{
@@ -76,32 +76,6 @@
"print(\"Gemini 1.5: \" + gemini_15.invoke(\"what's your name\").content + \"\\n\")"
]
},
{
"cell_type": "markdown",
"id": "fff9a4c8-b6ee-4a1a-8d3d-0ecaa312d4ed",
"metadata": {},
"source": [
"## Simple config example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75c25d39-bf47-4b51-a6c6-64d9c572bfd6",
"metadata": {},
"outputs": [],
"source": [
"user_config = {\n",
" \"model\": \"...user-specified...\",\n",
" \"model_provider\": \"...user-specified...\",\n",
" \"temperature\": 0,\n",
" \"max_tokens\": 1000,\n",
"}\n",
"\n",
"llm = init_chat_model(**user_config)\n",
"llm.invoke(\"what's your name\")"
]
},
{
"cell_type": "markdown",
"id": "f811f219-5e78-4b62-b495-915d52a22532",
@@ -125,12 +99,215 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da07b5c0-d2e6-42e4-bfcd-2efcfaae6221",
"cell_type": "markdown",
"id": "476a44db-c50d-4846-951d-0f1c9ba8bbaa",
"metadata": {},
"outputs": [],
"source": []
"source": [
"## Creating a configurable model\n",
"\n",
"You can also create a runtime-configurable model by specifying `configurable_fields`. If you don't specify a `model` value, then \"model\" and \"model_provider\" be configurable by default."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6c037f27-12d7-4e83-811e-4245c0e3ba58",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! How can I assist you today?\", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d576307f90', 'finish_reason': 'stop', 'logprobs': None}, id='run-5428ab5c-b5c0-46de-9946-5d4ca40dbdc8-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"configurable_model = init_chat_model(temperature=0)\n",
"\n",
"configurable_model.invoke(\n",
" \"what's your name\", config={\"configurable\": {\"model\": \"gpt-4o\"}}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "321e3036-abd2-4e1f-bcc6-606efd036954",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"My name is Claude. It's nice to meet you!\", response_metadata={'id': 'msg_012XvotUJ3kGLXJUWKBVxJUi', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-1ad1eefe-f1c6-4244-8bc6-90e2cb7ee554-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"configurable_model.invoke(\n",
" \"what's your name\", config={\"configurable\": {\"model\": \"claude-3-5-sonnet-20240620\"}}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7f3b3d4a-4066-45e4-8297-ea81ac8e70b7",
"metadata": {},
"source": [
"### Configurable model with default values\n",
"\n",
"We can create a configurable model with default model values, specify which parameters are configurable, and add prefixes to configurable params:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "814a2289-d0db-401e-b555-d5116112b413",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! How can I assist you today?\", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_ce0793330f', 'finish_reason': 'stop', 'logprobs': None}, id='run-3923e328-7715-4cd6-b215-98e4b6bf7c9d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"first_llm = init_chat_model(\n",
" model=\"gpt-4o\",\n",
" temperature=0,\n",
" configurable_fields=(\"model\", \"model_provider\", \"temperature\", \"max_tokens\"),\n",
" config_prefix=\"first\", # useful when you have a chain with multiple models\n",
")\n",
"\n",
"first_llm.invoke(\"what's your name\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6c8755ba-c001-4f5a-a497-be3f1db83244",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"My name is Claude. It's nice to meet you!\", response_metadata={'id': 'msg_01RyYR64DoMPNCfHeNnroMXm', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-22446159-3723-43e6-88df-b84797e7751d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"first_llm.invoke(\n",
" \"what's your name\",\n",
" config={\n",
" \"configurable\": {\n",
" \"first_model\": \"claude-3-5-sonnet-20240620\",\n",
" \"first_temperature\": 0.5,\n",
" \"first_max_tokens\": 100,\n",
" }\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0072b1a3-7e44-4b4e-8b07-efe1ba91a689",
"metadata": {},
"source": [
"### Using a configurable model declaratively\n",
"\n",
"We can call declarative operations like `bind_tools`, `with_structured_output`, `with_configurable`, etc. on a configurable model and chain a configurable model in the same way that we would a regularly instantiated chat model object."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "067dabee-1050-4110-ae24-c48eba01e13b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'GetPopulation',\n",
" 'args': {'location': 'Los Angeles, CA'},\n",
" 'id': 'call_sYT3PFMufHGWJD32Hi2CTNUP'},\n",
" {'name': 'GetPopulation',\n",
" 'args': {'location': 'New York, NY'},\n",
" 'id': 'call_j1qjhxRnD3ffQmRyqjlI1Lnk'}]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"class GetWeather(BaseModel):\n",
" \"\"\"Get the current weather in a given location\"\"\"\n",
"\n",
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
"\n",
"\n",
"class GetPopulation(BaseModel):\n",
" \"\"\"Get the current population in a given location\"\"\"\n",
"\n",
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
"\n",
"\n",
"llm = init_chat_model(temperature=0)\n",
"llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])\n",
"\n",
"llm_with_tools.invoke(\n",
" \"what's bigger in 2024 LA or NYC\", config={\"configurable\": {\"model\": \"gpt-4o\"}}\n",
").tool_calls"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e57dfe9f-cd24-4e37-9ce9-ccf8daf78f89",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'GetPopulation',\n",
" 'args': {'location': 'Los Angeles, CA'},\n",
" 'id': 'toolu_01CxEHxKtVbLBrvzFS7GQ5xR'},\n",
" {'name': 'GetPopulation',\n",
" 'args': {'location': 'New York City, NY'},\n",
" 'id': 'toolu_013A79qt5toWSsKunFBDZd5S'}]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_with_tools.invoke(\n",
" \"what's bigger in 2024 LA or NYC\",\n",
" config={\"configurable\": {\"model\": \"claude-3-5-sonnet-20240620\"}},\n",
").tool_calls"
]
}
],
"metadata": {
@@ -149,7 +326,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.11.9"
}
},
"nbformat": 4,

View File

@@ -1327,7 +1327,7 @@ class Runnable(Generic[Input, Output], ABC):
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
# Sadly Unpack is not well-supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
"""

View File

@@ -1,30 +1,99 @@
from __future__ import annotations
import warnings
from importlib import util
from typing import Any, Optional
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
overload,
)
from langchain_core._api import beta
from langchain_core.language_models.chat_models import (
from langchain_core.language_models import (
BaseChatModel,
LanguageModelInput,
SimpleChatModel,
)
from langchain_core.language_models.chat_models import (
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
from typing_extensions import TypeAlias
__all__ = [
"init_chat_model",
# For backwards compatibility
"BaseChatModel",
"SimpleChatModel",
"generate_from_stream",
"agenerate_from_stream",
"init_chat_model",
]
@overload
def init_chat_model( # type: ignore[overload-overlap]
model: str,
*,
model_provider: Optional[str] = None,
configurable_fields: Literal[None] = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> BaseChatModel: ...
@overload
def init_chat_model(
model: Literal[None] = None,
*,
model_provider: Optional[str] = None,
configurable_fields: Literal[None] = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
@overload
def init_chat_model(
model: Optional[str] = None,
*,
model_provider: Optional[str] = None,
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = ...,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
@beta()
def init_chat_model(
model: str, *, model_provider: Optional[str] = None, **kwargs: Any
) -> BaseChatModel:
model: Optional[str] = None,
*,
model_provider: Optional[str] = None,
configurable_fields: Optional[
Union[Literal["any"], List[str], Tuple[str, ...]]
] = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> Union[BaseChatModel, _ConfigurableModel]:
"""Initialize a ChatModel from the model name and provider.
Must have the integration package corresponding to the model provider installed.
@@ -55,19 +124,43 @@ def init_chat_model(
- gemini... -> google_vertexai
- command... -> cohere
- accounts/fireworks... -> fireworks
configurable_fields: Which model parameters are
configurable:
- None: No configurable fields.
- "any": All fields are configurable. *See Security Note below.*
- Union[List[str], Tuple[str, ...]]: Specified fields are configurable.
Fields are assumed to have config_prefix stripped if there is a
config_prefix. If model is specified, then defaults to None. If model is
not specified, then defaults to ``("model", "model_provider")``.
***Security Note***: Setting ``configurable_fields="any"`` means fields like
api_key, base_url, etc. can be altered at runtime, potentially redirecting
model requests to a different service/user. If you're accepting untrusted
configurations, make sure to enumerate ``configurable_fields=(...)``
explicitly.
config_prefix: If config_prefix is a non-empty string then model will be
configurable at runtime via the
``config["configurable"]["{config_prefix}_{param}"]`` keys. If
config_prefix is an empty string then model will be configurable via
``config["configurable"]["{param}"]``.
kwargs: Additional keyword args to pass to
``<<selected ChatModel>>.__init__(model=model_name, **kwargs)``.
Returns:
The BaseChatModel corresponding to the model_name and model_provider specified.
A BaseChatModel corresponding to the model_name and model_provider specified if
configurability is inferred to be False. If configurable, a chat model emulator
that initializes the underlying model at runtime once a config is passed in.
Raises:
ValueError: If model_provider cannot be inferred or isn't supported.
ImportError: If the model provider integration package is not installed.
Example:
Initialize non-configurable models:
.. code-block:: python
# pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
from langchain.chat_models import init_chat_model
gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
@@ -77,7 +170,125 @@ def init_chat_model(
gpt_4o.invoke("what's your name")
claude_opus.invoke("what's your name")
gemini_15.invoke("what's your name")
Create a partially configurable model with no default model:
.. code-block:: python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
# Since we don't specify a model, "model" and "model_provider" are configurable by default.
configurable_model = init_chat_model(temperature=0)
configurable_model.invoke(
"what's your name",
config={"configurable": {"model": "gpt-4o"}}
)
# GPT-4o response
configurable_model.invoke(
"what's your name",
config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
)
# claude-3.5 sonnet response
Create a fully configurable model with a default model and a config prefix:
.. code-block:: python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
configurable_model_with_default = init_chat_model(
"gpt-4o",
model_provider="openai",
configurable_fields="any", # this allows us to configure other params like temperature, max_tokens, etc at runtime.
config_prefix="foo",
temperature=0
)
configurable_model_with_default.invoke("what's your name")
# GPT-4o response with temperature 0
configurable_model_with_default.invoke(
"what's your name",
config={
"configurable": {
"foo_model": "claude-3-5-sonnet-20240620",
"foo_model_provider": "anthropic",
"foo_temperature": 0.6
}
}
)
# Claude-3.5 sonnet response with temperature 0.6
Bind tools to a configurable model:
You can call any ChatModel declarative methods on a configurable model in the
same way that you would with a normal model.
.. code-block:: python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
from langchain_core.pydantic_v1 import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
configurable_model = init_chat_model(
"gpt-4o",
configurable_fields=("model", "model_provider"),
temperature=0
)
configurable_model_with_tools = configurable_model.bind_tools([GetWeather, GetPopulation])
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
# GPT-4o response with tool calls
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?",
config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
)
# Claude-3.5 sonnet response with tools
""" # noqa: E501
if not model and not configurable_fields:
configurable_fields = ("model", "model_provider")
config_prefix = config_prefix or ""
if config_prefix and not configurable_fields:
warnings.warn(
f"{config_prefix=} has been set but no fields are configurable. Set "
f"`configurable_fields=(...)` to specify the model params that are "
f"configurable."
)
if not configurable_fields:
return _init_chat_model_helper(
cast(str, model), model_provider=model_provider, **kwargs
)
else:
if model:
kwargs["model"] = model
if model_provider:
kwargs["model_provider"] = model_provider
return _ConfigurableModel(
default_config=kwargs,
config_prefix=config_prefix,
configurable_fields=configurable_fields,
)
def _init_chat_model_helper(
model: str, *, model_provider: Optional[str] = None, **kwargs: Any
) -> BaseChatModel:
model_provider = model_provider or _attempt_infer_model_provider(model)
if not model_provider:
raise ValueError(
@@ -200,3 +411,386 @@ def _check_pkg(pkg: str) -> None:
f"Unable to import {pkg_kebab}. Please install with "
f"`pip install -U {pkg_kebab}`"
)
def _remove_prefix(s: str, prefix: str) -> str:
if s.startswith(prefix):
s = s[len(prefix) :]
return s
_DECLARATIVE_METHODS = ("bind_tools", "with_structured_output")
class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def __init__(
self,
*,
default_config: Optional[dict] = None,
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = "any",
config_prefix: str = "",
queued_declarative_operations: Sequence[Tuple[str, Tuple, Dict]] = (),
) -> None:
self._default_config: dict = default_config or {}
self._configurable_fields: Union[Literal["any"], List[str]] = (
configurable_fields
if configurable_fields == "any"
else list(configurable_fields)
)
self._config_prefix = (
config_prefix + "_"
if config_prefix and not config_prefix.endswith("_")
else config_prefix
)
self._queued_declarative_operations: List[Tuple[str, Tuple, Dict]] = list(
queued_declarative_operations
)
def __getattr__(self, name: str) -> Any:
if name in _DECLARATIVE_METHODS:
# Declarative operations that cannot be applied until after an actual model
# object is instantiated. So instead of returning the actual operation,
# we record the operation and its arguments in a queue. This queue is
# then applied in order whenever we actually instantiate the model (in
# self._model()).
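# For example, model.bind_tools(tools) returns a new _ConfigurableModel with
# ("bind_tools", (tools,), {}) appended to its queue; the tools are only
# bound once a concrete model is built.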
def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
queued_declarative_operations = list(
self._queued_declarative_operations
)
queued_declarative_operations.append((name, args, kwargs))
return _ConfigurableModel(
default_config=dict(self._default_config),
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
return queue
elif self._default_config and (model := self._model()) and hasattr(model, name):
return getattr(model, name)
else:
msg = f"{name} is not a BaseChatModel attribute"
if self._default_config:
msg += " and is not implemented on the default model"
msg += "."
raise AttributeError(msg)
def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
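# Build the model from the default params merged with the runtime config,
# then replay any queued declarative operations (e.g. bind_tools) on it.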
params = {**self._default_config, **self._model_params(config)}
model = _init_chat_model_helper(**params)
for name, args, kwargs in self._queued_declarative_operations:
model = getattr(model, name)(*args, **kwargs)
return model
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
config = config or {}
model_params = {
_remove_prefix(k, self._config_prefix): v
for k, v in config.get("configurable", {}).items()
if k.startswith(self._config_prefix)
}
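# Hypothetical example: with config_prefix "first_", a config of
# {"configurable": {"first_model": "gpt-4o", "other": 1}} yields
# {"model": "gpt-4o"} here; keys without the prefix are dropped.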
if self._configurable_fields != "any":
model_params = {
k: v for k, v in model_params.items() if k in self._configurable_fields
}
return model_params
def with_config(
self,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> _ConfigurableModel:
"""Bind config to a Runnable, returning a new Runnable."""
config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs))
model_params = self._model_params(config)
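# Configurable keys that resolve to model params are folded into the new
# default_config below; everything else (tags, callbacks, leftover
# configurable keys) is queued as a "with_config" declarative operation.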
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
remaining_config["configurable"] = {
k: v
for k, v in config.get("configurable", {}).items()
if _remove_prefix(k, self._config_prefix) not in model_params
}
queued_declarative_operations = list(self._queued_declarative_operations)
if remaining_config:
queued_declarative_operations.append(
("with_config", (), {"config": remaining_config})
)
return _ConfigurableModel(
default_config={**self._default_config, **model_params},
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
StringPromptValue,
)
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
]
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
return self._model(config).invoke(input, config=config, **kwargs)
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
return await self._model(config).ainvoke(input, config=config, **kwargs)
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Any]:
yield from self._model(config).stream(input, config=config, **kwargs)
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Any]:
async for x in self._model(config).astream(input, config=config, **kwargs):
yield x
def batch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Any]:
config = config or None
# If <= 1 config, use the underlying model's batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return self._model(config).batch(
inputs, config=config, return_exceptions=return_exceptions, **kwargs
)
# If there are multiple configs, default to Runnable.batch, which uses an
# executor to invoke in parallel.
else:
return super().batch(
inputs, config=config, return_exceptions=return_exceptions, **kwargs
)
async def abatch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Any]:
config = config or None
# If <= 1 config, use the underlying model's abatch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return await self._model(config).abatch(
inputs, config=config, return_exceptions=return_exceptions, **kwargs
)
# If there are multiple configs, default to Runnable.abatch, which uses an
# executor to invoke in parallel.
else:
return await super().abatch(
inputs, config=config, return_exceptions=return_exceptions, **kwargs
)
def batch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Any, Exception]]]:
config = config or None
# If <= 1 config, use the underlying model's batch_as_completed implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
yield from self._model(cast(RunnableConfig, config)).batch_as_completed( # type: ignore[call-overload]
inputs, config=config, return_exceptions=return_exceptions, **kwargs
)
# If there are multiple configs, default to Runnable.batch_as_completed,
# which uses an executor to invoke in parallel.
else:
yield from super().batch_as_completed( # type: ignore[call-overload]
inputs, config=config, return_exceptions=return_exceptions, **kwargs
)
async def abatch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> AsyncIterator[Tuple[int, Any]]:
config = config or None
# If <= 1 config, use the underlying model's abatch_as_completed implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
async for x in self._model(
cast(RunnableConfig, config)
).abatch_as_completed( # type: ignore[call-overload]
inputs, config=config, return_exceptions=return_exceptions, **kwargs
):
yield x
# If there are multiple configs, default to Runnable.abatch_as_completed,
# which uses an executor to invoke in parallel.
else:
async for x in super().abatch_as_completed( # type: ignore[call-overload]
inputs, config=config, return_exceptions=return_exceptions, **kwargs
):
yield x
def transform(
self,
input: Iterator[LanguageModelInput],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Any]:
for x in self._model(config).transform(input, config=config, **kwargs):
yield x
async def atransform(
self,
input: AsyncIterator[LanguageModelInput],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Any]:
async for x in self._model(config).atransform(input, config=config, **kwargs):
yield x
@overload
def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
diff: Literal[True] = True,
with_streamed_output_list: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch]: ...
@overload
def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
diff: Literal[False],
with_streamed_output_list: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[RunLog]: ...
async def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
diff: bool = True,
with_streamed_output_list: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
input,
config=config,
diff=diff,
with_streamed_output_list=with_streamed_output_list,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
async def astream_events(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1", "v2"],
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
async for x in self._model(config).astream_events(
input,
config=config,
version=version,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
# Explicitly added to satisfy downstream linters.
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs)
# Explicitly added to satisfy downstream linters.
def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
return self.__getattr__("with_structured_output")(schema, **kwargs)

View File

@@ -1760,7 +1760,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.2.13"
version = "0.2.18"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@@ -1784,7 +1784,7 @@ url = "../core"
[[package]]
name = "langchain-openai"
version = "0.1.15"
version = "0.1.16"
description = "An integration package connecting OpenAI and LangChain"
optional = true
python-versions = ">=3.8.1,<4.0"
@@ -1792,7 +1792,7 @@ files = []
develop = true
[package.dependencies]
langchain-core = "^0.2.13"
langchain-core = "^0.2.17"
openai = "^1.32.0"
tiktoken = ">=0.7,<1"
@@ -1800,6 +1800,24 @@ tiktoken = ">=0.7,<1"
type = "directory"
url = "../partners/openai"
[[package]]
name = "langchain-standard-tests"
version = "0.1.1"
description = "Standard tests for LangChain implementations"
optional = false
python-versions = ">=3.8.1,<4.0"
files = []
develop = true
[package.dependencies]
httpx = "^0.27.0"
langchain-core = ">=0.1.40,<0.3"
pytest = ">=7,<9"
[package.source]
type = "directory"
url = "../standard-tests"
[[package]]
name = "langchain-text-splitters"
version = "0.2.2"
@@ -2490,8 +2508,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@@ -4111,20 +4129,6 @@ files = [
cryptography = ">=35.0.0"
types-pyOpenSSL = "*"
[[package]]
name = "types-requests"
version = "2.31.0.6"
description = "Typing stubs for requests"
optional = false
python-versions = ">=3.7"
files = [
{file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"},
{file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"},
]
[package.dependencies]
types-urllib3 = "*"
[[package]]
name = "types-requests"
version = "2.32.0.20240622"
@@ -4161,17 +4165,6 @@ files = [
{file = "types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d"},
]
[[package]]
name = "types-urllib3"
version = "1.26.25.14"
description = "Typing stubs for urllib3"
optional = false
python-versions = "*"
files = [
{file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"},
{file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"},
]
[[package]]
name = "typing-extensions"
version = "4.12.2"
@@ -4208,22 +4201,6 @@ files = [
[package.extras]
dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"]
[[package]]
name = "urllib3"
version = "1.26.19"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
files = [
{file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"},
{file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"},
]
[package.extras]
brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
[[package]]
name = "urllib3"
version = "2.2.2"
@@ -4241,6 +4218,23 @@ h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "vcrpy"
version = "4.3.0"
description = "Automatically mock your HTTP interactions to simplify and speed up testing"
optional = false
python-versions = ">=3.7"
files = [
{file = "vcrpy-4.3.0-py2.py3-none-any.whl", hash = "sha256:8fbd4be412e8a7f35f623dd61034e6380a1c8dbd0edf6e87277a3289f6e98093"},
{file = "vcrpy-4.3.0.tar.gz", hash = "sha256:49c270ce67e826dba027d83e20d25b67a5885487697e97bca6dbdf53d750a0ac"},
]
[package.dependencies]
PyYAML = "*"
six = ">=1.5"
wrapt = "*"
yarl = "*"
[[package]]
name = "vcrpy"
version = "6.0.1"
@@ -4253,7 +4247,6 @@ files = [
[package.dependencies]
PyYAML = "*"
urllib3 = {version = "<2", markers = "platform_python_implementation == \"PyPy\" or python_version < \"3.10\""}
wrapt = "*"
yarl = "*"
@@ -4568,4 +4561,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "30237e9280ade99d7c7741aec1b3d38a8e1ccb24a3d0c4380d48ae80ab86a136"
content-hash = "14ebfabffa095e7619e9646bf56bc166d18c1c975b65e301bb6163c4e8eecaac"

View File

@@ -95,6 +95,10 @@ pytest-socket = "^0.6.0"
syrupy = "^4.0.2"
requests-mock = "^1.11.0"
[tool.poetry.group.test.dependencies.langchain-standard-tests]
path = "../standard-tests"
develop = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

View File

@@ -0,0 +1,59 @@
from typing import Type, cast
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableConfig
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain.chat_models import init_chat_model
class multiply(BaseModel):
"""Product of two ints."""
x: int
y: int
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
async def test_init_chat_model_chain() -> None:
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
model_with_tools = model.bind_tools([multiply])
model_with_config = model_with_tools.with_config(
RunnableConfig(tags=["foo"]),
configurable={"bar_model": "claude-3-sonnet-20240229"},
)
prompt = ChatPromptTemplate.from_messages([("system", "foo"), ("human", "{input}")])
chain = prompt | model_with_config
output = chain.invoke({"input": "bar"})
assert isinstance(output, AIMessage)
events = []
async for event in chain.astream_events({"input": "bar"}, version="v2"):
events.append(event)
assert events
class TestStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return cast(Type[BaseChatModel], init_chat_model)
@property
def chat_model_params(self) -> dict:
return {"model": "gpt-4o", "configurable_fields": "any"}
@property
def supports_image_inputs(self) -> bool:
return True
@property
def has_tool_calling(self) -> bool:
return True
@property
def has_structured_output(self) -> bool:
return True

View File

@@ -1,4 +1,11 @@
import os
from unittest import mock
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableSequence
from langchain.chat_models.base import __all__, init_chat_model
@@ -34,14 +41,156 @@ def test_all_imports() -> None:
],
)
def test_init_chat_model(model_name: str, model_provider: str) -> None:
init_chat_model(model_name, model_provider=model_provider, api_key="foo")
_: BaseChatModel = init_chat_model(
model_name, model_provider=model_provider, api_key="foo"
)
def test_init_missing_dep() -> None:
with pytest.raises(ImportError):
init_chat_model("gpt-4o", model_provider="openai")
init_chat_model("mixtral-8x7b-32768", model_provider="groq")
def test_init_unknown_provider() -> None:
with pytest.raises(ValueError):
init_chat_model("foo", model_provider="bar")
@pytest.mark.requires("langchain_openai")
@mock.patch.dict(
os.environ, {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "foo"}, clear=True
)
def test_configurable() -> None:
model = init_chat_model()
for method in (
"invoke",
"ainvoke",
"batch",
"abatch",
"stream",
"astream",
"batch_as_completed",
"abatch_as_completed",
):
assert hasattr(model, method)
# Doesn't have access to non-configurable, non-declarative methods until a
# config is provided.
for method in ("get_num_tokens", "get_num_tokens_from_messages"):
with pytest.raises(AttributeError):
getattr(model, method)
# Can call declarative methods even without a default model.
model_with_tools = model.bind_tools(
[{"name": "foo", "description": "foo", "parameters": {}}]
)
# Check that the original model wasn't mutated by the declarative operation.
assert model._queued_declarative_operations == []
# Can iteratively call declarative methods.
model_with_config = model_with_tools.with_config(
RunnableConfig(tags=["foo"]), configurable={"model": "gpt-4o"}
)
assert model_with_config.model_name == "gpt-4o" # type: ignore[attr-defined]
for method in ("get_num_tokens", "get_num_tokens_from_messages"):
assert hasattr(model_with_config, method)
assert model_with_config.dict() == { # type: ignore[attr-defined]
"name": None,
"bound": {
"model_name": "gpt-4o",
"model": "gpt-4o",
"stream": False,
"n": 1,
"temperature": 0.7,
"presence_penalty": None,
"frequency_penalty": None,
"seed": None,
"top_p": None,
"logprobs": False,
"top_logprobs": None,
"logit_bias": None,
"_type": "openai-chat",
},
"kwargs": {
"tools": [
{
"type": "function",
"function": {"name": "foo", "description": "foo", "parameters": {}},
}
]
},
"config": {"tags": ["foo"], "configurable": {}},
"config_factories": [],
"custom_input_type": None,
"custom_output_type": None,
}
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
@mock.patch.dict(
os.environ, {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "foo"}, clear=True
)
def test_configurable_with_default() -> None:
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
for method in (
"invoke",
"ainvoke",
"batch",
"abatch",
"stream",
"astream",
"batch_as_completed",
"abatch_as_completed",
):
assert hasattr(model, method)
# Does have access to non-configurable, non-declarative methods since default
# params are provided.
for method in ("get_num_tokens", "get_num_tokens_from_messages", "dict"):
assert hasattr(model, method)
assert model.model_name == "gpt-4o" # type: ignore[attr-defined]
model_with_tools = model.bind_tools(
[{"name": "foo", "description": "foo", "parameters": {}}]
)
model_with_config = model_with_tools.with_config(
RunnableConfig(tags=["foo"]),
configurable={"bar_model": "claude-3-sonnet-20240229"},
)
assert model_with_config.model == "claude-3-sonnet-20240229" # type: ignore[attr-defined]
# Anthropic defaults to using `transformers` for token counting.
with pytest.raises(ImportError):
model_with_config.get_num_tokens_from_messages([(HumanMessage("foo"))]) # type: ignore[attr-defined]
assert model_with_config.dict() == { # type: ignore[attr-defined]
"name": None,
"bound": {
"model": "claude-3-sonnet-20240229",
"max_tokens": 1024,
"temperature": None,
"top_k": None,
"top_p": None,
"model_kwargs": {},
"streaming": False,
"max_retries": 2,
"default_request_timeout": None,
"_type": "anthropic-chat",
},
"kwargs": {
"tools": [{"name": "foo", "description": "foo", "input_schema": {}}]
},
"config": {"tags": ["foo"], "configurable": {}},
"config_factories": [],
"custom_input_type": None,
"custom_output_type": None,
}
prompt = ChatPromptTemplate.from_messages([("system", "foo")])
chain = prompt | model_with_config
assert isinstance(chain, RunnableSequence)

View File

@@ -79,6 +79,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
"duckdb-engine",
"freezegun",
"langchain-core",
"langchain-standard-tests",
"langchain-text-splitters",
"langchain-openai",
"lark",