diff --git a/libs/partners/ai21/tests/integration_tests/test_standard.py b/libs/partners/ai21/tests/integration_tests/test_standard.py index 2d74ca59ec..2235dcd321 100644 --- a/libs/partners/ai21/tests/integration_tests/test_standard.py +++ b/libs/partners/ai21/tests/integration_tests/test_standard.py @@ -10,98 +10,38 @@ from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_ai21 import ChatAI21 -class TestAI21J2(ChatModelIntegrationTests): +class BaseTestAI21(ChatModelIntegrationTests): def teardown(self) -> None: # avoid getting rate limited time.sleep(1) - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatAI21 @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") - def test_stream( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - super().test_stream( - chat_model_class, - chat_model_params, - ) + def test_stream(self, model: BaseChatModel) -> None: + super().test_stream(model) @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") - async def test_astream( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - await super().test_astream( - chat_model_class, - chat_model_params, - ) + async def test_astream(self, model: BaseChatModel) -> None: + await super().test_astream(model) @pytest.mark.xfail(reason="Not implemented.") - def test_usage_metadata( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - super().test_usage_metadata( - chat_model_class, - chat_model_params, - ) + def test_usage_metadata(self, model: BaseChatModel) -> None: + super().test_usage_metadata(model) - @pytest.fixture + +class TestAI21J2(BaseTestAI21): + @property def chat_model_params(self) -> dict: return { "model": "j2-ultra", } -class TestAI21Jamba(ChatModelIntegrationTests): - def teardown(self) -> None: - # avoid getting rate limited - time.sleep(1) - - @pytest.fixture - def chat_model_class(self) -> Type[BaseChatModel]: - return ChatAI21 - - @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") - def test_stream( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - super().test_stream( - chat_model_class, - chat_model_params, - ) - - @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") - async def test_astream( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - await super().test_astream( - chat_model_class, - chat_model_params, - ) - - @pytest.mark.xfail(reason="Not implemented.") - def test_usage_metadata( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - ) -> None: - super().test_usage_metadata( - chat_model_class, - chat_model_params, - ) - - @pytest.fixture +class TestAI21Jamba(BaseTestAI21): + @property def chat_model_params(self) -> dict: return { "model": "jamba-instruct-preview", diff --git a/libs/partners/ai21/tests/unit_tests/test_standard.py b/libs/partners/ai21/tests/unit_tests/test_standard.py index e0b9c5b98c..f5a0e3eb8a 100644 --- a/libs/partners/ai21/tests/unit_tests/test_standard.py +++ b/libs/partners/ai21/tests/unit_tests/test_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ChatModelUnitTests @@ -10,11 +9,11 @@ from langchain_ai21 import ChatAI21 class TestAI21J2(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatAI21 - @pytest.fixture + @property def chat_model_params(self) -> dict: return { "model": "j2-ultra", @@ -23,11 +22,11 @@ class TestAI21J2(ChatModelUnitTests): class TestAI21Jamba(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatAI21 - @pytest.fixture + @property def chat_model_params(self) -> dict: return { "model": "jamba-instruct", diff --git a/libs/partners/anthropic/tests/integration_tests/test_standard.py b/libs/partners/anthropic/tests/integration_tests/test_standard.py index 464f5f947e..28a194e75a 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_standard.py +++ b/libs/partners/anthropic/tests/integration_tests/test_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ChatModelIntegrationTests @@ -10,12 +9,10 @@ from langchain_anthropic import ChatAnthropic class TestAnthropicStandard(ChatModelIntegrationTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatAnthropic - @pytest.fixture + @property def chat_model_params(self) -> dict: - return { - "model": "claude-3-haiku-20240307", - } + return {"model": "claude-3-haiku-20240307"} diff --git a/libs/partners/anthropic/tests/unit_tests/test_standard.py b/libs/partners/anthropic/tests/unit_tests/test_standard.py index 2650554e79..7976dcb2bc 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_standard.py +++ b/libs/partners/anthropic/tests/unit_tests/test_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ChatModelUnitTests @@ -10,12 +9,10 @@ from langchain_anthropic import ChatAnthropic class TestAnthropicStandard(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatAnthropic - @pytest.fixture + @property def chat_model_params(self) -> dict: - return { - "model": "claude-3-haiku-20240307", - } + return {"model": "claude-3-haiku-20240307"} diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 48de39f8c8..fb439a6811 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -296,6 +296,8 @@ class ChatFireworks(BaseChatModel): """Model name to use.""" temperature: float = 0.0 """What sampling temperature to use.""" + stop: Optional[Union[str, List[str]]] = Field(None, alias="stop_sequences") + """Default stop sequences.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" fireworks_api_key: SecretStr = Field(default=None, alias="api_key") @@ -314,8 +316,8 @@ class ChatFireworks(BaseChatModel): """Number of chat completions to generate for each prompt.""" max_tokens: Optional[int] = None """Maximum number of tokens to generate.""" - stop: Optional[List[str]] = Field(None, alias="stop_sequences") - """Default stop sequences.""" + max_retries: Optional[int] = None + """Maximum number of retries to make when generating.""" class Config: """Configuration for this pydantic object.""" @@ -360,6 +362,9 @@ class ChatFireworks(BaseChatModel): values["client"] = Fireworks(**client_params).chat.completions if not values.get("async_client"): values["async_client"] = AsyncFireworks(**client_params).chat.completions + if values["max_retries"]: + values["client"]._max_retries = values["max_retries"] + values["async_client"]._max_retries = values["max_retries"] return values @property diff --git a/libs/partners/fireworks/tests/integration_tests/test_standard.py b/libs/partners/fireworks/tests/integration_tests/test_standard.py index bfeeca693d..84e118c2a9 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_standard.py +++ b/libs/partners/fireworks/tests/integration_tests/test_standard.py @@ -10,11 +10,11 @@ from langchain_fireworks import ChatFireworks class TestFireworksStandard(ChatModelIntegrationTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatFireworks - @pytest.fixture + @property def chat_model_params(self) -> dict: return { "model": "accounts/fireworks/models/firefunction-v1", @@ -22,12 +22,5 @@ class TestFireworksStandard(ChatModelIntegrationTests): } @pytest.mark.xfail(reason="Not yet implemented.") - def test_tool_message_histories_list_content( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_tool_calling: bool, - ) -> None: - super().test_tool_message_histories_list_content( - chat_model_class, chat_model_params, chat_model_has_tool_calling - ) + def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None: + super().test_tool_message_histories_list_content(model) diff --git a/libs/partners/fireworks/tests/unit_tests/test_standard.py b/libs/partners/fireworks/tests/unit_tests/test_standard.py index 455af60288..6e5bdc2f50 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_standard.py +++ b/libs/partners/fireworks/tests/unit_tests/test_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ChatModelUnitTests @@ -10,12 +9,10 @@ from langchain_fireworks import ChatFireworks class TestFireworksStandard(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatFireworks - @pytest.fixture + @property def chat_model_params(self) -> dict: - return { - "api_key": "test_api_key", - } + return {"api_key": "test_api_key"} diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 7c46e266a8..df49602b73 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -304,6 +304,8 @@ class ChatGroq(BaseChatModel): """Model name to use.""" temperature: float = 0.7 """What sampling temperature to use.""" + stop: Optional[Union[List[str], str]] = Field(None, alias="stop_sequences") + """Default stop sequences.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") @@ -326,8 +328,6 @@ class ChatGroq(BaseChatModel): """Number of chat completions to generate for each prompt.""" max_tokens: Optional[int] = None """Maximum number of tokens to generate.""" - stop: Optional[List[str]] = Field(None, alias="stop_sequences") - """Default stop sequences.""" default_headers: Union[Mapping[str, str], None] = None default_query: Union[Mapping[str, object], None] = None # Configure a custom httpx client. See the @@ -449,7 +449,7 @@ class ChatGroq(BaseChatModel): if ls_max_tokens := params.get("max_tokens", self.max_tokens): ls_params["ls_max_tokens"] = ls_max_tokens if ls_stop := stop or params.get("stop", None) or self.stop: - ls_params["ls_stop"] = ls_stop + ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop] return ls_params def _generate( @@ -804,10 +804,19 @@ class ChatGroq(BaseChatModel): formatted_tools = [convert_to_openai_tool(tool) for tool in tools] if tool_choice is not None and tool_choice: + if tool_choice == "any": + if len(tools) > 1: + raise ValueError( + f"Groq does not currently support {tool_choice=}. Should " + f"be one of 'auto', 'none', or the name of the tool to call." + ) + else: + tool_choice = convert_to_openai_tool(tools[0])["function"]["name"] if isinstance(tool_choice, str) and ( tool_choice not in ("auto", "any", "none") ): tool_choice = {"type": "function", "function": {"name": tool_choice}} + # TODO: Remove this update once 'any' is supported. if isinstance(tool_choice, dict) and (len(formatted_tools) != 1): raise ValueError( "When specifying `tool_choice`, you must provide exactly one " diff --git a/libs/partners/groq/pyproject.toml b/libs/partners/groq/pyproject.toml index 15f2334a80..daab187a7a 100644 --- a/libs/partners/groq/pyproject.toml +++ b/libs/partners/groq/pyproject.toml @@ -92,5 +92,6 @@ filterwarnings = [ 'ignore:The method `ChatGroq.with_structured_output` is in beta', # Maintain support for pydantic 1.X 'default:The `dict` method is deprecated; use `model_dump` instead:DeprecationWarning', + "ignore:tool_choice='any' is not currently supported. Converting to 'auto'.", ] asyncio_mode = "auto" diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py index b458d8adf1..4234ff616c 100644 --- a/libs/partners/groq/tests/integration_tests/test_standard.py +++ b/libs/partners/groq/tests/integration_tests/test_standard.py @@ -10,17 +10,10 @@ from langchain_groq import ChatGroq class TestGroqStandard(ChatModelIntegrationTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatGroq @pytest.mark.xfail(reason="Not yet implemented.") - def test_tool_message_histories_list_content( - self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_tool_calling: bool, - ) -> None: - super().test_tool_message_histories_list_content( - chat_model_class, chat_model_params, chat_model_has_tool_calling - ) + def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None: + super().test_tool_message_histories_list_content(model) diff --git a/libs/partners/groq/tests/unit_tests/test_standard.py b/libs/partners/groq/tests/unit_tests/test_standard.py index 38841230a9..c677d6378e 100644 --- a/libs/partners/groq/tests/unit_tests/test_standard.py +++ b/libs/partners/groq/tests/unit_tests/test_standard.py @@ -2,14 +2,26 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel -from langchain_standard_tests.unit_tests import ChatModelUnitTests +from langchain_core.runnables import RunnableBinding +from langchain_standard_tests.unit_tests.chat_models import ( + ChatModelUnitTests, + Person, + my_adder_tool, +) from langchain_groq import ChatGroq class TestGroqStandard(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatGroq + + def test_bind_tool_pydantic(self, model: BaseChatModel) -> None: + """Does not currently support tool_choice='any'.""" + if not self.has_tool_calling: + return + + tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool]) + assert isinstance(tool_model, RunnableBinding) diff --git a/libs/partners/mistralai/tests/integration_tests/test_standard.py b/libs/partners/mistralai/tests/integration_tests/test_standard.py index d9b8ff1969..6971719930 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_standard.py +++ b/libs/partners/mistralai/tests/integration_tests/test_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ChatModelIntegrationTests @@ -10,13 +9,10 @@ from langchain_mistralai import ChatMistralAI class TestMistralStandard(ChatModelIntegrationTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatMistralAI - @pytest.fixture + @property def chat_model_params(self) -> dict: - return { - "model": "mistral-large-latest", - "temperature": 0, - } + return {"model": "mistral-large-latest", "temperature": 0} diff --git a/libs/partners/mistralai/tests/unit_tests/test_standard.py b/libs/partners/mistralai/tests/unit_tests/test_standard.py index 46ef3ec3a4..100a766e89 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_standard.py +++ b/libs/partners/mistralai/tests/unit_tests/test_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ChatModelUnitTests @@ -10,6 +9,6 @@ from langchain_mistralai import ChatMistralAI class TestMistralStandard(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatMistralAI diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py index f820228787..30db136384 100644 --- a/libs/partners/openai/langchain_openai/chat_models/azure.py +++ b/libs/partners/openai/langchain_openai/chat_models/azure.py @@ -3,13 +3,18 @@ from __future__ import annotations import logging import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Type, Union import openai +from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import LangSmithParams +from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatResult -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_openai.chat_models.base import BaseChatOpenAI @@ -210,6 +215,27 @@ class AzureChatOpenAI(BaseChatOpenAI): ).chat.completions return values + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + *, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + # As of 05/2024 Azure OpenAI doesn't support tool_choice="required". + # TODO: Update this condition once tool_choice="required" is supported. + if tool_choice in ("any", "required", True): + if len(tools) > 1: + raise ValueError( + f"Azure OpenAI does not currently support {tool_choice=}. Should " + f"be one of 'auto', 'none', or the name of the tool to call." + ) + else: + tool_choice = convert_to_openai_tool(tools[0])["function"]["name"] + return super().bind_tools(tools, tool_choice=tool_choice, **kwargs) + @property def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters.""" diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 2df447b70f..91edb1309d 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -345,6 +345,8 @@ class BaseChatOpenAI(BaseChatModel): http_async_client: Union[Any, None] = None """Optional httpx.AsyncClient. Only used for async invocations. Must specify http_client as well if you'd like a custom client for sync invocations.""" + stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences") + """Default stop sequences.""" class Config: """Configuration for this pydantic object.""" @@ -441,6 +443,7 @@ class BaseChatOpenAI(BaseChatModel): "stream": self.streaming, "n": self.n, "temperature": self.temperature, + "stop": self.stop, **self.model_kwargs, } if self.max_tokens is not None: @@ -548,8 +551,6 @@ class BaseChatOpenAI(BaseChatModel): ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: params = self._default_params if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop message_dicts = [_convert_message_to_dict(m) for m in messages] return message_dicts, params @@ -871,15 +872,7 @@ class BaseChatOpenAI(BaseChatModel): if tool_choice == "any": tool_choice = "required" elif isinstance(tool_choice, bool): - if len(tools) > 1: - raise ValueError( - "tool_choice=True can only be used when a single tool is " - f"passed in, received {len(tools)} tools." - ) - tool_choice = { - "type": "function", - "function": {"name": formatted_tools[0]["function"]["name"]}, - } + tool_choice = "required" elif isinstance(tool_choice, dict): tool_names = [ formatted_tool["function"]["name"] @@ -1094,7 +1087,7 @@ class BaseChatOpenAI(BaseChatModel): "schema must be specified when method is 'function_calling'. " "Received None." ) - llm = self.bind_tools([schema], tool_choice=True) + llm = self.bind_tools([schema], tool_choice="any") if is_pydantic_schema: output_parser: OutputParserLike = PydanticToolsParser( tools=[schema], first_tool_only=True diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index d156fcc602..07d864ab7f 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -440,8 +440,6 @@ class BaseOpenAI(BaseLLM): ) -> List[List[str]]: """Get the sub prompts for llm call.""" if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop if params["max_tokens"] == -1: if len(prompts) != 1: diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py index ad21b06311..8cddaac726 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py @@ -3,7 +3,6 @@ import os from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ChatModelIntegrationTests @@ -19,15 +18,15 @@ DEPLOYMENT_NAME = os.environ.get( class TestOpenAIStandard(ChatModelIntegrationTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return AzureChatOpenAI - @pytest.fixture + @property def chat_model_params(self) -> dict: return { "deployment_name": DEPLOYMENT_NAME, "openai_api_version": OPENAI_API_VERSION, "azure_endpoint": OPENAI_API_BASE, - "openai_api_key": OPENAI_API_KEY, + "api_key": OPENAI_API_KEY, } diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py index 48cdb4d8e7..b4b925aa38 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ChatModelIntegrationTests @@ -10,6 +9,6 @@ from langchain_openai import ChatOpenAI class TestOpenAIStandard(ChatModelIntegrationTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatOpenAI diff --git a/libs/partners/openai/tests/integration_tests/llms/test_base.py b/libs/partners/openai/tests/integration_tests/llms/test_base.py index 651f94ef1b..b5e4defb71 100644 --- a/libs/partners/openai/tests/integration_tests/llms/test_base.py +++ b/libs/partners/openai/tests/integration_tests/llms/test_base.py @@ -99,13 +99,6 @@ def test_openai_stop_valid() -> None: assert first_output == second_output -def test_openai_stop_error() -> None: - """Test openai stop logic on bad configuration.""" - llm = OpenAI(stop="3", temperature=0) - with pytest.raises(ValueError): - llm.invoke("write an ordered list of five items", stop=["\n"]) - - @pytest.mark.scheduled def test_openai_streaming() -> None: """Test streaming tokens from OpenAI.""" diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py b/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py index 40c4ff2d0c..d75cf98bc8 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py @@ -2,23 +2,31 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel +from langchain_core.runnables import RunnableBinding from langchain_standard_tests.unit_tests import ChatModelUnitTests +from langchain_standard_tests.unit_tests.chat_models import Person, my_adder_tool from langchain_openai import AzureChatOpenAI class TestOpenAIStandard(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return AzureChatOpenAI - @pytest.fixture + @property def chat_model_params(self) -> dict: return { "deployment_name": "test", "openai_api_version": "2021-10-01", "azure_endpoint": "https://test.azure.com", - "openai_api_key": "test", } + + def test_bind_tool_pydantic(self, model: BaseChatModel) -> None: + """Does not currently support tool_choice='any'.""" + if not self.has_tool_calling: + return + + tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool]) + assert isinstance(tool_model, RunnableBinding) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py index 5936989a34..34198d0fb9 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ChatModelUnitTests @@ -10,6 +9,6 @@ from langchain_openai import ChatOpenAI class TestOpenAIStandard(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatOpenAI diff --git a/libs/partners/together/poetry.lock b/libs/partners/together/poetry.lock index 51867e30b1..f43b7a1b7f 100644 --- a/libs/partners/together/poetry.lock +++ b/libs/partners/together/poetry.lock @@ -566,7 +566,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.2.4" +version = "0.2.7" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -575,15 +575,12 @@ develop = true [package.dependencies] jsonpatch = "^1.33" -langsmith = "^0.1.66" -packaging = "^23.2" +langsmith = "^0.1.75" +packaging = ">=23.2,<25" pydantic = ">=1,<3" PyYAML = ">=5.3" tenacity = "^8.1.0" -[package.extras] -extended-testing = ["jinja2 (>=3,<4)"] - [package.source] type = "directory" url = "../../core" @@ -608,7 +605,7 @@ url = "../openai" [[package]] name = "langchain-standard-tests" -version = "0.1.0" +version = "0.1.1" description = "Standard tests for LangChain implementations" optional = false python-versions = ">=3.8.1,<4.0" @@ -625,13 +622,13 @@ url = "../../standard-tests" [[package]] name = "langsmith" -version = "0.1.69" +version = "0.1.77" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.69-py3-none-any.whl", hash = "sha256:3d7bd6fadb0852fc4cd2e7cf8a1593306046900052da3970bb2b48ed21cc73d8"}, - {file = "langsmith-0.1.69.tar.gz", hash = "sha256:0146764904a8e479620b7e73efcba1cf172b621799564dd7e7342859c05c264a"}, + {file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"}, + {file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"}, ] [package.dependencies] @@ -871,6 +868,51 @@ files = [ {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + [[package]] name = "openai" version = "1.30.1" @@ -1679,4 +1721,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "e7b21f556475be4c7133b74b6b0e138012bef9d47bc5bdc9709b24e55d9500f0" +content-hash = "8a868382f8f3b693dccc1ce99428cdf9d6f8b6f77b3403c342c2bcc7b8526db9" diff --git a/libs/partners/together/pyproject.toml b/libs/partners/together/pyproject.toml index ae6bd8c4ba..7e1ad59e52 100644 --- a/libs/partners/together/pyproject.toml +++ b/libs/partners/together/pyproject.toml @@ -43,6 +43,12 @@ codespell = "^2.2.0" optional = true [tool.poetry.group.test_integration.dependencies] +# Support Python 3.8 and 3.12+. +numpy = [ + {version = "^1", python = "<3.12"}, + {version = "^1.26.0", python = ">=3.12"} +] + [tool.poetry.group.lint] optional = true diff --git a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py index 5c51ffe8fe..0150f12c1f 100644 --- a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py @@ -2,20 +2,17 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_together import ChatTogether -class TestTogethertandard(ChatModelIntegrationTests): - @pytest.fixture +class TestTogetherStandard(ChatModelIntegrationTests): + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatTogether - @pytest.fixture + @property def chat_model_params(self) -> dict: - return { - "model": "mistralai/Mistral-7B-Instruct-v0.1", - } + return {"model": "mistralai/Mistral-7B-Instruct-v0.1"} diff --git a/libs/partners/together/tests/unit_tests/test_chat_models_standard.py b/libs/partners/together/tests/unit_tests/test_chat_models_standard.py index 5ce9fea21e..824cde156c 100644 --- a/libs/partners/together/tests/unit_tests/test_chat_models_standard.py +++ b/libs/partners/together/tests/unit_tests/test_chat_models_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ChatModelUnitTests @@ -10,12 +9,10 @@ from langchain_together import ChatTogether class TestTogetherStandard(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatTogether - @pytest.fixture + @property def chat_model_params(self) -> dict: - return { - "model": "meta-llama/Llama-3-8b-chat-hf", - } + return {"model": "meta-llama/Llama-3-8b-chat-hf"} diff --git a/libs/partners/upstage/poetry.lock b/libs/partners/upstage/poetry.lock index 79b75628aa..7fde6b0df6 100644 --- a/libs/partners/upstage/poetry.lock +++ b/libs/partners/upstage/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "anyio" @@ -340,7 +340,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.2.4" +version = "0.2.7" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -349,15 +349,12 @@ develop = true [package.dependencies] jsonpatch = "^1.33" -langsmith = "^0.1.66" -packaging = "^23.2" +langsmith = "^0.1.75" +packaging = ">=23.2,<25" pydantic = ">=1,<3" PyYAML = ">=5.3" tenacity = "^8.1.0" -[package.extras] -extended-testing = ["jinja2 (>=3,<4)"] - [package.source] type = "directory" url = "../../core" @@ -382,7 +379,7 @@ url = "../openai" [[package]] name = "langchain-standard-tests" -version = "0.1.0" +version = "0.1.1" description = "Standard tests for LangChain implementations" optional = false python-versions = ">=3.8.1,<4.0" @@ -399,13 +396,13 @@ url = "../../standard-tests" [[package]] name = "langsmith" -version = "0.1.69" +version = "0.1.77" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.69-py3-none-any.whl", hash = "sha256:3d7bd6fadb0852fc4cd2e7cf8a1593306046900052da3970bb2b48ed21cc73d8"}, - {file = "langsmith-0.1.69.tar.gz", hash = "sha256:0146764904a8e479620b7e73efcba1cf172b621799564dd7e7342859c05c264a"}, + {file = "langsmith-0.1.77-py3-none-any.whl", hash = "sha256:2202cc21b1ed7e7b9e5d2af2694be28898afa048c09fdf09f620cbd9301755ae"}, + {file = "langsmith-0.1.77.tar.gz", hash = "sha256:4ace09077a9a4e412afeb4b517ca68e7de7b07f36e4792dc8236ac5207c0c0c7"}, ] [package.dependencies] @@ -546,6 +543,51 @@ files = [ {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] +[[package]] +name = "numpy" +version = "1.26.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, +] + [[package]] name = "openai" version = "1.30.1" @@ -725,26 +767,31 @@ python-versions = ">=3.8" files = [ {file = "PyMuPDF-1.24.3-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:8d63d850d337c10fa49859697b9517e461b28e6d5d5a80121c72cc518eb0bae0"}, {file = "PyMuPDF-1.24.3-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:5f4a9ffabbcf8f19f6938484702e393ed6d423516f3e52c9d443162e3e42a884"}, + {file = "PyMuPDF-1.24.3-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:c7dfddf19d2a8c734c5439692e87419c86f2621f1f205100355afb3bb43e5675"}, {file = "PyMuPDF-1.24.3-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:cf743d8c7f7261112153525ba7de1d954f9d563b875414814b27da35fb0df2cc"}, {file = "PyMuPDF-1.24.3-cp310-none-win32.whl", hash = "sha256:e30e8dec04c241739e0e9cf89b8a0317e991889dbca781e30abef228009c8cbd"}, {file = "PyMuPDF-1.24.3-cp310-none-win_amd64.whl", hash = "sha256:3ceca02b143efe6b6f159d64a2f0e0aa32d0670894149a7f7144125fe2982da2"}, {file = "PyMuPDF-1.24.3-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:171313ee03e031437687cf856914eb61f66a5a76eddedc63048a63b69b00474b"}, {file = "PyMuPDF-1.24.3-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a421b332c257e70d9daed350cebefc043817ae0fd6b361734ee27f98288cc8c7"}, + {file = "PyMuPDF-1.24.3-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:cc519230e352024111f065a1d32366eea4f1f1034e01515f10dbed709d9ab5ad"}, {file = "PyMuPDF-1.24.3-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:9df2d5e89810d3708fb8933dbc07505a57bfcb976a72bc559c7f0ccacc054c76"}, {file = "PyMuPDF-1.24.3-cp311-none-win32.whl", hash = "sha256:1de61f186c8367d1647d679bf6a4a77198751b378f9b67958a3b5d59adbc8c95"}, {file = "PyMuPDF-1.24.3-cp311-none-win_amd64.whl", hash = "sha256:28e8c6c29de2951e29f98f17752eff0e80776fca7fe7ed5c7368363dff887c6c"}, {file = "PyMuPDF-1.24.3-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:34ab87e6d0f79eea9b632ed0401de20aff2622c95aa1a57fd17b49401c22c906"}, {file = "PyMuPDF-1.24.3-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:ef2311861a3173072c489dc365827bb26f2c4487f969501afbbf1746478553ea"}, + {file = "PyMuPDF-1.24.3-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:c4df2c50eba8fb8d8ffe63bd4099c57b562d11ed01dcf6cd922c4ea774212a34"}, {file = "PyMuPDF-1.24.3-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:401f2da8621f19bc302efa2a404c794b17982dea0e552b48ecd2c3f8d10b4707"}, {file = "PyMuPDF-1.24.3-cp312-none-win32.whl", hash = "sha256:ce4c07355b45e95803d1221cece01be58e32d1d9daec0d1ebc075ad03640c177"}, {file = "PyMuPDF-1.24.3-cp312-none-win_amd64.whl", hash = "sha256:4f084f735e2e2d21f2c76de1abdcb44261889ec01a2842b57e69c89502f74b7a"}, {file = "PyMuPDF-1.24.3-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:14b2459e1a7e4dbf9ec6026e6056ccba6868bdfff1ffb346fd910108a61be095"}, {file = "PyMuPDF-1.24.3-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:f3572c2a85a12026637d485d6156b7f279a4aac7f474a341e5e06e8943ab2e0b"}, + {file = "PyMuPDF-1.24.3-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:b0ceed71fa62eebd1bf8b55875cd6da7c2f09bbe2067218b68b5deb0d9feaa6e"}, {file = "PyMuPDF-1.24.3-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:e6b9100fa5194be1240b9998643ba122fcffd76149dccda3607455ccfed5fa2b"}, {file = "PyMuPDF-1.24.3-cp38-none-win32.whl", hash = "sha256:88e52a5c6d0375d27401c08fe7f7894f19db4af31169ba6deb6b3c1453f8b6e0"}, {file = "PyMuPDF-1.24.3-cp38-none-win_amd64.whl", hash = "sha256:45c93944a14b19da3ee9b6d648e609f3ca35b8bca5c1cd16e6addcc59e7816d9"}, {file = "PyMuPDF-1.24.3-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0d4b6caf5ad25b7bd654ad4d42b8b3a00683b742bc5a81b8aeface79811386d5"}, {file = "PyMuPDF-1.24.3-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:6b52ee0460db88c71a06677353a0c768a8bb17718aa313462e9847ed1bf53f87"}, + {file = "PyMuPDF-1.24.3-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:06a8a3226c9ec97c5e1df8cd16ec29b5df83d04ae88e9e0f5f4e25fcc1b997a1"}, {file = "PyMuPDF-1.24.3-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:ce113303b41adb74ae30ebd98761d9bd53477573e47566f05b3b7ff1c7354675"}, {file = "PyMuPDF-1.24.3-cp39-none-win32.whl", hash = "sha256:e4b4b2d5700c48a67da278476767488005408fac29426467b5bb437012197c0b"}, {file = "PyMuPDF-1.24.3-cp39-none-win_amd64.whl", hash = "sha256:39acbac2854ef5b58f28c71bb19e84840771a771ec09cb33c4e66e2679c3b419"}, @@ -763,6 +810,7 @@ python-versions = ">=3.8" files = [ {file = "PyMuPDFb-1.24.3-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d2ccca660042896d4af479f979ec10674c5a0b3cd2d9ecb0011f08dc82380cce"}, {file = "PyMuPDFb-1.24.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ad51d21086a16199684a3eebcb47d9c8460fc27e7bebae77f5fe64e8c34ebf34"}, + {file = "PyMuPDFb-1.24.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3e7aab000d707c40e3254cd60152897b90952ed9a3567584d70974292f4912ce"}, {file = "PyMuPDFb-1.24.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f39588fd2b7a63e2456df42cd8925c316202e0eb77d115d9c01ba032b2c9086f"}, {file = "PyMuPDFb-1.24.3-py3-none-win32.whl", hash = "sha256:0d606a10cb828cefc9f864bf67bc9d46e8007af55e643f022b59d378af4151a8"}, {file = "PyMuPDFb-1.24.3-py3-none-win_amd64.whl", hash = "sha256:e88289bd4b4afe5966a028774b302f37d4b51dad5c5e6720dd04524910db6c6e"}, @@ -1304,4 +1352,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "0073172ce2312055480e9ff47dc99ce7dfd6809208ad5ea4cee5ecf7f12eef56" +content-hash = "b21648a1fdc08f901c82fb3b4773682f0a4b83b03b97ae1ddbd0834b730ff8c2" diff --git a/libs/partners/upstage/pyproject.toml b/libs/partners/upstage/pyproject.toml index 233584d838..f315da7fdf 100644 --- a/libs/partners/upstage/pyproject.toml +++ b/libs/partners/upstage/pyproject.toml @@ -43,6 +43,12 @@ codespell = "^2.2.0" optional = true [tool.poetry.group.test_integration.dependencies] +# Support Python 3.8 and 3.12+. +numpy = [ + {version = "^1", python = "<3.12"}, + {version = "^1.26.0", python = ">=3.12"} +] + [tool.poetry.group.lint] optional = true diff --git a/libs/partners/upstage/tests/integration_tests/test_chat_models_standard.py b/libs/partners/upstage/tests/integration_tests/test_chat_models_standard.py index ba06a00e34..0d1b29ec82 100644 --- a/libs/partners/upstage/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/upstage/tests/integration_tests/test_chat_models_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ChatModelIntegrationTests @@ -10,12 +9,10 @@ from langchain_upstage import ChatUpstage class TestUpstageStandard(ChatModelIntegrationTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatUpstage - @pytest.fixture + @property def chat_model_params(self) -> dict: - return { - "model": "solar-1-mini-chat", - } + return {"model": "solar-1-mini-chat"} diff --git a/libs/partners/upstage/tests/unit_tests/test_chat_models_standard.py b/libs/partners/upstage/tests/unit_tests/test_chat_models_standard.py index b42053b224..91dcb760ad 100644 --- a/libs/partners/upstage/tests/unit_tests/test_chat_models_standard.py +++ b/libs/partners/upstage/tests/unit_tests/test_chat_models_standard.py @@ -2,7 +2,6 @@ from typing import Type -import pytest from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ChatModelUnitTests @@ -10,12 +9,10 @@ from langchain_upstage import ChatUpstage class TestUpstageStandard(ChatModelUnitTests): - @pytest.fixture + @property def chat_model_class(self) -> Type[BaseChatModel]: return ChatUpstage - @pytest.fixture + @property def chat_model_params(self) -> dict: - return { - "model": "solar-1-mini-chat", - } + return {"model": "solar-1-mini-chat"} diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 0dfc9a0481..78041da003 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -1,74 +1,31 @@ import json -from abc import ABC, abstractmethod -from typing import Type import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage -from langchain_core.pydantic_v1 import BaseModel, Field -from langchain_core.tools import tool +from langchain_standard_tests.unit_tests.chat_models import ( + ChatModelTests, + my_adder_tool, +) -class Person(BaseModel): - name: str = Field(..., description="The name of the person.") - age: int = Field(..., description="The age of the person.") - -@tool -def my_adder_tool(a: int, b: int) -> int: - """Takes two integers, a and b, and returns their sum.""" - return a + b - - -class ChatModelIntegrationTests(ABC): - @abstractmethod - @pytest.fixture - def chat_model_class(self) -> Type[BaseChatModel]: - ... - - @pytest.fixture - def chat_model_params(self) -> dict: - return {} - - @pytest.fixture - def chat_model_has_tool_calling( - self, chat_model_class: Type[BaseChatModel] - ) -> bool: - return chat_model_class.bind_tools is not BaseChatModel.bind_tools - - @pytest.fixture - def chat_model_has_structured_output( - self, chat_model_class: Type[BaseChatModel] - ) -> bool: - return ( - chat_model_class.with_structured_output - is not BaseChatModel.with_structured_output - ) - - def test_invoke( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) +class ChatModelIntegrationTests(ChatModelTests): + def test_invoke(self, model: BaseChatModel) -> None: result = model.invoke("Hello") assert result is not None assert isinstance(result, AIMessage) assert isinstance(result.content, str) assert len(result.content) > 0 - async def test_ainvoke( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) + async def test_ainvoke(self, model: BaseChatModel) -> None: result = await model.ainvoke("Hello") assert result is not None assert isinstance(result, AIMessage) assert isinstance(result.content, str) assert len(result.content) > 0 - def test_stream( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) + def test_stream(self, model: BaseChatModel) -> None: num_tokens = 0 for token in model.stream("Hello"): assert token is not None @@ -76,10 +33,7 @@ class ChatModelIntegrationTests(ABC): num_tokens += len(token.content) assert num_tokens > 0 - async def test_astream( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) + async def test_astream(self, model: BaseChatModel) -> None: num_tokens = 0 async for token in model.astream("Hello"): assert token is not None @@ -87,10 +41,7 @@ class ChatModelIntegrationTests(ABC): num_tokens += len(token.content) assert num_tokens > 0 - def test_batch( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) + def test_batch(self, model: BaseChatModel) -> None: batch_results = model.batch(["Hello", "Hey"]) assert batch_results is not None assert isinstance(batch_results, list) @@ -101,10 +52,7 @@ class ChatModelIntegrationTests(ABC): assert isinstance(result.content, str) assert len(result.content) > 0 - async def test_abatch( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) + async def test_abatch(self, model: BaseChatModel) -> None: batch_results = await model.abatch(["Hello", "Hey"]) assert batch_results is not None assert isinstance(batch_results, list) @@ -115,14 +63,11 @@ class ChatModelIntegrationTests(ABC): assert isinstance(result.content, str) assert len(result.content) > 0 - def test_conversation( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) + def test_conversation(self, model: BaseChatModel) -> None: messages = [ - HumanMessage(content="hello"), - AIMessage(content="hello"), - HumanMessage(content="how are you"), + HumanMessage("hello"), + AIMessage("hello"), + HumanMessage("how are you"), ] result = model.invoke(messages) assert result is not None @@ -130,10 +75,9 @@ class ChatModelIntegrationTests(ABC): assert isinstance(result.content, str) assert len(result.content) > 0 - def test_usage_metadata( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) + def test_usage_metadata(self, model: BaseChatModel) -> None: + if not self.returns_usage_metadata: + pytest.skip("Not implemented.") result = model.invoke("Hello") assert result is not None assert isinstance(result, AIMessage) @@ -142,39 +86,35 @@ class ChatModelIntegrationTests(ABC): assert isinstance(result.usage_metadata["output_tokens"], int) assert isinstance(result.usage_metadata["total_tokens"], int) - def test_stop_sequence( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) + def test_stop_sequence(self, model: BaseChatModel) -> None: result = model.invoke("hi", stop=["you"]) assert isinstance(result, AIMessage) - model = chat_model_class(**chat_model_params, stop=["you"]) - result = model.invoke("hi") + custom_model = self.chat_model_class( + **{**self.chat_model_params, "stop": ["you"]} + ) + result = custom_model.invoke("hi") assert isinstance(result, AIMessage) def test_tool_message_histories_string_content( self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_tool_calling: bool, + model: BaseChatModel, ) -> None: """ Test that message histories are compatible with string tool contents (e.g. OpenAI). """ - if not chat_model_has_tool_calling: + if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - model = chat_model_class(**chat_model_params) model_with_tools = model.bind_tools([my_adder_tool]) function_name = "my_adder_tool" function_args = {"a": "1", "b": "2"} messages_string_content = [ - HumanMessage(content="What is 1 + 2"), + HumanMessage("What is 1 + 2"), # string content (e.g. OpenAI) AIMessage( - content="", + "", tool_calls=[ { "name": function_name, @@ -184,8 +124,8 @@ class ChatModelIntegrationTests(ABC): ], ), ToolMessage( + json.dumps({"result": 3}), name=function_name, - content=json.dumps({"result": 3}), tool_call_id="abc123", ), ] @@ -194,26 +134,23 @@ class ChatModelIntegrationTests(ABC): def test_tool_message_histories_list_content( self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_tool_calling: bool, + model: BaseChatModel, ) -> None: """ Test that message histories are compatible with list tool contents (e.g. Anthropic). """ - if not chat_model_has_tool_calling: + if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - model = chat_model_class(**chat_model_params) model_with_tools = model.bind_tools([my_adder_tool]) function_name = "my_adder_tool" function_args = {"a": 1, "b": 2} messages_list_content = [ - HumanMessage(content="What is 1 + 2"), + HumanMessage("What is 1 + 2"), # List content (e.g., Anthropic) AIMessage( - content=[ + [ {"type": "text", "text": "some text"}, { "type": "tool_use", @@ -231,8 +168,8 @@ class ChatModelIntegrationTests(ABC): ], ), ToolMessage( + json.dumps({"result": 3}), name=function_name, - content=json.dumps({"result": 3}), tool_call_id="abc123", ), ] @@ -241,25 +178,22 @@ class ChatModelIntegrationTests(ABC): def test_structured_few_shot_examples( self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_tool_calling: bool, + model: BaseChatModel, ) -> None: """ Test that model can process few-shot examples with tool calls. """ - if not chat_model_has_tool_calling: + if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - model = chat_model_class(**chat_model_params) - model_with_tools = model.bind_tools([my_adder_tool]) + model_with_tools = model.bind_tools([my_adder_tool], tool_choice="any") function_name = "my_adder_tool" function_args = {"a": 1, "b": 2} function_result = json.dumps({"result": 3}) messages_string_content = [ - HumanMessage(content="What is 1 + 2"), + HumanMessage("What is 1 + 2"), AIMessage( - content="", + "", tool_calls=[ { "name": function_name, @@ -269,12 +203,12 @@ class ChatModelIntegrationTests(ABC): ], ), ToolMessage( + function_result, name=function_name, - content=function_result, tool_call_id="abc123", ), - AIMessage(content=function_result), - HumanMessage(content="What is 3 + 4"), + AIMessage(function_result), + HumanMessage("What is 3 + 4"), ] result_string_content = model_with_tools.invoke(messages_string_content) assert isinstance(result_string_content, AIMessage) diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py index a3e71411b0..89d069c203 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py @@ -1,13 +1,16 @@ from abc import ABC, abstractmethod -from typing import List, Literal, Optional, Type +from typing import Any, List, Literal, Optional, Type import pytest from langchain_core.language_models import BaseChatModel from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError +from langchain_core.runnables import RunnableBinding from langchain_core.tools import tool class Person(BaseModel): + """Record attributes of a person.""" + name: str = Field(..., description="The name of the person.") age: int = Field(..., description="The age of the person.") @@ -18,81 +21,105 @@ def my_adder_tool(a: int, b: int) -> int: return a + b -class ChatModelUnitTests(ABC): +class ChatModelTests(ABC): + @property @abstractmethod - @pytest.fixture def chat_model_class(self) -> Type[BaseChatModel]: ... - @pytest.fixture + @property def chat_model_params(self) -> dict: return {} - @pytest.fixture - def chat_model_has_tool_calling( - self, chat_model_class: Type[BaseChatModel] - ) -> bool: - return chat_model_class.bind_tools is not BaseChatModel.bind_tools + @property + def standard_chat_model_params(self) -> dict: + return { + "temperature": 0, + "max_tokens": 100, + "timeout": 60, + "stop_sequences": [], + "max_retries": 2, + } @pytest.fixture - def chat_model_has_structured_output( - self, chat_model_class: Type[BaseChatModel] - ) -> bool: + def model(self) -> BaseChatModel: + return self.chat_model_class( + **{**self.standard_chat_model_params, **self.chat_model_params} + ) + + @property + def has_tool_calling(self) -> bool: + return self.chat_model_class.bind_tools is not BaseChatModel.bind_tools + + @property + def has_structured_output(self) -> bool: return ( - chat_model_class.with_structured_output + self.chat_model_class.with_structured_output is not BaseChatModel.with_structured_output ) - def test_chat_model_init( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - model = chat_model_class(**chat_model_params) - assert model is not None + @property + def supports_image_inputs(self) -> bool: + return False - def test_chat_model_init_api_key( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: - params = {**chat_model_params, "api_key": "test"} - model = chat_model_class(**params) # type: ignore + @property + def supports_video_inputs(self) -> bool: + return False + + @property + def returns_usage_metadata(self) -> bool: + return True + + +class ChatModelUnitTests(ChatModelTests): + @property + def standard_chat_model_params(self) -> dict: + params = super().standard_chat_model_params + params["api_key"] = "test" + return params + + def test_init(self) -> None: + model = self.chat_model_class( + **{**self.standard_chat_model_params, **self.chat_model_params} + ) assert model is not None - def test_chat_model_init_streaming( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict + def test_init_streaming( + self, ) -> None: - model = chat_model_class(streaming=True, **chat_model_params) # type: ignore + model = self.chat_model_class( + **{ + **self.standard_chat_model_params, + **self.chat_model_params, + "streaming": True, + } + ) assert model is not None - def test_chat_model_bind_tool_pydantic( + def test_bind_tool_pydantic( self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_tool_calling: bool, + model: BaseChatModel, ) -> None: - if not chat_model_has_tool_calling: + if not self.has_tool_calling: return - model = chat_model_class(**chat_model_params) - - assert hasattr(model, "bind_tools") - tool_model = model.bind_tools([Person]) - assert tool_model is not None + tool_model = model.bind_tools( + [Person, Person.schema(), my_adder_tool], tool_choice="any" + ) + assert isinstance(tool_model, RunnableBinding) - def test_chat_model_with_structured_output( + @pytest.mark.parametrize("schema", [Person, Person.schema()]) + def test_with_structured_output( self, - chat_model_class: Type[BaseChatModel], - chat_model_params: dict, - chat_model_has_structured_output: bool, + model: BaseChatModel, + schema: Any, ) -> None: - if not chat_model_has_structured_output: + if not self.has_structured_output: return - model = chat_model_class(**chat_model_params) - assert model is not None - assert model.with_structured_output(Person) is not None + assert model.with_structured_output(schema) is not None - def test_standard_params( - self, chat_model_class: Type[BaseChatModel], chat_model_params: dict - ) -> None: + def test_standard_params(self, model: BaseChatModel) -> None: class ExpectedParams(BaseModel): ls_provider: str ls_model_name: str @@ -101,7 +128,6 @@ class ChatModelUnitTests(ABC): ls_max_tokens: Optional[int] ls_stop: Optional[List[str]] - model = chat_model_class(**chat_model_params) ls_params = model._get_ls_params() try: ExpectedParams(**ls_params) @@ -109,7 +135,9 @@ class ChatModelUnitTests(ABC): pytest.fail(f"Validation error: {e}") # Test optional params - model = chat_model_class(max_tokens=10, stop=["test"], **chat_model_params) + model = self.chat_model_class( + max_tokens=10, stop=["test"], **self.chat_model_params + ) ls_params = model._get_ls_params() try: ExpectedParams(**ls_params)