From 677da6a0fd7ac8e6c3127487b09b6bd3e1d620d8 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Thu, 17 Aug 2023 14:25:43 +0100
Subject: [PATCH 1/6] Add support for async funcs in RunnableSequence

---
 .../langchain/schema/runnable/base.py         |  41 ++++++-
 .../runnable/__snapshots__/test_runnable.ambr | 116 ++++++++++++++++++
 .../schema/runnable/test_runnable.py          |  46 +++++++
 3 files changed, 199 insertions(+), 4 deletions(-)

diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py
index c7dd38865f..722a7c376b 100644
--- a/libs/langchain/langchain/schema/runnable/base.py
+++ b/libs/langchain/langchain/schema/runnable/base.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
 import threading
 from abc import ABC, abstractmethod
 from concurrent.futures import FIRST_COMPLETED, wait
@@ -12,6 +13,7 @@ from typing import (
     AsyncIterator,
     Awaitable,
     Callable,
+    Coroutine,
     Dict,
     Generic,
     Iterator,
@@ -1343,8 +1345,17 @@ class RunnableLambda(Runnable[Input, Output]):
     A runnable that runs a callable.
     """
 
-    def __init__(self, func: Callable[[Input], Output]) -> None:
-        if callable(func):
+    def __init__(
+        self,
+        func: Union[Callable[[Input], Output], Coroutine[Input, Any, Output]],
+        afunc: Optional[Coroutine[Input, Any, Output]] = None,
+    ) -> None:
+        if afunc is not None:
+            self.afunc = afunc
+
+        if inspect.iscoroutinefunction(func):
+            self.afunc = func
+        elif callable(func):
             self.func = func
         else:
             raise TypeError(
@@ -1354,7 +1365,12 @@
     def __eq__(self, other: Any) -> bool:
         if isinstance(other, RunnableLambda):
-            return self.func == other.func
+            if hasattr(self, "func") and hasattr(other, "func"):
+                return self.func == other.func
+            elif hasattr(self, "afunc") and hasattr(other, "afunc"):
+                return self.afunc == other.afunc
+            else:
+                return False
         else:
             return False
 
@@ -1364,7 +1380,24 @@
     def invoke(
         self,
         input: Input,
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> Output:
-        return self._call_with_config(self.func, input, config)
+        if hasattr(self, "func"):
+            return self._call_with_config(self.func, input, config)
+        else:
+            raise TypeError(
+                "Cannot invoke a coroutine function synchronously. "
+                "Use `ainvoke` instead."
+ ) + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + if hasattr(self, "afunc"): + return await self._acall_with_config(self.afunc, input, config) + else: + return await super().ainvoke(input, config) class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 828227fd30..ac8deb3de2 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -1187,6 +1187,122 @@ Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000004'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 
1, 1, 0, 0)}], inputs={'question': 'What is your favorite color?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your favorite color?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000005'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your favorite color?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000003'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- +# name: test_prompt_with_llm_and_async_lambda + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "SystemMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [], + "template": "You are a nice assistant.", + "template_format": "f-string", + "partial_variables": {} + } + } + } + }, + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "question" + ], + "template": "{question}", + "template_format": "f-string", + "partial_variables": {} + } + } + } + } + ] + } + }, + "middle": [ + { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "llms", + "fake", + "FakeListLLM" + ] + } + ], + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + } + } + } + ''' +# --- +# name: test_prompt_with_llm_and_async_lambda.1 + list([ + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 
'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}], 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': 
FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'foo'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]), + ]) +# --- # name: test_router_runnable ''' { diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 544b42da69..8c616770ef 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,3 +1,4 @@ +from ast import Not from typing import Any, Dict, List, Optional from uuid import UUID @@ -38,6 +39,7 @@ from langchain.schema.runnable import ( RunnablePassthrough, RunnableSequence, RunnableWithFallbacks, + passthrough, ) @@ -438,6 +440,50 @@ async def test_prompt_with_llm( ) +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_prompt_with_llm_and_async_lambda( + mocker: MockerFixture, snapshot: SnapshotAssertion +) -> None: + prompt = ( + SystemMessagePromptTemplate.from_template("You are a nice assistant.") + + "{question}" + ) + llm = FakeListLLM(responses=["foo", "bar"]) + + async def passthrough(input: Any) -> Any: + return input + + chain = prompt | llm | passthrough + + assert isinstance(chain, RunnableSequence) + assert chain.first == prompt + assert chain.middle == [llm] + assert chain.last == RunnableLambda(func=passthrough) + assert dumps(chain, pretty=True) == snapshot + + # Test invoke + prompt_spy = mocker.spy(prompt.__class__, "ainvoke") + llm_spy = mocker.spy(llm.__class__, "ainvoke") + tracer = FakeTracer() + assert ( + await chain.ainvoke( + {"question": "What is your name?"}, dict(callbacks=[tracer]) + ) + == "foo" + ) + assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} + assert llm_spy.call_args.args[1] == ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert tracer.runs == snapshot + mocker.stop(prompt_spy) + mocker.stop(llm_spy) + + @freeze_time("2023-01-01") def test_prompt_with_chat_model_and_parser( mocker: MockerFixture, snapshot: SnapshotAssertion From 6d19709b65d98e1ff7f37a459357e5470bd1c4bf Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 17 Aug 2023 15:04:36 +0100 Subject: [PATCH 2/6] RunnableLambda, if func returns a Runnable, run it --- .../langchain/schema/runnable/base.py | 59 +++++++++- .../langchain/schema/runnable/config.py | 9 ++ .../runnable/__snapshots__/test_runnable.ambr | 82 +++++++++++++- .../schema/runnable/test_runnable.py | 104 +++++++++++++++++- 4 
From 6d19709b65d98e1ff7f37a459357e5470bd1c4bf Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Thu, 17 Aug 2023 15:04:36 +0100
Subject: [PATCH 2/6] RunnableLambda, if func returns a Runnable, run it

---
 .../langchain/schema/runnable/base.py         |  59 +++++++++-
 .../langchain/schema/runnable/config.py       |   9 ++
 .../runnable/__snapshots__/test_runnable.ambr |  82 +++++++++++++-
 .../schema/runnable/test_runnable.py          | 104 +++++++++++++++++-
 4 files changed, 245 insertions(+), 9 deletions(-)

diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py
index 722a7c376b..3f6148e949 100644
--- a/libs/langchain/langchain/schema/runnable/base.py
+++ b/libs/langchain/langchain/schema/runnable/base.py
@@ -13,7 +13,6 @@ from typing import (
     AsyncIterator,
     Awaitable,
     Callable,
-    Coroutine,
     Dict,
     Generic,
     Iterator,
@@ -1347,8 +1346,8 @@ class RunnableLambda(Runnable[Input, Output]):
 
     def __init__(
         self,
-        func: Union[Callable[[Input], Output], Coroutine[Input, Any, Output]],
-        afunc: Optional[Coroutine[Input, Any, Output]] = None,
+        func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
+        afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
     ) -> None:
         if afunc is not None:
             self.afunc = afunc
@@ -1356,7 +1355,7 @@ class RunnableLambda(Runnable[Input, Output]):
         if inspect.iscoroutinefunction(func):
             self.afunc = func
         elif callable(func):
-            self.func = func
+            self.func = cast(Callable[[Input], Output], func)
         else:
             raise TypeError(
                 "Expected a callable type for `func`."
@@ -1374,6 +1373,54 @@ class RunnableLambda(Runnable[Input, Output]):
         else:
             return False
 
+    def _invoke(
+        self,
+        input: Input,
+        run_manager: CallbackManagerForChainRun,
+        config: RunnableConfig,
+    ) -> Output:
+        output = self.func(input)
+        # If the output is a runnable, invoke it
+        if isinstance(output, Runnable):
+            recursion_limit = config["recursion_limit"]
+            if recursion_limit <= 0:
+                raise RecursionError(
+                    f"Recursion limit reached when invoking {self} with input {input}."
+                )
+            output = output.invoke(
+                input,
+                patch_config(
+                    config,
+                    callbacks=run_manager.get_child(),
+                    recursion_limit=recursion_limit - 1,
+                ),
+            )
+        return output
+
+    async def _ainvoke(
+        self,
+        input: Input,
+        run_manager: AsyncCallbackManagerForChainRun,
+        config: RunnableConfig,
+    ) -> Output:
+        output = await self.afunc(input)
+        # If the output is a runnable, invoke it
+        if isinstance(output, Runnable):
+            recursion_limit = config["recursion_limit"]
+            if recursion_limit <= 0:
+                raise RecursionError(
+                    f"Recursion limit reached when invoking {self} with input {input}."
+                )
+            output = await output.ainvoke(
+                input,
+                patch_config(
+                    config,
+                    callbacks=run_manager.get_child(),
+                    recursion_limit=recursion_limit - 1,
+                ),
+            )
+        return output
+
     def invoke(
         self,
         input: Input,
@@ -1381,7 +1428,7 @@
         **kwargs: Optional[Any],
     ) -> Output:
         if hasattr(self, "func"):
-            return self._call_with_config(self.func, input, config)
+            return self._call_with_config(self._invoke, input, config)
         else:
             raise TypeError(
                 "Cannot invoke a coroutine function synchronously. "
@@ -1395,7 +1442,7 @@
         **kwargs: Optional[Any],
     ) -> Output:
         if hasattr(self, "afunc"):
-            return await self._acall_with_config(self.afunc, input, config)
+            return await self._acall_with_config(self._ainvoke, input, config)
         else:
             return await super().ainvoke(input, config)
diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py
index 4ad3b4fb8d..ce4e11861e 100644
--- a/libs/langchain/langchain/schema/runnable/config.py
+++ b/libs/langchain/langchain/schema/runnable/config.py
@@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False):
     ThreadPoolExecutor will be created.
     """
 
+    recursion_limit: int
+    """
+    Maximum number of times a call can recurse. If not provided, defaults to 10.
+ """ + def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: empty = RunnableConfig( @@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: metadata={}, callbacks=None, _locals={}, + recursion_limit=10, ) if config is not None: empty.update(config) @@ -66,6 +72,7 @@ def patch_config( deep_copy_locals: bool = False, callbacks: Optional[BaseCallbackManager] = None, executor: Optional[Executor] = None, + recursion_limit: Optional[int] = None, ) -> RunnableConfig: config = ensure_config(config) if deep_copy_locals: @@ -74,6 +81,8 @@ def patch_config( config["callbacks"] = callbacks if executor is not None: config["executor"] = executor + if recursion_limit is not None: + config["recursion_limit"] = recursion_limit return config diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index ac8deb3de2..c48d4edbd4 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -599,6 +599,83 @@ } ''' # --- +# name: test_higher_order_lambda_runnable + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "kwargs": { + "first": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "key": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + }, + "input": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "schema", + "runnable", + "RunnableMap" + ], + "kwargs": { + "steps": { + "question": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + } + } + } + } + } + } + }, + "last": { + "lc": 1, + "type": "not_implemented", + "id": [ + "langchain", + "schema", + "runnable", + "base", + "RunnableLambda" + ] + } + } + } + ''' +# --- # name: test_llm_with_fallbacks[llm_chain_with_fallbacks] ''' { @@ -1268,6 +1345,9 @@ } } } + ], + "input_variables": [ + "question" ] } }, @@ -1300,7 +1380,7 @@ # --- # name: test_prompt_with_llm_and_async_lambda.1 list([ - Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, 'middle': [{'lc': 1, 
'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}], 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}]}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}, 
events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'foo'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]), + Run(id=UUID('00000000-0000-4000-8000-000000000000'), name='RunnableSequence', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'runnable', 'RunnableSequence'], 'kwargs': {'first': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, 'middle': [{'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}], 'last': {'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=None, tags=[], execution_order=None, child_execution_order=None, child_runs=[Run(id=UUID('00000000-0000-4000-8000-000000000001'), name='ChatPromptTemplate', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='prompt', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptTemplate'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'SystemMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': [], 'template': 'You are a nice assistant.', 'template_format': 'f-string', 'partial_variables': {}}}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'HumanMessagePromptTemplate'], 'kwargs': {'prompt': {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'prompt', 'PromptTemplate'], 'kwargs': {'input_variables': ['question'], 'template': '{question}', 'template_format': 'f-string', 'partial_variables': {}}}}}], 'input_variables': ['question']}}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'question': 'What is your name?'}, outputs={'lc': 1, 'type': 'constructor', 'id': ['langchain', 'prompts', 'chat', 'ChatPromptValue'], 'kwargs': {'messages': [{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'content': 'You are a nice 
assistant.', 'additional_kwargs': {}}}, {'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'HumanMessage'], 'kwargs': {'content': 'What is your name?', 'additional_kwargs': {}}}]}}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000002'), name='FakeListLLM', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='llm', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={'invocation_params': {'responses': ['foo', 'bar'], '_type': 'fake-list', 'stop': None}, 'options': {'stop': None}}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'llms', 'fake', 'FakeListLLM']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'prompts': ['System: You are a nice assistant.\nHuman: What is your name?']}, outputs={'generations': [[{'text': 'foo', 'generation_info': None}]], 'llm_output': None, 'run': None}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[]), Run(id=UUID('00000000-0000-4000-8000-000000000003'), name='RunnableLambda', start_time=FakeDatetime(2023, 1, 1, 0, 0), run_type='chain', end_time=FakeDatetime(2023, 1, 1, 0, 0), extra={}, error=None, serialized={'lc': 1, 'type': 'not_implemented', 'id': ['langchain', 'schema', 'runnable', 'base', 'RunnableLambda']}, events=[{'name': 'start', 'time': FakeDatetime(2023, 1, 1, 0, 0)}, {'name': 'end', 'time': FakeDatetime(2023, 1, 1, 0, 0)}], inputs={'input': 'foo'}, outputs={'output': 'foo'}, reference_example_id=None, parent_run_id=UUID('00000000-0000-4000-8000-000000000000'), tags=[], execution_order=None, child_execution_order=None, child_runs=[])]), ]) # --- # name: test_router_runnable diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 8c616770ef..0a36408d2c 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,4 +1,4 @@ -from ast import Not +from operator import itemgetter from typing import Any, Dict, List, Optional from uuid import UUID @@ -39,7 +39,6 @@ from langchain.schema.runnable import ( RunnablePassthrough, RunnableSequence, RunnableWithFallbacks, - passthrough, ) @@ -178,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: tags=[], callbacks=None, _locals={}, + recursion_limit=10, ), ), mocker.call( @@ -187,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: tags=[], callbacks=None, _locals={}, + recursion_limit=10, ), ), ] @@ -768,6 +769,105 @@ async def test_router_runnable( assert len(router_run.child_runs) == 2 +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_higher_order_lambda_runnable( + mocker: MockerFixture, snapshot: SnapshotAssertion +) -> None: + math_chain = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + english_chain = ChatPromptTemplate.from_template( + "You are an english major. 
Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + input_map = RunnableMap( + { # type: ignore[arg-type] + "key": lambda x: x["key"], + "input": {"question": lambda x: x["question"]}, + } + ) + + def router(input: Dict[str, Any]) -> Runnable: + if input["key"] == "math": + return itemgetter("input") | math_chain + elif input["key"] == "english": + return itemgetter("input") | english_chain + else: + raise ValueError(f"Unknown key: {input['key']}") + + chain: Runnable = input_map | router + assert dumps(chain, pretty=True) == snapshot + + result = chain.invoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = chain.batch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + ) + assert result2 == ["4", "2"] + + # Test invoke + math_spy = mocker.spy(math_chain.__class__, "invoke") + tracer = FakeTracer() + assert ( + chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])) + == "4" + ) + assert math_spy.call_args.args[1] == { + "key": "math", + "input": {"question": "2 + 2"}, + } + assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 + parent_run = next(r for r in tracer.runs if r.parent_run_id is None) + assert len(parent_run.child_runs) == 2 + router_run = parent_run.child_runs[1] + assert router_run.name == "RunnableLambda" + assert len(router_run.child_runs) == 1 + math_run = router_run.child_runs[0] + assert math_run.name == "RunnableSequence" + assert len(math_run.child_runs) == 3 + + # Test ainvoke + async def arouter(input: Dict[str, Any]) -> Runnable: + if input["key"] == "math": + return itemgetter("input") | math_chain + elif input["key"] == "english": + return itemgetter("input") | english_chain + else: + raise ValueError(f"Unknown key: {input['key']}") + + achain: Runnable = input_map | arouter + math_spy = mocker.spy(math_chain.__class__, "ainvoke") + tracer = FakeTracer() + assert ( + await achain.ainvoke( + {"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]) + ) + == "4" + ) + assert math_spy.call_args.args[1] == { + "key": "math", + "input": {"question": "2 + 2"}, + } + assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 + parent_run = next(r for r in tracer.runs if r.parent_run_id is None) + assert len(parent_run.child_runs) == 2 + router_run = parent_run.child_runs[1] + assert router_run.name == "RunnableLambda" + assert len(router_run.child_runs) == 1 + math_run = router_run.child_runs[0] + assert math_run.name == "RunnableSequence" + assert len(math_run.child_runs) == 3 + + @freeze_time("2023-01-01") def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: passthrough = mocker.Mock(side_effect=lambda x: x) From c326751085ad2f89a21a74c86c9a3597904bdebe Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 09:22:35 +0100 Subject: [PATCH 3/6] Lint --- .../langchain/tests/unit_tests/schema/runnable/test_runnable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 0a36408d2c..5214f5b366 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ 
From c326751085ad2f89a21a74c86c9a3597904bdebe Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Fri, 18 Aug 2023 09:22:35 +0100
Subject: [PATCH 3/6] Lint

---
 .../langchain/tests/unit_tests/schema/runnable/test_runnable.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
index 0a36408d2c..5214f5b366 100644
--- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
+++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
@@ -780,7 +780,7 @@ async def test_higher_order_lambda_runnable(
     english_chain = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
-    input_map = RunnableMap(
+    input_map: Runnable = RunnableMap(
         {  # type: ignore[arg-type]
             "key": lambda x: x["key"],
             "input": {"question": lambda x: x["question"]},

From da18e177f13023da59c64234f6ad76b845abee3e Mon Sep 17 00:00:00 2001
From: William FH <13333726+hinthornw@users.noreply.github.com>
Date: Fri, 18 Aug 2023 09:21:05 -0700
Subject: [PATCH 4/6] Update libs/langchain/langchain/schema/runnable/base.py

Co-authored-by: Eugene Yurtsev
---
 libs/langchain/langchain/schema/runnable/base.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py
index 3f6148e949..a4c95c12ee 100644
--- a/libs/langchain/langchain/schema/runnable/base.py
+++ b/libs/langchain/langchain/schema/runnable/base.py
@@ -1444,6 +1444,8 @@ class RunnableLambda(Runnable[Input, Output]):
         if hasattr(self, "afunc"):
             return await self._acall_with_config(self._ainvoke, input, config)
         else:
+            # Delegating to super implementation of ainvoke.
+            # Uses asyncio executor to run the sync version (invoke)
             return await super().ainvoke(input, config)

From 6424b3cde0136c0b5fb8698bc764435692b017bd Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Wed, 23 Aug 2023 20:02:35 +0100
Subject: [PATCH 5/6] Add another test

---
 .../unit_tests/schema/runnable/test_runnable.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
index 5214f5b366..f244753310 100644
--- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
+++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
@@ -1,5 +1,5 @@
 from operator import itemgetter
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 from uuid import UUID
 
 import pytest
@@ -1282,3 +1282,17 @@ def test_each(snapshot: SnapshotAssertion) -> None:
         "test",
         "this",
     ]
+
+
+def test_recursive_lambda() -> None:
+    def _simple_recursion(x: int) -> Union[int, Runnable]:
+        if x < 10:
+            return RunnableLambda(lambda *args: _simple_recursion(x + 1))
+        else:
+            return x
+
+    runnable = RunnableLambda(_simple_recursion)
+    assert runnable.invoke(5) == 10
+
+    with pytest.raises(RecursionError):
+        runnable.invoke(0, {"recursion_limit": 9})

From 20ce283fa7ff7efa5447e922ffa64b3889603c4a Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Wed, 23 Aug 2023 20:03:35 +0100
Subject: [PATCH 6/6] Format

---
 libs/langchain/langchain/schema/runnable/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py
index a4c95c12ee..a130dc62b8 100644
--- a/libs/langchain/langchain/schema/runnable/base.py
+++ b/libs/langchain/langchain/schema/runnable/base.py
@@ -1445,7 +1445,7 @@ class RunnableLambda(Runnable[Input, Output]):
             return await self._acall_with_config(self._ainvoke, input, config)
         else:
             # Delegating to super implementation of ainvoke.
-            # Uses asyncio executor to run the sync version (invoke) 
+            # Uses asyncio executor to run the sync version (invoke)
             return await super().ainvoke(input, config)
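Two behaviors from the last three patches, sketched together (assuming this branch): a sync-only lambda still answers `ainvoke`, since the base implementation runs `invoke` on an asyncio executor as the PATCH 4 comment notes; and `recursion_limit`, default 10 and overridable per call as in PATCH 5's test, bounds how many consecutive Runnables a lambda may return:

    import asyncio

    from langchain.schema.runnable import RunnableLambda

    # Sync-only lambda: ainvoke is inherited and runs invoke in an executor.
    inc = RunnableLambda(lambda x: x + 1)
    assert asyncio.run(inc.ainvoke(1)) == 2

    # Recursion guard: every Runnable returned by the lambda costs one unit
    # of recursion_limit, so a self-returning lambda trips RecursionError.
    loop = RunnableLambda(lambda x: loop)
    try:
        loop.invoke(0, {"recursion_limit": 3})
    except RecursionError as err:
        print(err)  # Recursion limit reached when invoking ...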