astream_events: Add version parameter while method is in beta (#16290)

Add a version parameter while the method is in beta phase.

The idea is to make it possible to minimize making breaking changes for users while we're iterating on schema.

Once the API is stable we can assign a default version requirement.
This commit is contained in:
Eugene Yurtsev 2024-01-19 13:20:02 -05:00 committed by GitHub
parent 91230ef5d1
commit 4ef0ed4ddc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 17 deletions

View File

@ -699,6 +699,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1"],
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
@ -793,7 +794,9 @@ class Runnable(Generic[Input, Output], ABC):
chain = RunnableLambda(func=reverse)
events = [event async for event in chain.astream_events("hello")]
events = [
event async for event in chain.astream_events("hello", version="v1")
]
# will produce the following events (run_id has been omitted for brevity):
[
@ -823,6 +826,9 @@ class Runnable(Generic[Input, Output], ABC):
Args:
input: The input to the runnable.
config: The config to use for the runnable.
version: The version of the schema to use.
Currently only version 1 is available.
No default will be assigned until the API is stabilized.
include_names: Only include events from runnables with matching names.
include_types: Only include events from runnables with matching types.
include_tags: Only include events from runnables with matching tags.
@ -836,6 +842,11 @@ class Runnable(Generic[Input, Output], ABC):
Returns:
An async stream of StreamEvents.
""" # noqa: E501
if version != "v1":
raise NotImplementedError(
'Only version "v1" of the schema is currently supported.'
)
from langchain_core.runnables.utils import (
_RootEventFilter,
)

View File

@ -53,7 +53,7 @@ async def test_event_stream_with_single_lambda() -> None:
chain = RunnableLambda(func=reverse)
events = await _collect_events(chain.astream_events("hello"))
events = await _collect_events(chain.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
@ -94,7 +94,7 @@ async def test_event_stream_with_triple_lambda() -> None:
| r.with_config({"run_name": "2"})
| r.with_config({"run_name": "3"})
)
events = await _collect_events(chain.astream_events("hello"))
events = await _collect_events(chain.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
@ -209,7 +209,9 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
| r.with_config({"run_name": "2", "tags": ["my_tag"]})
| r.with_config({"run_name": "3", "tags": ["my_tag"]})
)
events = await _collect_events(chain.astream_events("hello", include_names=["1"]))
events = await _collect_events(
chain.astream_events("hello", include_names=["1"], version="v1")
)
assert events == [
{
"data": {},
@ -238,7 +240,9 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
]
events = await _collect_events(
chain.astream_events("hello", include_tags=["my_tag"], exclude_names=["2"])
chain.astream_events(
"hello", include_tags=["my_tag"], exclude_names=["2"], version="v1"
)
)
assert events == [
{
@ -272,7 +276,9 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"}
)
events = await _collect_events(as_lambdas.astream_events({"question": "hello"}))
events = await _collect_events(
as_lambdas.astream_events({"question": "hello"}, version="v1")
)
assert events == [
{
"data": {"input": {"question": "hello"}},
@ -331,7 +337,9 @@ async def test_event_stream_with_simple_chain() -> None:
}
)
events = await _collect_events(chain.astream_events({"question": "hello"}))
events = await _collect_events(
chain.astream_events({"question": "hello"}, version="v1")
)
assert events == [
{
"data": {"input": {"question": "hello"}},
@ -497,7 +505,7 @@ async def test_event_streaming_with_tools() -> None:
# type ignores below because the tools don't appear to be runnables to type checkers
# we can remove as soon as that's fixed
events = await _collect_events(parameterless.astream_events({})) # type: ignore
events = await _collect_events(parameterless.astream_events({}, version="v1")) # type: ignore
assert events == [
{
"data": {"input": {}},
@ -525,7 +533,7 @@ async def test_event_streaming_with_tools() -> None:
},
]
events = await _collect_events(with_callbacks.astream_events({})) # type: ignore
events = await _collect_events(with_callbacks.astream_events({}, version="v1")) # type: ignore
assert events == [
{
"data": {"input": {}},
@ -552,7 +560,9 @@ async def test_event_streaming_with_tools() -> None:
"tags": [],
},
]
events = await _collect_events(with_parameters.astream_events({"x": 1, "y": "2"})) # type: ignore
events = await _collect_events(
with_parameters.astream_events({"x": 1, "y": "2"}, version="v1") # type: ignore
)
assert events == [
{
"data": {"input": {"x": 1, "y": "2"}},
@ -581,7 +591,7 @@ async def test_event_streaming_with_tools() -> None:
]
events = await _collect_events(
with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}) # type: ignore
with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}, version="v1") # type: ignore
)
assert events == [
{
@ -634,7 +644,9 @@ async def test_event_stream_with_retriever() -> None:
),
]
)
events = await _collect_events(retriever.astream_events({"query": "hello"}))
events = await _collect_events(
retriever.astream_events({"query": "hello"}, version="v1")
)
assert events == [
{
"data": {
@ -695,7 +707,7 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
return ", ".join([doc.page_content for doc in docs])
chain = retriever | format_docs
events = await _collect_events(chain.astream_events("hello"))
events = await _collect_events(chain.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
@ -796,7 +808,9 @@ async def test_event_stream_on_chain_with_tool() -> None:
# does not appear to be a runnable
chain = concat | reverse # type: ignore
events = await _collect_events(chain.astream_events({"a": "hello", "b": "world"}))
events = await _collect_events(
chain.astream_events({"a": "hello", "b": "world"}, version="v1")
)
assert events == [
{
"data": {"input": {"a": "hello", "b": "world"}},
@ -878,7 +892,7 @@ async def test_event_stream_with_retry() -> None:
chain = RunnableLambda(success) | RunnableLambda(fail).with_retry(
stop_after_attempt=1,
)
iterable = chain.astream_events("q")
iterable = chain.astream_events("q", version="v1")
events = []
@ -953,7 +967,9 @@ async def test_with_llm() -> None:
llm = FakeStreamingListLLM(responses=["abc"])
chain = prompt | llm
events = await _collect_events(chain.astream_events({"question": "hello"}))
events = await _collect_events(
chain.astream_events({"question": "hello"}, version="v1")
)
assert events == [
{
"data": {"input": {"question": "hello"}},
@ -1061,5 +1077,5 @@ async def test_runnable_each() -> None:
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3]):
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
pass