mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
astream_events: Add version parameter while method is in beta (#16290)
Add a version parameter while the method is in beta phase. The idea is to make it possible to minimize making breaking changes for users while we're iterating on schema. Once the API is stable we can assign a default version requirement.
This commit is contained in:
parent
91230ef5d1
commit
4ef0ed4ddc
@ -699,6 +699,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
version: Literal["v1"],
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
@ -793,7 +794,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
chain = RunnableLambda(func=reverse)
|
||||
|
||||
events = [event async for event in chain.astream_events("hello")]
|
||||
events = [
|
||||
event async for event in chain.astream_events("hello", version="v1")
|
||||
]
|
||||
|
||||
# will produce the following events (run_id has been omitted for brevity):
|
||||
[
|
||||
@ -823,6 +826,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Args:
|
||||
input: The input to the runnable.
|
||||
config: The config to use for the runnable.
|
||||
version: The version of the schema to use.
|
||||
Currently only version 1 is available.
|
||||
No default will be assigned until the API is stabilized.
|
||||
include_names: Only include events from runnables with matching names.
|
||||
include_types: Only include events from runnables with matching types.
|
||||
include_tags: Only include events from runnables with matching tags.
|
||||
@ -836,6 +842,11 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Returns:
|
||||
An async stream of StreamEvents.
|
||||
""" # noqa: E501
|
||||
if version != "v1":
|
||||
raise NotImplementedError(
|
||||
'Only version "v1" of the schema is currently supported.'
|
||||
)
|
||||
|
||||
from langchain_core.runnables.utils import (
|
||||
_RootEventFilter,
|
||||
)
|
||||
|
@ -53,7 +53,7 @@ async def test_event_stream_with_single_lambda() -> None:
|
||||
|
||||
chain = RunnableLambda(func=reverse)
|
||||
|
||||
events = await _collect_events(chain.astream_events("hello"))
|
||||
events = await _collect_events(chain.astream_events("hello", version="v1"))
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": "hello"},
|
||||
@ -94,7 +94,7 @@ async def test_event_stream_with_triple_lambda() -> None:
|
||||
| r.with_config({"run_name": "2"})
|
||||
| r.with_config({"run_name": "3"})
|
||||
)
|
||||
events = await _collect_events(chain.astream_events("hello"))
|
||||
events = await _collect_events(chain.astream_events("hello", version="v1"))
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": "hello"},
|
||||
@ -209,7 +209,9 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
|
||||
| r.with_config({"run_name": "2", "tags": ["my_tag"]})
|
||||
| r.with_config({"run_name": "3", "tags": ["my_tag"]})
|
||||
)
|
||||
events = await _collect_events(chain.astream_events("hello", include_names=["1"]))
|
||||
events = await _collect_events(
|
||||
chain.astream_events("hello", include_names=["1"], version="v1")
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
"data": {},
|
||||
@ -238,7 +240,9 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
|
||||
]
|
||||
|
||||
events = await _collect_events(
|
||||
chain.astream_events("hello", include_tags=["my_tag"], exclude_names=["2"])
|
||||
chain.astream_events(
|
||||
"hello", include_tags=["my_tag"], exclude_names=["2"], version="v1"
|
||||
)
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
@ -272,7 +276,9 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
|
||||
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
|
||||
{"run_name": "my_lambda"}
|
||||
)
|
||||
events = await _collect_events(as_lambdas.astream_events({"question": "hello"}))
|
||||
events = await _collect_events(
|
||||
as_lambdas.astream_events({"question": "hello"}, version="v1")
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": {"question": "hello"}},
|
||||
@ -331,7 +337,9 @@ async def test_event_stream_with_simple_chain() -> None:
|
||||
}
|
||||
)
|
||||
|
||||
events = await _collect_events(chain.astream_events({"question": "hello"}))
|
||||
events = await _collect_events(
|
||||
chain.astream_events({"question": "hello"}, version="v1")
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": {"question": "hello"}},
|
||||
@ -497,7 +505,7 @@ async def test_event_streaming_with_tools() -> None:
|
||||
|
||||
# type ignores below because the tools don't appear to be runnables to type checkers
|
||||
# we can remove as soon as that's fixed
|
||||
events = await _collect_events(parameterless.astream_events({})) # type: ignore
|
||||
events = await _collect_events(parameterless.astream_events({}, version="v1")) # type: ignore
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": {}},
|
||||
@ -525,7 +533,7 @@ async def test_event_streaming_with_tools() -> None:
|
||||
},
|
||||
]
|
||||
|
||||
events = await _collect_events(with_callbacks.astream_events({})) # type: ignore
|
||||
events = await _collect_events(with_callbacks.astream_events({}, version="v1")) # type: ignore
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": {}},
|
||||
@ -552,7 +560,9 @@ async def test_event_streaming_with_tools() -> None:
|
||||
"tags": [],
|
||||
},
|
||||
]
|
||||
events = await _collect_events(with_parameters.astream_events({"x": 1, "y": "2"})) # type: ignore
|
||||
events = await _collect_events(
|
||||
with_parameters.astream_events({"x": 1, "y": "2"}, version="v1") # type: ignore
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": {"x": 1, "y": "2"}},
|
||||
@ -581,7 +591,7 @@ async def test_event_streaming_with_tools() -> None:
|
||||
]
|
||||
|
||||
events = await _collect_events(
|
||||
with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}) # type: ignore
|
||||
with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}, version="v1") # type: ignore
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
@ -634,7 +644,9 @@ async def test_event_stream_with_retriever() -> None:
|
||||
),
|
||||
]
|
||||
)
|
||||
events = await _collect_events(retriever.astream_events({"query": "hello"}))
|
||||
events = await _collect_events(
|
||||
retriever.astream_events({"query": "hello"}, version="v1")
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
"data": {
|
||||
@ -695,7 +707,7 @@ async def test_event_stream_with_retriever_and_formatter() -> None:
|
||||
return ", ".join([doc.page_content for doc in docs])
|
||||
|
||||
chain = retriever | format_docs
|
||||
events = await _collect_events(chain.astream_events("hello"))
|
||||
events = await _collect_events(chain.astream_events("hello", version="v1"))
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": "hello"},
|
||||
@ -796,7 +808,9 @@ async def test_event_stream_on_chain_with_tool() -> None:
|
||||
# does not appear to be a runnable
|
||||
chain = concat | reverse # type: ignore
|
||||
|
||||
events = await _collect_events(chain.astream_events({"a": "hello", "b": "world"}))
|
||||
events = await _collect_events(
|
||||
chain.astream_events({"a": "hello", "b": "world"}, version="v1")
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": {"a": "hello", "b": "world"}},
|
||||
@ -878,7 +892,7 @@ async def test_event_stream_with_retry() -> None:
|
||||
chain = RunnableLambda(success) | RunnableLambda(fail).with_retry(
|
||||
stop_after_attempt=1,
|
||||
)
|
||||
iterable = chain.astream_events("q")
|
||||
iterable = chain.astream_events("q", version="v1")
|
||||
|
||||
events = []
|
||||
|
||||
@ -953,7 +967,9 @@ async def test_with_llm() -> None:
|
||||
llm = FakeStreamingListLLM(responses=["abc"])
|
||||
|
||||
chain = prompt | llm
|
||||
events = await _collect_events(chain.astream_events({"question": "hello"}))
|
||||
events = await _collect_events(
|
||||
chain.astream_events({"question": "hello"}, version="v1")
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": {"question": "hello"}},
|
||||
@ -1061,5 +1077,5 @@ async def test_runnable_each() -> None:
|
||||
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
async for _ in add_one_map.astream_events([1, 2, 3]):
|
||||
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user