In stream_event and stream_log handle closed streams (#16661)

If, e.g., the stream iterator is interrupted, then adding more events to the `send_stream` will raise an exception, which we should catch (and handle where appropriate).
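
For context, anyio's memory object streams raise `ClosedResourceError` once the send end has been closed and `BrokenResourceError` once the receive end has gone away. Below is a minimal sketch of the pattern adopted here, assuming anyio's memory object stream API; the `try_send` helper and stream names are illustrative only, not part of the change:

```python
from typing import Any

from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream

# Unbounded buffer so send_nowait never blocks on a slow consumer.
send_stream, receive_stream = create_memory_object_stream(float("inf"))


def try_send(item: Any) -> bool:
    """Return False instead of raising once the consumer is gone."""
    try:
        send_stream.send_nowait(item)
        return True
    except (ClosedResourceError, BrokenResourceError):
        # ClosedResourceError: our own send end was closed.
        # BrokenResourceError: the receive end was closed, e.g. because the
        # caller stopped iterating over the stream.
        return False


# Simulate the consumer going away mid-stream, then verify we degrade gracefully.
receive_stream.close()
assert try_send({"op": "add", "path": "/streamed_output/-", "value": "chunk"}) is False
```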

Nuno Campos · 2024-01-27 08:09:29 -08:00 · committed by GitHub
commit e86fd946c8 (parent 0bc397957b)

libs/core/langchain_core/tracers/log_stream.py

@@ -20,7 +20,7 @@ from typing import (
 from uuid import UUID
 
 import jsonpatch  # type: ignore[import]
-from anyio import create_memory_object_stream
+from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream
 from typing_extensions import NotRequired, TypedDict
 
 from langchain_core.load import dumps
@@ -223,6 +223,14 @@ class LogStreamCallbackHandler(BaseTracer):
     def __aiter__(self) -> AsyncIterator[RunLogPatch]:
         return self.receive_stream.__aiter__()
 
+    def send(self, *ops: Dict[str, Any]) -> bool:
+        """Send a patch to the stream, return False if the stream is closed."""
+        try:
+            self.send_stream.send_nowait(RunLogPatch(*ops))
+            return True
+        except (ClosedResourceError, BrokenResourceError):
+            return False
+
     async def tap_output_aiter(
         self, run_id: UUID, output: AsyncIterator[T]
     ) -> AsyncIterator[T]:
@@ -233,15 +241,14 @@ class LogStreamCallbackHandler(BaseTracer):
                 # if we can't find the run silently ignore
                 # eg. because this run wasn't included in the log
                 if key := self._key_map_by_run_id.get(run_id):
-                    self.send_stream.send_nowait(
-                        RunLogPatch(
-                            {
-                                "op": "add",
-                                "path": f"/logs/{key}/streamed_output/-",
-                                "value": chunk,
-                            }
-                        )
-                    )
+                    if not self.send(
+                        {
+                            "op": "add",
+                            "path": f"/logs/{key}/streamed_output/-",
+                            "value": chunk,
+                        }
+                    ):
+                        break
 
             yield chunk
@@ -285,22 +292,21 @@ class LogStreamCallbackHandler(BaseTracer):
         """Start a run."""
         if self.root_id is None:
             self.root_id = run.id
-            self.send_stream.send_nowait(
-                RunLogPatch(
-                    {
-                        "op": "replace",
-                        "path": "",
-                        "value": RunState(
-                            id=str(run.id),
-                            streamed_output=[],
-                            final_output=None,
-                            logs={},
-                            name=run.name,
-                            type=run.run_type,
-                        ),
-                    }
-                )
-            )
+            if not self.send(
+                {
+                    "op": "replace",
+                    "path": "",
+                    "value": RunState(
+                        id=str(run.id),
+                        streamed_output=[],
+                        final_output=None,
+                        logs={},
+                        name=run.name,
+                        type=run.run_type,
+                    ),
+                }
+            ):
+                return
 
         if not self.include_run(run):
             return
@@ -331,14 +337,12 @@ class LogStreamCallbackHandler(BaseTracer):
             entry["inputs"] = _get_standardized_inputs(run, self._schema_format)
 
         # Add the run to the stream
-        self.send_stream.send_nowait(
-            RunLogPatch(
-                {
-                    "op": "add",
-                    "path": f"/logs/{self._key_map_by_run_id[run.id]}",
-                    "value": entry,
-                }
-            )
+        self.send(
+            {
+                "op": "add",
+                "path": f"/logs/{self._key_map_by_run_id[run.id]}",
+                "value": entry,
+            }
         )
 
     def _on_run_update(self, run: Run) -> None:
@@ -382,7 +386,7 @@ class LogStreamCallbackHandler(BaseTracer):
                     ]
                 )
 
-                self.send_stream.send_nowait(RunLogPatch(*ops))
+                self.send(*ops)
         finally:
             if run.id == self.root_id:
                 if self.auto_close:
@@ -400,21 +404,19 @@ class LogStreamCallbackHandler(BaseTracer):
         if index is None:
             return
 
-        self.send_stream.send_nowait(
-            RunLogPatch(
-                {
-                    "op": "add",
-                    "path": f"/logs/{index}/streamed_output_str/-",
-                    "value": token,
-                },
-                {
-                    "op": "add",
-                    "path": f"/logs/{index}/streamed_output/-",
-                    "value": chunk.message
-                    if isinstance(chunk, ChatGenerationChunk)
-                    else token,
-                },
-            )
+        self.send(
+            {
+                "op": "add",
+                "path": f"/logs/{index}/streamed_output_str/-",
+                "value": token,
+            },
+            {
+                "op": "add",
+                "path": f"/logs/{index}/streamed_output/-",
+                "value": chunk.message
+                if isinstance(chunk, ChatGenerationChunk)
+                else token,
+            },
         )
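
For reference, the interruption scenario the description mentions can be reproduced with any Runnable. A sketch (the `RunnableLambda` chain is just a stand-in; the point is the early `break` out of the stream iterator):

```python
import asyncio

from langchain_core.runnables import RunnableLambda

chain = RunnableLambda(lambda x: x.upper())


async def main() -> None:
    # Breaking out of astream_log closes the handler's receive stream; with
    # this change, later internal sends return False instead of raising.
    async for patch in chain.astream_log("hello"):
        print(patch)
        break  # interrupt the stream early


asyncio.run(main())
```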