mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
only add handlers if they are new (#7504)
When using callbacks, there are times when callbacks can be added redundantly: for instance sometimes you might need to create an llm with specific callbacks, but then also create and agent that uses a chain that has those callbacks already set. This means that "callbacks" might get passed down again to the llm at predict() time, resulting in duplicate calls to the `on_llm_start` callback. For the sake of simplicity, I made it so that langchain never adds an exact handler/callbacks object in `add_handler`, thus avoiding the duplicate handler issue. Tagging @hwchase17 for callback review --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
50316f6477
commit
6cdd4b5edc
@ -498,8 +498,9 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
if inherit:
|
||||
if handler not in self.handlers:
|
||||
self.handlers.append(handler)
|
||||
if inherit and handler not in self.inheritable_handlers:
|
||||
self.inheritable_handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -22,6 +22,9 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
ignore_retriever_: bool = False
|
||||
ignore_chat_model_: bool = False
|
||||
|
||||
# to allow for similar callback handlers that are not technicall equal
|
||||
fake_id: Union[str, None] = None
|
||||
|
||||
# add finer-grained counters for easier debugging of failing tests
|
||||
chain_starts: int = 0
|
||||
chain_ends: int = 0
|
||||
|
@ -178,10 +178,10 @@ async def test_async_callback_manager_sync_handler() -> None:
|
||||
|
||||
def test_callback_manager_inheritance() -> None:
|
||||
handler1, handler2, handler3, handler4 = (
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(fake_id="handler1"),
|
||||
FakeCallbackHandler(fake_id="handler2"),
|
||||
FakeCallbackHandler(fake_id="handler3"),
|
||||
FakeCallbackHandler(fake_id="handler4"),
|
||||
)
|
||||
|
||||
callback_manager1 = CallbackManager(handlers=[handler1, handler2])
|
||||
@ -222,15 +222,22 @@ def test_callback_manager_inheritance() -> None:
|
||||
assert child_manager2.inheritable_handlers == [handler1]
|
||||
|
||||
|
||||
def test_duplicate_callbacks() -> None:
|
||||
handler = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler])
|
||||
manager.add_handler(handler)
|
||||
assert manager.handlers == [handler]
|
||||
|
||||
|
||||
def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test callback manager configuration."""
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false")
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
|
||||
handler1, handler2, handler3, handler4 = (
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(fake_id="handler1"),
|
||||
FakeCallbackHandler(fake_id="handler2"),
|
||||
FakeCallbackHandler(fake_id="handler3"),
|
||||
FakeCallbackHandler(fake_id="handler4"),
|
||||
)
|
||||
|
||||
inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2]
|
||||
|
Loading…
Reference in New Issue
Block a user