mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
langchain[minor]: Generic configurable model (#23419)
alternative to [23244](https://github.com/langchain-ai/langchain/pull/23244). allows you to use chat model declarative methods ![Screenshot 2024-06-25 at 1 07 10 PM](https://github.com/langchain-ai/langchain/assets/22008038/910d1694-9b7b-46bc-bc2e-3792df9321d6)
This commit is contained in:
parent
d0728b0ba0
commit
0da5078cad
@ -25,7 +25,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -qU langchain langchain-openai langchain-anthropic langchain-google-vertexai"
|
||||
"%pip install -qU langchain>=0.2.7 langchain-openai langchain-anthropic langchain-google-vertexai"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -76,32 +76,6 @@
|
||||
"print(\"Gemini 1.5: \" + gemini_15.invoke(\"what's your name\").content + \"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fff9a4c8-b6ee-4a1a-8d3d-0ecaa312d4ed",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Simple config example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "75c25d39-bf47-4b51-a6c6-64d9c572bfd6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"user_config = {\n",
|
||||
" \"model\": \"...user-specified...\",\n",
|
||||
" \"model_provider\": \"...user-specified...\",\n",
|
||||
" \"temperature\": 0,\n",
|
||||
" \"max_tokens\": 1000,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"llm = init_chat_model(**user_config)\n",
|
||||
"llm.invoke(\"what's your name\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f811f219-5e78-4b62-b495-915d52a22532",
|
||||
@ -125,12 +99,215 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "da07b5c0-d2e6-42e4-bfcd-2efcfaae6221",
|
||||
"cell_type": "markdown",
|
||||
"id": "476a44db-c50d-4846-951d-0f1c9ba8bbaa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"## Creating a configurable model\n",
|
||||
"\n",
|
||||
"You can also create a runtime-configurable model by specifying `configurable_fields`. If you don't specify a `model` value, then \"model\" and \"model_provider\" be configurable by default."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "6c037f27-12d7-4e83-811e-4245c0e3ba58",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! How can I assist you today?\", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d576307f90', 'finish_reason': 'stop', 'logprobs': None}, id='run-5428ab5c-b5c0-46de-9946-5d4ca40dbdc8-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"configurable_model = init_chat_model(temperature=0)\n",
|
||||
"\n",
|
||||
"configurable_model.invoke(\n",
|
||||
" \"what's your name\", config={\"configurable\": {\"model\": \"gpt-4o\"}}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "321e3036-abd2-4e1f-bcc6-606efd036954",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"My name is Claude. It's nice to meet you!\", response_metadata={'id': 'msg_012XvotUJ3kGLXJUWKBVxJUi', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-1ad1eefe-f1c6-4244-8bc6-90e2cb7ee554-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"configurable_model.invoke(\n",
|
||||
" \"what's your name\", config={\"configurable\": {\"model\": \"claude-3-5-sonnet-20240620\"}}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7f3b3d4a-4066-45e4-8297-ea81ac8e70b7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Configurable model with default values\n",
|
||||
"\n",
|
||||
"We can create a configurable model with default model values, specify which parameters are configurable, and add prefixes to configurable params:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "814a2289-d0db-401e-b555-d5116112b413",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! How can I assist you today?\", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_ce0793330f', 'finish_reason': 'stop', 'logprobs': None}, id='run-3923e328-7715-4cd6-b215-98e4b6bf7c9d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"first_llm = init_chat_model(\n",
|
||||
" model=\"gpt-4o\",\n",
|
||||
" temperature=0,\n",
|
||||
" configurable_fields=(\"model\", \"model_provider\", \"temperature\", \"max_tokens\"),\n",
|
||||
" config_prefix=\"first\", # useful when you have a chain with multiple models\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"first_llm.invoke(\"what's your name\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "6c8755ba-c001-4f5a-a497-be3f1db83244",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"My name is Claude. It's nice to meet you!\", response_metadata={'id': 'msg_01RyYR64DoMPNCfHeNnroMXm', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-22446159-3723-43e6-88df-b84797e7751d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"first_llm.invoke(\n",
|
||||
" \"what's your name\",\n",
|
||||
" config={\n",
|
||||
" \"configurable\": {\n",
|
||||
" \"first_model\": \"claude-3-5-sonnet-20240620\",\n",
|
||||
" \"first_temperature\": 0.5,\n",
|
||||
" \"first_max_tokens\": 100,\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0072b1a3-7e44-4b4e-8b07-efe1ba91a689",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Using a configurable model declaratively\n",
|
||||
"\n",
|
||||
"We can call declarative operations like `bind_tools`, `with_structured_output`, `with_configurable`, etc. on a configurable model and chain a configurable model in the same way that we would a regularly instantiated chat model object."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "067dabee-1050-4110-ae24-c48eba01e13b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'name': 'GetPopulation',\n",
|
||||
" 'args': {'location': 'Los Angeles, CA'},\n",
|
||||
" 'id': 'call_sYT3PFMufHGWJD32Hi2CTNUP'},\n",
|
||||
" {'name': 'GetPopulation',\n",
|
||||
" 'args': {'location': 'New York, NY'},\n",
|
||||
" 'id': 'call_j1qjhxRnD3ffQmRyqjlI1Lnk'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GetWeather(BaseModel):\n",
|
||||
" \"\"\"Get the current weather in a given location\"\"\"\n",
|
||||
"\n",
|
||||
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GetPopulation(BaseModel):\n",
|
||||
" \"\"\"Get the current population in a given location\"\"\"\n",
|
||||
"\n",
|
||||
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"llm = init_chat_model(temperature=0)\n",
|
||||
"llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])\n",
|
||||
"\n",
|
||||
"llm_with_tools.invoke(\n",
|
||||
" \"what's bigger in 2024 LA or NYC\", config={\"configurable\": {\"model\": \"gpt-4o\"}}\n",
|
||||
").tool_calls"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "e57dfe9f-cd24-4e37-9ce9-ccf8daf78f89",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'name': 'GetPopulation',\n",
|
||||
" 'args': {'location': 'Los Angeles, CA'},\n",
|
||||
" 'id': 'toolu_01CxEHxKtVbLBrvzFS7GQ5xR'},\n",
|
||||
" {'name': 'GetPopulation',\n",
|
||||
" 'args': {'location': 'New York City, NY'},\n",
|
||||
" 'id': 'toolu_013A79qt5toWSsKunFBDZd5S'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm_with_tools.invoke(\n",
|
||||
" \"what's bigger in 2024 LA or NYC\",\n",
|
||||
" config={\"configurable\": {\"model\": \"claude-3-5-sonnet-20240620\"}},\n",
|
||||
").tool_calls"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -149,7 +326,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -1327,7 +1327,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
def with_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||
# Sadly Unpack is not well-supported by mypy so this will have to be untyped
|
||||
**kwargs: Any,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
|
@ -1,30 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from importlib import util
|
||||
from typing import Any, Optional
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.language_models.chat_models import (
|
||||
from langchain_core.language_models import (
|
||||
BaseChatModel,
|
||||
LanguageModelInput,
|
||||
SimpleChatModel,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tracers import RunLog, RunLogPatch
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
__all__ = [
|
||||
"init_chat_model",
|
||||
# For backwards compatibility
|
||||
"BaseChatModel",
|
||||
"SimpleChatModel",
|
||||
"generate_from_stream",
|
||||
"agenerate_from_stream",
|
||||
"init_chat_model",
|
||||
]
|
||||
|
||||
|
||||
@overload
|
||||
def init_chat_model( # type: ignore[overload-overlap]
|
||||
model: str,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
configurable_fields: Literal[None] = None,
|
||||
config_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseChatModel: ...
|
||||
|
||||
|
||||
@overload
|
||||
def init_chat_model(
|
||||
model: Literal[None] = None,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
configurable_fields: Literal[None] = None,
|
||||
config_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel: ...
|
||||
|
||||
|
||||
@overload
|
||||
def init_chat_model(
|
||||
model: Optional[str] = None,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = ...,
|
||||
config_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel: ...
|
||||
|
||||
|
||||
# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
|
||||
# name to the supported list in the docstring below. Do *not* change the order of the
|
||||
# existing providers.
|
||||
@beta()
|
||||
def init_chat_model(
|
||||
model: str, *, model_provider: Optional[str] = None, **kwargs: Any
|
||||
) -> BaseChatModel:
|
||||
model: Optional[str] = None,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
configurable_fields: Optional[
|
||||
Union[Literal["any"], List[str], Tuple[str, ...]]
|
||||
] = None,
|
||||
config_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[BaseChatModel, _ConfigurableModel]:
|
||||
"""Initialize a ChatModel from the model name and provider.
|
||||
|
||||
Must have the integration package corresponding to the model provider installed.
|
||||
@ -55,19 +124,43 @@ def init_chat_model(
|
||||
- gemini... -> google_vertexai
|
||||
- command... -> cohere
|
||||
- accounts/fireworks... -> fireworks
|
||||
configurable_fields: Which model parameters are
|
||||
configurable:
|
||||
- None: No configurable fields.
|
||||
- "any": All fields are configurable. *See Security Note below.*
|
||||
- Union[List[str], Tuple[str, ...]]: Specified fields are configurable.
|
||||
|
||||
Fields are assumed to have config_prefix stripped if there is a
|
||||
config_prefix. If model is specified, then defaults to None. If model is
|
||||
not specified, then defaults to ``("model", "model_provider")``.
|
||||
|
||||
***Security Note***: Setting ``configurable_fields="any"`` means fields like
|
||||
api_key, base_url, etc. can be altered at runtime, potentially redirecting
|
||||
model requests to a different service/user. Make sure that if you're
|
||||
accepting untrusted configurations that you enumerate the
|
||||
``configurable_fields=(...)`` explicitly.
|
||||
|
||||
config_prefix: If config_prefix is a non-empty string then model will be
|
||||
configurable at runtime via the
|
||||
``config["configurable"]["{config_prefix}_{param}"]`` keys. If
|
||||
config_prefix is an empty string then model will be configurable via
|
||||
``config["configurable"]["{param}"]``.
|
||||
kwargs: Additional keyword args to pass to
|
||||
``<<selected ChatModel>>.__init__(model=model_name, **kwargs)``.
|
||||
|
||||
Returns:
|
||||
The BaseChatModel corresponding to the model_name and model_provider specified.
|
||||
A BaseChatModel corresponding to the model_name and model_provider specified if
|
||||
configurability is inferred to be False. If configurable, a chat model emulator
|
||||
that initializes the underlying model at runtime once a config is passed in.
|
||||
|
||||
Raises:
|
||||
ValueError: If model_provider cannot be inferred or isn't supported.
|
||||
ImportError: If the model provider integration package is not installed.
|
||||
|
||||
Example:
|
||||
Initialize non-configurable models:
|
||||
.. code-block:: python
|
||||
|
||||
# pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
|
||||
@ -77,7 +170,125 @@ def init_chat_model(
|
||||
gpt_4o.invoke("what's your name")
|
||||
claude_opus.invoke("what's your name")
|
||||
gemini_15.invoke("what's your name")
|
||||
|
||||
|
||||
Create a partially configurable model with no default model:
|
||||
.. code-block:: python
|
||||
|
||||
# pip install langchain langchain-openai langchain-anthropic
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
# We don't need to specify configurable=True if a model isn't specified.
|
||||
configurable_model = init_chat_model(temperature=0)
|
||||
|
||||
configurable_model.invoke(
|
||||
"what's your name",
|
||||
config={"configurable": {"model": "gpt-4o"}}
|
||||
)
|
||||
# GPT-4o response
|
||||
|
||||
configurable_model.invoke(
|
||||
"what's your name",
|
||||
config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
|
||||
)
|
||||
# claude-3.5 sonnet response
|
||||
|
||||
Create a fully configurable model with a default model and a config prefix:
|
||||
.. code-block:: python
|
||||
|
||||
# pip install langchain langchain-openai langchain-anthropic
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
configurable_model_with_default = init_chat_model(
|
||||
"gpt-4o",
|
||||
model_provider="openai",
|
||||
configurable_fields="any", # this allows us to configure other params like temperature, max_tokens, etc at runtime.
|
||||
config_prefix="foo",
|
||||
temperature=0
|
||||
)
|
||||
|
||||
configurable_model_with_default.invoke("what's your name")
|
||||
# GPT-4o response with temperature 0
|
||||
|
||||
configurable_model_with_default.invoke(
|
||||
"what's your name",
|
||||
config={
|
||||
"configurable": {
|
||||
"foo_model": "claude-3-5-sonnet-20240620",
|
||||
"foo_model_provider": "anthropic",
|
||||
"foo_temperature": 0.6
|
||||
}
|
||||
}
|
||||
)
|
||||
# Claude-3.5 sonnet response with temperature 0.6
|
||||
|
||||
Bind tools to a configurable model:
|
||||
You can call any ChatModel declarative methods on a configurable model in the
|
||||
same way that you would with a normal model.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# pip install langchain langchain-openai langchain-anthropic
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
'''Get the current weather in a given location'''
|
||||
|
||||
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
||||
|
||||
class GetPopulation(BaseModel):
|
||||
'''Get the current population in a given location'''
|
||||
|
||||
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
||||
|
||||
configurable_model = init_chat_model(
|
||||
"gpt-4o",
|
||||
configurable_fields=("model", "model_provider"),
|
||||
temperature=0
|
||||
)
|
||||
|
||||
configurable_model_with_tools = configurable_model.bind_tools([GetWeather, GetPopulation])
|
||||
configurable_model_with_tools.invoke(
|
||||
"Which city is hotter today and which is bigger: LA or NY?"
|
||||
)
|
||||
# GPT-4o response with tool calls
|
||||
|
||||
configurable_model_with_tools.invoke(
|
||||
"Which city is hotter today and which is bigger: LA or NY?",
|
||||
config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
|
||||
)
|
||||
# Claude-3.5 sonnet response with tools
|
||||
""" # noqa: E501
|
||||
if not model and not configurable_fields:
|
||||
configurable_fields = ("model", "model_provider")
|
||||
config_prefix = config_prefix or ""
|
||||
if config_prefix and not configurable_fields:
|
||||
warnings.warn(
|
||||
f"{config_prefix=} has been set but no fields are configurable. Set "
|
||||
f"`configurable_fields=(...)` to specify the model params that are "
|
||||
f"configurable."
|
||||
)
|
||||
|
||||
if not configurable_fields:
|
||||
return _init_chat_model_helper(
|
||||
cast(str, model), model_provider=model_provider, **kwargs
|
||||
)
|
||||
else:
|
||||
if model:
|
||||
kwargs["model"] = model
|
||||
if model_provider:
|
||||
kwargs["model_provider"] = model_provider
|
||||
return _ConfigurableModel(
|
||||
default_config=kwargs,
|
||||
config_prefix=config_prefix,
|
||||
configurable_fields=configurable_fields,
|
||||
)
|
||||
|
||||
|
||||
def _init_chat_model_helper(
|
||||
model: str, *, model_provider: Optional[str] = None, **kwargs: Any
|
||||
) -> BaseChatModel:
|
||||
model_provider = model_provider or _attempt_infer_model_provider(model)
|
||||
if not model_provider:
|
||||
raise ValueError(
|
||||
@ -200,3 +411,386 @@ def _check_pkg(pkg: str) -> None:
|
||||
f"Unable to import {pkg_kebab}. Please install with "
|
||||
f"`pip install -U {pkg_kebab}`"
|
||||
)
|
||||
|
||||
|
||||
def _remove_prefix(s: str, prefix: str) -> str:
|
||||
if s.startswith(prefix):
|
||||
s = s[len(prefix) :]
|
||||
return s
|
||||
|
||||
|
||||
_DECLARATIVE_METHODS = ("bind_tools", "with_structured_output")
|
||||
|
||||
|
||||
class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
default_config: Optional[dict] = None,
|
||||
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = "any",
|
||||
config_prefix: str = "",
|
||||
queued_declarative_operations: Sequence[Tuple[str, Tuple, Dict]] = (),
|
||||
) -> None:
|
||||
self._default_config: dict = default_config or {}
|
||||
self._configurable_fields: Union[Literal["any"], List[str]] = (
|
||||
configurable_fields
|
||||
if configurable_fields == "any"
|
||||
else list(configurable_fields)
|
||||
)
|
||||
self._config_prefix = (
|
||||
config_prefix + "_"
|
||||
if config_prefix and not config_prefix.endswith("_")
|
||||
else config_prefix
|
||||
)
|
||||
self._queued_declarative_operations: List[Tuple[str, Tuple, Dict]] = list(
|
||||
queued_declarative_operations
|
||||
)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in _DECLARATIVE_METHODS:
|
||||
# Declarative operations that cannot be applied until after an actual model
|
||||
# object is instantiated. So instead of returning the actual operation,
|
||||
# we record the operation and its arguments in a queue. This queue is
|
||||
# then applied in order whenever we actually instantiate the model (in
|
||||
# self._model()).
|
||||
def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
|
||||
queued_declarative_operations = list(
|
||||
self._queued_declarative_operations
|
||||
)
|
||||
queued_declarative_operations.append((name, args, kwargs))
|
||||
return _ConfigurableModel(
|
||||
default_config=dict(self._default_config),
|
||||
configurable_fields=list(self._configurable_fields)
|
||||
if isinstance(self._configurable_fields, list)
|
||||
else self._configurable_fields,
|
||||
config_prefix=self._config_prefix,
|
||||
queued_declarative_operations=queued_declarative_operations,
|
||||
)
|
||||
|
||||
return queue
|
||||
elif self._default_config and (model := self._model()) and hasattr(model, name):
|
||||
return getattr(model, name)
|
||||
else:
|
||||
msg = f"{name} is not a BaseChatModel attribute"
|
||||
if self._default_config:
|
||||
msg += " and is not implemented on the default model"
|
||||
msg += "."
|
||||
raise AttributeError(msg)
|
||||
|
||||
def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
|
||||
params = {**self._default_config, **self._model_params(config)}
|
||||
model = _init_chat_model_helper(**params)
|
||||
for name, args, kwargs in self._queued_declarative_operations:
|
||||
model = getattr(model, name)(*args, **kwargs)
|
||||
return model
|
||||
|
||||
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
|
||||
config = config or {}
|
||||
model_params = {
|
||||
_remove_prefix(k, self._config_prefix): v
|
||||
for k, v in config.get("configurable", {}).items()
|
||||
if k.startswith(self._config_prefix)
|
||||
}
|
||||
if self._configurable_fields != "any":
|
||||
model_params = {
|
||||
k: v for k, v in model_params.items() if k in self._configurable_fields
|
||||
}
|
||||
return model_params
|
||||
|
||||
def with_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel:
|
||||
"""Bind config to a Runnable, returning a new Runnable."""
|
||||
config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs))
|
||||
model_params = self._model_params(config)
|
||||
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
|
||||
remaining_config["configurable"] = {
|
||||
k: v
|
||||
for k, v in config.get("configurable", {}).items()
|
||||
if _remove_prefix(k, self._config_prefix) not in model_params
|
||||
}
|
||||
queued_declarative_operations = list(self._queued_declarative_operations)
|
||||
if remaining_config:
|
||||
queued_declarative_operations.append(
|
||||
("with_config", (), {"config": remaining_config})
|
||||
)
|
||||
return _ConfigurableModel(
|
||||
default_config={**self._default_config, **model_params},
|
||||
configurable_fields=list(self._configurable_fields)
|
||||
if isinstance(self._configurable_fields, list)
|
||||
else self._configurable_fields,
|
||||
config_prefix=self._config_prefix,
|
||||
queued_declarative_operations=queued_declarative_operations,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> TypeAlias:
|
||||
"""Get the input type for this runnable."""
|
||||
from langchain_core.prompt_values import (
|
||||
ChatPromptValueConcrete,
|
||||
StringPromptValue,
|
||||
)
|
||||
|
||||
# This is a version of LanguageModelInput which replaces the abstract
|
||||
# base class BaseMessage with a union of its subclasses, which makes
|
||||
# for a much better schema.
|
||||
return Union[
|
||||
str,
|
||||
Union[StringPromptValue, ChatPromptValueConcrete],
|
||||
List[AnyMessage],
|
||||
]
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return self._model(config).invoke(input, config=config, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return await self._model(config).ainvoke(input, config=config, **kwargs)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Any]:
|
||||
yield from self._model(config).stream(input, config=config, **kwargs)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Any]:
|
||||
async for x in self._model(config).astream(input, config=config, **kwargs):
|
||||
yield x
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Any]:
|
||||
config = config or None
|
||||
# If <= 1 config use the underlying models batch implementation.
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
return self._model(config).batch(
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
# If multiple configs default to Runnable.batch which uses executor to invoke
|
||||
# in parallel.
|
||||
else:
|
||||
return super().batch(
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Any]:
|
||||
config = config or None
|
||||
# If <= 1 config use the underlying models batch implementation.
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
return await self._model(config).abatch(
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
# If multiple configs default to Runnable.batch which uses executor to invoke
|
||||
# in parallel.
|
||||
else:
|
||||
return await super().abatch(
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
def batch_as_completed(
|
||||
self,
|
||||
inputs: Sequence[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Tuple[int, Union[Any, Exception]]]:
|
||||
config = config or None
|
||||
# If <= 1 config use the underlying models batch implementation.
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
yield from self._model(cast(RunnableConfig, config)).batch_as_completed( # type: ignore[call-overload]
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
# If multiple configs default to Runnable.batch which uses executor to invoke
|
||||
# in parallel.
|
||||
else:
|
||||
yield from super().batch_as_completed( # type: ignore[call-overload]
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
async def abatch_as_completed(
|
||||
self,
|
||||
inputs: Sequence[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Tuple[int, Any]]:
|
||||
config = config or None
|
||||
# If <= 1 config use the underlying models batch implementation.
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
async for x in self._model(
|
||||
cast(RunnableConfig, config)
|
||||
).abatch_as_completed( # type: ignore[call-overload]
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
):
|
||||
yield x
|
||||
# If multiple configs default to Runnable.batch which uses executor to invoke
|
||||
# in parallel.
|
||||
else:
|
||||
async for x in super().abatch_as_completed( # type: ignore[call-overload]
|
||||
inputs, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
):
|
||||
yield x
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[LanguageModelInput],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Any]:
|
||||
for x in self._model(config).transform(input, config=config, **kwargs):
|
||||
yield x
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[LanguageModelInput],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Any]:
|
||||
async for x in self._model(config).atransform(input, config=config, **kwargs):
|
||||
yield x
|
||||
|
||||
@overload
|
||||
def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
diff: Literal[True] = True,
|
||||
with_streamed_output_list: bool = True,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[RunLogPatch]: ...
|
||||
|
||||
@overload
|
||||
def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
diff: Literal[False],
|
||||
with_streamed_output_list: bool = True,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[RunLog]: ...
|
||||
|
||||
async def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
diff: bool = True,
|
||||
with_streamed_output_list: bool = True,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
|
||||
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
|
||||
input,
|
||||
config=config,
|
||||
diff=diff,
|
||||
with_streamed_output_list=with_streamed_output_list,
|
||||
include_names=include_names,
|
||||
include_types=include_types,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
exclude_types=exclude_types,
|
||||
exclude_names=exclude_names,
|
||||
**kwargs,
|
||||
):
|
||||
yield x
|
||||
|
||||
async def astream_events(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
version: Literal["v1", "v2"],
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
async for x in self._model(config).astream_events(
|
||||
input,
|
||||
config=config,
|
||||
version=version,
|
||||
include_names=include_names,
|
||||
include_types=include_types,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
exclude_types=exclude_types,
|
||||
exclude_names=exclude_names,
|
||||
**kwargs,
|
||||
):
|
||||
yield x
|
||||
|
||||
# Explicitly added to satisfy downstream linters.
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
return self.__getattr__("bind_tools")(tools, **kwargs)
|
||||
|
||||
# Explicitly added to satisfy downstream linters.
|
||||
def with_structured_output(
|
||||
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
return self.__getattr__("with_structured_output")(schema, **kwargs)
|
||||
|
87
libs/langchain/poetry.lock
generated
87
libs/langchain/poetry.lock
generated
@ -1760,7 +1760,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.2.13"
|
||||
version = "0.2.18"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -1784,7 +1784,7 @@ url = "../core"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-openai"
|
||||
version = "0.1.15"
|
||||
version = "0.1.16"
|
||||
description = "An integration package connecting OpenAI and LangChain"
|
||||
optional = true
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -1792,7 +1792,7 @@ files = []
|
||||
develop = true
|
||||
|
||||
[package.dependencies]
|
||||
langchain-core = "^0.2.13"
|
||||
langchain-core = "^0.2.17"
|
||||
openai = "^1.32.0"
|
||||
tiktoken = ">=0.7,<1"
|
||||
|
||||
@ -1800,6 +1800,24 @@ tiktoken = ">=0.7,<1"
|
||||
type = "directory"
|
||||
url = "../partners/openai"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-standard-tests"
|
||||
version = "0.1.1"
|
||||
description = "Standard tests for LangChain implementations"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = []
|
||||
develop = true
|
||||
|
||||
[package.dependencies]
|
||||
httpx = "^0.27.0"
|
||||
langchain-core = ">=0.1.40,<0.3"
|
||||
pytest = ">=7,<9"
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
url = "../standard-tests"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-text-splitters"
|
||||
version = "0.2.2"
|
||||
@ -2490,8 +2508,8 @@ files = [
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
|
||||
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
@ -4111,20 +4129,6 @@ files = [
|
||||
cryptography = ">=35.0.0"
|
||||
types-pyOpenSSL = "*"
|
||||
|
||||
[[package]]
|
||||
name = "types-requests"
|
||||
version = "2.31.0.6"
|
||||
description = "Typing stubs for requests"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"},
|
||||
{file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
types-urllib3 = "*"
|
||||
|
||||
[[package]]
|
||||
name = "types-requests"
|
||||
version = "2.32.0.20240622"
|
||||
@ -4161,17 +4165,6 @@ files = [
|
||||
{file = "types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-urllib3"
|
||||
version = "1.26.25.14"
|
||||
description = "Typing stubs for urllib3"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"},
|
||||
{file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.12.2"
|
||||
@ -4208,22 +4201,6 @@ files = [
|
||||
[package.extras]
|
||||
dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "1.26.19"
|
||||
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
|
||||
files = [
|
||||
{file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"},
|
||||
{file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
|
||||
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
|
||||
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.2.2"
|
||||
@ -4241,6 +4218,23 @@ h2 = ["h2 (>=4,<5)"]
|
||||
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "vcrpy"
|
||||
version = "4.3.0"
|
||||
description = "Automatically mock your HTTP interactions to simplify and speed up testing"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "vcrpy-4.3.0-py2.py3-none-any.whl", hash = "sha256:8fbd4be412e8a7f35f623dd61034e6380a1c8dbd0edf6e87277a3289f6e98093"},
|
||||
{file = "vcrpy-4.3.0.tar.gz", hash = "sha256:49c270ce67e826dba027d83e20d25b67a5885487697e97bca6dbdf53d750a0ac"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
PyYAML = "*"
|
||||
six = ">=1.5"
|
||||
wrapt = "*"
|
||||
yarl = "*"
|
||||
|
||||
[[package]]
|
||||
name = "vcrpy"
|
||||
version = "6.0.1"
|
||||
@ -4253,7 +4247,6 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
PyYAML = "*"
|
||||
urllib3 = {version = "<2", markers = "platform_python_implementation == \"PyPy\" or python_version < \"3.10\""}
|
||||
wrapt = "*"
|
||||
yarl = "*"
|
||||
|
||||
@ -4568,4 +4561,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "30237e9280ade99d7c7741aec1b3d38a8e1ccb24a3d0c4380d48ae80ab86a136"
|
||||
content-hash = "14ebfabffa095e7619e9646bf56bc166d18c1c975b65e301bb6163c4e8eecaac"
|
||||
|
@ -95,6 +95,10 @@ pytest-socket = "^0.6.0"
|
||||
syrupy = "^4.0.2"
|
||||
requests-mock = "^1.11.0"
|
||||
|
||||
[tool.poetry.group.test.dependencies.langchain-standard-tests]
|
||||
path = "../standard-tests"
|
||||
develop = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
|
@ -0,0 +1,59 @@
|
||||
from typing import Type, cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
|
||||
class multiply(BaseModel):
|
||||
"""Product of two ints."""
|
||||
|
||||
x: int
|
||||
y: int
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
|
||||
async def test_init_chat_model_chain() -> None:
|
||||
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
|
||||
model_with_tools = model.bind_tools([multiply])
|
||||
|
||||
model_with_config = model_with_tools.with_config(
|
||||
RunnableConfig(tags=["foo"]),
|
||||
configurable={"bar_model": "claude-3-sonnet-20240229"},
|
||||
)
|
||||
prompt = ChatPromptTemplate.from_messages([("system", "foo"), ("human", "{input}")])
|
||||
chain = prompt | model_with_config
|
||||
output = chain.invoke({"input": "bar"})
|
||||
assert isinstance(output, AIMessage)
|
||||
events = []
|
||||
async for event in chain.astream_events({"input": "bar"}, version="v2"):
|
||||
events.append(event)
|
||||
assert events
|
||||
|
||||
|
||||
class TestStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return cast(Type[BaseChatModel], init_chat_model)
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": "gpt-4o", "configurable_fields": "any"}
|
||||
|
||||
@property
|
||||
def supports_image_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def has_tool_calling(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def has_structured_output(self) -> bool:
|
||||
return True
|
@ -1,4 +1,11 @@
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSequence
|
||||
|
||||
from langchain.chat_models.base import __all__, init_chat_model
|
||||
|
||||
@ -34,14 +41,156 @@ def test_all_imports() -> None:
|
||||
],
|
||||
)
|
||||
def test_init_chat_model(model_name: str, model_provider: str) -> None:
|
||||
init_chat_model(model_name, model_provider=model_provider, api_key="foo")
|
||||
_: BaseChatModel = init_chat_model(
|
||||
model_name, model_provider=model_provider, api_key="foo"
|
||||
)
|
||||
|
||||
|
||||
def test_init_missing_dep() -> None:
|
||||
with pytest.raises(ImportError):
|
||||
init_chat_model("gpt-4o", model_provider="openai")
|
||||
init_chat_model("mixtral-8x7b-32768", model_provider="groq")
|
||||
|
||||
|
||||
def test_init_unknown_provider() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
init_chat_model("foo", model_provider="bar")
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai")
|
||||
@mock.patch.dict(
|
||||
os.environ, {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "foo"}, clear=True
|
||||
)
|
||||
def test_configurable() -> None:
|
||||
model = init_chat_model()
|
||||
|
||||
for method in (
|
||||
"invoke",
|
||||
"ainvoke",
|
||||
"batch",
|
||||
"abatch",
|
||||
"stream",
|
||||
"astream",
|
||||
"batch_as_completed",
|
||||
"abatch_as_completed",
|
||||
):
|
||||
assert hasattr(model, method)
|
||||
|
||||
# Doesn't have access non-configurable, non-declarative methods until a config is
|
||||
# provided.
|
||||
for method in ("get_num_tokens", "get_num_tokens_from_messages"):
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(model, method)
|
||||
|
||||
# Can call declarative methods even without a default model.
|
||||
model_with_tools = model.bind_tools(
|
||||
[{"name": "foo", "description": "foo", "parameters": {}}]
|
||||
)
|
||||
|
||||
# Check that original model wasn't mutated by declarative operation.
|
||||
assert model._queued_declarative_operations == []
|
||||
|
||||
# Can iteratively call declarative methods.
|
||||
model_with_config = model_with_tools.with_config(
|
||||
RunnableConfig(tags=["foo"]), configurable={"model": "gpt-4o"}
|
||||
)
|
||||
assert model_with_config.model_name == "gpt-4o" # type: ignore[attr-defined]
|
||||
|
||||
for method in ("get_num_tokens", "get_num_tokens_from_messages"):
|
||||
assert hasattr(model_with_config, method)
|
||||
|
||||
assert model_with_config.dict() == { # type: ignore[attr-defined]
|
||||
"name": None,
|
||||
"bound": {
|
||||
"model_name": "gpt-4o",
|
||||
"model": "gpt-4o",
|
||||
"stream": False,
|
||||
"n": 1,
|
||||
"temperature": 0.7,
|
||||
"presence_penalty": None,
|
||||
"frequency_penalty": None,
|
||||
"seed": None,
|
||||
"top_p": None,
|
||||
"logprobs": False,
|
||||
"top_logprobs": None,
|
||||
"logit_bias": None,
|
||||
"_type": "openai-chat",
|
||||
},
|
||||
"kwargs": {
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "foo", "description": "foo", "parameters": {}},
|
||||
}
|
||||
]
|
||||
},
|
||||
"config": {"tags": ["foo"], "configurable": {}},
|
||||
"config_factories": [],
|
||||
"custom_input_type": None,
|
||||
"custom_output_type": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
|
||||
@mock.patch.dict(
|
||||
os.environ, {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "foo"}, clear=True
|
||||
)
|
||||
def test_configurable_with_default() -> None:
|
||||
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
|
||||
for method in (
|
||||
"invoke",
|
||||
"ainvoke",
|
||||
"batch",
|
||||
"abatch",
|
||||
"stream",
|
||||
"astream",
|
||||
"batch_as_completed",
|
||||
"abatch_as_completed",
|
||||
):
|
||||
assert hasattr(model, method)
|
||||
|
||||
# Does have access non-configurable, non-declarative methods since default params
|
||||
# are provided.
|
||||
for method in ("get_num_tokens", "get_num_tokens_from_messages", "dict"):
|
||||
assert hasattr(model, method)
|
||||
|
||||
assert model.model_name == "gpt-4o" # type: ignore[attr-defined]
|
||||
|
||||
model_with_tools = model.bind_tools(
|
||||
[{"name": "foo", "description": "foo", "parameters": {}}]
|
||||
)
|
||||
|
||||
model_with_config = model_with_tools.with_config(
|
||||
RunnableConfig(tags=["foo"]),
|
||||
configurable={"bar_model": "claude-3-sonnet-20240229"},
|
||||
)
|
||||
|
||||
assert model_with_config.model == "claude-3-sonnet-20240229" # type: ignore[attr-defined]
|
||||
# Anthropic defaults to using `transformers` for token counting.
|
||||
with pytest.raises(ImportError):
|
||||
model_with_config.get_num_tokens_from_messages([(HumanMessage("foo"))]) # type: ignore[attr-defined]
|
||||
|
||||
assert model_with_config.dict() == { # type: ignore[attr-defined]
|
||||
"name": None,
|
||||
"bound": {
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"max_tokens": 1024,
|
||||
"temperature": None,
|
||||
"top_k": None,
|
||||
"top_p": None,
|
||||
"model_kwargs": {},
|
||||
"streaming": False,
|
||||
"max_retries": 2,
|
||||
"default_request_timeout": None,
|
||||
"_type": "anthropic-chat",
|
||||
},
|
||||
"kwargs": {
|
||||
"tools": [{"name": "foo", "description": "foo", "input_schema": {}}]
|
||||
},
|
||||
"config": {"tags": ["foo"], "configurable": {}},
|
||||
"config_factories": [],
|
||||
"custom_input_type": None,
|
||||
"custom_output_type": None,
|
||||
}
|
||||
prompt = ChatPromptTemplate.from_messages([("system", "foo")])
|
||||
chain = prompt | model_with_config
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
|
@ -79,6 +79,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
|
||||
"duckdb-engine",
|
||||
"freezegun",
|
||||
"langchain-core",
|
||||
"langchain-standard-tests",
|
||||
"langchain-text-splitters",
|
||||
"langchain-openai",
|
||||
"lark",
|
||||
|
Loading…
Reference in New Issue
Block a user