From 5e3b96807837178de76473178011c26764e3008f Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 31 Jul 2023 11:07:10 -0700 Subject: [PATCH] router runnable (#8496) Co-authored-by: Nuno Campos --- .../guides/expression_language/cookbook.ipynb | 107 ++++++++- libs/langchain/langchain/schema/runnable.py | 138 +++++++++++ .../schema/__snapshots__/test_runnable.ambr | 221 ++++++++++++++++++ .../tests/unit_tests/schema/test_runnable.py | 49 ++++ 4 files changed, 512 insertions(+), 3 deletions(-) diff --git a/docs/extras/guides/expression_language/cookbook.ipynb b/docs/extras/guides/expression_language/cookbook.ipynb index ed122f260f..19dd9b5554 100644 --- a/docs/extras/guides/expression_language/cookbook.ipynb +++ b/docs/extras/guides/expression_language/cookbook.ipynb @@ -22,10 +22,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "id": "466b65b3", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.14) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from langchain.prompts import ChatPromptTemplate\n", "from langchain.chat_models import ChatOpenAI" @@ -33,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "3c634ef0", "metadata": {}, "outputs": [], @@ -583,6 +592,98 @@ "chain2.invoke({})" ] }, + { + "cell_type": "markdown", + "id": "d094d637", + "metadata": {}, + "source": [ + "## Router\n", + "\n", + "You can also use the router runnable to conditionally route inputs to different runnables." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "252625fd", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains import create_tagging_chain_pydantic\n", + "from pydantic import BaseModel, Field\n", + "\n", + "class PromptToUse(BaseModel):\n", + " \"\"\"Used to determine which prompt to use to answer the user's input.\"\"\"\n", + " \n", + " name: str = Field(description=\"Should be one of `math` or `english`\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "57886e84", + "metadata": {}, + "outputs": [], + "source": [ + "tagger = create_tagging_chain_pydantic(PromptToUse, ChatOpenAI(temperature=0))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a303b089", + "metadata": {}, + "outputs": [], + "source": [ + "chain1 = ChatPromptTemplate.from_template(\"You are a math genius. Answer the question: {question}\") | ChatOpenAI()\n", + "chain2 = ChatPromptTemplate.from_template(\"You are an english major. 
Answer the question: {question}\") | ChatOpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7aa9ea06", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.schema.runnable import RouterRunnable\n", + "router = RouterRunnable({\"math\": chain1, \"english\": chain2})" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6a3d3f5d", + "metadata": {}, + "outputs": [], + "source": [ + "chain = {\n", + " \"key\": {\"input\": lambda x: x[\"question\"]} | tagger | (lambda x: x['text'].name),\n", + " \"input\": {\"question\": lambda x: x[\"question\"]}\n", + "} | router" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8aeda930", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='Thank you for the compliment! The sum of 2 + 2 is equal to 4.', additional_kwargs={}, example=False)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.invoke({\"question\": \"whats 2 + 2\"})" + ] + }, { "cell_type": "markdown", "id": "29781123", diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index 09f76de492..f13786565c 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -108,6 +108,10 @@ class Runnable(Generic[Input, Output], ABC): ) -> List[Output]: configs = self._get_config_list(config, len(inputs)) + # If there's only one input, don't bother with the executor + if len(inputs) == 1: + return [self.invoke(inputs[0], configs[0])] + with ThreadPoolExecutor(max_workers=max_concurrency) as executor: return list(executor.map(self.invoke, inputs, configs)) @@ -759,6 +763,140 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): yield item +class RouterInput(TypedDict): + key: str + input: Any + + +class RouterRunnable( + Serializable, Generic[Input, Output], Runnable[RouterInput, Output] +): + runnables: Mapping[str, Runnable[Input, Output]] + + def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None: + super().__init__(runnables=runnables) + + class Config: + arbitrary_types_allowed = True + + @property + def lc_serializable(self) -> bool: + return True + + def __or__( + self, + other: Union[ + Runnable[Any, Other], + Callable[[Any], Other], + Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], + Mapping[str, Any], + ], + ) -> RunnableSequence[RouterInput, Other]: + return RunnableSequence(first=self, last=_coerce_to_runnable(other)) + + def __ror__( + self, + other: Union[ + Runnable[Other, Any], + Callable[[Any], Other], + Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], + Mapping[str, Any], + ], + ) -> RunnableSequence[Other, Output]: + return RunnableSequence(first=_coerce_to_runnable(other), last=self) + + def invoke( + self, input: RouterInput, config: Optional[RunnableConfig] = None + ) -> Output: + key = input["key"] + actual_input = input["input"] + if key not in self.runnables: + raise ValueError(f"No runnable associated with key '{key}'") + + runnable = self.runnables[key] + return runnable.invoke(actual_input, config) + + async def ainvoke( + self, input: RouterInput, config: Optional[RunnableConfig] = None + ) -> Output: + key = input["key"] + actual_input = input["input"] + if key not in self.runnables: + raise ValueError(f"No runnable associated with key '{key}'") + + runnable = self.runnables[key] + return await runnable.ainvoke(actual_input, config) + + 
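+    # batch/abatch below resolve every key to its runnable up front (failing
+    # fast if any key is unknown), then dispatch the per-input invocations
+    # concurrently: batch via a thread pool, abatch via gathered coroutines.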
def batch( + self, + inputs: List[RouterInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + max_concurrency: Optional[int] = None, + ) -> List[Output]: + keys = [input["key"] for input in inputs] + actual_inputs = [input["input"] for input in inputs] + if any(key not in self.runnables for key in keys): + raise ValueError("One or more keys do not have a corresponding runnable") + + runnables = [self.runnables[key] for key in keys] + configs = self._get_config_list(config, len(inputs)) + with ThreadPoolExecutor(max_workers=max_concurrency) as executor: + return list( + executor.map( + lambda runnable, input, config: runnable.invoke(input, config), + runnables, + actual_inputs, + configs, + ) + ) + + async def abatch( + self, + inputs: List[RouterInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + max_concurrency: Optional[int] = None, + ) -> List[Output]: + keys = [input["key"] for input in inputs] + actual_inputs = [input["input"] for input in inputs] + if any(key not in self.runnables for key in keys): + raise ValueError("One or more keys do not have a corresponding runnable") + + runnables = [self.runnables[key] for key in keys] + configs = self._get_config_list(config, len(inputs)) + return await _gather_with_concurrency( + max_concurrency, + *( + runnable.ainvoke(input, config) + for runnable, input, config in zip(runnables, actual_inputs, configs) + ), + ) + + def stream( + self, input: RouterInput, config: Optional[RunnableConfig] = None + ) -> Iterator[Output]: + key = input["key"] + actual_input = input["input"] + if key not in self.runnables: + raise ValueError(f"No runnable associated with key '{key}'") + + runnable = self.runnables[key] + yield from runnable.stream(actual_input, config) + + async def astream( + self, input: RouterInput, config: Optional[RunnableConfig] = None + ) -> AsyncIterator[Output]: + key = input["key"] + actual_input = input["input"] + if key not in self.runnables: + raise ValueError(f"No runnable associated with key '{key}'") + + runnable = self.runnables[key] + async for output in runnable.astream(actual_input, config): + yield output + + def _patch_config( config: RunnableConfig, callback_manager: BaseCallbackManager ) -> RunnableConfig: diff --git a/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr index 4d45f38918..5f53c2146e 100644 --- a/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr @@ -327,6 +327,227 @@ Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 
'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=3, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': ChatPromptValue(messages=[SystemMessage(content='You are a nice assistant.', additional_kwargs={}), HumanMessage(content='What is your favorite color?', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=2, child_execution_order=2, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your favorite color?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=3, child_execution_order=3, child_runs=[])]), ]) # --- +# name: test_router_runnable + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "key": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", 
+ "RunnableLambda" + ] + }, + "input": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "question": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableLambda" + ] + } + } + } + } + } + } + }, + "last": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RouterRunnable" + ], + "kwargs": { + "runnables": { + "math": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "template": "You are a math genius. Answer the question: {question}", + "template_format": "f-string", + "partial_variables": {} + } + } + } + } + ] + } + }, + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "llms", + "fake", + "FakeListLLM" + ] + } + } + }, + "english": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "template": "You are an english major. 
Answer the question: {question}", + "template_format": "f-string", + "partial_variables": {} + } + } + } + } + ] + } + }, + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "llms", + "fake", + "FakeListLLM" + ] + } + } + } + } + } + } + } + } + ''' +# --- +# name: test_router_runnable.1 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableMap'], 'kwargs': {'steps': {'key': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, 'input': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableMap'], 'kwargs': {'steps': {'question': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}}}}}}}, 'last': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RouterRunnable'], 'kwargs': {'runnables': {'math': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'input_variables': ['question'], 'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': 'You are a math genius. Answer the question: {question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}, 'english': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'input_variables': ['question'], 'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': 'You are an english major. 
Answer the question: {question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}}}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'key': 'math', 'question': '2 + 2'}, outputs={'output': '4'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=1, child_execution_order=8, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='RunnableMap', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableMap'], 'kwargs': {'steps': {'key': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, 'input': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableMap'], 'kwargs': {'steps': {'question': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}}}}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': {'key': 'math', 'question': '2 + 2'}}, outputs={'key': 'math', 'input': {'question': '2 + 2'}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=2, child_execution_order=5, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'key': 'math', 'question': '2 + 2'}, outputs={'output': 'math'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000001'), tags=[], execution_order=3, child_execution_order=3, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableMap', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableMap'], 'kwargs': {'steps': {'question': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': {'key': 'math', 'question': '2 + 2'}}, outputs={'question': '2 + 2'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000001'), tags=[], execution_order=4, child_execution_order=5, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'key': 'math', 'question': '2 + 2'}, outputs={'output': '2 + 2'}, reference_example_id=None, 
parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=5, child_execution_order=5, child_runs=[])])]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'input_variables': ['question'], 'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': 'You are a math genius. Answer the question: {question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': '2 + 2'}, outputs={'output': '4'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=6, child_execution_order=8, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000006'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'input_variables': ['question'], 'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': 'You are a math genius. Answer the question: {question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': '2 + 2'}, outputs={'output': ChatPromptValue(messages=[HumanMessage(content='You are a math genius. Answer the question: 2 + 2', additional_kwargs={}, example=False)])}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000005'), tags=[], execution_order=7, child_execution_order=7, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000007'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type=, end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['4'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['Human: You are a math genius. 
Answer the question: 2 + 2']}, outputs={'generations': [[{'text': '4', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000005'), tags=[], execution_order=8, child_execution_order=8, child_runs=[])])]), + ]) +# --- # name: test_seq_dict_prompt_llm ''' { diff --git a/libs/langchain/tests/unit_tests/schema/test_runnable.py b/libs/langchain/tests/unit_tests/schema/test_runnable.py index a3489cbdc8..d912e18982 100644 --- a/libs/langchain/tests/unit_tests/schema/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/test_runnable.py @@ -23,6 +23,7 @@ from langchain.schema.document import Document from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage from langchain.schema.retriever import BaseRetriever from langchain.schema.runnable import ( + RouterRunnable, Runnable, RunnableConfig, RunnableLambda, @@ -572,6 +573,54 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> assert len(map_run.child_runs) == 2 +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_router_runnable( + mocker: MockerFixture, snapshot: SnapshotAssertion +) -> None: + chain1 = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + chain2 = ChatPromptTemplate.from_template( + "You are an english major. Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + router = RouterRunnable({"math": chain1, "english": chain2}) + chain: Runnable = { + "key": lambda x: x["key"], + "input": {"question": lambda x: x["question"]}, + } | router + assert dumps(chain, pretty=True) == snapshot + + result = chain.invoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = chain.batch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + + # Test invoke + router_spy = mocker.spy(router.__class__, "invoke") + tracer = FakeTracer() + assert ( + chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])) + == "4" + ) + assert router_spy.call_args.args[1] == { + "key": "math", + "input": {"question": "2 + 2"}, + } + assert tracer.runs == snapshot + + @freeze_time("2023-01-01") def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: passthrough = mocker.Mock(side_effect=lambda x: x)
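Note: the following is a minimal, self-contained sketch of the RouterRunnable interface this patch introduces, distilled from test_router_runnable above. The FakeListLLM responses and the {"key": ..., "input": ...} input shape come straight from that test; the FakeListLLM import path is inferred from the serialized id ['langchain', 'llms', 'fake', 'FakeListLLM'] in the snapshot, and the variable names are illustrative only.

    from langchain.llms.fake import FakeListLLM
    from langchain.prompts import ChatPromptTemplate
    from langchain.schema.runnable import RouterRunnable

    # Two candidate chains, keyed by subject, as in the unit test.
    math_chain = ChatPromptTemplate.from_template(
        "You are a math genius. Answer the question: {question}"
    ) | FakeListLLM(responses=["4"])
    english_chain = ChatPromptTemplate.from_template(
        "You are an english major. Answer the question: {question}"
    ) | FakeListLLM(responses=["2"])

    router = RouterRunnable({"math": math_chain, "english": english_chain})

    # RouterInput is a TypedDict: "key" selects the runnable, "input" is
    # forwarded to it unchanged.
    assert router.invoke({"key": "math", "input": {"question": "2 + 2"}}) == "4"
    assert router.batch(
        [
            {"key": "math", "input": {"question": "2 + 2"}},
            {"key": "english", "input": {"question": "2 + 2"}},
        ]
    ) == ["4", "2"]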