mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
langchain[patch]: allow passing client with OpenAIAssistantRunnable (#16486)
- **Description:** This addresses the issue tagged below where if you try to pass your own client when creating an OpenAI assistant, a pydantic error is raised: Example code: ```python import openai from langchain.agents.openai_assistant import OpenAIAssistantRunnable client = openai.OpenAI() interpreter_assistant = OpenAIAssistantRunnable.create_assistant( name="langchain assistant", instructions="You are a personal math tutor. Write and run code to answer math questions.", tools=[{"type": "code_interpreter"}], model="gpt-4-1106-preview", client=client ) ``` Error: `pydantic.v1.errors.ConfigError: field "client" not yet prepared, so the type is still a ForwardRef. You might need to call OpenAIAssistantRunnable.update_forward_refs()` It additionally updates type hints and docstrings to indicate that an AzureOpenAI client is permissible as well. - **Issue:** https://github.com/langchain-ai/langchain/issues/15948 - **Dependencies:** N/A
This commit is contained in:
parent
d898d2f07b
commit
0e2e7d8b83
@ -146,8 +146,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
client: openai.OpenAI = Field(default_factory=_get_openai_client)
|
||||
"""OpenAI client."""
|
||||
client: Any = Field(default_factory=_get_openai_client)
|
||||
"""OpenAI or AzureOpenAI client."""
|
||||
assistant_id: str
|
||||
"""OpenAI assistant id."""
|
||||
check_every_ms: float = 1_000.0
|
||||
@ -163,7 +163,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
tools: Sequence[Union[BaseTool, dict]],
|
||||
model: str,
|
||||
*,
|
||||
client: Optional[openai.OpenAI] = None,
|
||||
client: Optional[Union[openai.OpenAI, openai.AzureOpenAI]] = None,
|
||||
**kwargs: Any,
|
||||
) -> OpenAIAssistantRunnable:
|
||||
"""Create an OpenAI Assistant and instantiate the Runnable.
|
||||
@ -173,7 +173,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
instructions: Assistant instructions.
|
||||
tools: Assistant tools. Can be passed in OpenAI format or as BaseTools.
|
||||
model: Assistant model to use.
|
||||
client: OpenAI client. Will create default client if not specified.
|
||||
client: OpenAI or AzureOpenAI client.
|
||||
Will create default OpenAI client if not specified.
|
||||
|
||||
Returns:
|
||||
OpenAIAssistantRunnable configured to run using the created assistant.
|
||||
@ -191,7 +192,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
tools=openai_tools,
|
||||
model=model,
|
||||
)
|
||||
return cls(assistant_id=assistant.id, **kwargs)
|
||||
return cls(assistant_id=assistant.id, client=client, **kwargs)
|
||||
|
||||
def invoke(
|
||||
self, input: dict, config: Optional[RunnableConfig] = None
|
||||
|
@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from langchain.agents.openai_assistant import OpenAIAssistantRunnable
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_user_supplied_client() -> None:
|
||||
import openai
|
||||
|
||||
client = openai.AzureOpenAI(
|
||||
azure_endpoint="azure_endpoint",
|
||||
api_key="api_key",
|
||||
api_version="api_version",
|
||||
)
|
||||
|
||||
assistant = OpenAIAssistantRunnable(
|
||||
assistant_id="assistant_id",
|
||||
client=client,
|
||||
)
|
||||
|
||||
assert assistant.client == client
|
Loading…
Reference in New Issue
Block a user