mirror of https://github.com/hwchase17/langchain
Implement a router for openai functions (#8589)
parent
a6e6e9bb86
commit
808248049d
@ -0,0 +1,46 @@
|
|||||||
|
from operator import itemgetter
|
||||||
|
from typing import Any, Callable, List, Mapping, Optional, Union
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||||
|
from langchain.schema.output import ChatGeneration
|
||||||
|
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIFunction(TypedDict):
|
||||||
|
"""A function description for ChatOpenAI"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""The name of the function."""
|
||||||
|
description: str
|
||||||
|
"""The description of the function."""
|
||||||
|
parameters: dict
|
||||||
|
"""The parameters to the function."""
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
|
||||||
|
"""A runnable that routes to the selected function."""
|
||||||
|
|
||||||
|
functions: Optional[List[OpenAIFunction]]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
runnables: Mapping[
|
||||||
|
str,
|
||||||
|
Union[
|
||||||
|
Runnable[dict, Any],
|
||||||
|
Callable[[dict], Any],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
functions: Optional[List[OpenAIFunction]] = None,
|
||||||
|
):
|
||||||
|
if functions is not None:
|
||||||
|
assert len(functions) == len(runnables)
|
||||||
|
assert all(func["name"] in runnables for func in functions)
|
||||||
|
router = (
|
||||||
|
JsonOutputFunctionsParser(args_only=False)
|
||||||
|
| {"key": itemgetter("name"), "input": itemgetter("arguments")}
|
||||||
|
| RouterRunnable(runnables)
|
||||||
|
)
|
||||||
|
super().__init__(bound=router, kwargs={}, functions=functions)
|
@ -0,0 +1,31 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_openai_functions_router
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'description': 'Sends the draft for revision.',
|
||||||
|
'name': 'revise',
|
||||||
|
'parameters': dict({
|
||||||
|
'properties': dict({
|
||||||
|
'notes': dict({
|
||||||
|
'description': "The editor's notes to guide the revision.",
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': 'object',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'description': 'Accepts the draft.',
|
||||||
|
'name': 'accept',
|
||||||
|
'parameters': dict({
|
||||||
|
'properties': dict({
|
||||||
|
'draft': dict({
|
||||||
|
'description': 'The draft to accept.',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': 'object',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
@ -0,0 +1,95 @@
|
|||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
from syrupy import SnapshotAssertion
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
|
||||||
|
from langchain.schema import ChatResult
|
||||||
|
from langchain.schema.messages import AIMessage, BaseMessage
|
||||||
|
from langchain.schema.output import ChatGeneration
|
||||||
|
|
||||||
|
|
||||||
|
class FakeChatOpenAI(BaseChatModel):
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "fake-openai-chat-model"
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
return ChatResult(
|
||||||
|
generations=[
|
||||||
|
ChatGeneration(
|
||||||
|
message=AIMessage(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {
|
||||||
|
"name": "accept",
|
||||||
|
"arguments": '{\n "draft": "turtles"\n}',
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_functions_router(
|
||||||
|
snapshot: SnapshotAssertion, mocker: MockerFixture
|
||||||
|
) -> None:
|
||||||
|
revise = mocker.Mock(
|
||||||
|
side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
|
||||||
|
)
|
||||||
|
accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')
|
||||||
|
|
||||||
|
router = OpenAIFunctionsRouter(
|
||||||
|
{
|
||||||
|
"revise": revise,
|
||||||
|
"accept": accept,
|
||||||
|
},
|
||||||
|
functions=[
|
||||||
|
{
|
||||||
|
"name": "revise",
|
||||||
|
"description": "Sends the draft for revision.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"notes": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The editor's notes to guide the revision.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "accept",
|
||||||
|
"description": "Accepts the draft.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"draft": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The draft to accept.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
model = FakeChatOpenAI()
|
||||||
|
|
||||||
|
chain = model.bind(functions=router.functions) | router
|
||||||
|
|
||||||
|
assert router.functions == snapshot
|
||||||
|
|
||||||
|
assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"
|
||||||
|
|
||||||
|
revise.assert_not_called()
|
||||||
|
accept.assert_called_once_with({"draft": "turtles"})
|
Loading…
Reference in New Issue