diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index 84d99c8f97..67e361b02e 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -146,8 +146,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): """ # noqa: E501 - client: openai.OpenAI = Field(default_factory=_get_openai_client) - """OpenAI client.""" + client: Any = Field(default_factory=_get_openai_client) + """OpenAI or AzureOpenAI client.""" assistant_id: str """OpenAI assistant id.""" check_every_ms: float = 1_000.0 @@ -163,7 +163,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): tools: Sequence[Union[BaseTool, dict]], model: str, *, - client: Optional[openai.OpenAI] = None, + client: Optional[Union[openai.OpenAI, openai.AzureOpenAI]] = None, **kwargs: Any, ) -> OpenAIAssistantRunnable: """Create an OpenAI Assistant and instantiate the Runnable. @@ -173,7 +173,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): instructions: Assistant instructions. tools: Assistant tools. Can be passed in OpenAI format or as BaseTools. model: Assistant model to use. - client: OpenAI client. Will create default client if not specified. + client: OpenAI or AzureOpenAI client. + Will create default OpenAI client if not specified. Returns: OpenAIAssistantRunnable configured to run using the created assistant. @@ -191,7 +192,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): tools=openai_tools, model=model, ) - return cls(assistant_id=assistant.id, **kwargs) + return cls(assistant_id=assistant.id, client=client, **kwargs) def invoke( self, input: dict, config: Optional[RunnableConfig] = None diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py new file mode 100644 index 0000000000..aaa4ba48d1 --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py @@ -0,0 +1,21 @@ +import pytest + +from langchain.agents.openai_assistant import OpenAIAssistantRunnable + + +@pytest.mark.requires("openai") +def test_user_supplied_client() -> None: + import openai + + client = openai.AzureOpenAI( + azure_endpoint="azure_endpoint", + api_key="api_key", + api_version="api_version", + ) + + assistant = OpenAIAssistantRunnable( + assistant_id="assistant_id", + client=client, + ) + + assert assistant.client == client