From 0e2e7d8b83215f2e24b85fd5324cd3ae6b2766b9 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Tue, 23 Jan 2024 18:48:29 -0800 Subject: [PATCH] langchain[patch]: allow passing client with OpenAIAssistantRunnable (#16486) - **Description:** This addresses the issue tagged below where if you try to pass your own client when creating an OpenAI assistant, a pydantic error is raised: Example code: ```python import openai from langchain.agents.openai_assistant import OpenAIAssistantRunnable client = openai.OpenAI() interpreter_assistant = OpenAIAssistantRunnable.create_assistant( name="langchain assistant", instructions="You are a personal math tutor. Write and run code to answer math questions.", tools=[{"type": "code_interpreter"}], model="gpt-4-1106-preview", client=client ) ``` Error: `pydantic.v1.errors.ConfigError: field "client" not yet prepared, so the type is still a ForwardRef. You might need to call OpenAIAssistantRunnable.update_forward_refs()` It additionally updates type hints and docstrings to indicate that an AzureOpenAI client is permissible as well. - **Issue:** https://github.com/langchain-ai/langchain/issues/15948 - **Dependencies:** N/A --- .../langchain/agents/openai_assistant/base.py | 11 +++++----- .../agents/test_openai_assistant.py | 21 +++++++++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/agents/test_openai_assistant.py diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index 84d99c8f97..67e361b02e 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -146,8 +146,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): """ # noqa: E501 - client: openai.OpenAI = Field(default_factory=_get_openai_client) - """OpenAI client.""" + client: Any = Field(default_factory=_get_openai_client) + """OpenAI or AzureOpenAI client.""" assistant_id: str """OpenAI assistant id.""" check_every_ms: float = 1_000.0 @@ -163,7 +163,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): tools: Sequence[Union[BaseTool, dict]], model: str, *, - client: Optional[openai.OpenAI] = None, + client: Optional[Union[openai.OpenAI, openai.AzureOpenAI]] = None, **kwargs: Any, ) -> OpenAIAssistantRunnable: """Create an OpenAI Assistant and instantiate the Runnable. @@ -173,7 +173,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): instructions: Assistant instructions. tools: Assistant tools. Can be passed in OpenAI format or as BaseTools. model: Assistant model to use. - client: OpenAI client. Will create default client if not specified. + client: OpenAI or AzureOpenAI client. + Will create default OpenAI client if not specified. Returns: OpenAIAssistantRunnable configured to run using the created assistant. @@ -191,7 +192,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): tools=openai_tools, model=model, ) - return cls(assistant_id=assistant.id, **kwargs) + return cls(assistant_id=assistant.id, client=client, **kwargs) def invoke( self, input: dict, config: Optional[RunnableConfig] = None diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py new file mode 100644 index 0000000000..aaa4ba48d1 --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py @@ -0,0 +1,21 @@ +import pytest + +from langchain.agents.openai_assistant import OpenAIAssistantRunnable + + +@pytest.mark.requires("openai") +def test_user_supplied_client() -> None: + import openai + + client = openai.AzureOpenAI( + azure_endpoint="azure_endpoint", + api_key="api_key", + api_version="api_version", + ) + + assistant = OpenAIAssistantRunnable( + assistant_id="assistant_id", + client=client, + ) + + assert assistant.client == client