mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
Allow astream_log to be used inside atrace_as_chain_group (#12558)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
8e88ba16a8
commit
7897483819
@ -1832,6 +1832,18 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
|||||||
self.parent_run_manager = parent_run_manager
|
self.parent_run_manager = parent_run_manager
|
||||||
self.ended = False
|
self.ended = False
|
||||||
|
|
||||||
|
def copy(self) -> AsyncCallbackManagerForChainGroup:
|
||||||
|
return self.__class__(
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
inheritable_metadata=self.inheritable_metadata,
|
||||||
|
parent_run_manager=self.parent_run_manager,
|
||||||
|
)
|
||||||
|
|
||||||
async def on_chain_end(
|
async def on_chain_end(
|
||||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -157,12 +157,13 @@ class LogStreamCallbackHandler(BaseTracer):
|
|||||||
self.receive_stream = receive_stream
|
self.receive_stream = receive_stream
|
||||||
self._key_map_by_run_id: Dict[UUID, str] = {}
|
self._key_map_by_run_id: Dict[UUID, str] = {}
|
||||||
self._counter_map_by_name: Dict[str, int] = defaultdict(int)
|
self._counter_map_by_name: Dict[str, int] = defaultdict(int)
|
||||||
|
self.root_id: Optional[UUID] = None
|
||||||
|
|
||||||
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
|
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
|
||||||
return self.receive_stream.__aiter__()
|
return self.receive_stream.__aiter__()
|
||||||
|
|
||||||
def include_run(self, run: Run) -> bool:
|
def include_run(self, run: Run) -> bool:
|
||||||
if run.parent_run_id is None:
|
if run.id == self.root_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
run_tags = run.tags or []
|
run_tags = run.tags or []
|
||||||
@ -199,7 +200,8 @@ class LogStreamCallbackHandler(BaseTracer):
|
|||||||
|
|
||||||
def _on_run_create(self, run: Run) -> None:
|
def _on_run_create(self, run: Run) -> None:
|
||||||
"""Start a run."""
|
"""Start a run."""
|
||||||
if run.parent_run_id is None:
|
if self.root_id is None:
|
||||||
|
self.root_id = run.id
|
||||||
self.send_stream.send_nowait(
|
self.send_stream.send_nowait(
|
||||||
RunLogPatch(
|
RunLogPatch(
|
||||||
{
|
{
|
||||||
@ -273,7 +275,7 @@ class LogStreamCallbackHandler(BaseTracer):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
if run.parent_run_id is None:
|
if run.id == self.root_id:
|
||||||
self.send_stream.send_nowait(
|
self.send_stream.send_nowait(
|
||||||
RunLogPatch(
|
RunLogPatch(
|
||||||
{
|
{
|
||||||
|
@ -463,7 +463,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
config["callbacks"] = callbacks + [stream]
|
config["callbacks"] = callbacks + [stream]
|
||||||
elif isinstance(callbacks, BaseCallbackManager):
|
elif isinstance(callbacks, BaseCallbackManager):
|
||||||
callbacks = callbacks.copy()
|
callbacks = callbacks.copy()
|
||||||
callbacks.inheritable_handlers.append(stream)
|
callbacks.add_handler(stream, inherit=True)
|
||||||
config["callbacks"] = callbacks
|
config["callbacks"] = callbacks
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -19,7 +19,7 @@ from freezegun import freeze_time
|
|||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
|
|
||||||
from langchain.callbacks.manager import Callbacks, collect_runs
|
from langchain.callbacks.manager import Callbacks, atrace_as_chain_group, collect_runs
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
@ -1442,6 +1442,40 @@ async def test_prompt() -> None:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# nested inside trace_with_chain_group
|
||||||
|
|
||||||
|
async with atrace_as_chain_group("a_group") as manager:
|
||||||
|
stream_log_nested = [
|
||||||
|
part
|
||||||
|
async for part in prompt.astream_log(
|
||||||
|
{"question": "What is your name?"}, config={"callbacks": manager}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(stream_log_nested[0].ops) == 1
|
||||||
|
assert stream_log_nested[0].ops[0]["op"] == "replace"
|
||||||
|
assert stream_log_nested[0].ops[0]["path"] == ""
|
||||||
|
assert stream_log_nested[0].ops[0]["value"]["logs"] == {}
|
||||||
|
assert stream_log_nested[0].ops[0]["value"]["final_output"] is None
|
||||||
|
assert stream_log_nested[0].ops[0]["value"]["streamed_output"] == []
|
||||||
|
assert isinstance(stream_log_nested[0].ops[0]["value"]["id"], str)
|
||||||
|
|
||||||
|
assert stream_log_nested[1:] == [
|
||||||
|
RunLogPatch(
|
||||||
|
{
|
||||||
|
"op": "replace",
|
||||||
|
"path": "/final_output",
|
||||||
|
"value": ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_template_params() -> None:
|
def test_prompt_template_params() -> None:
|
||||||
prompt = ChatPromptTemplate.from_template(
|
prompt = ChatPromptTemplate.from_template(
|
||||||
|
Loading…
Reference in New Issue
Block a user