forked from Archives/langchain
Add BaseCallbackHandler and CallbackManager (#476)
This commit is contained in:
parent
0f1df0dc2c
commit
46b31626b5
@ -9,6 +9,7 @@ from pydantic import BaseModel, root_validator
|
|||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
|
from langchain.callbacks import get_callback_manager
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.input import get_color_mapping
|
from langchain.input import get_color_mapping
|
||||||
@ -213,8 +214,12 @@ class AgentExecutor(Chain, BaseModel):
|
|||||||
# And then we lookup the tool
|
# And then we lookup the tool
|
||||||
if output.tool in name_to_tool_map:
|
if output.tool in name_to_tool_map:
|
||||||
chain = name_to_tool_map[output.tool]
|
chain = name_to_tool_map[output.tool]
|
||||||
|
get_callback_manager().on_tool_start(
|
||||||
|
{"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
|
||||||
|
)
|
||||||
# We then call the tool on the tool input to get an observation
|
# We then call the tool on the tool input to get an observation
|
||||||
observation = chain(output.tool_input)
|
observation = chain(output.tool_input)
|
||||||
|
get_callback_manager().on_tool_end(observation)
|
||||||
color = color_mapping[output.tool]
|
color = color_mapping[output.tool]
|
||||||
else:
|
else:
|
||||||
observation = f"{output.tool} is not a valid tool, try another one."
|
observation = f"{output.tool} is not a valid tool, try another one."
|
||||||
|
8
langchain/callbacks/__init__.py
Normal file
8
langchain/callbacks/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
"""Callback handlers that allow listening to events in LangChain."""
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.callbacks.shared import SharedCallbackManager
|
||||||
|
|
||||||
|
|
||||||
|
def get_callback_manager() -> BaseCallbackManager:
|
||||||
|
"""Return the shared callback manager."""
|
||||||
|
return SharedCallbackManager()
|
137
langchain/callbacks/base.py
Normal file
137
langchain/callbacks/base.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCallbackHandler(ABC):
|
||||||
|
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM starts running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_llm_end(
|
||||||
|
self,
|
||||||
|
response: LLMResult,
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM ends running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_llm_error(self, error: Exception) -> None:
|
||||||
|
"""Run when LLM errors."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_chain_start(
|
||||||
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when chain starts running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||||
|
"""Run when chain ends running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_chain_error(self, error: Exception) -> None:
|
||||||
|
"""Run when chain errors."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when tool starts running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_tool_end(self, output: str) -> None:
|
||||||
|
"""Run when tool ends running."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
|
"""Run when tool errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||||
|
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||||
|
"""Add a handler to the callback manager."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
"""Remove a handler from the callback manager."""
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackManager(BaseCallbackManager):
|
||||||
|
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||||
|
|
||||||
|
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
|
||||||
|
"""Initialize the callback manager."""
|
||||||
|
self.handlers = handlers
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM starts running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_llm_start(serialized, prompts, **extra)
|
||||||
|
|
||||||
|
def on_llm_end(
|
||||||
|
self,
|
||||||
|
response: LLMResult,
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM ends running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_llm_end(response)
|
||||||
|
|
||||||
|
def on_llm_error(self, error: Exception) -> None:
|
||||||
|
"""Run when LLM errors."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_llm_error(error)
|
||||||
|
|
||||||
|
def on_chain_start(
|
||||||
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when chain starts running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_chain_start(serialized, inputs, **extra)
|
||||||
|
|
||||||
|
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||||
|
"""Run when chain ends running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_chain_end(outputs)
|
||||||
|
|
||||||
|
def on_chain_error(self, error: Exception) -> None:
|
||||||
|
"""Run when chain errors."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_chain_error(error)
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when tool starts running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_tool_start(serialized, action, tool_input, **extra)
|
||||||
|
|
||||||
|
def on_tool_end(self, output: str) -> None:
|
||||||
|
"""Run when tool ends running."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_tool_end(output)
|
||||||
|
|
||||||
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
|
"""Run when tool errors."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_tool_error(error)
|
||||||
|
|
||||||
|
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
"""Add a handler to the callback manager."""
|
||||||
|
self.handlers.append(handler)
|
||||||
|
|
||||||
|
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
"""Remove a handler from the callback manager."""
|
||||||
|
self.handlers.remove(handler)
|
101
langchain/callbacks/shared.py
Normal file
101
langchain/callbacks/shared.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
"""A shared CallbackManager."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.base import (
|
||||||
|
BaseCallbackHandler,
|
||||||
|
BaseCallbackManager,
|
||||||
|
CallbackManager,
|
||||||
|
)
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class Singleton:
|
||||||
|
"""A thread-safe singleton class that can be inherited from."""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __new__(cls) -> Any:
|
||||||
|
"""Create a new shared instance of the class."""
|
||||||
|
if cls._instance is None:
|
||||||
|
with cls._lock:
|
||||||
|
# Another thread could have created the instance
|
||||||
|
# before we acquired the lock. So check that the
|
||||||
|
# instance is still nonexistent.
|
||||||
|
if not cls._instance:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
|
||||||
|
class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||||
|
"""A thread-safe singleton CallbackManager."""
|
||||||
|
|
||||||
|
_callback_manager: CallbackManager = CallbackManager([])
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM starts running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_llm_start(serialized, prompts, **extra)
|
||||||
|
|
||||||
|
def on_llm_end(
|
||||||
|
self,
|
||||||
|
response: LLMResult,
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM ends running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_llm_end(response)
|
||||||
|
|
||||||
|
def on_llm_error(self, error: Exception) -> None:
|
||||||
|
"""Run when LLM errors."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_llm_error(error)
|
||||||
|
|
||||||
|
def on_chain_start(
|
||||||
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when chain starts running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_chain_start(serialized, inputs, **extra)
|
||||||
|
|
||||||
|
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||||
|
"""Run when chain ends running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_chain_end(outputs)
|
||||||
|
|
||||||
|
def on_chain_error(self, error: Exception) -> None:
|
||||||
|
"""Run when chain errors."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_chain_error(error)
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when tool starts running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_tool_start(
|
||||||
|
serialized, action, tool_input, **extra
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_tool_end(self, output: str) -> None:
|
||||||
|
"""Run when tool ends running."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_tool_end(output)
|
||||||
|
|
||||||
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
|
"""Run when tool errors."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_tool_error(error)
|
||||||
|
|
||||||
|
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||||
|
"""Add a callback to the callback manager."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.add_handler(callback)
|
||||||
|
|
||||||
|
def remove_handler(self, callback: BaseCallbackHandler) -> None:
|
||||||
|
"""Remove a callback from the callback manager."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.remove_handler(callback)
|
@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||||||
from pydantic import BaseModel, Extra, Field
|
from pydantic import BaseModel, Extra, Field
|
||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
|
from langchain.callbacks import get_callback_manager
|
||||||
|
|
||||||
|
|
||||||
class Memory(BaseModel, ABC):
|
class Memory(BaseModel, ABC):
|
||||||
@ -109,7 +110,9 @@ class Chain(BaseModel, ABC):
|
|||||||
print(
|
print(
|
||||||
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
|
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
|
||||||
)
|
)
|
||||||
|
get_callback_manager().on_chain_start({"name": self.__class__.__name__}, inputs)
|
||||||
outputs = self._call(inputs)
|
outputs = self._call(inputs)
|
||||||
|
get_callback_manager().on_chain_end(outputs)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
|
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
|
||||||
self._validate_outputs(outputs)
|
self._validate_outputs(outputs)
|
||||||
|
@ -5,8 +5,9 @@ from pydantic import BaseModel, Extra
|
|||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.llms.base import BaseLLM, LLMResult
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
|
||||||
class LLMChain(Chain, BaseModel):
|
class LLMChain(Chain, BaseModel):
|
||||||
|
@ -2,23 +2,14 @@
|
|||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
|
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
from langchain.schema import Generation
|
from langchain.callbacks import get_callback_manager
|
||||||
|
from langchain.schema import Generation, LLMResult
|
||||||
|
|
||||||
class LLMResult(NamedTuple):
|
|
||||||
"""Class that contains all relevant information for an LLM Result."""
|
|
||||||
|
|
||||||
generations: List[List[Generation]]
|
|
||||||
"""List of the things generated. This is List[List[]] because
|
|
||||||
each input could have multiple generations."""
|
|
||||||
llm_output: Optional[dict] = None
|
|
||||||
"""For arbitrary LLM provider specific output."""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLLM(BaseModel, ABC):
|
class BaseLLM(BaseModel, ABC):
|
||||||
@ -48,7 +39,12 @@ class BaseLLM(BaseModel, ABC):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Asked to cache, but no cache found at `langchain.cache`."
|
"Asked to cache, but no cache found at `langchain.cache`."
|
||||||
)
|
)
|
||||||
return self._generate(prompts, stop=stop)
|
get_callback_manager().on_llm_start(
|
||||||
|
{"name": self.__class__.__name__}, prompts
|
||||||
|
)
|
||||||
|
output = self._generate(prompts, stop=stop)
|
||||||
|
get_callback_manager().on_llm_end(output)
|
||||||
|
return output
|
||||||
params = self._llm_dict()
|
params = self._llm_dict()
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
@ -62,7 +58,11 @@ class BaseLLM(BaseModel, ABC):
|
|||||||
else:
|
else:
|
||||||
missing_prompts.append(prompt)
|
missing_prompts.append(prompt)
|
||||||
missing_prompt_idxs.append(i)
|
missing_prompt_idxs.append(i)
|
||||||
|
get_callback_manager().on_llm_start(
|
||||||
|
{"name": self.__class__.__name__}, missing_prompts
|
||||||
|
)
|
||||||
new_results = self._generate(missing_prompts, stop=stop)
|
new_results = self._generate(missing_prompts, stop=stop)
|
||||||
|
get_callback_manager().on_llm_end(new_results)
|
||||||
for i, result in enumerate(new_results.generations):
|
for i, result in enumerate(new_results.generations):
|
||||||
existing_prompts[i] = result
|
existing_prompts[i] = result
|
||||||
prompt = prompts[i]
|
prompt = prompts[i]
|
||||||
|
@ -4,8 +4,8 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
|
|||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
from pydantic import BaseModel, Extra, Field, root_validator
|
||||||
|
|
||||||
from langchain.llms.base import BaseLLM, LLMResult
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.schema import Generation
|
from langchain.schema import Generation, LLMResult
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Common schema objects."""
|
"""Common schema objects."""
|
||||||
|
|
||||||
from typing import NamedTuple
|
from typing import List, NamedTuple, Optional
|
||||||
|
|
||||||
|
|
||||||
class AgentAction(NamedTuple):
|
class AgentAction(NamedTuple):
|
||||||
@ -24,3 +24,13 @@ class Generation(NamedTuple):
|
|||||||
text: str
|
text: str
|
||||||
"""Generated text output."""
|
"""Generated text output."""
|
||||||
# TODO: add log probs
|
# TODO: add log probs
|
||||||
|
|
||||||
|
|
||||||
|
class LLMResult(NamedTuple):
|
||||||
|
"""Class that contains all relevant information for an LLM Result."""
|
||||||
|
|
||||||
|
generations: List[List[Generation]]
|
||||||
|
"""List of the things generated. This is List[List[]] because
|
||||||
|
each input could have multiple generations."""
|
||||||
|
llm_output: Optional[dict] = None
|
||||||
|
"""For arbitrary LLM provider specific output."""
|
||||||
|
1
tests/unit_tests/callbacks/__init__.py
Normal file
1
tests/unit_tests/callbacks/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Tests for correct functioning of callbacks."""
|
107
tests/unit_tests/callbacks/test_callback_manager.py
Normal file
107
tests/unit_tests/callbacks/test_callback_manager.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
"""Test CallbackManager."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.callbacks.base import (
|
||||||
|
BaseCallbackHandler,
|
||||||
|
BaseCallbackManager,
|
||||||
|
CallbackManager,
|
||||||
|
)
|
||||||
|
from langchain.callbacks.shared import SharedCallbackManager
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
class FakeCallbackHandler(BaseCallbackHandler):
|
||||||
|
"""Fake callback handler for testing."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the mock callback handler."""
|
||||||
|
self.starts = 0
|
||||||
|
self.ends = 0
|
||||||
|
self.errors = 0
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM starts running."""
|
||||||
|
self.starts += 1
|
||||||
|
|
||||||
|
def on_llm_end(
|
||||||
|
self,
|
||||||
|
response: LLMResult,
|
||||||
|
) -> None:
|
||||||
|
"""Run when LLM ends running."""
|
||||||
|
self.ends += 1
|
||||||
|
|
||||||
|
def on_llm_error(self, error: Exception) -> None:
|
||||||
|
"""Run when LLM errors."""
|
||||||
|
self.errors += 1
|
||||||
|
|
||||||
|
def on_chain_start(
|
||||||
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when chain starts running."""
|
||||||
|
self.starts += 1
|
||||||
|
|
||||||
|
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||||
|
"""Run when chain ends running."""
|
||||||
|
self.ends += 1
|
||||||
|
|
||||||
|
def on_chain_error(self, error: Exception) -> None:
|
||||||
|
"""Run when chain errors."""
|
||||||
|
self.errors += 1
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||||
|
) -> None:
|
||||||
|
"""Run when tool starts running."""
|
||||||
|
self.starts += 1
|
||||||
|
|
||||||
|
def on_tool_end(self, output: str) -> None:
|
||||||
|
"""Run when tool ends running."""
|
||||||
|
self.ends += 1
|
||||||
|
|
||||||
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
|
"""Run when tool errors."""
|
||||||
|
self.errors += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _test_callback_manager(
|
||||||
|
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
|
||||||
|
) -> None:
|
||||||
|
"""Test the CallbackManager."""
|
||||||
|
manager.on_llm_start({}, [])
|
||||||
|
manager.on_llm_end(LLMResult(generations=[]))
|
||||||
|
manager.on_llm_error(Exception())
|
||||||
|
manager.on_chain_start({}, {})
|
||||||
|
manager.on_chain_end({})
|
||||||
|
manager.on_chain_error(Exception())
|
||||||
|
manager.on_tool_start({}, "", "")
|
||||||
|
manager.on_tool_end("")
|
||||||
|
manager.on_tool_error(Exception())
|
||||||
|
for handler in handlers:
|
||||||
|
assert handler.starts == 3
|
||||||
|
assert handler.ends == 3
|
||||||
|
assert handler.errors == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_manager() -> None:
|
||||||
|
"""Test the CallbackManager."""
|
||||||
|
handler1 = FakeCallbackHandler()
|
||||||
|
handler2 = FakeCallbackHandler()
|
||||||
|
manager = CallbackManager([handler1, handler2])
|
||||||
|
_test_callback_manager(manager, handler1, handler2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shared_callback_manager() -> None:
|
||||||
|
"""Test the SharedCallbackManager."""
|
||||||
|
manager1 = SharedCallbackManager()
|
||||||
|
manager2 = SharedCallbackManager()
|
||||||
|
|
||||||
|
assert manager1 is manager2
|
||||||
|
|
||||||
|
handler1 = FakeCallbackHandler()
|
||||||
|
handler2 = FakeCallbackHandler()
|
||||||
|
manager1.add_handler(handler1)
|
||||||
|
manager2.add_handler(handler2)
|
||||||
|
_test_callback_manager(manager1, handler1, handler2)
|
@ -7,8 +7,8 @@ from pydantic import BaseModel
|
|||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
|
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
|
||||||
from langchain.embeddings.hyde.prompts import PROMPT_MAP
|
from langchain.embeddings.hyde.prompts import PROMPT_MAP
|
||||||
from langchain.llms.base import BaseLLM, LLMResult
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.schema import Generation
|
from langchain.schema import Generation, LLMResult
|
||||||
|
|
||||||
|
|
||||||
class FakeEmbeddings(Embeddings):
|
class FakeEmbeddings(Embeddings):
|
||||||
|
Loading…
Reference in New Issue
Block a user