only add handlers if they are new (#7504)

When using callbacks, there are times when callbacks can be added
redundantly: for instance sometimes you might need to create an llm with
specific callbacks, but then also create and agent that uses a chain
that has those callbacks already set. This means that "callbacks" might
get passed down again to the llm at predict() time, resulting in
duplicate calls to the `on_llm_start` callback.

For the sake of simplicity, I made it so that langchain never adds an
exact handler/callbacks object in `add_handler`, thus avoiding the
duplicate handler issue.

Tagging @hwchase17 for callback review

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Alec Flett 2023-07-12 00:48:29 -07:00 committed by GitHub
parent 50316f6477
commit 6cdd4b5edc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 11 deletions

View File

@ -498,8 +498,9 @@ class BaseCallbackManager(CallbackManagerMixin):
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
if inherit:
if handler not in self.handlers:
self.handlers.append(handler)
if inherit and handler not in self.inheritable_handlers:
self.inheritable_handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:

View File

@ -1,6 +1,6 @@
"""A fake callback handler for testing purposes."""
from itertools import chain
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from pydantic import BaseModel
@ -22,6 +22,9 @@ class BaseFakeCallbackHandler(BaseModel):
ignore_retriever_: bool = False
ignore_chat_model_: bool = False
# to allow for similar callback handlers that are not technicall equal
fake_id: Union[str, None] = None
# add finer-grained counters for easier debugging of failing tests
chain_starts: int = 0
chain_ends: int = 0

View File

@ -178,10 +178,10 @@ async def test_async_callback_manager_sync_handler() -> None:
def test_callback_manager_inheritance() -> None:
handler1, handler2, handler3, handler4 = (
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(fake_id="handler1"),
FakeCallbackHandler(fake_id="handler2"),
FakeCallbackHandler(fake_id="handler3"),
FakeCallbackHandler(fake_id="handler4"),
)
callback_manager1 = CallbackManager(handlers=[handler1, handler2])
@ -222,15 +222,22 @@ def test_callback_manager_inheritance() -> None:
assert child_manager2.inheritable_handlers == [handler1]
def test_duplicate_callbacks() -> None:
handler = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler])
manager.add_handler(handler)
assert manager.handlers == [handler]
def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test callback manager configuration."""
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false")
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
handler1, handler2, handler3, handler4 = (
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(),
FakeCallbackHandler(fake_id="handler1"),
FakeCallbackHandler(fake_id="handler2"),
FakeCallbackHandler(fake_id="handler3"),
FakeCallbackHandler(fake_id="handler4"),
)
inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2]