[experimental]: minor fix to open assistants code (#24682)

This commit is contained in:
Isaac Francisco 2024-08-15 10:50:57 -07:00 committed by GitHub
parent 2b4fbcb4b4
commit 5150ec3a04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -272,7 +272,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
instructions=instructions,
tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore
model=model,
file_ids=kwargs.get("file_ids"),
)
return cls(assistant_id=assistant.id, client=client, **kwargs)
@ -287,7 +286,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
thread_id: Existing thread to use.
run_id: Existing run to use. Should only be supplied when providing
the tool output for a required action after an initial invocation.
file_ids: File ids to include in new run. Used for retrieval.
message_metadata: Metadata to associate with new message.
thread_metadata: Metadata to associate with new thread. Only relevant
when new thread being created.
@ -327,7 +325,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
{
"role": "user",
"content": input["content"],
"file_ids": input.get("file_ids", []),
"metadata": input.get("message_metadata"),
}
],
@ -340,7 +337,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
input["thread_id"],
content=input["content"],
role="user",
file_ids=input.get("file_ids", []),
metadata=input.get("message_metadata"),
)
run = self._create_run(input)
@ -394,7 +390,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
instructions=instructions,
tools=openai_tools, # type: ignore
model=model,
file_ids=kwargs.get("file_ids"),
)
return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)
@ -409,7 +404,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
thread_id: Existing thread to use.
run_id: Existing run to use. Should only be supplied when providing
the tool output for a required action after an initial invocation.
file_ids: File ids to include in new run. Used for retrieval.
message_metadata: Metadata to associate with a new message.
thread_metadata: Metadata to associate with new thread. Only relevant
when a new thread is created.
@ -439,7 +433,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
try:
# Being run within AgentExecutor and there are tool outputs to submit.
if self.as_agent and input.get("intermediate_steps"):
tool_outputs = self._parse_intermediate_steps(
tool_outputs = await self._aparse_intermediate_steps(
input["intermediate_steps"]
)
run = await self.async_client.beta.threads.runs.submit_tool_outputs(
@ -452,7 +446,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
{
"role": "user",
"content": input["content"],
"file_ids": input.get("file_ids", []),
"metadata": input.get("message_metadata"),
}
],
@ -465,7 +458,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
input["thread_id"],
content=input["content"],
role="user",
file_ids=input.get("file_ids", []),
metadata=input.get("message_metadata"),
)
run = await self._acreate_run(input)
@ -493,9 +485,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
) -> dict:
last_action, last_output = intermediate_steps[-1]
run = self._wait_for_run(last_action.run_id, last_action.thread_id)
required_tool_call_ids = {
tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
}
required_tool_call_ids = set()
if run.required_action:
required_tool_call_ids = {
tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
}
tool_outputs = [
{"output": str(output), "tool_call_id": action.tool_call_id}
for action, output in intermediate_steps
@ -621,9 +615,11 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
) -> dict:
last_action, last_output = intermediate_steps[-1]
run = await self._wait_for_run(last_action.run_id, last_action.thread_id)
required_tool_call_ids = {
tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
}
required_tool_call_ids = set()
if run.required_action:
required_tool_call_ids = {
tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
}
tool_outputs = [
{"output": str(output), "tool_call_id": action.tool_call_id}
for action, output in intermediate_steps