Mirror of https://github.com/hwchase17/langchain, synced 2024-11-04 06:00:26 +00:00
Implement better reprs for Runnables (#11175)
```
ChatPromptTemplate(messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a nice assistant.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='{question}'))])
| RunnableLambda(lambda x: x)
| {
    chat: FakeListChatModel(responses=["i'm a chatbot"]),
    llm: FakeListLLM(responses=["i'm a textbot"])
  }
```
Commit 61b5942adf
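Below is a minimal usage sketch of the behaviour this commit adds (an editorial illustration, not part of the PR body): on Python 3.9+, where `ast.unparse` exists, a `RunnableLambda` now shows the lambda's recovered source instead of a placeholder.

```python
from langchain.schema.runnable import RunnableLambda

runnable = RunnableLambda(lambda x: x * 2)

# Before this change every RunnableLambda printed as "RunnableLambda(...)".
# With it, the lambda's source is recovered from the AST when possible.
print(repr(runnable))  # RunnableLambda(lambda x: x * 2)

# On Python < 3.9, or when several lambdas share one source statement,
# the repr still falls back to the old placeholder: RunnableLambda(...)
```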
```diff
@@ -77,6 +77,13 @@ class Serializable(BaseModel, ABC):
     class Config:
         extra = "ignore"
 
+    def __repr_args__(self) -> Any:
+        return [
+            (k, v)
+            for k, v in super().__repr_args__()
+            if (k not in self.__fields__ or self.__fields__[k].get_default() != v)
+        ]
+
     _lc_kwargs = PrivateAttr(default_factory=dict)
 
     def __init__(self, **kwargs: Any) -> None:
```
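The `__repr_args__` override above hooks into pydantic v1's repr machinery: `__repr__` is assembled from the `(name, value)` pairs this method returns, so pairs whose value still equals the field default are hidden. This is also why, in the test changes further down, the expected `Document(...)` strings lose their `metadata={}`. A standalone sketch of the same idea on a made-up pydantic v1 model (class and field names are illustrative only, not langchain code):

```python
from typing import Any

from pydantic import BaseModel  # pydantic v1 API, as langchain used at the time


class Doc(BaseModel):  # hypothetical model, not langchain's Document
    page_content: str = ""
    metadata: dict = {}

    def __repr_args__(self) -> Any:
        # Drop (field, value) pairs that are still at their default value.
        return [
            (k, v)
            for k, v in super().__repr_args__()
            if k not in self.__fields__ or self.__fields__[k].get_default() != v
        ]


print(repr(Doc(page_content="foo")))                     # Doc(page_content='foo')
print(repr(Doc(page_content="foo", metadata={"a": 1})))  # includes metadata={'a': 1}
```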
```diff
@@ -59,6 +59,8 @@ from langchain.schema.runnable.utils import (
     accepts_run_manager,
     gather_with_concurrency,
     get_function_first_arg_dict_keys,
+    get_lambda_source,
+    indent_lines_after_first,
 )
 from langchain.utils.aiter import atee, py_anext
 from langchain.utils.iter import safetee
```
```diff
@@ -1298,6 +1300,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
     def output_schema(self) -> Type[BaseModel]:
         return self.last.output_schema
 
+    def __repr__(self) -> str:
+        return "\n| ".join(
+            repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
+            for i, s in enumerate(self.steps)
+        )
+
     def __or__(
         self,
         other: Union[
```
```diff
@@ -1819,6 +1827,13 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
             **{k: (v.OutputType, None) for k, v in self.steps.items()},
         )
 
+    def __repr__(self) -> str:
+        map_for_repr = ",\n  ".join(
+            f"{k}: {indent_lines_after_first(repr(v), '  ' + k + ': ')}"
+            for k, v in self.steps.items()
+        )
+        return "{\n  " + map_for_repr + "\n}"
+
     def invoke(
         self, input: Input, config: Optional[RunnableConfig] = None
     ) -> Dict[str, Any]:
```
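Taken together, the two `__repr__` methods above produce the pipe-style layout from the PR description: a sequence joins its steps with `"\n| "`, and a map renders as an indented `{...}` block. A self-contained sketch of just the string formatting, using stand-in step reprs and a local copy of the `indent_lines_after_first` helper added later in this diff:

```python
def indent_lines_after_first(text: str, prefix: str) -> str:
    # Local copy of the helper added in this PR: pad every line after the
    # first with as many spaces as the prefix is wide.
    spaces = " " * len(prefix)
    lines = text.splitlines()
    return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])


# Stand-in step reprs: a one-line first step followed by a map-style block.
steps = ["PromptTemplate(...)", "{\n  chat: FakeChat(...),\n  llm: FakeLLM(...)\n}"]

print(
    "\n| ".join(
        s if i == 0 else indent_lines_after_first(s, "| ") for i, s in enumerate(steps)
    )
)
# PromptTemplate(...)
# | {
#     chat: FakeChat(...),
#     llm: FakeLLM(...)
#   }
```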
```diff
@@ -2134,7 +2149,7 @@ class RunnableLambda(Runnable[Input, Output]):
         return False
 
     def __repr__(self) -> str:
-        return "RunnableLambda(...)"
+        return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
 
     def _invoke(
         self,
```
```diff
@@ -87,6 +87,17 @@ class IsFunctionArgDict(ast.NodeVisitor):
         IsLocalDict(input_arg_name, self.keys).visit(node)
 
 
+class GetLambdaSource(ast.NodeVisitor):
+    def __init__(self) -> None:
+        self.source: Optional[str] = None
+        self.count = 0
+
+    def visit_Lambda(self, node: ast.Lambda) -> Any:
+        self.count += 1
+        if hasattr(ast, "unparse"):
+            self.source = ast.unparse(node)
+
+
 def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
     try:
         code = inspect.getsource(func)
```
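`GetLambdaSource` is a small `ast.NodeVisitor`: it counts `Lambda` nodes and, when `ast.unparse` is available (Python 3.9+), keeps the unparsed source of the node it visited; `get_lambda_source` in the next hunk only trusts the result when exactly one lambda was found. A quick standalone illustration of those two building blocks (the `CountLambdas` class here is a simplified stand-in, not the PR's code):

```python
import ast

# ast.unparse (Python 3.9+) turns a parsed lambda back into source text.
tree = ast.parse("lambda x: x * 2", mode="eval")
print(ast.unparse(tree.body))  # lambda x: x * 2


class CountLambdas(ast.NodeVisitor):  # simplified stand-in for GetLambdaSource
    def __init__(self) -> None:
        self.count = 0

    def visit_Lambda(self, node: ast.Lambda) -> None:
        self.count += 1
        self.generic_visit(node)


# Two lambdas in one source statement make the recovered text ambiguous,
# which is why the repr "gives up" and falls back to "..." in that case.
counter = CountLambdas()
counter.visit(ast.parse("pair = (lambda x: x + 1, lambda x: x - 1)"))
print(counter.count)  # 2
```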
```diff
@@ -94,5 +105,40 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
         visitor = IsFunctionArgDict()
         visitor.visit(tree)
         return list(visitor.keys) if visitor.keys else None
-    except (TypeError, OSError):
+    except (SyntaxError, TypeError, OSError):
         return None
+
+
+def get_lambda_source(func: Callable) -> Optional[str]:
+    """Get the source code of a lambda function.
+
+    Args:
+        func: a callable that can be a lambda function
+
+    Returns:
+        str: the source code of the lambda function
+    """
+    try:
+        code = inspect.getsource(func)
+        tree = ast.parse(textwrap.dedent(code))
+        visitor = GetLambdaSource()
+        visitor.visit(tree)
+        return visitor.source if visitor.count == 1 else None
+    except (SyntaxError, TypeError, OSError):
+        return None
+
+
+def indent_lines_after_first(text: str, prefix: str) -> str:
+    """Indent all lines of text after the first line.
+
+    Args:
+        text: The text to indent
+        prefix: Used to determine the number of spaces to indent
+
+    Returns:
+        str: The indented text
+    """
+    n_spaces = len(prefix)
+    spaces = " " * n_spaces
+    lines = text.splitlines()
+    return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
```
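Expected behaviour of the two new helpers, mirroring the unit tests added at the end of this PR (import path as in the diff; the lambda case assumes Python 3.9+):

```python
from langchain.schema.runnable.utils import get_lambda_source, indent_lines_after_first

# Source recovery (Python 3.9+): the lambda text is read back from the AST.
assert get_lambda_source(lambda x: x * 2) == "lambda x: x * 2"


# A def-style function contains no lambda, so there is nothing to recover.
def double(x: int) -> int:
    return x * 2


assert get_lambda_source(double) is None

# indent_lines_after_first pads continuation lines to the width of the prefix.
assert indent_lines_after_first("line 1\nline 2", "| ") == "line 1\n  line 2"
```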
File diff suppressed because one or more lines are too long
```diff
@@ -867,6 +867,7 @@ async def test_prompt_with_chat_model(
 
     chain = prompt | chat
 
+    assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
     assert chain.first == prompt
     assert chain.middle == []
@@ -1276,6 +1277,7 @@ def test_combining_sequences(
     assert chain.first == prompt
     assert chain.middle == [chat]
     assert chain.last == parser
+    if sys.version_info >= (3, 9):
         assert dumps(chain, pretty=True) == snapshot
 
     prompt2 = (
@@ -1294,6 +1296,7 @@ def test_combining_sequences(
     assert chain2.first == input_formatter
     assert chain2.middle == [prompt2, chat2]
     assert chain2.last == parser2
+    if sys.version_info >= (3, 9):
         assert dumps(chain2, pretty=True) == snapshot
 
     combined_chain = chain | chain2
@@ -1307,6 +1310,7 @@ def test_combining_sequences(
         chat2,
     ]
     assert combined_chain.last == parser2
+    if sys.version_info >= (3, 9):
         assert dumps(combined_chain, pretty=True) == snapshot
 
     # Test invoke
@@ -1315,6 +1319,7 @@ def test_combining_sequences(
         {"question": "What is your name?"}, dict(callbacks=[tracer])
     ) == ["baz", "qux"]
 
+    if sys.version_info >= (3, 9):
         assert tracer.runs == snapshot
 
 
@@ -1350,6 +1355,7 @@ Question:
         | parser
     )
 
+    assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
     assert isinstance(chain.first, RunnableMap)
     assert chain.middle == [prompt, chat]
```
```diff
@@ -1375,7 +1381,7 @@ Question:
             SystemMessage(content="You are a nice assistant."),
             HumanMessage(
                 content="""Context:
-[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]
+[Document(page_content='foo'), Document(page_content='bar')]
 
 Question:
 What is your name?"""
@@ -1413,6 +1419,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
         }
     )
 
+    assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
     assert chain.first == prompt
     assert chain.middle == [RunnableLambda(passthrough)]
@@ -2098,6 +2105,7 @@ async def test_llm_with_fallbacks(
     assert await runnable.ainvoke("hello") == "bar"
     assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
     assert list(await runnable.ainvoke("hello")) == list("bar")
+    if sys.version_info >= (3, 9):
         assert dumps(runnable, pretty=True) == snapshot
 
 
```
```diff
@@ -2196,6 +2204,7 @@ def test_retrying(mocker: MockerFixture) -> None:
     with pytest.raises(RuntimeError):
         runnable.with_retry(
             stop_after_attempt=2,
+            wait_exponential_jitter=False,
             retry_if_exception_type=(ValueError,),
         ).invoke(2)
 
@@ -2205,6 +2214,7 @@ def test_retrying(mocker: MockerFixture) -> None:
     with pytest.raises(ValueError):
         runnable.with_retry(
             stop_after_attempt=2,
+            wait_exponential_jitter=False,
             retry_if_exception_type=(ValueError,),
         ).batch([1, 2, 0])
 
@@ -2214,6 +2224,7 @@ def test_retrying(mocker: MockerFixture) -> None:
 
     output = runnable.with_retry(
         stop_after_attempt=2,
+        wait_exponential_jitter=False,
         retry_if_exception_type=(ValueError,),
     ).batch([1, 2, 0], return_exceptions=True)
 
@@ -2248,6 +2259,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
     with pytest.raises(ValueError):
         await runnable.with_retry(
             stop_after_attempt=2,
+            wait_exponential_jitter=False,
             retry_if_exception_type=(ValueError, KeyError),
         ).ainvoke(1)
 
@@ -2257,6 +2269,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
     with pytest.raises(RuntimeError):
         await runnable.with_retry(
             stop_after_attempt=2,
+            wait_exponential_jitter=False,
             retry_if_exception_type=(ValueError,),
         ).ainvoke(2)
 
@@ -2266,6 +2279,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
     with pytest.raises(ValueError):
         await runnable.with_retry(
             stop_after_attempt=2,
+            wait_exponential_jitter=False,
             retry_if_exception_type=(ValueError,),
         ).abatch([1, 2, 0])
 
@@ -2275,6 +2289,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
 
     output = await runnable.with_retry(
         stop_after_attempt=2,
+        wait_exponential_jitter=False,
         retry_if_exception_type=(ValueError,),
     ).abatch([1, 2, 0], return_exceptions=True)
 
```
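The `wait_exponential_jitter=False` argument threaded through the retry tests above disables the backoff wait between attempts, so retries run back-to-back and the tests stay fast. A small sketch of the same call shape (the flaky function below is invented for illustration, not taken from the PR):

```python
from langchain.schema.runnable import RunnableLambda

attempts = 0


def flaky(x: int) -> int:
    """Hypothetical function that fails on the first call, then succeeds."""
    global attempts
    attempts += 1
    if attempts == 1:
        raise ValueError("first call fails")
    return x * 2


retrying = RunnableLambda(flaky).with_retry(
    stop_after_attempt=2,
    wait_exponential_jitter=False,  # no sleep between the two attempts
    retry_if_exception_type=(ValueError,),
)
print(retrying.invoke(3))  # 6, produced on the second attempt
```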
```diff
@@ -2729,3 +2744,38 @@ async def test_runnable_branch_abatch() -> None:
     )
 
     assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
+)
+def test_representation_of_runnables() -> None:
+    """Test representation of runnables."""
+    runnable = RunnableLambda(lambda x: x * 2)
+    assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
+
+    def f(x: int) -> int:
+        """Return 2."""
+        return 2
+
+    assert repr(RunnableLambda(func=f)) == "RunnableLambda(...)"
+
+    async def af(x: int) -> int:
+        """Return 2."""
+        return 2
+
+    assert repr(RunnableLambda(func=f, afunc=af)) == "RunnableLambda(...)"
+
+    assert repr(
+        RunnableLambda(lambda x: x + 2)
+        | {
+            "a": RunnableLambda(lambda x: x * 2),
+            "b": RunnableLambda(lambda x: x * 3),
+        }
+    ) == (
+        "RunnableLambda(...)\n"
+        "| {\n"
+        "    a: RunnableLambda(...),\n"
+        "    b: RunnableLambda(...)\n"
+        "  }"
+    ), "repr where code string contains multiple lambdas gives up"
```
```diff
@@ -0,0 +1,39 @@
+import sys
+from typing import Callable
+
+import pytest
+
+from langchain.schema.runnable.utils import (
+    get_lambda_source,
+    indent_lines_after_first,
+)
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
+)
+@pytest.mark.parametrize(
+    "func, expected_source",
+    [
+        (lambda x: x * 2, "lambda x: x * 2"),
+        (lambda a, b: a + b, "lambda a, b: a + b"),
+        (lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),
+    ],
+)
+def test_get_lambda_source(func: Callable, expected_source: str) -> None:
+    """Test get_lambda_source function"""
+    source = get_lambda_source(func)
+    assert source == expected_source
+
+
+@pytest.mark.parametrize(
+    "text,prefix,expected_output",
+    [
+        ("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"),
+        ("line 1\nline 2\nline 3", "ax", "line 1\n  line 2\n  line 3"),
+    ],
+)
+def test_indent_lines_after_first(text: str, prefix: str, expected_output: str) -> None:
+    """Test indent_lines_after_first function"""
+    indented_text = indent_lines_after_first(text, prefix)
+    assert indented_text == expected_output
```