From 808248049ddf060732c8bac502d71d5dd04d761c Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Wed, 9 Aug 2023 21:17:04 +0100
Subject: [PATCH] Implement a router for openai functions (#8589)

---
 libs/langchain/langchain/chains/base.py       | 18 +++-
 libs/langchain/langchain/chat_models/base.py  | 16 +++-
 libs/langchain/langchain/llms/base.py         | 16 +++-
 .../langchain/langchain/runnables/__init__.py |  0
 .../langchain/runnables/openai_functions.py   | 46 +++++++++
 libs/langchain/langchain/schema/retriever.py  | 16 +++-
 libs/langchain/langchain/schema/runnable.py   | 13 ++-
 libs/langchain/langchain/tools/base.py        | 16 +++-
 .../__snapshots__/test_openai_functions.ambr  | 31 ++++++
 .../runnables/test_openai_functions.py        | 95 +++++++++++++++++++
 10 files changed, 254 insertions(+), 13 deletions(-)
 create mode 100644 libs/langchain/langchain/runnables/__init__.py
 create mode 100644 libs/langchain/langchain/runnables/openai_functions.py
 create mode 100644 libs/langchain/tests/unit_tests/runnables/__snapshots__/test_openai_functions.ambr
 create mode 100644 libs/langchain/tests/unit_tests/runnables/test_openai_functions.py

diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py
index 301b0143e7..751dcbd581 100644
--- a/libs/langchain/langchain/chains/base.py
+++ b/libs/langchain/langchain/chains/base.py
@@ -62,7 +62,14 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        return self(input, **(config or {}), **kwargs)
+        config = config or {}
+        return self(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     async def ainvoke(
         self,
@@ -76,7 +83,14 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
                 None, partial(self.invoke, input, config, **kwargs)
             )
 
-        return await self.acall(input, **(config or {}), **kwargs)
+        config = config or {}
+        return await self.acall(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     memory: Optional[BaseMemory] = None
     """Optional memory object. Defaults to None.
diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index b06b99f99d..0a39dff54a 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -103,12 +103,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> BaseMessageChunk:
+        config = config or {}
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+                    [self._convert_input(input)],
+                    stop=stop,
+                    callbacks=config.get("callbacks"),
+                    tags=config.get("tags"),
+                    metadata=config.get("metadata"),
+                    **kwargs,
                 ).generations[0][0],
             ).message,
         )
@@ -127,8 +133,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
 
+        config = config or {}
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+            [self._convert_input(input)],
+            stop=stop,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py
index 7da494de78..3fa006ea72 100644
--- a/libs/langchain/langchain/llms/base.py
+++ b/libs/langchain/langchain/llms/base.py
@@ -219,9 +219,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> str:
+        config = config or {}
         return (
             self.generate_prompt(
-                [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                **kwargs,
             )
             .generations[0][0]
             .text
@@ -241,8 +247,14 @@
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
 
+        config = config or {}
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+            [self._convert_input(input)],
+            stop=stop,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
         )
         return llm_result.generations[0][0].text
 
diff --git a/libs/langchain/langchain/runnables/__init__.py b/libs/langchain/langchain/runnables/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/libs/langchain/langchain/runnables/openai_functions.py b/libs/langchain/langchain/runnables/openai_functions.py
new file mode 100644
index 0000000000..55c9765d20
--- /dev/null
+++ b/libs/langchain/langchain/runnables/openai_functions.py
@@ -0,0 +1,46 @@
+from operator import itemgetter
+from typing import Any, Callable, List, Mapping, Optional, Union
+
+from typing_extensions import TypedDict
+
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+from langchain.schema.output import ChatGeneration
+from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
+
+
+class OpenAIFunction(TypedDict):
+    """A function description for ChatOpenAI"""
+
+    name: str
+    """The name of the function."""
+    description: str
+    """The description of the function."""
+    parameters: dict
+    """The parameters to the function."""
+
+
+class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
+    """A runnable that routes to the selected function."""
+
+    functions: Optional[List[OpenAIFunction]]
+
+    def __init__(
+        self,
+        runnables: Mapping[
+            str,
+            Union[
+                Runnable[dict, Any],
+                Callable[[dict], Any],
+            ],
+        ],
+        functions: Optional[List[OpenAIFunction]] = None,
+    ):
+        if functions is not None:
+            assert len(functions) == len(runnables)
+            assert all(func["name"] in runnables for func in functions)
+        router = (
+            JsonOutputFunctionsParser(args_only=False)
+            | {"key": itemgetter("name"), "input": itemgetter("arguments")}
+            | RouterRunnable(runnables)
+        )
+        super().__init__(bound=router, kwargs={}, functions=functions)
diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py
index 9df3e7a138..72c5cf6366 100644
--- a/libs/langchain/langchain/schema/retriever.py
+++ b/libs/langchain/langchain/schema/retriever.py
@@ -107,7 +107,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
     def invoke(
         self, input: str, config: Optional[RunnableConfig] = None
     ) -> List[Document]:
-        return self.get_relevant_documents(input, **(config or {}))
+        config = config or {}
+        return self.get_relevant_documents(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+        )
 
     async def ainvoke(
         self, input: str, config: Optional[RunnableConfig] = None
@@ -116,7 +122,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             # If the retriever doesn't implement async, use default implementation
             return await super().ainvoke(input, config)
 
-        return await self.aget_relevant_documents(input, **(config or {}))
+        config = config or {}
+        return await self.aget_relevant_documents(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+        )
 
     @abstractmethod
     def _get_relevant_documents(
diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py
index 8edafe4599..84399a2c0b 100644
--- a/libs/langchain/langchain/schema/runnable.py
+++ b/libs/langchain/langchain/schema/runnable.py
@@ -1254,7 +1254,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
 
 class RunnableBinding(Serializable, Runnable[Input, Output]):
     """
-    A runnable that binds a runnable to a set of kwargs.
+    A runnable that delegates calls to another runnable with a set of kwargs.
     """
 
     bound: Runnable[Input, Output]
@@ -1339,8 +1339,15 @@ class RouterRunnable(
 
     runnables: Mapping[str, Runnable[Input, Output]]
 
-    def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
-        super().__init__(runnables=runnables)
+    def __init__(
+        self,
+        runnables: Mapping[
+            str, Union[Runnable[Input, Output], Callable[[Input], Output]]
+        ],
+    ) -> None:
+        super().__init__(
+            runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
+        )
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py
index 651138718b..f8607d5144 100644
--- a/libs/langchain/langchain/tools/base.py
+++ b/libs/langchain/langchain/tools/base.py
@@ -203,7 +203,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
         **kwargs: Any,
     ) -> Any:
         config = config or {}
-        return self.run(input, **config, **kwargs)
+        return self.run(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     async def ainvoke(
         self,
@@ -216,7 +222,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
             return super().ainvoke(input, config, **kwargs)
 
         config = config or {}
-        return await self.arun(input, **config, **kwargs)
+        return await self.arun(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
 
 # --- Tool ---
diff --git a/libs/langchain/tests/unit_tests/runnables/__snapshots__/test_openai_functions.ambr b/libs/langchain/tests/unit_tests/runnables/__snapshots__/test_openai_functions.ambr
new file mode 100644
index 0000000000..ed3f36e061
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/runnables/__snapshots__/test_openai_functions.ambr
@@ -0,0 +1,31 @@
+# serializer version: 1
+# name: test_openai_functions_router
+  list([
+    dict({
+      'description': 'Sends the draft for revision.',
+      'name': 'revise',
+      'parameters': dict({
+        'properties': dict({
+          'notes': dict({
+            'description': "The editor's notes to guide the revision.",
+            'type': 'string',
+          }),
+        }),
+        'type': 'object',
+      }),
+    }),
+    dict({
+      'description': 'Accepts the draft.',
+      'name': 'accept',
+      'parameters': dict({
+        'properties': dict({
+          'draft': dict({
+            'description': 'The draft to accept.',
+            'type': 'string',
+          }),
+        }),
+        'type': 'object',
+      }),
+    }),
+  ])
+# ---
diff --git a/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py b/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py
new file mode 100644
index 0000000000..e4cec167d8
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py
@@ -0,0 +1,95 @@
+from typing import Any, List, Optional
+
+from pytest_mock import MockerFixture
+from syrupy import SnapshotAssertion
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.chat_models.base import BaseChatModel
+from langchain.runnables.openai_functions import OpenAIFunctionsRouter
+from langchain.schema import ChatResult
+from langchain.schema.messages import AIMessage, BaseMessage
+from langchain.schema.output import ChatGeneration
+
+
+class FakeChatOpenAI(BaseChatModel):
+    @property
+    def _llm_type(self) -> str:
+        return "fake-openai-chat-model"
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return ChatResult(
+            generations=[
+                ChatGeneration(
+                    message=AIMessage(
+                        content="",
+                        additional_kwargs={
+                            "function_call": {
+                                "name": "accept",
+                                "arguments": '{\n "draft": "turtles"\n}',
+                            }
+                        },
+                    )
+                )
+            ]
+        )
+
+
+def test_openai_functions_router(
+    snapshot: SnapshotAssertion, mocker: MockerFixture
+) -> None:
+    revise = mocker.Mock(
+        side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
+    )
+    accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')
+
+    router = OpenAIFunctionsRouter(
+        {
+            "revise": revise,
+            "accept": accept,
+        },
+        functions=[
+            {
+                "name": "revise",
+                "description": "Sends the draft for revision.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "notes": {
+                            "type": "string",
+                            "description": "The editor's notes to guide the revision.",
+                        },
+                    },
+                },
+            },
+            {
+                "name": "accept",
+                "description": "Accepts the draft.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "draft": {
+                            "type": "string",
+                            "description": "The draft to accept.",
+                        },
+                    },
+                },
+            },
+        ],
+    )
+
+    model = FakeChatOpenAI()
+
+    chain = model.bind(functions=router.functions) | router
+
+    assert router.functions == snapshot
+
+    assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"
+
+    revise.assert_not_called()
+    accept.assert_called_once_with({"draft": "turtles"})
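
Not part of the patch: a minimal usage sketch of the new router, adapted from the unit test above but pointed at the real ChatOpenAI model rather than the fake. It assumes the openai package is installed and OPENAI_API_KEY is set; the single "accept" schema and handler are illustrative, mirroring the test fixtures.

    from langchain.chat_models import ChatOpenAI
    from langchain.runnables.openai_functions import OpenAIFunctionsRouter

    # Each destination receives the parsed "arguments" dict of the function
    # call the model emitted. A plain callable works as a destination because
    # RouterRunnable now coerces callables to runnables (_coerce_to_runnable).
    router = OpenAIFunctionsRouter(
        {"accept": lambda kw: f'Accepted draft: {kw["draft"]}!'},
        functions=[
            {
                "name": "accept",
                "description": "Accepts the draft.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "draft": {
                            "type": "string",
                            "description": "The draft to accept.",
                        },
                    },
                },
            },
        ],
    )

    # Bind the schemas so the model can emit a function call; function_call
    # forces it to call "accept", so the router always has a call to parse.
    model = ChatOpenAI()
    chain = (
        model.bind(functions=router.functions, function_call={"name": "accept"})
        | router
    )
    print(chain.invoke("Draft one word about turtles."))

As in the test, the model's generation flows through JsonOutputFunctionsParser, which extracts the function name and arguments, and RouterRunnable then dispatches to the handler registered under that name.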