From 9cbf14dec2a871d69b8d0bf3ccb076005950f565 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 2 Jan 2024 12:16:39 -0800 Subject: [PATCH] Fetch runnable config from context var inside runnable lambda and runnable generator (#15334) - easier to write custom logic/loops with automatic tracing - if you don't want to streaming support write a regular function and pass to RunnableLambda - if you do want streaming write a generator and pass it to RunnableGenerator ```py import json from typing import AsyncIterator from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage from langchain_core.agents import AgentAction, AgentFinish from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import Runnable, RunnableGenerator, RunnablePassthrough from langchain_core.tools import BaseTool from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser from langchain.chat_models import ChatOpenAI from langchain.tools.render import format_tool_to_openai_function def _get_tavily(): from langchain.tools.tavily_search import TavilySearchResults from langchain.utilities.tavily_search import TavilySearchAPIWrapper tavily_search = TavilySearchAPIWrapper() return TavilySearchResults(api_wrapper=tavily_search) async def _agent_executor_generator( input: AsyncIterator[list[BaseMessage]], *, max_iterations: int = 10, tools: dict[str, BaseTool], agent: Runnable[list[BaseMessage], BaseMessage], parser: Runnable[BaseMessage, AgentAction | AgentFinish], ) -> AsyncIterator[BaseMessage]: messages = [m async for mm in input for m in mm] for _ in range(max_iterations): next_message = await agent.ainvoke(messages) yield next_message messages.append(next_message) parsed = await parser.ainvoke(next_message) if isinstance(parsed, AgentAction): result = await tools[parsed.tool].ainvoke(parsed.tool_input) next_message = FunctionMessage(name=parsed.tool, content=json.dumps(result)) yield next_message messages.append(next_message) elif isinstance(parsed, AgentFinish): return def get_agent_executor(tools: list[BaseTool], system_message: str): llm = ChatOpenAI(model="gpt-4-1106-preview", temperature=0, streaming=True) prompt = ChatPromptTemplate.from_messages( [ ("system", system_message), MessagesPlaceholder(variable_name="messages"), ] ) llm_with_tools = llm.bind( functions=[format_tool_to_openai_function(t) for t in tools] ) agent = {"messages": RunnablePassthrough()} | prompt | llm_with_tools parser = OpenAIFunctionsAgentOutputParser() executor = RunnableGenerator(_agent_executor_generator) return executor.bind( tools={tool.name for tool in tools}, agent=agent, parser=parser ) agent = get_agent_executor([_get_tavily()], "You are a very nice agent!") async def main(): async for message in agent.astream( [HumanMessage(content="whats the weather in sf tomorrow?")] ): print(message) if __name__ == "__main__": import asyncio asyncio.run(main()) ``` results in this trace https://smith.langchain.com/public/fa17f05d-9724-4d08-8fa1-750f8fcd051b/r --- .../core/langchain_core/runnables/__init__.py | 2 + libs/core/langchain_core/runnables/base.py | 363 +++++++++++++++--- libs/core/langchain_core/runnables/config.py | 6 +- libs/core/langchain_core/runnables/utils.py | 8 + .../unit_tests/runnables/test_context.py | 48 +-- .../unit_tests/runnables/test_imports.py | 1 + .../unit_tests/runnables/test_runnable.py | 322 +++++++++++++++- 7 files changed, 656 insertions(+), 94 deletions(-) diff --git a/libs/core/langchain_core/runnables/__init__.py b/libs/core/langchain_core/runnables/__init__.py index 2d23a78dc1..903eefd079 100644 --- a/libs/core/langchain_core/runnables/__init__.py +++ b/libs/core/langchain_core/runnables/__init__.py @@ -23,6 +23,7 @@ from langchain_core.runnables.base import ( RunnableParallel, RunnableSequence, RunnableSerializable, + chain, ) from langchain_core.runnables.branch import RunnableBranch from langchain_core.runnables.config import ( @@ -50,6 +51,7 @@ from langchain_core.runnables.utils import ( ) __all__ = [ + "chain", "AddableDict", "ConfigurableField", "ConfigurableFieldSingleOption", diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 5bf6e8d98c..ac98aecf6b 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -1,10 +1,12 @@ from __future__ import annotations import asyncio +import collections import inspect import threading from abc import ABC, abstractmethod from concurrent.futures import FIRST_COMPLETED, wait +from contextvars import copy_context from copy import deepcopy from functools import wraps from itertools import groupby, tee @@ -15,6 +17,7 @@ from typing import ( AsyncIterator, Awaitable, Callable, + Coroutine, Dict, Generic, Iterator, @@ -48,6 +51,7 @@ from langchain_core.runnables.config import ( merge_configs, patch_config, run_in_executor, + var_child_runnable_config, ) from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( @@ -58,6 +62,7 @@ from langchain_core.runnables.utils import ( Input, Output, accepts_config, + accepts_context, accepts_run_manager, gather_with_concurrency, get_function_first_arg_dict_keys, @@ -950,8 +955,19 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name") or self.get_name(), ) try: - output = call_func_with_variable_args( - func, input, config, run_manager, **kwargs + child_config = patch_config(config, callbacks=run_manager.get_child()) + context = copy_context() + context.run(var_child_runnable_config.set, child_config) + output = cast( + Output, + context.run( + call_func_with_variable_args, + func, # type: ignore[arg-type] + input, # type: ignore[arg-type] + config, + run_manager, + **kwargs, + ), ) except BaseException as e: run_manager.on_chain_error(e) @@ -986,9 +1002,16 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name") or self.get_name(), ) try: - output = await acall_func_with_variable_args( + child_config = patch_config(config, callbacks=run_manager.get_child()) + context = copy_context() + context.run(var_child_runnable_config.set, child_config) + coro = acall_func_with_variable_args( func, input, config, run_manager, **kwargs ) + if accepts_context(asyncio.create_task): + output: Output = await asyncio.create_task(coro, context=context) # type: ignore + else: + output = await coro except BaseException as e: await run_manager.on_chain_error(e) raise @@ -1178,24 +1201,29 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name") or self.get_name(), ) try: + child_config = patch_config(config, callbacks=run_manager.get_child()) if accepts_config(transformer): - kwargs["config"] = patch_config( - config, callbacks=run_manager.get_child() - ) + kwargs["config"] = child_config if accepts_run_manager(transformer): kwargs["run_manager"] = run_manager - iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg] - for chunk in iterator: - yield chunk - if final_output_supported: - if final_output is None: - final_output = chunk - else: - try: - final_output = final_output + chunk # type: ignore - except TypeError: - final_output = None - final_output_supported = False + context = copy_context() + context.run(var_child_runnable_config.set, child_config) + iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] + try: + while True: + chunk: Output = context.run(next, iterator) # type: ignore + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = None + final_output_supported = False + except StopIteration: + pass for ichunk in input_for_tracing: if final_input_supported: if final_input is None: @@ -1254,24 +1282,35 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name") or self.get_name(), ) try: + child_config = patch_config(config, callbacks=run_manager.get_child()) if accepts_config(transformer): - kwargs["config"] = patch_config( - config, callbacks=run_manager.get_child() - ) + kwargs["config"] = child_config if accepts_run_manager(transformer): kwargs["run_manager"] = run_manager - iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg] - async for chunk in iterator: - yield chunk - if final_output_supported: - if final_output is None: - final_output = chunk + context = copy_context() + context.run(var_child_runnable_config.set, child_config) + iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] + try: + while True: + if accepts_context(asyncio.create_task): + chunk: Output = await asyncio.create_task( # type: ignore[call-arg] + py_anext(iterator), # type: ignore[arg-type] + context=context, + ) else: - try: - final_output = final_output + chunk # type: ignore - except TypeError: - final_output = None - final_output_supported = False + chunk = cast(Output, await py_anext(iterator)) + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = None + final_output_supported = False + except StopAsyncIteration: + pass async for ichunk in input_for_tracing: if final_input_supported: if final_input is None: @@ -1472,7 +1511,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): .. code-block:: python from langchain_core.output_parsers.json import SimpleJsonOutputParser - from langchain_core.chat_models.openai import ChatOpenAI + from langchain.chat_models.openai import ChatOpenAI prompt = PromptTemplate.from_template( 'In JSON format, give me a list of {topic} and their ' @@ -2482,17 +2521,25 @@ class RunnableGenerator(Runnable[Input, Output]): ) -> None: if atransform is not None: self._atransform = atransform + func_for_name: Callable = atransform if inspect.isasyncgenfunction(transform): self._atransform = transform + func_for_name = transform elif inspect.isgeneratorfunction(transform): self._transform = transform + func_for_name = transform else: raise TypeError( "Expected a generator function type for `transform`." f"Instead got an unsupported type: {type(transform)}" ) + try: + self.name = func_for_name.__name__ + except AttributeError: + pass + @property def InputType(self) -> Any: func = getattr(self, "_transform", None) or getattr(self, "_atransform") @@ -2646,12 +2693,14 @@ class RunnableLambda(Runnable[Input, Output]): func: Union[ Union[ Callable[[Input], Output], + Callable[[Input], Iterator[Output]], Callable[[Input, RunnableConfig], Output], Callable[[Input, CallbackManagerForChainRun], Output], Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], ], Union[ Callable[[Input], Awaitable[Output]], + Callable[[Input], AsyncIterator[Output]], Callable[[Input, RunnableConfig], Awaitable[Output]], Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], Callable[ @@ -2663,6 +2712,7 @@ class RunnableLambda(Runnable[Input, Output]): afunc: Optional[ Union[ Callable[[Input], Awaitable[Output]], + Callable[[Input], AsyncIterator[Output]], Callable[[Input, RunnableConfig], Awaitable[Output]], Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], Callable[ @@ -2685,7 +2735,7 @@ class RunnableLambda(Runnable[Input, Output]): self.afunc = afunc func_for_name: Callable = afunc - if inspect.iscoroutinefunction(func): + if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): if afunc is not None: raise TypeError( "Func was provided as a coroutine function, but afunc was " @@ -2767,11 +2817,16 @@ class RunnableLambda(Runnable[Input, Output]): func = getattr(self, "func", None) or getattr(self, "afunc") try: sig = inspect.signature(func) - return ( - sig.return_annotation - if sig.return_annotation != inspect.Signature.empty - else Any - ) + if sig.return_annotation != inspect.Signature.empty: + # unwrap iterator types + if getattr(sig.return_annotation, "__origin__", None) in ( + collections.abc.Iterator, + collections.abc.AsyncIterator, + ): + return getattr(sig.return_annotation, "__args__", (Any,))[0] + return sig.return_annotation + else: + return Any except ValueError: return Any @@ -2848,9 +2903,26 @@ class RunnableLambda(Runnable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> Output: - output = call_func_with_variable_args( - self.func, input, config, run_manager, **kwargs - ) + if inspect.isgeneratorfunction(self.func): + output: Optional[Output] = None + for chunk in call_func_with_variable_args( + cast(Callable[[Input], Iterator[Output]], self.func), + input, + config, + run_manager, + **kwargs, + ): + if output is None: + output = chunk + else: + try: + output = output + chunk # type: ignore[operator] + except TypeError: + output = chunk + else: + output = call_func_with_variable_args( + self.func, input, config, run_manager, **kwargs + ) # If the output is a runnable, invoke it if isinstance(output, Runnable): recursion_limit = config["recursion_limit"] @@ -2866,7 +2938,7 @@ class RunnableLambda(Runnable[Input, Output]): recursion_limit=recursion_limit - 1, ), ) - return output + return cast(Output, output) async def _ainvoke( self, @@ -2878,16 +2950,69 @@ class RunnableLambda(Runnable[Input, Output]): if hasattr(self, "afunc"): afunc = self.afunc else: + if inspect.isgeneratorfunction(self.func): - @wraps(self.func) + def func( + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + output: Optional[Output] = None + for chunk in call_func_with_variable_args( + cast(Callable[[Input], Iterator[Output]], self.func), + input, + config, + run_manager.get_sync(), + **kwargs, + ): + if output is None: + output = chunk + else: + try: + output = output + chunk # type: ignore[operator] + except TypeError: + output = chunk + return cast(Output, output) + else: + + def func( + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + return call_func_with_variable_args( + self.func, input, config, run_manager.get_sync(), **kwargs + ) + + @wraps(func) async def f(*args, **kwargs): # type: ignore[no-untyped-def] - return await run_in_executor(config, self.func, *args, **kwargs) + return await run_in_executor(config, func, *args, **kwargs) afunc = f - output = await acall_func_with_variable_args( - afunc, input, config, run_manager, **kwargs - ) + if inspect.isasyncgenfunction(afunc): + output: Optional[Output] = None + async for chunk in cast( + AsyncIterator[Output], + acall_func_with_variable_args( + cast(Callable, afunc), + input, + config, + run_manager, + **kwargs, + ), + ): + if output is None: + output = chunk + else: + try: + output = output + chunk # type: ignore[operator] + except TypeError: + output = chunk + else: + output = await acall_func_with_variable_args( + cast(Callable, afunc), input, config, run_manager, **kwargs + ) # If the output is a runnable, invoke it if isinstance(output, Runnable): recursion_limit = config["recursion_limit"] @@ -2903,7 +3028,7 @@ class RunnableLambda(Runnable[Input, Output]): recursion_limit=recursion_limit - 1, ), ) - return output + return cast(Output, output) def _config( self, config: Optional[RunnableConfig], callable: Callable[..., Any] @@ -2972,9 +3097,23 @@ class RunnableLambda(Runnable[Input, Output]): except TypeError: final = ichunk - output = call_func_with_variable_args( - self.func, cast(Input, final), config, run_manager, **kwargs - ) + if inspect.isgeneratorfunction(self.func): + output: Optional[Output] = None + for chunk in call_func_with_variable_args( + self.func, cast(Input, final), config, run_manager, **kwargs + ): + yield chunk + if output is None: + output = chunk + else: + try: + output = output + chunk + except TypeError: + output = chunk + else: + output = call_func_with_variable_args( + self.func, cast(Input, final), config, run_manager, **kwargs + ) # If the output is a runnable, use its stream output if isinstance(output, Runnable): @@ -2993,9 +3132,9 @@ class RunnableLambda(Runnable[Input, Output]): ), ): yield chunk - else: + elif not inspect.isgeneratorfunction(self.func): # Otherwise, just yield it - yield output + yield cast(Output, output) def transform( self, @@ -3030,6 +3169,7 @@ class RunnableLambda(Runnable[Input, Output]): input: AsyncIterator[Input], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> AsyncIterator[Output]: final: Optional[Input] = None async for ichunk in input: @@ -3044,16 +3184,51 @@ class RunnableLambda(Runnable[Input, Output]): if hasattr(self, "afunc"): afunc = self.afunc else: + if inspect.isgeneratorfunction(self.func): + raise TypeError( + "Cannot stream from a generator function asynchronously." + "Use .stream() instead." + ) - @wraps(self.func) + def func( + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + return call_func_with_variable_args( + self.func, input, config, run_manager.get_sync(), **kwargs + ) + + @wraps(func) async def f(*args, **kwargs): # type: ignore[no-untyped-def] - return await run_in_executor(config, self.func, *args, **kwargs) + return await run_in_executor(config, func, *args, **kwargs) afunc = f - output = await acall_func_with_variable_args( - afunc, cast(Input, final), config, run_manager - ) + if inspect.isasyncgenfunction(afunc): + output: Optional[Output] = None + async for chunk in cast( + AsyncIterator[Output], + acall_func_with_variable_args( + cast(Callable, afunc), + cast(Input, final), + config, + run_manager, + **kwargs, + ), + ): + yield chunk + if output is None: + output = chunk + else: + try: + output = output + chunk # type: ignore[operator] + except TypeError: + output = chunk + else: + output = await acall_func_with_variable_args( + cast(Callable, afunc), cast(Input, final), config, run_manager, **kwargs + ) # If the output is a runnable, use its astream output if isinstance(output, Runnable): @@ -3072,9 +3247,9 @@ class RunnableLambda(Runnable[Input, Output]): ), ): yield chunk - else: + elif not inspect.isasyncgenfunction(afunc): # Otherwise, just yield it - yield output + yield cast(Output, output) async def atransform( self, @@ -3699,3 +3874,69 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: f"Expected a Runnable, callable or dict." f"Instead got an unsupported type: {type(thing)}" ) + + +@overload +def chain( + func: Callable[[Input], Coroutine[Any, Any, Output]], +) -> Runnable[Input, Output]: + ... + + +@overload +def chain( + func: Callable[[Input], Iterator[Output]], +) -> Runnable[Input, Output]: + ... + + +@overload +def chain( + func: Callable[[Input], AsyncIterator[Output]], +) -> Runnable[Input, Output]: + ... + + +@overload +def chain( + func: Callable[[Input], Output], +) -> Runnable[Input, Output]: + ... + + +def chain( + func: Union[ + Callable[[Input], Output], + Callable[[Input], Iterator[Output]], + Callable[[Input], Coroutine[Any, Any, Output]], + Callable[[Input], AsyncIterator[Output]], + ], +) -> Runnable[Input, Output]: + """Decorate a function to make it a Runnable. + Sets the name of the runnable to the name of the function. + Any runnables called by the function will be traced as dependencies. + + Args: + func: A callable. + + Returns: + A Runnable. + + Example: + + .. code-block:: python + + from langchain_core.runnables import chain + from langchain_core.prompts import PromptTemplate + from langchain.llms import OpenAI + + @chain + def my_func(fields): + prompt = PromptTemplate("Hello, {name}!") + llm = OpenAI() + formatted = prompt.invoke(**fields) + + for chunk in llm.stream(formatted): + yield chunk + """ + return RunnableLambda(func) diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 080dfa9cdb..bd9330a1fc 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -323,7 +323,7 @@ def call_func_with_variable_args( return func(input, **kwargs) # type: ignore[call-arg] -async def acall_func_with_variable_args( +def acall_func_with_variable_args( func: Union[ Callable[[Input], Awaitable[Output]], Callable[[Input, RunnableConfig], Awaitable[Output]], @@ -337,7 +337,7 @@ async def acall_func_with_variable_args( config: RunnableConfig, run_manager: Optional[AsyncCallbackManagerForChainRun] = None, **kwargs: Any, -) -> Output: +) -> Awaitable[Output]: """Call function that may optionally accept a run_manager and/or config. Args: @@ -361,7 +361,7 @@ async def acall_func_with_variable_args( kwargs["config"] = config if run_manager is not None and accepts_run_manager(func): kwargs["run_manager"] = run_manager - return await func(input, **kwargs) # type: ignore[call-arg] + return func(input, **kwargs) # type: ignore[call-arg] def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 0da6b97b26..bd629194b6 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -68,6 +68,14 @@ def accepts_config(callable: Callable[..., Any]) -> bool: return False +def accepts_context(callable: Callable[..., Any]) -> bool: + """Check if a callable accepts a context argument.""" + try: + return signature(callable).parameters.get("context") is not None + except ValueError: + return False + + class IsLocalDict(ast.NodeVisitor): """Check if a name is a local dict.""" diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py index bf7901d037..7a61eaec86 100644 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -12,7 +12,7 @@ from langchain_core.runnables.utils import aadd, add from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM -class TestCase(NamedTuple): +class _TestCase(NamedTuple): input: Any output: Any @@ -102,22 +102,22 @@ test_cases = [ ( Context.setter("foo") | Context.getter("foo"), ( - TestCase("foo", "foo"), - TestCase("bar", "bar"), + _TestCase("foo", "foo"), + _TestCase("bar", "bar"), ), ), ( Context.setter("input") | {"bar": Context.getter("input")}, ( - TestCase("foo", {"bar": "foo"}), - TestCase("bar", {"bar": "bar"}), + _TestCase("foo", {"bar": "foo"}), + _TestCase("bar", {"bar": "bar"}), ), ), ( {"bar": Context.setter("input")} | Context.getter("input"), ( - TestCase("foo", "foo"), - TestCase("bar", "bar"), + _TestCase("foo", "foo"), + _TestCase("bar", "bar"), ), ), ( @@ -132,11 +132,11 @@ test_cases = [ } ), ( - TestCase( + _TestCase( {"foo": "foo", "bar": "bar"}, {"response": "hello", "prompt": StringPromptValue(text="foo bar")}, ), - TestCase( + _TestCase( {"foo": "bar", "bar": "foo"}, {"response": "hello", "prompt": StringPromptValue(text="bar foo")}, ), @@ -155,7 +155,7 @@ test_cases = [ } ), ( - TestCase( + _TestCase( {"foo": "foo", "bar": "bar"}, { "response": "hello", @@ -163,7 +163,7 @@ test_cases = [ "prompt_str": "foo bar", }, ), - TestCase( + _TestCase( {"foo": "bar", "bar": "foo"}, { "response": "hello", @@ -185,11 +185,11 @@ test_cases = [ } ), ( - TestCase( + _TestCase( {"foo": "foo", "bar": "bar"}, {"response": "hello", "prompt_str": "foo bar"}, ), - TestCase( + _TestCase( {"foo": "bar", "bar": "foo"}, {"response": "hello", "prompt_str": "bar foo"}, ), @@ -207,11 +207,11 @@ test_cases = [ } ), ( - TestCase( + _TestCase( {"foo": "foo", "bar": "bar"}, {"response": "hello", "prompt_str": "foo bar"}, ), - TestCase( + _TestCase( {"foo": "bar", "bar": "foo"}, {"response": "hello", "prompt_str": "bar foo"}, ), @@ -229,11 +229,11 @@ test_cases = [ } ), ( - TestCase( + _TestCase( {"foo": "foo", "bar": "bar"}, {"response": "hello", "prompt": StringPromptValue(text="foo bar")}, ), - TestCase( + _TestCase( {"foo": "bar", "bar": "foo"}, {"response": "hello", "prompt": StringPromptValue(text="bar foo")}, ), @@ -242,7 +242,7 @@ test_cases = [ ( seq_naive_rag, ( - TestCase( + _TestCase( "What up", { "result": "hello", @@ -254,7 +254,7 @@ test_cases = [ "input": "What up", }, ), - TestCase( + _TestCase( "Howdy", { "result": "hello", @@ -271,7 +271,7 @@ test_cases = [ ( seq_naive_rag_alt, ( - TestCase( + _TestCase( "What up", { "result": "hello", @@ -283,7 +283,7 @@ test_cases = [ "input": "What up", }, ), - TestCase( + _TestCase( "Howdy", { "result": "hello", @@ -300,7 +300,7 @@ test_cases = [ ( seq_naive_rag_scoped, ( - TestCase( + _TestCase( "What up", { "result": "hello", @@ -312,7 +312,7 @@ test_cases = [ "input": "What up", }, ), - TestCase( + _TestCase( "Howdy", { "result": "hello", @@ -331,7 +331,7 @@ test_cases = [ @pytest.mark.parametrize("runnable, cases", test_cases) async def test_context_runnables( - runnable: Union[Runnable, Callable[[], Runnable]], cases: List[TestCase] + runnable: Union[Runnable, Callable[[], Runnable]], cases: List[_TestCase] ) -> None: runnable = runnable if isinstance(runnable, Runnable) else runnable() assert runnable.invoke(cases[0].input) == cases[0].output diff --git a/libs/core/tests/unit_tests/runnables/test_imports.py b/libs/core/tests/unit_tests/runnables/test_imports.py index 8300292af1..48098aa6db 100644 --- a/libs/core/tests/unit_tests/runnables/test_imports.py +++ b/libs/core/tests/unit_tests/runnables/test_imports.py @@ -1,6 +1,7 @@ from langchain_core.runnables import __all__ EXPECTED_ALL = [ + "chain", "AddableDict", "ConfigurableField", "ConfigurableFieldSingleOption", diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 90e29aed73..23458f70b0 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -68,6 +68,7 @@ from langchain_core.runnables import ( RunnableSequence, RunnableWithFallbacks, add, + chain, ) from langchain_core.tools import BaseTool, tool from langchain_core.tracers import ( @@ -4388,9 +4389,9 @@ async def test_runnable_gen() -> None: runnable = RunnableGenerator(gen) - assert runnable.input_schema.schema() == {"title": "RunnableGeneratorInput"} + assert runnable.input_schema.schema() == {"title": "gen_input"} assert runnable.output_schema.schema() == { - "title": "RunnableGeneratorOutput", + "title": "gen_output", "type": "integer", } @@ -4410,6 +4411,315 @@ async def test_runnable_gen() -> None: assert await arunnable.abatch([None, None]) == [6, 6] +async def test_runnable_gen_context_config() -> None: + """Test that a generator can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) + + def gen(input: Iterator[Any]) -> Iterator[int]: + yield fake.invoke("a") + yield fake.invoke("aa") + yield fake.invoke("aaa") + + runnable = RunnableGenerator(gen) + + assert runnable.input_schema.schema() == {"title": "gen_input"} + assert runnable.output_schema.schema() == { + "title": "gen_output", + "type": "integer", + } + + tracer = FakeTracer() + assert runnable.invoke(None, {"callbacks": [tracer]}) == 6 + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + tracer.runs.clear() + + assert list(runnable.stream(None)) == [1, 2, 3] + assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call" + + tracer = FakeTracer() + assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3] + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + + tracer = FakeTracer() + assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6] + assert len(tracer.runs) == 2 + assert tracer.runs[0].outputs == {"output": 6} + assert tracer.runs[1].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + assert len(tracer.runs[1].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] + + if sys.version_info < (3, 11): + # Python 3.10 and below don't support running async tasks in a specific context + return + + async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: + yield await fake.ainvoke("a") + yield await fake.ainvoke("aa") + yield await fake.ainvoke("aaa") + + arunnable = RunnableGenerator(agen) + + tracer = FakeTracer() + assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6 + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + tracer.runs.clear() + + assert [p async for p in arunnable.astream(None)] == [1, 2, 3] + assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call" + + tracer = FakeTracer() + assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [ + 1, + 2, + 3, + ] + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + + tracer = FakeTracer() + assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6] + assert len(tracer.runs) == 2 + assert tracer.runs[0].outputs == {"output": 6} + assert tracer.runs[1].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + assert len(tracer.runs[1].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] + + +async def test_runnable_iter_context_config() -> None: + """Test that a generator can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) + + @chain + def gen(input: str) -> Iterator[int]: + yield fake.invoke(input) + yield fake.invoke(input * 2) + yield fake.invoke(input * 3) + + assert gen.input_schema.schema() == { + "title": "gen_input", + "type": "string", + } + assert gen.output_schema.schema() == { + "title": "gen_output", + "type": "integer", + } + + tracer = FakeTracer() + assert gen.invoke("a", {"callbacks": [tracer]}) == 6 + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + tracer.runs.clear() + + assert list(gen.stream("a")) == [1, 2, 3] + assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call" + + tracer = FakeTracer() + assert list(gen.stream("a", {"callbacks": [tracer]})) == [1, 2, 3] + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + + tracer = FakeTracer() + assert gen.batch(["a", "a"], {"callbacks": [tracer]}) == [6, 6] + assert len(tracer.runs) == 2 + assert tracer.runs[0].outputs == {"output": 6} + assert tracer.runs[1].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + assert len(tracer.runs[1].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] + + if sys.version_info < (3, 11): + # Python 3.10 and below don't support running async tasks in a specific context + return + + @chain + async def agen(input: str) -> AsyncIterator[int]: + yield await fake.ainvoke(input) + yield await fake.ainvoke(input * 2) + yield await fake.ainvoke(input * 3) + + assert agen.input_schema.schema() == { + "title": "agen_input", + "type": "string", + } + assert agen.output_schema.schema() == { + "title": "agen_output", + "type": "integer", + } + + tracer = FakeTracer() + assert await agen.ainvoke("a", {"callbacks": [tracer]}) == 6 + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + tracer.runs.clear() + + assert [p async for p in agen.astream("a")] == [1, 2, 3] + assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call" + + tracer = FakeTracer() + assert [p async for p in agen.astream("a", {"callbacks": [tracer]})] == [ + 1, + 2, + 3, + ] + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + + tracer = FakeTracer() + assert await agen.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6] + assert len(tracer.runs) == 2 + assert tracer.runs[0].outputs == {"output": 6} + assert tracer.runs[1].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + assert len(tracer.runs[1].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] + + +async def test_runnable_lambda_context_config() -> None: + """Test that a function can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) + + @chain + def fun(input: str) -> int: + output = fake.invoke(input) + output += fake.invoke(input * 2) + output += fake.invoke(input * 3) + return output + + assert fun.input_schema.schema() == {"title": "fun_input", "type": "string"} + assert fun.output_schema.schema() == { + "title": "fun_output", + "type": "integer", + } + + tracer = FakeTracer() + assert fun.invoke("a", {"callbacks": [tracer]}) == 6 + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + tracer.runs.clear() + + assert list(fun.stream("a")) == [6] + assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call" + + tracer = FakeTracer() + assert list(fun.stream("a", {"callbacks": [tracer]})) == [6] + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + + tracer = FakeTracer() + assert fun.batch(["a", "a"], {"callbacks": [tracer]}) == [6, 6] + assert len(tracer.runs) == 2 + assert tracer.runs[0].outputs == {"output": 6} + assert tracer.runs[1].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + assert len(tracer.runs[1].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] + + if sys.version_info < (3, 11): + # Python 3.10 and below don't support running async tasks in a specific context + return + + @chain + async def afun(input: str) -> int: + output = await fake.ainvoke(input) + output += await fake.ainvoke(input * 2) + output += await fake.ainvoke(input * 3) + return output + + assert afun.input_schema.schema() == {"title": "afun_input", "type": "string"} + assert afun.output_schema.schema() == { + "title": "afun_output", + "type": "integer", + } + + tracer = FakeTracer() + assert await afun.ainvoke("a", {"callbacks": [tracer]}) == 6 + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + tracer.runs.clear() + + assert [p async for p in afun.astream("a")] == [6] + assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call" + + tracer = FakeTracer() + assert [p async for p in afun.astream("a", {"callbacks": [tracer]})] == [6] + assert len(tracer.runs) == 1 + assert tracer.runs[0].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + + tracer = FakeTracer() + assert await afun.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6] + assert len(tracer.runs) == 2 + assert tracer.runs[0].outputs == {"output": 6} + assert tracer.runs[1].outputs == {"output": 6} + assert len(tracer.runs[0].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + assert len(tracer.runs[1].child_runs) == 3 + assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] + assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] + + async def test_runnable_gen_transform() -> None: """Test that a generator can be used as a runnable.""" @@ -4434,19 +4744,19 @@ async def test_runnable_gen_transform() -> None: achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one assert chain.input_schema.schema() == { - "title": "RunnableGeneratorInput", + "title": "gen_indexes_input", "type": "integer", } assert chain.output_schema.schema() == { - "title": "RunnableGeneratorOutput", + "title": "plus_one_output", "type": "integer", } assert achain.input_schema.schema() == { - "title": "RunnableGeneratorInput", + "title": "gen_indexes_input", "type": "integer", } assert achain.output_schema.schema() == { - "title": "RunnableGeneratorOutput", + "title": "aplus_one_output", "type": "integer", }