Implement a router for openai functions (#8589)

pull/9002/head
Nuno Campos 11 months ago committed by GitHub
parent a6e6e9bb86
commit 808248049d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -62,7 +62,14 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
return self(input, **(config or {}), **kwargs)
config = config or {}
return self(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
async def ainvoke(
self,
@ -76,7 +83,14 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
None, partial(self.invoke, input, config, **kwargs)
)
return await self.acall(input, **(config or {}), **kwargs)
config = config or {}
return await self.acall(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.

@ -103,12 +103,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
config = config or {}
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
).generations[0][0],
).message,
)
@ -127,8 +133,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message

@ -219,9 +219,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
config = config or {}
return (
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
.generations[0][0]
.text
@ -241,8 +247,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
None, partial(self.invoke, input, config, stop=stop, **kwargs)
)
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {}), **kwargs
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
return llm_result.generations[0][0].text

@ -0,0 +1,46 @@
from operator import itemgetter
from typing import Any, Callable, List, Mapping, Optional, Union
from typing_extensions import TypedDict
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
class OpenAIFunction(TypedDict):
"""A function description for ChatOpenAI"""
name: str
"""The name of the function."""
description: str
"""The description of the function."""
parameters: dict
"""The parameters to the function."""
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
"""A runnable that routes to the selected function."""
functions: Optional[List[OpenAIFunction]]
def __init__(
self,
runnables: Mapping[
str,
Union[
Runnable[dict, Any],
Callable[[dict], Any],
],
],
functions: Optional[List[OpenAIFunction]] = None,
):
if functions is not None:
assert len(functions) == len(runnables)
assert all(func["name"] in runnables for func in functions)
router = (
JsonOutputFunctionsParser(args_only=False)
| {"key": itemgetter("name"), "input": itemgetter("arguments")}
| RouterRunnable(runnables)
)
super().__init__(bound=router, kwargs={}, functions=functions)

@ -107,7 +107,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
return self.get_relevant_documents(input, **(config or {}))
config = config or {}
return self.get_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
)
async def ainvoke(
self, input: str, config: Optional[RunnableConfig] = None
@ -116,7 +122,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
# If the retriever doesn't implement async, use default implementation
return await super().ainvoke(input, config)
return await self.aget_relevant_documents(input, **(config or {}))
config = config or {}
return await self.aget_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
)
@abstractmethod
def _get_relevant_documents(

@ -1254,7 +1254,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
class RunnableBinding(Serializable, Runnable[Input, Output]):
"""
A runnable that binds a runnable to a set of kwargs.
A runnable that delegates calls to another runnable with a set of kwargs.
"""
bound: Runnable[Input, Output]
@ -1339,8 +1339,15 @@ class RouterRunnable(
runnables: Mapping[str, Runnable[Input, Output]]
def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
super().__init__(runnables=runnables)
def __init__(
self,
runnables: Mapping[
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
],
) -> None:
super().__init__(
runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
)
class Config:
arbitrary_types_allowed = True

@ -203,7 +203,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
**kwargs: Any,
) -> Any:
config = config or {}
return self.run(input, **config, **kwargs)
return self.run(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
async def ainvoke(
self,
@ -216,7 +222,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
return super().ainvoke(input, config, **kwargs)
config = config or {}
return await self.arun(input, **config, **kwargs)
return await self.arun(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
**kwargs,
)
# --- Tool ---

@ -0,0 +1,31 @@
# serializer version: 1
# name: test_openai_functions_router
list([
dict({
'description': 'Sends the draft for revision.',
'name': 'revise',
'parameters': dict({
'properties': dict({
'notes': dict({
'description': "The editor's notes to guide the revision.",
'type': 'string',
}),
}),
'type': 'object',
}),
}),
dict({
'description': 'Accepts the draft.',
'name': 'accept',
'parameters': dict({
'properties': dict({
'draft': dict({
'description': 'The draft to accept.',
'type': 'string',
}),
}),
'type': 'object',
}),
}),
])
# ---

@ -0,0 +1,95 @@
from typing import Any, List, Optional
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
from langchain.schema import ChatResult
from langchain.schema.messages import AIMessage, BaseMessage
from langchain.schema.output import ChatGeneration
class FakeChatOpenAI(BaseChatModel):
@property
def _llm_type(self) -> str:
return "fake-openai-chat-model"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "accept",
"arguments": '{\n "draft": "turtles"\n}',
}
},
)
)
]
)
def test_openai_functions_router(
snapshot: SnapshotAssertion, mocker: MockerFixture
) -> None:
revise = mocker.Mock(
side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
)
accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')
router = OpenAIFunctionsRouter(
{
"revise": revise,
"accept": accept,
},
functions=[
{
"name": "revise",
"description": "Sends the draft for revision.",
"parameters": {
"type": "object",
"properties": {
"notes": {
"type": "string",
"description": "The editor's notes to guide the revision.",
},
},
},
},
{
"name": "accept",
"description": "Accepts the draft.",
"parameters": {
"type": "object",
"properties": {
"draft": {
"type": "string",
"description": "The draft to accept.",
},
},
},
},
],
)
model = FakeChatOpenAI()
chain = model.bind(functions=router.functions) | router
assert router.functions == snapshot
assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"
revise.assert_not_called()
accept.assert_called_once_with({"draft": "turtles"})
Loading…
Cancel
Save