From 9e04c34e2070b7e615e3be276f453232d554527d Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 4 Jan 2023 07:54:25 -0800 Subject: [PATCH] Add BaseCallbackHandler and CallbackManager (#478) Co-authored-by: Ankush Gola <9536492+agola11@users.noreply.github.com> --- docs/modules/agents/getting_started.ipynb | 19 +- docs/modules/memory/getting_started.ipynb | 4 +- langchain/__init__.py | 5 +- langchain/agents/agent.py | 44 ++++- langchain/agents/loading.py | 17 +- langchain/callbacks/__init__.py | 20 ++ langchain/callbacks/base.py | 177 ++++++++++++++++++ langchain/callbacks/shared.py | 114 +++++++++++ langchain/callbacks/stdout.py | 82 ++++++++ langchain/callbacks/streamlit.py | 77 ++++++++ langchain/chains/api/base.py | 5 +- langchain/chains/base.py | 41 +++- langchain/chains/llm.py | 11 +- langchain/chains/llm_bash/base.py | 9 +- langchain/chains/llm_math/base.py | 9 +- langchain/chains/pal/base.py | 3 +- langchain/chains/qa_with_sources/__init__.py | 11 +- .../chains/question_answering/__init__.py | 11 +- langchain/chains/sequential.py | 6 +- langchain/chains/sql_database/base.py | 13 +- langchain/chains/summarize/__init__.py | 11 +- langchain/input.py | 12 +- langchain/llms/base.py | 56 ++++-- langchain/llms/openai.py | 4 +- langchain/logger.py | 71 ------- langchain/schema.py | 12 +- tests/unit_tests/agents/test_agent.py | 118 +++++++++++- tests/unit_tests/callbacks/__init__.py | 1 + .../callbacks/fake_callback_handler.py | 67 +++++++ .../callbacks/test_callback_manager.py | 97 ++++++++++ tests/unit_tests/chains/test_base.py | 30 +++ tests/unit_tests/llms/test_callbacks.py | 30 +++ tests/unit_tests/test_hyde.py | 4 +- 33 files changed, 1014 insertions(+), 177 deletions(-) create mode 100644 langchain/callbacks/__init__.py create mode 100644 langchain/callbacks/base.py create mode 100644 langchain/callbacks/shared.py create mode 100644 langchain/callbacks/stdout.py create mode 100644 langchain/callbacks/streamlit.py delete mode 100644 langchain/logger.py create mode 100644 tests/unit_tests/callbacks/__init__.py create mode 100644 tests/unit_tests/callbacks/fake_callback_handler.py create mode 100644 tests/unit_tests/callbacks/test_callback_manager.py create mode 100644 tests/unit_tests/llms/test_callbacks.py diff --git a/docs/modules/agents/getting_started.ipynb b/docs/modules/agents/getting_started.ipynb index b8360e18b7..766d144a12 100644 --- a/docs/modules/agents/getting_started.ipynb +++ b/docs/modules/agents/getting_started.ipynb @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "0728f0d9", "metadata": {}, "outputs": [], @@ -69,7 +69,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "ba4e7618", "metadata": {}, "outputs": [], @@ -87,7 +87,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "03208e2b", "metadata": {}, "outputs": [], @@ -105,7 +105,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "244ee75c", "metadata": {}, "outputs": [ @@ -131,7 +131,7 @@ "\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\u001b[0m\n", - "\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { @@ -140,7 +140,7 @@ "\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\"" ] }, - "execution_count": 
6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -166,12 +166,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.0" - }, - "vscode": { - "interpreter": { - "hash": "b1677b440931f40d89ef8be7bf03acb108ce003de0ac9b18e8d43753ea2e7103" - } + "version": "3.10.9" } }, "nbformat": 4, diff --git a/docs/modules/memory/getting_started.ipynb b/docs/modules/memory/getting_started.ipynb index f80981671b..5818691df3 100644 --- a/docs/modules/memory/getting_started.ipynb +++ b/docs/modules/memory/getting_started.ipynb @@ -60,7 +60,7 @@ "Human: Hi there!\n", "AI:\u001b[0m\n", "\n", - "\u001b[1m> Finished ConversationChain chain.\u001b[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { @@ -101,7 +101,7 @@ "Human: I'm doing well! Just having a conversation with an AI.\n", "AI:\u001b[0m\n", "\n", - "\u001b[1m> Finished ConversationChain chain.\u001b[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { diff --git a/langchain/__init__.py b/langchain/__init__.py index 537d209e6f..071b8a3e28 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -4,6 +4,7 @@ from typing import Optional from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.cache import BaseCache +from langchain.callbacks import set_default_callback_manager, set_handler from langchain.chains import ( ConversationChain, LLMBashChain, @@ -19,7 +20,6 @@ from langchain.chains import ( from langchain.docstore import InMemoryDocstore, Wikipedia from langchain.llms import Cohere, HuggingFaceHub, OpenAI from langchain.llms.huggingface_pipeline import HuggingFacePipeline -from langchain.logger import BaseLogger, StdOutLogger from langchain.prompts import ( BasePromptTemplate, FewShotPromptTemplate, @@ -31,9 +31,9 @@ from langchain.sql_database import SQLDatabase from langchain.utilities.google_search import GoogleSearchAPIWrapper from langchain.vectorstores import FAISS, ElasticVectorSearch -logger: BaseLogger = StdOutLogger() verbose: bool = False llm_cache: Optional[BaseCache] = None +set_default_callback_manager() __all__ = [ "LLMChain", @@ -65,4 +65,5 @@ __all__ = [ "VectorDBQAWithSourcesChain", "QAWithSourcesChain", "PALChain", + "set_handler", ] diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index f2b92320d1..fd96f64be7 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, root_validator -import langchain from langchain.agents.tools import Tool +from langchain.callbacks.base import BaseCallbackManager from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping @@ -46,7 +46,7 @@ class Agent(BaseModel): def plan( self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any - ) -> Union[AgentFinish, AgentAction]: + ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. 
Args: @@ -132,10 +132,19 @@ class Agent(BaseModel): pass @classmethod - def from_llm_and_tools(cls, llm: BaseLLM, tools: List[Tool]) -> Agent: + def from_llm_and_tools( + cls, + llm: BaseLLM, + tools: List[Tool], + callback_manager: Optional[BaseCallbackManager] = None, + ) -> Agent: """Construct an agent from an LLM and tools.""" cls._validate_tools(tools) - llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) + llm_chain = LLMChain( + llm=llm, + prompt=cls.create_prompt(tools), + callback_manager=callback_manager, + ) return cls(llm_chain=llm_chain) def return_stopped_response( @@ -194,10 +203,16 @@ class AgentExecutor(Chain, BaseModel): @classmethod def from_agent_and_tools( - cls, agent: Agent, tools: List[Tool], **kwargs: Any + cls, + agent: Agent, + tools: List[Tool], + callback_manager: Optional[BaseCallbackManager] = None, + **kwargs: Any, ) -> AgentExecutor: """Create from agent and tools.""" - return cls(agent=agent, tools=tools, **kwargs) + return cls( + agent=agent, tools=tools, callback_manager=callback_manager, **kwargs + ) @property def input_keys(self) -> List[str]: @@ -244,24 +259,31 @@ class AgentExecutor(Chain, BaseModel): # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): if self.verbose: - langchain.logger.log_agent_end(output, color="green") + self.callback_manager.on_agent_finish(output, color="green") final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps return final_output - if self.verbose: - langchain.logger.log_agent_action(output, color="green") + # And then we lookup the tool if output.tool in name_to_tool_map: chain = name_to_tool_map[output.tool] + if self.verbose: + self.callback_manager.on_tool_start( + {"name": str(chain)[:60] + "..."}, output, color="green" + ) # We then call the tool on the tool input to get an observation observation = chain(output.tool_input) color = color_mapping[output.tool] else: + if self.verbose: + self.callback_manager.on_tool_start( + {"name": "N/A"}, output, color="green" + ) observation = f"{output.tool} is not a valid tool, try another one." 
color = None if self.verbose: - langchain.logger.log_agent_observation( + self.callback_manager.on_tool_end( observation, color=color, observation_prefix=self.agent.observation_prefix, @@ -272,6 +294,8 @@ class AgentExecutor(Chain, BaseModel): output = self.agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs ) + if self.verbose: + self.callback_manager.on_agent_finish(output, color="green") final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index f1823b7458..cb33d6607a 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,11 +1,12 @@ """Load agent.""" -from typing import Any, List +from typing import Any, List, Optional from langchain.agents.agent import AgentExecutor from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.react.base import ReActDocstoreAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool +from langchain.callbacks.base import BaseCallbackManager from langchain.llms.base import BaseLLM AGENT_TO_CLASS = { @@ -19,6 +20,7 @@ def initialize_agent( tools: List[Tool], llm: BaseLLM, agent: str = "zero-shot-react-description", + callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, ) -> AgentExecutor: """Load agent given tools and LLM. @@ -28,6 +30,8 @@ def initialize_agent( llm: Language model to use as the agent. agent: The agent to use. Valid options are: `zero-shot-react-description`, `react-docstore`, `self-ask-with-search`. + callback_manager: CallbackManager to use. Global callback manager is used if + not provided. Defaults to None. **kwargs: Additional key word arguments to pass to the agent. Returns: @@ -39,5 +43,12 @@ def initialize_agent( f"Valid types are: {AGENT_TO_CLASS.keys()}." 
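A minimal sketch of how the callback_manager threaded through above is meant to be used from this entry point: one manager is shared by the AgentExecutor, its inner LLMChain, and the LLM itself. The OpenAI LLM and the toy Tool below are illustrative stand-ins (OpenAI assumes OPENAI_API_KEY is set); any BaseLLM and tool list work the same way.

from langchain.agents import Tool, initialize_agent
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.llms import OpenAI

manager = CallbackManager(handlers=[StdOutCallbackHandler()])
llm = OpenAI(callback_manager=manager)  # illustrative; assumes OPENAI_API_KEY
agent = initialize_agent(
    tools=[Tool("Search", lambda x: x, "Useful for searching")],
    llm=llm,
    agent="zero-shot-react-description",
    callback_manager=manager,  # reaches the executor and the inner LLMChain
    verbose=True,  # events only fire on verbose chains and LLMs
)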
) agent_cls = AGENT_TO_CLASS[agent] - agent_obj = agent_cls.from_llm_and_tools(llm, tools) - return AgentExecutor.from_agent_and_tools(agent=agent_obj, tools=tools, **kwargs) + agent_obj = agent_cls.from_llm_and_tools( + llm, tools, callback_manager=callback_manager + ) + return AgentExecutor.from_agent_and_tools( + agent=agent_obj, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py new file mode 100644 index 0000000000..13a4c30213 --- /dev/null +++ b/langchain/callbacks/__init__.py @@ -0,0 +1,20 @@ +"""Callback handlers that allow listening to events in LangChain.""" +from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager +from langchain.callbacks.shared import SharedCallbackManager +from langchain.callbacks.stdout import StdOutCallbackHandler + + +def get_callback_manager() -> BaseCallbackManager: + """Return the shared callback manager.""" + return SharedCallbackManager() + + +def set_handler(handler: BaseCallbackHandler) -> None: + """Set handler.""" + callback = get_callback_manager() + callback.set_handler(handler) + + +def set_default_callback_manager() -> None: + """Set default callback manager.""" + set_handler(StdOutCallbackHandler()) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py new file mode 100644 index 0000000000..6b3b4ca924 --- /dev/null +++ b/langchain/callbacks/base.py @@ -0,0 +1,177 @@ +"""Base callback handler that can be used to handle callbacks from langchain.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from pydantic import BaseModel + +from langchain.schema import AgentAction, AgentFinish, LLMResult + + +class BaseCallbackHandler(BaseModel, ABC): + """Base callback handler that can be used to handle callbacks from langchain.""" + + ignore_llm: bool = False + ignore_chain: bool = False + ignore_agent: bool = False + + @abstractmethod + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + + @abstractmethod + def on_llm_end( + self, + response: LLMResult, + ) -> None: + """Run when LLM ends running.""" + + @abstractmethod + def on_llm_error(self, error: Exception) -> None: + """Run when LLM errors.""" + + @abstractmethod + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + + @abstractmethod + def on_chain_end(self, outputs: Dict[str, Any]) -> None: + """Run when chain ends running.""" + + @abstractmethod + def on_chain_error(self, error: Exception) -> None: + """Run when chain errors.""" + + @abstractmethod + def on_tool_start( + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + + @abstractmethod + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + + @abstractmethod + def on_tool_error(self, error: Exception) -> None: + """Run when tool errors.""" + + @abstractmethod + def on_text(self, text: str, **kwargs: Any) -> None: + """Run on arbitrary text.""" + + @abstractmethod + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" + + +class BaseCallbackManager(BaseCallbackHandler, ABC): + """Base callback manager that can be used to handle callbacks from LangChain.""" + + @abstractmethod + def add_handler(self, callback: BaseCallbackHandler) -> None: + """Add 
a handler to the callback manager.""" + + @abstractmethod + def remove_handler(self, handler: BaseCallbackHandler) -> None: + """Remove a handler from the callback manager.""" + + @abstractmethod + def set_handler(self, handler: BaseCallbackHandler) -> None: + """Set handler as the only handler on the callback manager.""" + + +class CallbackManager(BaseCallbackManager): + """Callback manager that can be used to handle callbacks from langchain.""" + + handlers: List[BaseCallbackHandler] + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + for handler in self.handlers: + if not handler.ignore_llm: + handler.on_llm_start(serialized, prompts, **kwargs) + + def on_llm_end( + self, + response: LLMResult, + ) -> None: + """Run when LLM ends running.""" + for handler in self.handlers: + if not handler.ignore_llm: + handler.on_llm_end(response) + + def on_llm_error(self, error: Exception) -> None: + """Run when LLM errors.""" + for handler in self.handlers: + if not handler.ignore_llm: + handler.on_llm_error(error) + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + for handler in self.handlers: + if not handler.ignore_chain: + handler.on_chain_start(serialized, inputs, **kwargs) + + def on_chain_end(self, outputs: Dict[str, Any]) -> None: + """Run when chain ends running.""" + for handler in self.handlers: + if not handler.ignore_chain: + handler.on_chain_end(outputs) + + def on_chain_error(self, error: Exception) -> None: + """Run when chain errors.""" + for handler in self.handlers: + if not handler.ignore_chain: + handler.on_chain_error(error) + + def on_tool_start( + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + for handler in self.handlers: + if not handler.ignore_agent: + handler.on_tool_start(serialized, action, **kwargs) + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + for handler in self.handlers: + if not handler.ignore_agent: + handler.on_tool_end(output, **kwargs) + + def on_tool_error(self, error: Exception) -> None: + """Run when tool errors.""" + for handler in self.handlers: + if not handler.ignore_agent: + handler.on_tool_error(error) + + def on_text(self, text: str, **kwargs: Any) -> None: + """Run on additional input from chains and agents.""" + for handler in self.handlers: + handler.on_text(text, **kwargs) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" + for handler in self.handlers: + if not handler.ignore_agent: + handler.on_agent_finish(finish, **kwargs) + + def add_handler(self, handler: BaseCallbackHandler) -> None: + """Add a handler to the callback manager.""" + self.handlers.append(handler) + + def remove_handler(self, handler: BaseCallbackHandler) -> None: + """Remove a handler from the callback manager.""" + self.handlers.remove(handler) + + def set_handler(self, handler: BaseCallbackHandler) -> None: + """Set handler as the only handler on the callback manager.""" + self.handlers = [handler] diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py new file mode 100644 index 0000000000..576f460968 --- /dev/null +++ b/langchain/callbacks/shared.py @@ -0,0 +1,114 @@ +"""A shared CallbackManager.""" + +import threading +from typing import Any, Dict, List + +from langchain.callbacks.base 
import ( + BaseCallbackHandler, + BaseCallbackManager, + CallbackManager, +) +from langchain.schema import AgentAction, AgentFinish, LLMResult + + +class Singleton: + """A thread-safe singleton class that can be inherited from.""" + + _instance = None + _lock = threading.Lock() + + def __new__(cls) -> Any: + """Create a new shared instance of the class.""" + if cls._instance is None: + with cls._lock: + # Another thread could have created the instance + # before we acquired the lock. So check that the + # instance is still nonexistent. + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + + +class SharedCallbackManager(Singleton, BaseCallbackManager): + """A thread-safe singleton CallbackManager.""" + + _callback_manager: CallbackManager = CallbackManager(handlers=[]) + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + with self._lock: + self._callback_manager.on_llm_start(serialized, prompts, **kwargs) + + def on_llm_end( + self, + response: LLMResult, + ) -> None: + """Run when LLM ends running.""" + with self._lock: + self._callback_manager.on_llm_end(response) + + def on_llm_error(self, error: Exception) -> None: + """Run when LLM errors.""" + with self._lock: + self._callback_manager.on_llm_error(error) + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Run when chain starts running.""" + with self._lock: + self._callback_manager.on_chain_start(serialized, inputs, **kwargs) + + def on_chain_end(self, outputs: Dict[str, Any]) -> None: + """Run when chain ends running.""" + with self._lock: + self._callback_manager.on_chain_end(outputs) + + def on_chain_error(self, error: Exception) -> None: + """Run when chain errors.""" + with self._lock: + self._callback_manager.on_chain_error(error) + + def on_tool_start( + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any + ) -> None: + """Run when tool starts running.""" + with self._lock: + self._callback_manager.on_tool_start(serialized, action, **kwargs) + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + with self._lock: + self._callback_manager.on_tool_end(output, **kwargs) + + def on_tool_error(self, error: Exception) -> None: + """Run when tool errors.""" + with self._lock: + self._callback_manager.on_tool_error(error) + + def on_text(self, text: str, **kwargs: Any) -> None: + """Run on arbitrary text.""" + with self._lock: + self._callback_manager.on_text(text, **kwargs) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" + with self._lock: + self._callback_manager.on_agent_finish(finish, **kwargs) + + def add_handler(self, callback: BaseCallbackHandler) -> None: + """Add a callback to the callback manager.""" + with self._lock: + self._callback_manager.add_handler(callback) + + def remove_handler(self, callback: BaseCallbackHandler) -> None: + """Remove a callback from the callback manager.""" + with self._lock: + self._callback_manager.remove_handler(callback) + + def set_handler(self, handler: BaseCallbackHandler) -> None: + """Set handler as the only handler on the callback manager.""" + with self._lock: + self._callback_manager.handlers = [handler] diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py new file mode 100644 index 0000000000..3e6cd28156 --- /dev/null +++ b/langchain/callbacks/stdout.py @@ -0,0 
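The Singleton mix-in above makes SharedCallbackManager process-global; a short sketch of the resulting behaviour, using only classes defined in this patch:

from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

m1 = SharedCallbackManager()
m2 = SharedCallbackManager()
assert m1 is m2  # every call returns the same shared instance

m1.set_handler(StdOutCallbackHandler())  # the change is visible through m2 too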
+1,82 @@
+"""Callback Handler that prints to std out."""
+from typing import Any, Dict, List, Optional
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.input import print_text
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+
+
+class StdOutCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that prints to std out."""
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Do nothing."""
+        pass
+
+    def on_llm_end(self, response: LLMResult) -> None:
+        """Do nothing."""
+        pass
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Print out that we are entering a chain."""
+        class_name = serialized["name"]
+        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Print out that we finished a chain."""
+        print("\n\033[1m> Finished chain.\033[0m")
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_tool_start(
+        self,
+        serialized: Dict[str, Any],
+        action: AgentAction,
+        color: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Print out the log in specified color."""
+        print_text(action.log, color=color)
+
+    def on_tool_end(
+        self,
+        output: str,
+        color: Optional[str] = None,
+        observation_prefix: Optional[str] = None,
+        llm_prefix: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """If not the final action, print out observation."""
+        print_text(f"\n{observation_prefix}")
+        print_text(output, color=color)
+        print_text(f"\n{llm_prefix}")
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_text(
+        self,
+        text: str,
+        color: Optional[str] = None,
+        end: str = "",
+        **kwargs: Any,
+    ) -> None:
+        """Print the text, optionally in the given color."""
+        print_text(text, color=color, end=end)
+
+    def on_agent_finish(
+        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
+    ) -> None:
+        """Run on agent end."""
+        print_text(finish.log, color=color, end="\n")
diff --git a/langchain/callbacks/streamlit.py b/langchain/callbacks/streamlit.py
new file mode 100644
index 0000000000..2e781f611e
--- /dev/null
+++ b/langchain/callbacks/streamlit.py
@@ -0,0 +1,77 @@
+"""Callback Handler that logs to streamlit."""
+from typing import Any, Dict, List, Optional
+
+import streamlit as st
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+
+
+class StreamlitCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that logs to streamlit."""
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Print out the prompts."""
+        st.write("Prompts after formatting:")
+        for prompt in prompts:
+            st.write(prompt)
+
+    def on_llm_end(self, response: LLMResult) -> None:
+        """Do nothing."""
+        pass
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Print out that we are entering a chain."""
+        class_name = serialized["name"]
+        st.write(f"Entering new {class_name} chain...")
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Print out that we finished a chain."""
+        st.write("Finished chain.")
+
+    def on_chain_error(self, 
error: Exception) -> None: + """Do nothing.""" + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + action: AgentAction, + **kwargs: Any, + ) -> None: + """Print out the log in specified color.""" + # st.write requires two spaces before a newline to render it + st.markdown(action.log.replace("\n", " \n")) + + def on_tool_end( + self, + output: str, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """If not the final action, print out observation.""" + st.write(f"{observation_prefix}{output}") + st.write(llm_prefix) + + def on_tool_error(self, error: Exception) -> None: + """Do nothing.""" + pass + + def on_text(self, text: str, **kwargs: Any) -> None: + """Run on text.""" + # st.write requires two spaces before a newline to render it + st.write(text.replace("\n", " \n")) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" + # st.write requires two spaces before a newline to render it + st.write(finish.log.replace("\n", " \n")) diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 07591bd425..4bb81879a1 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -8,7 +8,6 @@ from pydantic import BaseModel, root_validator from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.requests import RequestsWrapper @@ -67,10 +66,10 @@ class APIChain(Chain, BaseModel): question=question, api_docs=self.api_docs ) if self.verbose: - print_text(api_url, color="green", end="\n") + self.callback_manager.on_text(api_url, color="green", end="\n") api_response = self.requests_wrapper.run(api_url) if self.verbose: - print_text(api_response, color="yellow", end="\n") + self.callback_manager.on_text(api_response, color="yellow", end="\n") answer = self.api_answer_chain.predict( question=question, api_docs=self.api_docs, diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 24f90ea0bc..af2e6cc125 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -2,9 +2,11 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Extra, Field, validator import langchain +from langchain.callbacks import get_callback_manager +from langchain.callbacks.base import BaseCallbackManager class Memory(BaseModel, ABC): @@ -42,9 +44,36 @@ class Chain(BaseModel, ABC): """Base interface that all chains should implement.""" memory: Optional[Memory] = None + callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) + verbose: bool = Field( + default_factory=_get_verbosity + ) # Whether to print the response text - verbose: bool = Field(default_factory=_get_verbosity) - """Whether to print out response text.""" + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @validator("callback_manager", pre=True, always=True) + def set_callback_manager( + cls, callback_manager: Optional[BaseCallbackManager] + ) -> BaseCallbackManager: + """If callback manager is None, set it. + + This allows users to pass in None as callback manager, which is a nice UX. 
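A hedged sketch of wiring the new Streamlit handler in as the process-wide handler; set_handler comes from langchain.callbacks above, and the script would be launched with `streamlit run app.py`.

import langchain
from langchain.callbacks import set_handler
from langchain.callbacks.streamlit import StreamlitCallbackHandler

set_handler(StreamlitCallbackHandler())  # replaces the default stdout handler
langchain.verbose = True  # chains and LLMs only emit events when verbose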
+ """ + return callback_manager or get_callback_manager() + + @validator("verbose", pre=True, always=True) + def set_verbose(cls, verbose: Optional[bool]) -> bool: + """If verbose is None, set it. + + This allows users to pass in None as verbose to access the global setting. + """ + if verbose is None: + return _get_verbosity() + else: + return verbose @property @abstractmethod @@ -106,12 +135,12 @@ class Chain(BaseModel, ABC): inputs = dict(inputs, **external_context) self._validate_inputs(inputs) if self.verbose: - print( - f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m" + self.callback_manager.on_chain_start( + {"name": self.__class__.__name__}, inputs ) outputs = self._call(inputs) if self.verbose: - print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m") + self.callback_manager.on_chain_end(outputs) self._validate_outputs(outputs) if self.memory is not None: self.memory.save_context(inputs, outputs) diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index cbe01bd9a8..11c316380d 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -3,10 +3,11 @@ from typing import Any, Dict, List, Sequence, Union from pydantic import BaseModel, Extra -import langchain from langchain.chains.base import Chain -from langchain.llms.base import BaseLLM, LLMResult +from langchain.input import get_colored_text +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate +from langchain.schema import LLMResult class LLMChain(Chain, BaseModel): @@ -61,7 +62,9 @@ class LLMChain(Chain, BaseModel): selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} prompt = self.prompt.format(**selected_inputs) if self.verbose: - langchain.logger.log_llm_inputs(selected_inputs, prompt) + _colored_text = get_colored_text(prompt, "green") + _text = "Prompt after formatting:\n" + _colored_text + self.callback_manager.on_text(_text, end="\n") if "stop" in inputs and inputs["stop"] != stop: raise ValueError( "If `stop` is present in any inputs, should be present in all." @@ -77,8 +80,6 @@ class LLMChain(Chain, BaseModel): for generation in response.generations: # Get the text of the top generated string. 
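What the two validators above buy callers, sketched with LLMChain; the PromptTemplate and OpenAI specifics are illustrative (OpenAI assumes OPENAI_API_KEY is set). Passing None for callback_manager or verbose silently resolves to the global defaults instead of failing pydantic validation.

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

chain = LLMChain(
    llm=OpenAI(),  # illustrative; assumes OPENAI_API_KEY
    prompt=PromptTemplate(input_variables=["q"], template="{q}"),
    callback_manager=None,  # validator swaps in get_callback_manager()
    verbose=None,  # validator swaps in the global langchain.verbose setting
)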
response_str = generation[0].text - if self.verbose: - langchain.logger.log_llm_response(response_str) outputs.append({self.output_key: response_str}) return outputs diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 7696371302..5d9d8b610c 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.prompt import PROMPT -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.utilities.bash import BashProcess @@ -52,11 +51,11 @@ class LLMBashChain(Chain, BaseModel): llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) bash_executor = BashProcess() if self.verbose: - print_text(inputs[self.input_key]) + self.callback_manager.on_text(inputs[self.input_key]) t = llm_executor.predict(question=inputs[self.input_key]) if self.verbose: - print_text(t, color="green") + self.callback_manager.on_text(t, color="green") t = t.strip() if t.startswith("```bash"): @@ -69,8 +68,8 @@ class LLMBashChain(Chain, BaseModel): output = bash_executor.run(command_list) if self.verbose: - print_text("\nAnswer: ") - print_text(output, color="yellow") + self.callback_manager.on_text("\nAnswer: ") + self.callback_manager.on_text(output, color="yellow") else: raise ValueError(f"unknown format from LLM: {t}") diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index a0485e9e3a..383bcac96f 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.python import PythonREPL @@ -52,17 +51,17 @@ class LLMMathChain(Chain, BaseModel): llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) python_executor = PythonREPL() if self.verbose: - print_text(inputs[self.input_key]) + self.callback_manager.on_text(inputs[self.input_key]) t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"]) if self.verbose: - print_text(t, color="green") + self.callback_manager.on_text(t, color="green") t = t.strip() if t.startswith("```python"): code = t[9:-4] output = python_executor.run(code) if self.verbose: - print_text("\nAnswer: ") - print_text(output, color="yellow") + self.callback_manager.on_text("\nAnswer: ") + self.callback_manager.on_text(output, color="yellow") answer = "Answer: " + output elif t.startswith("Answer:"): answer = t diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 08335e2fb2..858f6a1f29 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -12,7 +12,6 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.python import PythonREPL @@ -53,7 +52,7 @@ class PALChain(Chain, BaseModel): llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) code = llm_chain.predict(stop=[self.stop], **inputs) if self.verbose: - print_text(code, color="green", 
end="\n") + self.callback_manager.on_text(code, color="green", end="\n") repl = PythonREPL() res = repl.run(code + f"\n{self.get_answer_expr}") return {self.output_key: res.strip()} diff --git a/langchain/chains/qa_with_sources/__init__.py b/langchain/chains/qa_with_sources/__init__.py index 88b60aecfe..56bb9bcd2e 100644 --- a/langchain/chains/qa_with_sources/__init__.py +++ b/langchain/chains/qa_with_sources/__init__.py @@ -26,7 +26,7 @@ def _load_stuff_chain( llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "summaries", - verbose: bool = False, + verbose: Optional[bool] = None, **kwargs: Any, ) -> StuffDocumentsChain: llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) @@ -49,7 +49,7 @@ def _load_map_reduce_chain( collapse_prompt: Optional[BasePromptTemplate] = None, reduce_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None, - verbose: bool = False, + verbose: Optional[bool] = None, **kwargs: Any, ) -> MapReduceDocumentsChain: map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) @@ -97,7 +97,7 @@ def _load_refine_chain( document_variable_name: str = "context_str", initial_response_name: str = "existing_answer", refine_llm: Optional[BaseLLM] = None, - verbose: bool = False, + verbose: Optional[bool] = None, **kwargs: Any, ) -> RefineDocumentsChain: initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) @@ -115,7 +115,10 @@ def _load_refine_chain( def load_qa_with_sources_chain( - llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any + llm: BaseLLM, + chain_type: str = "stuff", + verbose: Optional[bool] = None, + **kwargs: Any, ) -> BaseCombineDocumentsChain: """Load question answering with sources chain. diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index 7685d7ab18..1e9bc7cdb8 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -26,7 +26,7 @@ def _load_stuff_chain( llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "context", - verbose: bool = False, + verbose: Optional[bool] = None, **kwargs: Any, ) -> StuffDocumentsChain: llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) @@ -48,7 +48,7 @@ def _load_map_reduce_chain( collapse_prompt: Optional[BasePromptTemplate] = None, reduce_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None, - verbose: bool = False, + verbose: Optional[bool] = None, **kwargs: Any, ) -> MapReduceDocumentsChain: map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) @@ -94,7 +94,7 @@ def _load_refine_chain( document_variable_name: str = "context_str", initial_response_name: str = "existing_answer", refine_llm: Optional[BaseLLM] = None, - verbose: bool = False, + verbose: Optional[bool] = None, **kwargs: Any, ) -> RefineDocumentsChain: initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) @@ -111,7 +111,10 @@ def _load_refine_chain( def load_qa_chain( - llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any + llm: BaseLLM, + chain_type: str = "stuff", + verbose: Optional[bool] = None, + **kwargs: Any, ) -> BaseCombineDocumentsChain: """Load question answering chain. 
diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index 1bcc723d40..9db4be411e 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -5,7 +5,7 @@ from typing import Dict, List from pydantic import BaseModel, Extra, root_validator from langchain.chains.base import Chain -from langchain.input import get_color_mapping, print_text +from langchain.input import get_color_mapping class SequentialChain(Chain, BaseModel): @@ -133,5 +133,7 @@ class SimpleSequentialChain(Chain, BaseModel): if self.strip_outputs: _input = _input.strip() if self.verbose: - print_text(_input, color=color_mapping[str(i)], end="\n") + self.callback_manager.on_text( + _input, color=color_mapping[str(i)], end="\n" + ) return {self.output_key: _input} diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 103993f35d..3c56cac7d6 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sql_database.prompt import PROMPT -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.sql_database import SQLDatabase @@ -55,7 +54,7 @@ class SQLDatabaseChain(Chain, BaseModel): llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) input_text = f"{inputs[self.input_key]} \nSQLQuery:" if self.verbose: - print_text(input_text) + self.callback_manager.on_text(input_text) llm_inputs = { "input": input_text, "dialect": self.database.dialect, @@ -64,15 +63,15 @@ class SQLDatabaseChain(Chain, BaseModel): } sql_cmd = llm_chain.predict(**llm_inputs) if self.verbose: - print_text(sql_cmd, color="green") + self.callback_manager.on_text(sql_cmd, color="green") result = self.database.run(sql_cmd) if self.verbose: - print_text("\nSQLResult: ") - print_text(result, color="yellow") - print_text("\nAnswer:") + self.callback_manager.on_text("\nSQLResult: ") + self.callback_manager.on_text(result, color="yellow") + self.callback_manager.on_text("\nAnswer:") input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:" llm_inputs["input"] = input_text final_result = llm_chain.predict(**llm_inputs) if self.verbose: - print_text(final_result, color="green") + self.callback_manager.on_text(final_result, color="green") return {self.output_key: final_result} diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py index 8605ed2c2b..d0c9f35b31 100644 --- a/langchain/chains/summarize/__init__.py +++ b/langchain/chains/summarize/__init__.py @@ -22,7 +22,7 @@ def _load_stuff_chain( llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "text", - verbose: bool = False, + verbose: Optional[bool] = None, **kwargs: Any, ) -> StuffDocumentsChain: llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) @@ -44,7 +44,7 @@ def _load_map_reduce_chain( collapse_prompt: Optional[BasePromptTemplate] = None, reduce_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None, - verbose: bool = False, + verbose: Optional[bool] = None, **kwargs: Any, ) -> MapReduceDocumentsChain: map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose) @@ -90,7 +90,7 @@ def _load_refine_chain( document_variable_name: str = "text", initial_response_name: str = "existing_answer", refine_llm: Optional[BaseLLM] = None, - verbose: bool = False, + verbose: Optional[bool] = None, **kwargs: 
Any, ) -> RefineDocumentsChain: @@ -108,7 +108,10 @@ def _load_refine_chain( def load_summarize_chain( - llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any + llm: BaseLLM, + chain_type: str = "stuff", + verbose: Optional[bool] = None, + **kwargs: Any, ) -> BaseCombineDocumentsChain: """Load summarizing chain. diff --git a/langchain/input.py b/langchain/input.py index 680685fff5..197c60bdd6 100644 --- a/langchain/input.py +++ b/langchain/input.py @@ -20,10 +20,16 @@ def get_color_mapping( return color_mapping +def get_colored_text(text: str, color: str) -> str: + """Get colored text.""" + color_str = _TEXT_COLOR_MAPPING[color] + return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" + + def print_text(text: str, color: Optional[str] = None, end: str = "") -> None: """Print text with highlighting and no end characters.""" if color is None: - print(text, end=end) + text_to_print = text else: - color_str = _TEXT_COLOR_MAPPING[color] - print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end) + text_to_print = get_colored_text(text, color) + print(text_to_print, end=end) diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 1941385afc..22e63b4908 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -2,34 +2,55 @@ import json from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union import yaml -from pydantic import BaseModel, Extra +from pydantic import BaseModel, Extra, Field, validator import langchain -from langchain.schema import Generation +from langchain.callbacks import get_callback_manager +from langchain.callbacks.base import BaseCallbackManager +from langchain.schema import Generation, LLMResult -class LLMResult(NamedTuple): - """Class that contains all relevant information for an LLM Result.""" - - generations: List[List[Generation]] - """List of the things generated. This is List[List[]] because - each input could have multiple generations.""" - llm_output: Optional[dict] = None - """For arbitrary LLM provider specific output.""" +def _get_verbosity() -> bool: + return langchain.verbose class BaseLLM(BaseModel, ABC): """LLM wrapper should take in a prompt and return a string.""" cache: Optional[bool] = None + verbose: bool = Field(default_factory=_get_verbosity) + """Whether to print out response text.""" + callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) class Config: """Configuration for this pydantic object.""" extra = Extra.forbid + arbitrary_types_allowed = True + + @validator("callback_manager", pre=True, always=True) + def set_callback_manager( + cls, callback_manager: Optional[BaseCallbackManager] + ) -> BaseCallbackManager: + """If callback manager is None, set it. + + This allows users to pass in None as callback manager, which is a nice UX. + """ + return callback_manager or get_callback_manager() + + @validator("verbose", pre=True, always=True) + def set_verbose(cls, verbose: Optional[bool]) -> bool: + """If verbose is None, set it. + + This allows users to pass in None as verbose to access the global setting. + """ + if verbose is None: + return _get_verbosity() + else: + return verbose @abstractmethod def _generate( @@ -48,7 +69,14 @@ class BaseLLM(BaseModel, ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
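The refactor of langchain/input.py above leaves a single source of truth for ANSI colouring; a short sketch of the two helpers, using only colours already present in _TEXT_COLOR_MAPPING.

from langchain.input import get_colored_text, print_text

colored = get_colored_text("hello", "green")
assert colored.startswith("\u001b[") and colored.endswith("\u001b[0m")
print_text("hello", color="green", end="\n")  # prints the same escaped text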
) - return self._generate(prompts, stop=stop) + if self.verbose: + self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, prompts + ) + output = self._generate(prompts, stop=stop) + if self.verbose: + self.callback_manager.on_llm_end(output) + return output params = self._llm_dict() params["stop"] = stop llm_string = str(sorted([(k, v) for k, v in params.items()])) @@ -62,7 +90,11 @@ class BaseLLM(BaseModel, ABC): else: missing_prompts.append(prompt) missing_prompt_idxs.append(i) + self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, missing_prompts + ) new_results = self._generate(missing_prompts, stop=stop) + self.callback_manager.on_llm_end(new_results) for i, result in enumerate(new_results.generations): existing_prompts[i] = result prompt = prompts[i] diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 4ea229849b..ab15a39655 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -4,8 +4,8 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union from pydantic import BaseModel, Extra, Field, root_validator -from langchain.llms.base import BaseLLM, LLMResult -from langchain.schema import Generation +from langchain.llms.base import BaseLLM +from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env diff --git a/langchain/logger.py b/langchain/logger.py deleted file mode 100644 index d264da9048..0000000000 --- a/langchain/logger.py +++ /dev/null @@ -1,71 +0,0 @@ -"""BETA: everything in here is highly experimental, do not rely on.""" -from typing import Any, Optional - -from langchain.input import print_text -from langchain.schema import AgentAction, AgentFinish - - -class BaseLogger: - """Base logging interface.""" - - def log_agent_start(self, text: str, **kwargs: Any) -> None: - """Log the start of an agent interaction.""" - pass - - def log_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None: - """Log the end of an agent interaction.""" - pass - - def log_agent_action(self, action: AgentAction, **kwargs: Any) -> None: - """Log agent action decision.""" - pass - - def log_agent_observation(self, observation: str, **kwargs: Any) -> None: - """Log agent observation.""" - pass - - def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None: - """Log LLM inputs.""" - pass - - def log_llm_response(self, output: str, **kwargs: Any) -> None: - """Log LLM response.""" - pass - - -class StdOutLogger(BaseLogger): - """Interface for printing things to stdout.""" - - def log_agent_start(self, text: str, **kwargs: Any) -> None: - """Print the text to start the agent.""" - print_text(text) - - def log_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Print the log of the action in a certain color.""" - print_text(action.log, color=color) - - def log_agent_observation( - self, - observation: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Print the observation in a special color.""" - print_text(f"\n{observation_prefix}") - print_text(observation, color=color) - print_text(f"\n{llm_prefix}") - - def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None: - """Print the prompt in green.""" - print("Prompt after formatting:") - print_text(prompt, color="green", end="\n") - - def log_agent_end( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> 
None: - """Log the end of an agent interaction.""" - print_text(finish.log, color=color) diff --git a/langchain/schema.py b/langchain/schema.py index 31cf1cdc82..a4b4e6267c 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -1,6 +1,6 @@ """Common schema objects.""" -from typing import NamedTuple +from typing import List, NamedTuple, Optional class AgentAction(NamedTuple): @@ -24,3 +24,13 @@ class Generation(NamedTuple): text: str """Generated text output.""" # TODO: add log probs + + +class LLMResult(NamedTuple): + """Class that contains all relevant information for an LLM Result.""" + + generations: List[List[Generation]] + """List of the things generated. This is List[List[]] because + each input could have multiple generations.""" + llm_output: Optional[dict] = None + """For arbitrary LLM provider specific output.""" diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 8d1dfdb3d6..581f61ca52 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -4,8 +4,10 @@ from typing import Any, List, Mapping, Optional from pydantic import BaseModel -from langchain.agents import Tool, initialize_agent +from langchain.agents import AgentExecutor, Tool, initialize_agent +from langchain.callbacks.base import CallbackManager from langchain.llms.base import LLM +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler class FakeListLLM(LLM, BaseModel): @@ -31,8 +33,8 @@ class FakeListLLM(LLM, BaseModel): return "fake_list" -def test_agent_bad_action() -> None: - """Test react chain when bad action given.""" +def _get_agent(**kwargs: Any) -> AgentExecutor: + """Get agent for testing.""" bad_action_name = "BadAction" responses = [ f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", @@ -44,30 +46,126 @@ def test_agent_bad_action() -> None: Tool("Lookup", lambda x: x, "Useful for looking up things in a table"), ] agent = initialize_agent( - tools, fake_llm, agent="zero-shot-react-description", verbose=True + tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs ) + return agent + + +def test_agent_bad_action() -> None: + """Test react chain when bad action given.""" + agent = _get_agent() output = agent.run("when was langchain made") assert output == "curses foiled again" def test_agent_stopped_early() -> None: """Test react chain when bad action given.""" - bad_action_name = "BadAction" + agent = _get_agent(max_iterations=0) + output = agent.run("when was langchain made") + assert output == "Agent stopped due to max iterations." 
+ + +def test_agent_with_callbacks_global() -> None: + """Test react chain with callbacks by setting verbose globally.""" + import langchain + + langchain.verbose = True + handler = FakeCallbackHandler() + manager = CallbackManager(handlers=[handler]) + tool = "Search" responses = [ - f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", + f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", "Oh well\nAction: Final Answer\nAction Input: curses foiled again", ] - fake_llm = FakeListLLM(responses=responses) + fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True) tools = [ Tool("Search", lambda x: x, "Useful for searching"), - Tool("Lookup", lambda x: x, "Useful for looking up things in a table"), ] agent = initialize_agent( tools, fake_llm, agent="zero-shot-react-description", verbose=True, - max_iterations=0, + callback_manager=manager, ) + output = agent.run("when was langchain made") - assert output == "Agent stopped due to max iterations." + assert output == "curses foiled again" + + # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler.starts == 6 + # 1 extra agent end + assert handler.ends == 7 + assert handler.errors == 0 + # during LLMChain + assert handler.text == 2 + + +def test_agent_with_callbacks_local() -> None: + """Test react chain with callbacks by setting verbose locally.""" + import langchain + + langchain.verbose = False + handler = FakeCallbackHandler() + manager = CallbackManager(handlers=[handler]) + tool = "Search" + responses = [ + f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", + "Oh well\nAction: Final Answer\nAction Input: curses foiled again", + ] + fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True) + tools = [ + Tool("Search", lambda x: x, "Useful for searching"), + ] + agent = initialize_agent( + tools, + fake_llm, + agent="zero-shot-react-description", + verbose=True, + callback_manager=manager, + ) + + agent.agent.llm_chain.verbose = True + + output = agent.run("when was langchain made") + assert output == "curses foiled again" + + # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler.starts == 6 + # 1 extra agent end + assert handler.ends == 7 + assert handler.errors == 0 + # during LLMChain + assert handler.text == 2 + + +def test_agent_with_callbacks_not_verbose() -> None: + """Test react chain with callbacks but not verbose.""" + import langchain + + langchain.verbose = False + handler = FakeCallbackHandler() + manager = CallbackManager(handlers=[handler]) + tool = "Search" + responses = [ + f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", + "Oh well\nAction: Final Answer\nAction Input: curses foiled again", + ] + fake_llm = FakeListLLM(responses=responses, callback_manager=manager) + tools = [ + Tool("Search", lambda x: x, "Useful for searching"), + ] + agent = initialize_agent( + tools, + fake_llm, + agent="zero-shot-react-description", + callback_manager=manager, + ) + + output = agent.run("when was langchain made") + assert output == "curses foiled again" + + # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler.starts == 0 + assert handler.ends == 0 + assert handler.errors == 0 diff --git a/tests/unit_tests/callbacks/__init__.py b/tests/unit_tests/callbacks/__init__.py new file mode 100644 index 0000000000..cd34752b30 --- /dev/null +++ b/tests/unit_tests/callbacks/__init__.py @@ -0,0 +1 @@ +"""Tests for correct functioning of callbacks.""" diff --git 
a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py
new file mode 100644
index 0000000000..07cf52f5e2
--- /dev/null
+++ b/tests/unit_tests/callbacks/fake_callback_handler.py
@@ -0,0 +1,67 @@
+"""A fake callback handler for testing purposes."""
+from typing import Any, Dict, List
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+
+
+class FakeCallbackHandler(BaseCallbackHandler):
+    """Fake callback handler for testing."""
+
+    starts: int = 0
+    ends: int = 0
+    errors: int = 0
+    text: int = 0
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running."""
+        self.starts += 1
+
+    def on_llm_end(
+        self,
+        response: LLMResult,
+    ) -> None:
+        """Run when LLM ends running."""
+        self.ends += 1
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Run when LLM errors."""
+        self.errors += 1
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Run when chain starts running."""
+        self.starts += 1
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Run when chain ends running."""
+        self.ends += 1
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Run when chain errors."""
+        self.errors += 1
+
+    def on_tool_start(
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
+    ) -> None:
+        """Run when tool starts running."""
+        self.starts += 1
+
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
+        """Run when tool ends running."""
+        self.ends += 1
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Run when tool errors."""
+        self.errors += 1
+
+    def on_text(self, text: str, **kwargs: Any) -> None:
+        """Run on arbitrary text."""
+        self.text += 1
+
+    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+        """Run when agent ends running."""
+        self.ends += 1
diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py
new file mode 100644
index 0000000000..c5c493ce27
--- /dev/null
+++ b/tests/unit_tests/callbacks/test_callback_manager.py
@@ -0,0 +1,97 @@
+"""Test CallbackManager."""
+
+from langchain.callbacks.base import BaseCallbackManager, CallbackManager
+from langchain.callbacks.shared import SharedCallbackManager
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+
+
+def _test_callback_manager(
+    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
+) -> None:
+    """Test the CallbackManager."""
+    manager.on_llm_start({}, [])
+    manager.on_llm_end(LLMResult(generations=[]))
+    manager.on_llm_error(Exception())
+    manager.on_chain_start({"name": "foo"}, {})
+    manager.on_chain_end({})
+    manager.on_chain_error(Exception())
+    manager.on_tool_start({}, AgentAction("", "", ""))
+    manager.on_tool_end("")
+    manager.on_tool_error(Exception())
+    manager.on_agent_finish(AgentFinish({}, ""))
+    for handler in handlers:
+        assert handler.starts == 3
+        assert handler.ends == 4
+        assert handler.errors == 3
+
+
+def test_callback_manager() -> None:
+    """Test the CallbackManager."""
+    handler1 = FakeCallbackHandler()
+    handler2 = FakeCallbackHandler()
+    manager = CallbackManager(handlers=[handler1, handler2])
+    _test_callback_manager(manager, handler1, handler2)
+
+
+def test_ignore_llm() -> None:
+    """Test ignore llm 
param for callback handlers.""" + handler1 = FakeCallbackHandler(ignore_llm=True) + handler2 = FakeCallbackHandler() + manager = CallbackManager(handlers=[handler1, handler2]) + manager.on_llm_start({}, []) + manager.on_llm_end(LLMResult(generations=[])) + manager.on_llm_error(Exception()) + assert handler1.starts == 0 + assert handler1.ends == 0 + assert handler1.errors == 0 + assert handler2.starts == 1 + assert handler2.ends == 1 + assert handler2.errors == 1 + + +def test_ignore_chain() -> None: + """Test ignore chain param for callback handlers.""" + handler1 = FakeCallbackHandler(ignore_chain=True) + handler2 = FakeCallbackHandler() + manager = CallbackManager(handlers=[handler1, handler2]) + manager.on_chain_start({"name": "foo"}, {}) + manager.on_chain_end({}) + manager.on_chain_error(Exception()) + assert handler1.starts == 0 + assert handler1.ends == 0 + assert handler1.errors == 0 + assert handler2.starts == 1 + assert handler2.ends == 1 + assert handler2.errors == 1 + + +def test_ignore_agent() -> None: + """Test ignore agent param for callback handlers.""" + handler1 = FakeCallbackHandler(ignore_agent=True) + handler2 = FakeCallbackHandler() + manager = CallbackManager(handlers=[handler1, handler2]) + manager.on_tool_start({}, AgentAction("", "", "")) + manager.on_tool_end("") + manager.on_tool_error(Exception()) + manager.on_agent_finish(AgentFinish({}, "")) + assert handler1.starts == 0 + assert handler1.ends == 0 + assert handler1.errors == 0 + assert handler2.starts == 1 + assert handler2.ends == 2 + assert handler2.errors == 1 + + +def test_shared_callback_manager() -> None: + """Test the SharedCallbackManager.""" + manager1 = SharedCallbackManager() + manager2 = SharedCallbackManager() + + assert manager1 is manager2 + + handler1 = FakeCallbackHandler() + handler2 = FakeCallbackHandler() + manager1.add_handler(handler1) + manager2.add_handler(handler2) + _test_callback_manager(manager1, handler1, handler2) diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 69dfd6bf7c..c07f857631 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -4,7 +4,9 @@ from typing import Any, Dict, List import pytest from pydantic import BaseModel +from langchain.callbacks.base import CallbackManager from langchain.chains.base import Chain, Memory +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler class FakeMemory(Memory, BaseModel): @@ -133,3 +135,31 @@ def test_run_arg_with_memory() -> None: """Test run method works when arg is passed.""" chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory()) chain.run("bar") + + +def test_run_with_callback() -> None: + """Test run method works when callback manager is passed.""" + handler = FakeCallbackHandler() + chain = FakeChain( + callback_manager=CallbackManager(handlers=[handler]), verbose=True + ) + output = chain.run("bar") + assert output == "baz" + assert handler.starts == 1 + assert handler.ends == 1 + assert handler.errors == 0 + + +def test_run_with_callback_not_verbose() -> None: + """Test run method works when callback manager is passed and not verbose.""" + import langchain + + langchain.verbose = False + + handler = FakeCallbackHandler() + chain = FakeChain(callback_manager=CallbackManager(handlers=[handler])) + output = chain.run("bar") + assert output == "baz" + assert handler.starts == 0 + assert handler.ends == 0 + assert handler.errors == 0 diff --git a/tests/unit_tests/llms/test_callbacks.py 
b/tests/unit_tests/llms/test_callbacks.py new file mode 100644 index 0000000000..d9d52630b7 --- /dev/null +++ b/tests/unit_tests/llms/test_callbacks.py @@ -0,0 +1,30 @@ +"""Test LLM callbacks.""" +from langchain.callbacks.base import CallbackManager +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +from tests.unit_tests.llms.fake_llm import FakeLLM + + +def test_llm_with_callbacks() -> None: + """Test LLM callbacks.""" + handler = FakeCallbackHandler() + llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True) + output = llm("foo") + assert output == "foo" + assert handler.starts == 1 + assert handler.ends == 1 + assert handler.errors == 0 + + +def test_llm_with_callbacks_not_verbose() -> None: + """Test LLM callbacks but not verbose.""" + import langchain + + langchain.verbose = False + + handler = FakeCallbackHandler() + llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler])) + output = llm("foo") + assert output == "foo" + assert handler.starts == 0 + assert handler.ends == 0 + assert handler.errors == 0 diff --git a/tests/unit_tests/test_hyde.py b/tests/unit_tests/test_hyde.py index 91df0f3407..91b7bb550d 100644 --- a/tests/unit_tests/test_hyde.py +++ b/tests/unit_tests/test_hyde.py @@ -7,8 +7,8 @@ from pydantic import BaseModel from langchain.embeddings.base import Embeddings from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder from langchain.embeddings.hyde.prompts import PROMPT_MAP -from langchain.llms.base import BaseLLM, LLMResult -from langchain.schema import Generation +from langchain.llms.base import BaseLLM +from langchain.schema import Generation, LLMResult class FakeEmbeddings(Embeddings):