From a08f9a7ff999d95800a1a66a2992bbe1cb8030f2 Mon Sep 17 00:00:00 2001
From: chyroc
Date: Tue, 30 Jan 2024 04:19:47 +0800
Subject: [PATCH] langchain[patch]: support OpenAIAssistantRunnable async (#15302)

fix https://github.com/langchain-ai/langchain/issues/15299

---------

Co-authored-by: Bagatur
---
 .../langchain/agents/openai_assistant/base.py | 274 ++++++++++++++++++-
 1 file changed, 273 insertions(+), 1 deletion(-)

diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py
index 897fe4078a..cd4e144291 100644
--- a/libs/langchain/langchain/agents/openai_assistant/base.py
+++ b/libs/langchain/langchain/agents/openai_assistant/base.py
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Un
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks import CallbackManager
 from langchain_core.load import dumpd
-from langchain_core.pydantic_v1 import Field
+from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -60,6 +60,22 @@ def _get_openai_client() -> openai.OpenAI:
     ) from e
 
 
+def _get_openai_async_client() -> openai.AsyncOpenAI:
+    try:
+        import openai
+
+        return openai.AsyncOpenAI()
+    except ImportError as e:
+        raise ImportError(
+            "Unable to import openai, please install with `pip install openai`."
+        ) from e
+    except AttributeError as e:
+        raise AttributeError(
+            "Please make sure you are using a v1.1-compatible version of openai. You "
+            'can install with `pip install "openai>=1.1"`.'
+        ) from e
+
+
 OutputType = Union[
     List[OpenAIAssistantAction],
     OpenAIAssistantFinish,
@@ -148,6 +164,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
 
     client: Any = Field(default_factory=_get_openai_client)
     """OpenAI or AzureOpenAI client."""
+    async_client: Any = None
+    """OpenAI or AzureOpenAI async client."""
     assistant_id: str
     """OpenAI assistant id."""
     check_every_ms: float = 1_000.0
@@ -155,6 +173,15 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
     as_agent: bool = False
     """Use as a LangChain agent, compatible with the AgentExecutor."""
 
+    @root_validator()
+    def validate_async_client(cls, values: dict) -> dict:
+        if values["async_client"] is None:
+            import openai
+
+            api_key = values["client"].api_key
+            values["async_client"] = openai.AsyncOpenAI(api_key=api_key)
+        return values
+
     @classmethod
     def create_assistant(
         cls,
@@ -273,6 +300,131 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
             run_manager.on_chain_end(response)
         return response
 
+    @classmethod
+    async def acreate_assistant(
+        cls,
+        name: str,
+        instructions: str,
+        tools: Sequence[Union[BaseTool, dict]],
+        model: str,
+        *,
+        async_client: Optional[
+            Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]
+        ] = None,
+        **kwargs: Any,
+    ) -> OpenAIAssistantRunnable:
+        """Asynchronously create an OpenAI Assistant and instantiate the Runnable.
+
+        Args:
+            name: Assistant name.
+            instructions: Assistant instructions.
+            tools: Assistant tools. Can be passed in OpenAI format or as BaseTools.
+            model: Assistant model to use.
+            async_client: AsyncOpenAI client.
+                Will create a default async_client if not specified.
+
+        Returns:
+            OpenAIAssistantRunnable configured to run using the created assistant.
+        """
+        async_client = async_client or _get_openai_async_client()
+        openai_tools = [convert_to_openai_tool(tool) for tool in tools]
+        assistant = await async_client.beta.assistants.create(
+            name=name,
+            instructions=instructions,
+            tools=openai_tools,
+            model=model,
+        )
+        return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)
+
+    async def ainvoke(
+        self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
+    ) -> OutputType:
+        """Asynchronously invoke the assistant.
+
+        Args:
+            input: Runnable input dict that can have:
+                content: User message when starting a new run.
+                thread_id: Existing thread to use.
+                run_id: Existing run to use. Should only be supplied when providing
+                    the tool output for a required action after an initial invocation.
+                file_ids: File ids to include in new run. Used for retrieval.
+                message_metadata: Metadata to associate with new message.
+                thread_metadata: Metadata to associate with new thread. Only relevant
+                    when a new thread is being created.
+                instructions: Additional run instructions.
+                model: Override Assistant model for this run.
+                tools: Override Assistant tools for this run.
+                run_metadata: Metadata to associate with new run.
+            config: Runnable config.
+
+        Return:
+            If self.as_agent, will return
+                Union[List[OpenAIAssistantAction], OpenAIAssistantFinish]. Otherwise,
+                will return OpenAI types
+                Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
+        """
+
+        config = config or {}
+        callback_manager = CallbackManager.configure(
+            inheritable_callbacks=config.get("callbacks"),
+            inheritable_tags=config.get("tags"),
+            inheritable_metadata=config.get("metadata"),
+        )
+        run_manager = callback_manager.on_chain_start(
+            dumpd(self), input, name=config.get("run_name")
+        )
+        try:
+            # Being run within AgentExecutor and there are tool outputs to submit.
+            if self.as_agent and input.get("intermediate_steps"):
+                tool_outputs = await self._aparse_intermediate_steps(
+                    input["intermediate_steps"]
+                )
+                run = await self.async_client.beta.threads.runs.submit_tool_outputs(
+                    **tool_outputs
+                )
+            # Starting a new thread and a new run.
+            elif "thread_id" not in input:
+                thread = {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": input["content"],
+                            "file_ids": input.get("file_ids", []),
+                            "metadata": input.get("message_metadata"),
+                        }
+                    ],
+                    "metadata": input.get("thread_metadata"),
+                }
+                run = await self._acreate_thread_and_run(input, thread)
+            # Starting a new run in an existing thread.
+            elif "run_id" not in input:
+                _ = await self.async_client.beta.threads.messages.create(
+                    input["thread_id"],
+                    content=input["content"],
+                    role="user",
+                    file_ids=input.get("file_ids", []),
+                    metadata=input.get("message_metadata"),
+                )
+                run = await self._acreate_run(input)
+            # Submitting tool outputs to an existing run, outside the AgentExecutor
+            # framework.
+            else:
+                run = await self.async_client.beta.threads.runs.submit_tool_outputs(
+                    **input
+                )
+            run = await self._await_for_run(run.id, run.thread_id)
+        except BaseException as e:
+            run_manager.on_chain_error(e)
+            raise e
+        try:
+            response = await self._aget_response(run)
+        except BaseException as e:
+            run_manager.on_chain_error(e, metadata=run.dict())
+            raise e
+        else:
+            run_manager.on_chain_end(response)
+        return response
+
     def _parse_intermediate_steps(
         self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
     ) -> dict:
@@ -388,3 +540,123 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
             if in_progress:
                 sleep(self.check_every_ms / 1000)
         return run
+
+    async def _aparse_intermediate_steps(
+        self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
+    ) -> dict:
+        last_action, last_output = intermediate_steps[-1]
+        run = await self._await_for_run(last_action.run_id, last_action.thread_id)
+        required_tool_call_ids = {
+            tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
+        }
+        tool_outputs = [
+            {"output": str(output), "tool_call_id": action.tool_call_id}
+            for action, output in intermediate_steps
+            if action.tool_call_id in required_tool_call_ids
+        ]
+        submit_tool_outputs = {
+            "tool_outputs": tool_outputs,
+            "run_id": last_action.run_id,
+            "thread_id": last_action.thread_id,
+        }
+        return submit_tool_outputs
+
+    async def _acreate_run(self, input: dict) -> Any:
+        params = {
+            k: v
+            for k, v in input.items()
+            if k in ("instructions", "model", "tools", "run_metadata")
+        }
+        return await self.async_client.beta.threads.runs.create(
+            input["thread_id"],
+            assistant_id=self.assistant_id,
+            **params,
+        )
+
+    async def _acreate_thread_and_run(self, input: dict, thread: dict) -> Any:
+        params = {
+            k: v
+            for k, v in input.items()
+            if k in ("instructions", "model", "tools", "run_metadata")
+        }
+        run = await self.async_client.beta.threads.create_and_run(
+            assistant_id=self.assistant_id,
+            thread=thread,
+            **params,
+        )
+        return run
+
+    async def _aget_response(self, run: Any) -> Any:
+        # TODO: Pagination
+
+        if run.status == "completed":
+            import openai
+
+            messages = await self.async_client.beta.threads.messages.list(
+                run.thread_id, order="asc"
+            )
+            new_messages = [msg async for msg in messages if msg.run_id == run.id]
+            if not self.as_agent:
+                return new_messages
+            answer: Any = [
+                msg_content for msg in new_messages for msg_content in msg.content
+            ]
+            if all(
+                isinstance(content, openai.types.beta.threads.MessageContentText)
+                for content in answer
+            ):
+                answer = "\n".join(content.text.value for content in answer)
+            return OpenAIAssistantFinish(
+                return_values={
+                    "output": answer,
+                    "thread_id": run.thread_id,
+                    "run_id": run.id,
+                },
+                log="",
+                run_id=run.id,
+                thread_id=run.thread_id,
+            )
+        elif run.status == "requires_action":
+            if not self.as_agent:
+                return run.required_action.submit_tool_outputs.tool_calls
+            actions = []
+            for tool_call in run.required_action.submit_tool_outputs.tool_calls:
+                function = tool_call.function
+                try:
+                    args = json.loads(function.arguments, strict=False)
+                except JSONDecodeError as e:
+                    raise ValueError(
+                        f"Received invalid JSON function arguments: "
+                        f"{function.arguments} for function {function.name}"
+                    ) from e
+                if len(args) == 1 and "__arg1" in args:
+                    args = args["__arg1"]
+                actions.append(
+                    OpenAIAssistantAction(
+                        tool=function.name,
+                        tool_input=args,
+                        tool_call_id=tool_call.id,
+                        log="",
+                        run_id=run.id,
+                        thread_id=run.thread_id,
+                    )
+                )
+            return actions
+        else:
+            run_info = json.dumps(run.dict(), indent=2)
+            raise ValueError(
+                f"Unexpected run status: {run.status}. Full run info:\n\n{run_info}"
+            )
+
+    async def _await_for_run(self, run_id: str, thread_id: str) -> Any:
+        import asyncio
+
+        in_progress = True
+        while in_progress:
+            run = await self.async_client.beta.threads.runs.retrieve(
+                run_id, thread_id=thread_id
+            )
+            in_progress = run.status in ("in_progress", "queued")
+            if in_progress:
+                await asyncio.sleep(self.check_every_ms / 1000)
+        return run
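
Below is a minimal usage sketch of the new async entry points; it is not part of the patch itself. It assumes openai>=1.1 is installed and OPENAI_API_KEY is set in the environment, and the assistant name, instructions, and model id are illustrative placeholders.

    import asyncio

    from langchain.agents.openai_assistant import OpenAIAssistantRunnable


    async def main() -> None:
        # acreate_assistant provisions the assistant through the default AsyncOpenAI client.
        assistant = await OpenAIAssistantRunnable.acreate_assistant(
            name="demo assistant",  # illustrative
            instructions="Answer the user's question briefly.",  # illustrative
            tools=[],
            model="gpt-4-1106-preview",  # illustrative model id
        )
        # ainvoke starts a new thread and run, then polls for completion
        # without blocking the event loop.
        output = await assistant.ainvoke({"content": "What is 81 divided by 9?"})
        print(output)


    asyncio.run(main())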