Allow astream_log to be used inside atrace_as_chain_group (#12558)

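In short, as the diff below shows: `astream_log` attaches a `LogStreamCallbackHandler` to a copy of whatever callback manager it is given, and that handler used to identify the root run as "the run with no parent", which is never true when the runnable executes inside `atrace_as_chain_group`. With this change the group manager copies cleanly, the handler tracks the root run by id, and `astream_log` can be nested inside a chain group. Below is a minimal sketch of the newly supported usage (the prompt content is a placeholder; it mirrors the unit test added in this commit):

```python
# Sketch of the newly supported pattern: streaming the run log of a runnable
# whose callbacks come from an atrace_as_chain_group context manager.
import asyncio

from langchain.callbacks.manager import atrace_as_chain_group
from langchain.prompts import ChatPromptTemplate


async def main() -> None:
    prompt = ChatPromptTemplate.from_messages(
        [("system", "You are a nice assistant."), ("human", "{question}")]
    )
    async with atrace_as_chain_group("a_group") as manager:
        # Before this commit the root run had a parent (the group), which
        # confused the log-stream tracer; now the patches stream through.
        async for patch in prompt.astream_log(
            {"question": "What is your name?"}, config={"callbacks": manager}
        ):
            print(patch)


asyncio.run(main())
```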
Nuno Campos, 2023-10-30 15:55:16 +00:00 (committed by GitHub)
commit 7897483819
parent 8e88ba16a8
4 changed files with 53 additions and 5 deletions


@@ -1832,6 +1832,18 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
         self.parent_run_manager = parent_run_manager
         self.ended = False
 
+    def copy(self) -> AsyncCallbackManagerForChainGroup:
+        return self.__class__(
+            handlers=self.handlers,
+            inheritable_handlers=self.inheritable_handlers,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            inheritable_tags=self.inheritable_tags,
+            metadata=self.metadata,
+            inheritable_metadata=self.inheritable_metadata,
+            parent_run_manager=self.parent_run_manager,
+        )
+
     async def on_chain_end(
         self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
     ) -> None:
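Judging by the class name, this hunk is in the callbacks manager module. It gives `AsyncCallbackManagerForChainGroup` its own `copy()` so that copying preserves the subclass and its `parent_run_manager` rather than falling back to a plain `AsyncCallbackManager`; `astream_log` copies the incoming manager before attaching its tracer, so this override is what keeps the group's end-of-group bookkeeping working. A rough sketch of the property this provides, with a stdout handler standing in for the log-stream tracer:

```python
# Rough sketch (assumed setup, not library code). astream_log copies the
# caller's callback manager and registers its tracer on the copy; with the
# override above, a chain-group manager survives that copy intact.
import asyncio

from langchain.callbacks import StdOutCallbackHandler  # stand-in tracer
from langchain.callbacks.manager import atrace_as_chain_group


async def demo() -> None:
    async with atrace_as_chain_group("a_group") as manager:
        copied = manager.copy()  # uses the copy() added in this commit
        copied.add_handler(StdOutCallbackHandler(), inherit=True)
        # The copy is still a chain-group manager tied to the same parent run.
        assert type(copied) is type(manager)
        assert copied.parent_run_manager is manager.parent_run_manager


asyncio.run(demo())
```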


@@ -157,12 +157,13 @@ class LogStreamCallbackHandler(BaseTracer):
         self.receive_stream = receive_stream
         self._key_map_by_run_id: Dict[UUID, str] = {}
         self._counter_map_by_name: Dict[str, int] = defaultdict(int)
+        self.root_id: Optional[UUID] = None
 
     def __aiter__(self) -> AsyncIterator[RunLogPatch]:
         return self.receive_stream.__aiter__()
 
     def include_run(self, run: Run) -> bool:
-        if run.parent_run_id is None:
+        if run.id == self.root_id:
             return False
 
         run_tags = run.tags or []
@@ -199,7 +200,8 @@ class LogStreamCallbackHandler(BaseTracer):
 
     def _on_run_create(self, run: Run) -> None:
         """Start a run."""
-        if run.parent_run_id is None:
+        if self.root_id is None:
+            self.root_id = run.id
             self.send_stream.send_nowait(
                 RunLogPatch(
                     {
@@ -273,7 +275,7 @@ class LogStreamCallbackHandler(BaseTracer):
                 )
             )
         finally:
-            if run.parent_run_id is None:
+            if run.id == self.root_id:
                 self.send_stream.send_nowait(
                     RunLogPatch(
                         {
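These hunks (apparently the log-stream tracer, given the `LogStreamCallbackHandler` context) change how the root run is recognised: instead of looking for a run with no parent, the handler now remembers the id of the first run it sees (`root_id`) and compares against that, so a root run that itself has a parent, as it does under a chain group, is still the one that drives the initial state patch and the `/final_output` patch. A toy illustration of the two predicates, using a hypothetical stand-in for the tracer's Run schema:

```python
# Toy illustration only; FakeRun is a hypothetical stand-in, not the tracer's
# Run schema, and the predicates mirror the old and new include_run logic.
from dataclasses import dataclass
from typing import Optional
from uuid import UUID, uuid4


@dataclass
class FakeRun:
    id: UUID
    parent_run_id: Optional[UUID]


group = FakeRun(id=uuid4(), parent_run_id=None)      # run opened by the chain group
root = FakeRun(id=uuid4(), parent_run_id=group.id)   # root run of astream_log
child = FakeRun(id=uuid4(), parent_run_id=root.id)   # a nested run


def old_is_root(run: FakeRun) -> bool:
    # Old rule: the root is the run without a parent. Inside a chain group the
    # streamed runnable's root run has a parent, so it was never recognised.
    return run.parent_run_id is None


root_id = root.id  # new rule: _on_run_create stores the first run's id


def new_is_root(run: FakeRun) -> bool:
    return run.id == root_id


assert not old_is_root(root)  # the bug: the root was treated as an ordinary child
assert new_is_root(root) and not new_is_root(child)
```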


@@ -463,7 +463,7 @@ class Runnable(Generic[Input, Output], ABC):
             config["callbacks"] = callbacks + [stream]
         elif isinstance(callbacks, BaseCallbackManager):
             callbacks = callbacks.copy()
-            callbacks.inheritable_handlers.append(stream)
+            callbacks.add_handler(stream, inherit=True)
             config["callbacks"] = callbacks
         else:
             raise ValueError(
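This hunk (apparently `Runnable.astream_log`) replaces a direct append onto `inheritable_handlers` with `add_handler(stream, inherit=True)`, which puts the tracer in both `handlers` and `inheritable_handlers`, so it also receives events for runs started directly on the copied manager, not only for child runs. A small sketch of the difference, again using a stdout handler as the stand-in:

```python
# Hedged sketch of the two registration styles; StdOutCallbackHandler stands
# in for the LogStreamCallbackHandler that astream_log actually attaches.
from langchain.callbacks import StdOutCallbackHandler
from langchain.callbacks.manager import AsyncCallbackManager

handler = StdOutCallbackHandler()

old_style = AsyncCallbackManager(handlers=[])
old_style.inheritable_handlers.append(handler)  # previous behaviour
assert handler not in old_style.handlers  # never fired for this manager's own runs

new_style = AsyncCallbackManager(handlers=[])
new_style.add_handler(handler, inherit=True)  # behaviour after this commit
assert handler in new_style.handlers
assert handler in new_style.inheritable_handlers
```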


@@ -19,7 +19,7 @@ from freezegun import freeze_time
 from pytest_mock import MockerFixture
 from syrupy import SnapshotAssertion
 
-from langchain.callbacks.manager import Callbacks, collect_runs
+from langchain.callbacks.manager import Callbacks, atrace_as_chain_group, collect_runs
 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
 from langchain.callbacks.tracers.schemas import Run
@@ -1442,6 +1442,40 @@ async def test_prompt() -> None:
             },
         )
 
+
+    # nested inside trace_with_chain_group
+    async with atrace_as_chain_group("a_group") as manager:
+        stream_log_nested = [
+            part
+            async for part in prompt.astream_log(
+                {"question": "What is your name?"}, config={"callbacks": manager}
+            )
+        ]
+
+    assert len(stream_log_nested[0].ops) == 1
+    assert stream_log_nested[0].ops[0]["op"] == "replace"
+    assert stream_log_nested[0].ops[0]["path"] == ""
+    assert stream_log_nested[0].ops[0]["value"]["logs"] == {}
+    assert stream_log_nested[0].ops[0]["value"]["final_output"] is None
+    assert stream_log_nested[0].ops[0]["value"]["streamed_output"] == []
+    assert isinstance(stream_log_nested[0].ops[0]["value"]["id"], str)
+
+    assert stream_log_nested[1:] == [
+        RunLogPatch(
+            {
+                "op": "replace",
+                "path": "/final_output",
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
+            }
+        ),
+        RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
+    ]
+
 
 def test_prompt_template_params() -> None:
     prompt = ChatPromptTemplate.from_template(