diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index e0f27aee9d..7cf580348f 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2363,7 +2363,12 @@ class RunnableGenerator(Runnable[Input, Output]): return False def __repr__(self) -> str: - return "RunnableGenerator(...)" + if hasattr(self, "_transform"): + return f"RunnableGenerator({self._transform.__name__})" + elif hasattr(self, "_atransform"): + return f"RunnableGenerator({self._atransform.__name__})" + else: + return "RunnableGenerator(...)" def transform( self, diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 159f981607..21f75b9442 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -202,6 +202,21 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): """ return RunnableAssign(RunnableParallel(kwargs)) + @classmethod + def pick( + cls, + keys: Union[str, List[str]], + ) -> "RunnablePick": + """Pick keys from the Dict input. + + Args: + keys: A string or list of strings representing the keys to pick. + + Returns: + A runnable that picks keys from the Dict input. + """ + return RunnablePick(keys) + def invoke( self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Other: @@ -553,3 +568,124 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): async for chunk in self.atransform(input_aiter(), config, **kwargs): yield chunk + + +class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): + """ + A runnable that picks keys from Dict[str, Any] inputs. + """ + + keys: Union[str, List[str]] + + def __init__(self, keys: Union[str, List[str]], **kwargs: Any) -> None: + super().__init__(keys=keys, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "schema", "runnable"] + + def _pick(self, input: Dict[str, Any]) -> Any: + assert isinstance( + input, dict + ), "The input to RunnablePassthrough.assign() must be a dict." + + if isinstance(self.keys, str): + return input.get(self.keys) + else: + picked = {k: input.get(k) for k in self.keys if k in input} + if picked: + return AddableDict(picked) + else: + return None + + def _invoke( + self, + input: Dict[str, Any], + ) -> Dict[str, Any]: + return self._pick(input) + + def invoke( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + return self._call_with_config(self._invoke, input, config, **kwargs) + + async def _ainvoke( + self, + input: Dict[str, Any], + ) -> Dict[str, Any]: + return self._pick(input) + + async def ainvoke( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + return await self._acall_with_config(self._ainvoke, input, config, **kwargs) + + def _transform( + self, + input: Iterator[Dict[str, Any]], + ) -> Iterator[Dict[str, Any]]: + for chunk in input: + picked = self._pick(chunk) + if picked is not None: + yield picked + + def transform( + self, + input: Iterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> Iterator[Dict[str, Any]]: + yield from self._transform_stream_with_config( + input, self._transform, config, **kwargs + ) + + async def _atransform( + self, + input: AsyncIterator[Dict[str, Any]], + ) -> AsyncIterator[Dict[str, Any]]: + async for chunk in input: + picked = self._pick(chunk) + if picked is not None: + yield picked + + async def atransform( + self, + input: AsyncIterator[Dict[str, Any]], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + async for chunk in self._atransform_stream_with_config( + input, self._atransform, config, **kwargs + ): + yield chunk + + def stream( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Dict[str, Any]]: + return self.transform(iter([input]), config, **kwargs) + + async def astream( + self, + input: Dict[str, Any], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + async def input_aiter() -> AsyncIterator[Dict[str, Any]]: + yield input + + async for chunk in self.atransform(input_aiter(), config, **kwargs): + yield chunk diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 4130ed3190..e0fc13e5fa 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -2764,6 +2764,41 @@ def test_map_stream() -> None: {"question": "What is your name?"} ) + chain_pick_one = chain | RunnablePassthrough.pick("llm") + + stream = chain_pick_one.stream({"question": "What is your name?"}) + + final_value = None + streamed_chunks = [] + for chunk in stream: + streamed_chunks.append(chunk) + if final_value is None: + final_value = chunk + else: + final_value += chunk + + assert streamed_chunks[0] == "i" + assert len(streamed_chunks) == len(llm_res) + + chain_pick_two = chain | RunnablePassthrough.pick(["llm", "chat"]) + + stream = chain_pick_two.stream({"question": "What is your name?"}) + + final_value = None + streamed_chunks = [] + for chunk in stream: + streamed_chunks.append(chunk) + if final_value is None: + final_value = chunk + else: + final_value += chunk + + assert streamed_chunks[0] in [ + {"llm": "i"}, + {"chat": AIMessageChunk(content="i")}, + ] + assert len(streamed_chunks) == len(llm_res) + len(chat_res) + def test_map_stream_iterator_input() -> None: prompt = (