langchain[patch]: allow passing client with OpenAIAssistantRunnable (#16486)

- **Description:** This addresses the issue tagged below where if you
try to pass your own client when creating an OpenAI assistant, a
pydantic error is raised:

Example code:

```python
import openai
from langchain.agents.openai_assistant import OpenAIAssistantRunnable

client = openai.OpenAI()
interpreter_assistant = OpenAIAssistantRunnable.create_assistant(
    name="langchain assistant",
    instructions="You are a personal math tutor. Write and run code to answer math questions.",
    tools=[{"type": "code_interpreter"}],
    model="gpt-4-1106-preview",
    client=client
)

```

Error:
`pydantic.v1.errors.ConfigError: field "client" not yet prepared, so the
type is still a ForwardRef. You might need to call
OpenAIAssistantRunnable.update_forward_refs()`

It additionally updates type hints and docstrings to indicate that an
AzureOpenAI client is permissible as well.

  - **Issue:** https://github.com/langchain-ai/langchain/issues/15948
  - **Dependencies:** N/A
This commit is contained in:
Krista Pratico 2024-01-23 18:48:29 -08:00 committed by GitHub
parent d898d2f07b
commit 0e2e7d8b83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 5 deletions

View File

@ -146,8 +146,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
""" # noqa: E501
client: openai.OpenAI = Field(default_factory=_get_openai_client)
"""OpenAI client."""
client: Any = Field(default_factory=_get_openai_client)
"""OpenAI or AzureOpenAI client."""
assistant_id: str
"""OpenAI assistant id."""
check_every_ms: float = 1_000.0
@ -163,7 +163,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
tools: Sequence[Union[BaseTool, dict]],
model: str,
*,
client: Optional[openai.OpenAI] = None,
client: Optional[Union[openai.OpenAI, openai.AzureOpenAI]] = None,
**kwargs: Any,
) -> OpenAIAssistantRunnable:
"""Create an OpenAI Assistant and instantiate the Runnable.
@ -173,7 +173,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
instructions: Assistant instructions.
tools: Assistant tools. Can be passed in OpenAI format or as BaseTools.
model: Assistant model to use.
client: OpenAI client. Will create default client if not specified.
client: OpenAI or AzureOpenAI client.
Will create default OpenAI client if not specified.
Returns:
OpenAIAssistantRunnable configured to run using the created assistant.
@ -191,7 +192,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
tools=openai_tools,
model=model,
)
return cls(assistant_id=assistant.id, **kwargs)
return cls(assistant_id=assistant.id, client=client, **kwargs)
def invoke(
self, input: dict, config: Optional[RunnableConfig] = None

View File

@ -0,0 +1,21 @@
import pytest
from langchain.agents.openai_assistant import OpenAIAssistantRunnable
@pytest.mark.requires("openai")
def test_user_supplied_client() -> None:
import openai
client = openai.AzureOpenAI(
azure_endpoint="azure_endpoint",
api_key="api_key",
api_version="api_version",
)
assistant = OpenAIAssistantRunnable(
assistant_id="assistant_id",
client=client,
)
assert assistant.client == client