mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Improve graph repr for runnable passthrough and itemgetter (#15083)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
0d0901ea18
commit
a2d3042823
@ -2007,7 +2007,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
):
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableParallelInput",
|
||||
"RunnableMapInput",
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for step in self.steps.values()
|
||||
@ -2024,7 +2024,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableParallelOutput",
|
||||
"RunnableMapOutput",
|
||||
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
@ -2650,7 +2650,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""A string representation of this runnable."""
|
||||
if hasattr(self, "func"):
|
||||
if hasattr(self, "func") and isinstance(self.func, itemgetter):
|
||||
return f"RunnableLambda({str(self.func)[len('operator.'):]})"
|
||||
elif hasattr(self, "func"):
|
||||
return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
|
||||
elif hasattr(self, "afunc"):
|
||||
return f"RunnableLambda(afunc={get_lambda_source(self.afunc) or '...'})"
|
||||
|
@ -123,13 +123,13 @@ class Graph:
|
||||
or len(data.splitlines()) > 1
|
||||
):
|
||||
data = node.data.__class__.__name__
|
||||
elif len(data) > 36:
|
||||
data = data[:36] + "..."
|
||||
elif len(data) > 42:
|
||||
data = data[:42] + "..."
|
||||
except Exception:
|
||||
data = node.data.__class__.__name__
|
||||
else:
|
||||
data = node.data.__name__
|
||||
return data
|
||||
return data if not data.startswith("Runnable") else data[8:]
|
||||
|
||||
return draw(
|
||||
{node.id: node_data(node) for node in self.nodes.values()},
|
||||
|
@ -34,6 +34,7 @@ from langchain_core.runnables.config import (
|
||||
get_executor_for_config,
|
||||
patch_config,
|
||||
)
|
||||
from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
|
||||
from langchain_core.utils.aiter import atee, py_anext
|
||||
from langchain_core.utils.iter import safetee
|
||||
@ -297,6 +298,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
yield chunk
|
||||
|
||||
|
||||
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
|
||||
|
||||
|
||||
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
"""
|
||||
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
||||
@ -355,6 +359,18 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
return self.mapper.config_specs
|
||||
|
||||
def get_graph(self, config: RunnableConfig | None = None) -> Graph:
|
||||
# get graph from mapper
|
||||
graph = self.mapper.get_graph(config)
|
||||
# add passthrough node and edges
|
||||
input_node = graph.first_node()
|
||||
output_node = graph.last_node()
|
||||
if input_node is not None and output_node is not None:
|
||||
passthrough_node = graph.add_node(_graph_passthrough)
|
||||
graph.add_edge(input_node, passthrough_node)
|
||||
graph.add_edge(passthrough_node, output_node)
|
||||
return graph
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
input: Dict[str, Any],
|
||||
|
@ -32,51 +32,51 @@
|
||||
# ---
|
||||
# name: test_graph_sequence_map
|
||||
'''
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-----------------------+
|
||||
| RunnableParallelInput |
|
||||
+-----------------------+**
|
||||
**** *******
|
||||
**** *****
|
||||
** *******
|
||||
+---------------------+ ***
|
||||
| RunnableLambdaInput | *
|
||||
+---------------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-----------------+ +-----------------+ *
|
||||
| StrOutputParser | | XMLOutputParser | *
|
||||
+-----------------+ +-----------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+----------------------+ +--------------------------------+
|
||||
| RunnableLambdaOutput | | CommaSeparatedListOutputParser |
|
||||
+----------------------+ +--------------------------------+
|
||||
**** *******
|
||||
**** *****
|
||||
** ****
|
||||
+------------------------+
|
||||
| RunnableParallelOutput |
|
||||
+------------------------+
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+---------------+
|
||||
| ParallelInput |
|
||||
+---------------+*****
|
||||
*** ******
|
||||
*** *****
|
||||
** *****
|
||||
+-------------+ ***
|
||||
| LambdaInput | *
|
||||
+-------------+ *
|
||||
** ** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-----------------+ +-----------------+ *
|
||||
| StrOutputParser | | XMLOutputParser | *
|
||||
+-----------------+ +-----------------+ *
|
||||
** ** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+--------------+ +--------------------------------+
|
||||
| LambdaOutput | | CommaSeparatedListOutputParser |
|
||||
+--------------+ +--------------------------------+
|
||||
*** ******
|
||||
*** *****
|
||||
** ***
|
||||
+-----------+
|
||||
| MapOutput |
|
||||
+-----------+
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_single_runnable
|
||||
|
@ -569,7 +569,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"properties": {"name": {"title": "Name", "type": "string"}},
|
||||
}
|
||||
assert seq_w_map.output_schema.schema() == {
|
||||
"title": "RunnableParallelOutput",
|
||||
"title": "RunnableMapOutput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"original": {"title": "Original", "type": "string"},
|
||||
@ -613,7 +613,7 @@ def test_passthrough_assign_schema() -> None:
|
||||
# expected dict input_schema
|
||||
assert invalid_seq_w_assign.input_schema.schema() == {
|
||||
"properties": {"question": {"title": "Question"}},
|
||||
"title": "RunnableParallelInput",
|
||||
"title": "RunnableMapInput",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
@ -768,7 +768,7 @@ def test_schema_complex_seq() -> None:
|
||||
)
|
||||
|
||||
assert chain2.input_schema.schema() == {
|
||||
"title": "RunnableParallelInput",
|
||||
"title": "RunnableMapInput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person": {"title": "Person", "type": "string"},
|
||||
@ -2221,7 +2221,6 @@ async def test_stream_log_lists() -> None:
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_prompt_with_llm_and_async_lambda(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
@ -4262,7 +4261,6 @@ def test_with_config_callbacks() -> None:
|
||||
assert isinstance(result, RunnableBinding)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke_on_returned_runnable() -> None:
|
||||
"""Verify that a runnable returned by a sync runnable in the async path will
|
||||
be runthroughaasync path (issue #13407)"""
|
||||
@ -4301,7 +4299,6 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
||||
def idchain_sync(__input: dict) -> bool:
|
||||
return False
|
||||
|
Loading…
Reference in New Issue
Block a user