Compare commits

...

22 Commits

Author SHA1 Message Date
Harrison Chase d3a6387ab9 cr 1 year ago
Harrison Chase 3efee27e56 cr 1 year ago
Harrison Chase 7d0b1cafd7 cr 1 year ago
Harrison Chase 6fe6af7048 cr 1 year ago
Harrison Chase 6953c2e707 cr 1 year ago
Harrison Chase 3086b752a3 Merge branch 'ankush/callbackhandler' into harrison/callback-updates 1 year ago
Harrison Chase 03e3cd468b cr 1 year ago
Harrison Chase 7eb33690a9 callback updates 1 year ago
Harrison Chase 23b8cfc123 cr 1 year ago
Ankush Gola db5c8e0c42
explicitly set global verbosity flag in unit tests that test callbacks (#502)
If another file anywhere in `unit_tests` sets `langchain.verbose =
True`, it messes up all of the tests that check for no callbacks because
the `None` for verbose gets overridden by the global verbosity flag. By
explicitly setting it in unit tests, we bypass that potential issue
1 year ago
Ankush Gola aae3609aa8
fix ontext comment (#500) 1 year ago
Harrison Chase a3d2a2ec2a
Harrison/streamlit handler (#488)
also add a set handler method

usage is:
```
from langchain.callbacks.streamlit import StreamlitCallbackHandler
import langchain
langchain.set_handler(StreamlitCallbackHandler())
```

produces the following output


![Screen Shot 2022-12-29 at 10 50 33
PM](https://user-images.githubusercontent.com/11986836/210032762-7f53fffa-cb2f-4dac-af39-7d4cf81e55dd.png)

only works for agent stuff currently
1 year ago
Harrison Chase 45d6de177e
remove logger (#491)
remove old logging class (no longer used anyways)
1 year ago
Harrison Chase 175a248506
Harrison/get rid of prints (#490)
deprecate all prints in favor of callback_manager.on_text (open to
better naming)
1 year ago
Harrison Chase b902bddb8a
fix verbosity (#496)
1. remove verbose from someplace it didnt relaly belong
2. everywhere else, make verbose Optional[bool] with default to None
3. make base classes accept None, and then look up globabl verbosity if
thats the case
1 year ago
Ankush Gola 164806a844
quick comment fixes (#494) 1 year ago
Harrison Chase e3edd74eab
switch up defaults (#485)
i kinda like this just because we call `self.callback_manager` so many
times, and thats nicer than `self._get_callback_manager()`?
1 year ago
Harrison Chase 52490e2dcd
add explicit agent end method (#486) 1 year ago
Harrison Chase 7e36f28e78
Harrison/not verbose (#487) 1 year ago
Harrison Chase 5d43246694
WIP: stdout callback (#479)
first pass at stdout callback

for the most part, went pretty smoothly. aside from the code here, here
are some comments/observations.

1. 
should somehow default to stdouthandler so i dont have to do 
```
from langchain.callbacks import get_callback_manager
from langchain.callbacks.stdout import StdOutCallbackHandler

get_callback_manager().add_handler(StdOutCallbackHandler())
```

2. I kept around the verbosity flag. 1) this is pretty important for
getting the stdout to look good for agents (and other things). 2) I
actually added this for LLM class since it didn't have it.

3. The only part that isn't basically perfectly moved over is the end of
the agent run. Here's a screenshot of the new stdout tracing
![Screen Shot 2022-12-29 at 4 03 50
PM](https://user-images.githubusercontent.com/11986836/210011538-6a74551a-2e61-437b-98d3-674212dede56.png)

Noticing it is missing logging of the final thought, eg before this is
what it looked like
![Screen Shot 2022-12-29 at 4 13 07
PM](https://user-images.githubusercontent.com/11986836/210011635-de68b3f5-e2b0-4cd3-9f1a-3afe970a8716.png)

The reason its missing is that this was previously logged as part of
agent end (lines 205 and 206)

this is probably only relevant for the std out logger? any thoughts for
how to get it back in?
1 year ago
Ankush Gola 36922318d3
allow for optional CallbackManager in LLM, Chain, and Agent (#482) 1 year ago
Ankush Gola 46b31626b5
Add BaseCallbackHandler and CallbackManager (#476) 1 year ago

@ -51,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "0728f0d9",
"metadata": {},
"outputs": [],
@ -69,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "ba4e7618",
"metadata": {},
"outputs": [],
@ -87,7 +87,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "03208e2b",
"metadata": {},
"outputs": [],
@ -105,7 +105,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "244ee75c",
"metadata": {},
"outputs": [
@ -131,7 +131,7 @@
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\u001b[0m\n",
"\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n"
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
@ -140,7 +140,7 @@
"\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\""
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -166,12 +166,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
},
"vscode": {
"interpreter": {
"hash": "b1677b440931f40d89ef8be7bf03acb108ce003de0ac9b18e8d43753ea2e7103"
}
"version": "3.10.9"
}
},
"nbformat": 4,

@ -4,6 +4,7 @@ from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.callbacks import set_default_callback_manager, set_handler
from langchain.chains import (
ConversationChain,
LLMBashChain,
@ -19,7 +20,6 @@ from langchain.chains import (
from langchain.docstore import InMemoryDocstore, Wikipedia
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.logger import BaseLogger, StdOutLogger
from langchain.prompts import (
BasePromptTemplate,
FewShotPromptTemplate,
@ -31,9 +31,9 @@ from langchain.sql_database import SQLDatabase
from langchain.utilities.google_search import GoogleSearchAPIWrapper
from langchain.vectorstores import FAISS, ElasticVectorSearch
logger: BaseLogger = StdOutLogger()
verbose: bool = False
llm_cache: Optional[BaseCache] = None
set_default_callback_manager()
__all__ = [
"LLMChain",
@ -65,4 +65,5 @@ __all__ = [
"VectorDBQAWithSourcesChain",
"QAWithSourcesChain",
"PALChain",
"set_handler",
]

@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, root_validator
import langchain
from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
@ -46,7 +46,7 @@ class Agent(BaseModel):
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentFinish, AgentAction]:
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
@ -132,10 +132,19 @@ class Agent(BaseModel):
pass
@classmethod
def from_llm_and_tools(cls, llm: BaseLLM, tools: List[Tool]) -> Agent:
def from_llm_and_tools(
cls,
llm: BaseLLM,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
llm_chain = LLMChain(
llm=llm,
prompt=cls.create_prompt(tools),
callback_manager=callback_manager,
)
return cls(llm_chain=llm_chain)
def return_stopped_response(
@ -194,10 +203,16 @@ class AgentExecutor(Chain, BaseModel):
@classmethod
def from_agent_and_tools(
cls, agent: Agent, tools: List[Tool], **kwargs: Any
cls,
agent: Agent,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Create from agent and tools."""
return cls(agent=agent, tools=tools, **kwargs)
return cls(
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
)
@property
def input_keys(self) -> List[str]:
@ -244,24 +259,31 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.verbose:
langchain.logger.log_agent_end(output, color="green")
self.callback_manager.on_agent_end(output, color="green")
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
if self.verbose:
langchain.logger.log_agent_action(output, color="green")
# And then we lookup the tool
if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool]
if self.verbose:
self.callback_manager.on_tool_start(
{"name": str(chain)[:60] + "..."}, output, color="green"
)
# We then call the tool on the tool input to get an observation
observation = chain(output.tool_input)
color = color_mapping[output.tool]
else:
if self.verbose:
self.callback_manager.on_tool_start(
{"name": "N/A"}, output, color="green"
)
observation = f"{output.tool} is not a valid tool, try another one."
color = None
if self.verbose:
langchain.logger.log_agent_observation(
self.callback_manager.on_tool_end(
observation,
color=color,
observation_prefix=self.agent.observation_prefix,
@ -272,6 +294,8 @@ class AgentExecutor(Chain, BaseModel):
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
if self.verbose:
self.callback_manager.on_agent_end(output, color="green")
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps

@ -1,11 +1,12 @@
"""Load agent."""
from typing import Any, List
from typing import Any, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.llms.base import BaseLLM
AGENT_TO_CLASS = {
@ -19,6 +20,7 @@ def initialize_agent(
tools: List[Tool],
llm: BaseLLM,
agent: str = "zero-shot-react-description",
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Load agent given tools and LLM.
@ -28,6 +30,8 @@ def initialize_agent(
llm: Language model to use as the agent.
agent: The agent to use. Valid options are:
`zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
callback_manager: CallbackManager to use. Global callback manager is used if
not provided. Defaults to None.
**kwargs: Additional key word arguments to pass to the agent.
Returns:
@ -39,5 +43,12 @@ def initialize_agent(
f"Valid types are: {AGENT_TO_CLASS.keys()}."
)
agent_cls = AGENT_TO_CLASS[agent]
agent_obj = agent_cls.from_llm_and_tools(llm, tools)
return AgentExecutor.from_agent_and_tools(agent=agent_obj, tools=tools, **kwargs)
agent_obj = agent_cls.from_llm_and_tools(
llm, tools, callback_manager=callback_manager
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)

@ -0,0 +1,20 @@
"""Callback handlers that allow listening to events in LangChain."""
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
def get_callback_manager() -> BaseCallbackManager:
"""Return the shared callback manager."""
return SharedCallbackManager()
def set_handler(handler: BaseCallbackHandler) -> None:
"""Set handler."""
callback = get_callback_manager()
callback.set_handler(handler)
def set_default_callback_manager() -> None:
"""Set default callback manager."""
set_handler(StdOutCallbackHandler())

@ -0,0 +1,177 @@
"""Base callback handler that can be used to handle callbacks from langchain."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from pydantic import BaseModel
from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseCallbackHandler(BaseModel, ABC):
"""Base callback handler that can be used to handle callbacks from langchain."""
ignore_llm: bool = False
ignore_chain: bool = False
ignore_agent: bool = False
@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
@abstractmethod
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
@abstractmethod
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
@abstractmethod
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
@abstractmethod
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
@abstractmethod
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
@abstractmethod
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
@abstractmethod
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
@abstractmethod
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
@abstractmethod
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
class BaseCallbackManager(BaseCallbackHandler, ABC):
"""Base callback manager that can be used to handle callbacks from LangChain."""
@abstractmethod
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
@abstractmethod
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
@abstractmethod
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
handlers: List[BaseCallbackHandler]
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
handler.on_llm_start(serialized, prompts, **kwargs)
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
for handler in self.handlers:
if not handler.ignore_llm:
handler.on_llm_end(response)
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
for handler in self.handlers:
if not handler.ignore_llm:
handler.on_llm_error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
for handler in self.handlers:
if not handler.ignore_chain:
handler.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
for handler in self.handlers:
if not handler.ignore_chain:
handler.on_chain_end(outputs)
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
for handler in self.handlers:
if not handler.ignore_chain:
handler.on_chain_error(error)
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
handler.on_tool_start(serialized, action, **kwargs)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
for handler in self.handlers:
if not handler.ignore_agent:
handler.on_tool_end(output, **kwargs)
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
for handler in self.handlers:
if not handler.ignore_agent:
handler.on_tool_error(error)
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on additional input from chains and agents."""
for handler in self.handlers:
handler.on_text(text, **kwargs)
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
for handler in self.handlers:
if not handler.ignore_agent:
handler.on_agent_end(finish, **kwargs)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
self.handlers = [handler]

@ -0,0 +1,114 @@
"""A shared CallbackManager."""
import threading
from typing import Any, Dict, List
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
class Singleton:
"""A thread-safe singleton class that can be inherited from."""
_instance = None
_lock = threading.Lock()
def __new__(cls) -> Any:
"""Create a new shared instance of the class."""
if cls._instance is None:
with cls._lock:
# Another thread could have created the instance
# before we acquired the lock. So check that the
# instance is still nonexistent.
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
class SharedCallbackManager(Singleton, BaseCallbackManager):
"""A thread-safe singleton CallbackManager."""
_callback_manager: CallbackManager = CallbackManager(handlers=[])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
with self._lock:
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
with self._lock:
self._callback_manager.on_llm_end(response)
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
with self._lock:
self._callback_manager.on_llm_error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
with self._lock:
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
with self._lock:
self._callback_manager.on_chain_end(outputs)
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
with self._lock:
self._callback_manager.on_chain_error(error)
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
with self._lock:
self._callback_manager.on_tool_start(serialized, action, **kwargs)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
with self._lock:
self._callback_manager.on_tool_end(output, **kwargs)
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
with self._lock:
self._callback_manager.on_tool_error(error)
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
with self._lock:
self._callback_manager.on_text(text, **kwargs)
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
with self._lock:
self._callback_manager.on_agent_end(finish, **kwargs)
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a callback to the callback manager."""
with self._lock:
self._callback_manager.add_handler(callback)
def remove_handler(self, callback: BaseCallbackHandler) -> None:
"""Remove a callback from the callback manager."""
with self._lock:
self._callback_manager.remove_handler(callback)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
with self._lock:
self._callback_manager.handlers = [handler]

@ -0,0 +1,84 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print("Prompts after formatting:")
for prompt in prompts:
print_text(prompt, color="green", end="\n")
def on_llm_end(self, response: LLMResult) -> None:
"""Do nothing."""
pass
def on_llm_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
color: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
print_text(action.log, color=color)
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
print_text(f"\n{observation_prefix}")
print_text(output, color=color)
print_text(f"\n{llm_prefix}")
def on_tool_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run when agent ends."""
print_text(text, color=color, end=end)
def on_agent_end(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text(finish.log, color=color, end="\n")

@ -0,0 +1,77 @@
"""Callback Handler that logs to streamlit."""
from typing import Any, Dict, List, Optional
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class StreamlitCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to streamlit."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
st.write("Prompts after formatting:")
for prompt in prompts:
st.write(prompt)
def on_llm_end(self, response: LLMResult) -> None:
"""Do nothing."""
pass
def on_llm_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
st.write(f"Entering new {class_name} chain...")
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Print out that we finished a chain."""
st.write("Finished chain.")
def on_chain_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
# st.write requires two spaces before a newline to render it
st.markdown(action.log.replace("\n", " \n"))
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
st.write(f"{observation_prefix}{output}")
st.write(llm_prefix)
def on_tool_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on text."""
# st.write requires two spaces before a newline to render it
st.write(text.replace("\n", " \n"))
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
# st.write requires two spaces before a newline to render it
st.write(finish.log.replace("\n", " \n"))

@ -8,7 +8,6 @@ from pydantic import BaseModel, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import print_text
from langchain.llms.base import BaseLLM
from langchain.requests import RequestsWrapper
@ -67,10 +66,10 @@ class APIChain(Chain, BaseModel):
question=question, api_docs=self.api_docs
)
if self.verbose:
print_text(api_url, color="green", end="\n")
self.callback_manager.on_text(api_url, color="green", end="\n")
api_response = self.requests_wrapper.run(api_url)
if self.verbose:
print_text(api_response, color="yellow", end="\n")
self.callback_manager.on_text(api_response, color="yellow", end="\n")
answer = self.api_answer_chain.predict(
question=question,
api_docs=self.api_docs,

@ -2,9 +2,11 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Extra, Field
from pydantic import BaseModel, Extra, Field, validator
import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
class Memory(BaseModel, ABC):
@ -42,9 +44,36 @@ class Chain(BaseModel, ABC):
"""Base interface that all chains should implement."""
memory: Optional[Memory] = None
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@abstractmethod
@ -106,12 +135,12 @@ class Chain(BaseModel, ABC):
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
if self.verbose:
print(
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__}, inputs
)
outputs = self._call(inputs)
if self.verbose:
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
self.callback_manager.on_chain_end(outputs)
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)

@ -3,10 +3,10 @@ from typing import Any, Dict, List, Sequence, Union
from pydantic import BaseModel, Extra
import langchain
from langchain.chains.base import Chain
from langchain.llms.base import BaseLLM, LLMResult
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import LLMResult
class LLMChain(Chain, BaseModel):
@ -60,8 +60,6 @@ class LLMChain(Chain, BaseModel):
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format(**selected_inputs)
if self.verbose:
langchain.logger.log_llm_inputs(selected_inputs, prompt)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
@ -77,8 +75,6 @@ class LLMChain(Chain, BaseModel):
for generation in response.generations:
# Get the text of the top generated string.
response_str = generation[0].text
if self.verbose:
langchain.logger.log_llm_response(response_str)
outputs.append({self.output_key: response_str})
return outputs

@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT
from langchain.input import print_text
from langchain.llms.base import BaseLLM
from langchain.utilities.bash import BashProcess
@ -52,11 +51,11 @@ class LLMBashChain(Chain, BaseModel):
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
bash_executor = BashProcess()
if self.verbose:
print_text(inputs[self.input_key])
self.callback_manager.on_text(inputs[self.input_key])
t = llm_executor.predict(question=inputs[self.input_key])
if self.verbose:
print_text(t, color="green")
self.callback_manager.on_text(t, color="green")
t = t.strip()
if t.startswith("```bash"):
@ -69,8 +68,8 @@ class LLMBashChain(Chain, BaseModel):
output = bash_executor.run(command_list)
if self.verbose:
print_text("\nAnswer: ")
print_text(output, color="yellow")
self.callback_manager.on_text("\nAnswer: ")
self.callback_manager.on_text(output, color="yellow")
else:
raise ValueError(f"unknown format from LLM: {t}")

@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.input import print_text
from langchain.llms.base import BaseLLM
from langchain.python import PythonREPL
@ -52,17 +51,17 @@ class LLMMathChain(Chain, BaseModel):
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
python_executor = PythonREPL()
if self.verbose:
print_text(inputs[self.input_key])
self.callback_manager.on_text(inputs[self.input_key])
t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"])
if self.verbose:
print_text(t, color="green")
self.callback_manager.on_text(t, color="green")
t = t.strip()
if t.startswith("```python"):
code = t[9:-4]
output = python_executor.run(code)
if self.verbose:
print_text("\nAnswer: ")
print_text(output, color="yellow")
self.callback_manager.on_text("\nAnswer: ")
self.callback_manager.on_text(output, color="yellow")
answer = "Answer: " + output
elif t.startswith("Answer:"):
answer = t

@ -12,7 +12,6 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.input import print_text
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.python import PythonREPL
@ -53,7 +52,7 @@ class PALChain(Chain, BaseModel):
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
code = llm_chain.predict(stop=[self.stop], **inputs)
if self.verbose:
print_text(code, color="green", end="\n")
self.callback_manager.on_text(code, color="green", end="\n")
repl = PythonREPL()
res = repl.run(code + f"\n{self.get_answer_expr}")
return {self.output_key: res.strip()}

@ -26,7 +26,7 @@ def _load_stuff_chain(
llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "summaries",
verbose: bool = False,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@ -49,7 +49,7 @@ def _load_map_reduce_chain(
collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None,
verbose: bool = False,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@ -97,7 +97,7 @@ def _load_refine_chain(
document_variable_name: str = "context_str",
initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None,
verbose: bool = False,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@ -115,7 +115,10 @@ def _load_refine_chain(
def load_qa_with_sources_chain(
llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
llm: BaseLLM,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering with sources chain.

@ -26,7 +26,7 @@ def _load_stuff_chain(
llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "context",
verbose: bool = False,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@ -48,7 +48,7 @@ def _load_map_reduce_chain(
collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None,
verbose: bool = False,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@ -94,7 +94,7 @@ def _load_refine_chain(
document_variable_name: str = "context_str",
initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None,
verbose: bool = False,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@ -111,7 +111,10 @@ def _load_refine_chain(
def load_qa_chain(
llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
llm: BaseLLM,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering chain.

@ -5,7 +5,7 @@ from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.input import get_color_mapping, print_text
from langchain.input import get_color_mapping
class SequentialChain(Chain, BaseModel):
@ -133,5 +133,7 @@ class SimpleSequentialChain(Chain, BaseModel):
if self.strip_outputs:
_input = _input.strip()
if self.verbose:
print_text(_input, color=color_mapping[str(i)], end="\n")
self.callback_manager.on_text(
_input, color=color_mapping[str(i)], end="\n"
)
return {self.output_key: _input}

@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import PROMPT
from langchain.input import print_text
from langchain.llms.base import BaseLLM
from langchain.sql_database import SQLDatabase
@ -55,7 +54,7 @@ class SQLDatabaseChain(Chain, BaseModel):
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
if self.verbose:
print_text(input_text)
self.callback_manager.on_text(input_text)
llm_inputs = {
"input": input_text,
"dialect": self.database.dialect,
@ -64,15 +63,15 @@ class SQLDatabaseChain(Chain, BaseModel):
}
sql_cmd = llm_chain.predict(**llm_inputs)
if self.verbose:
print_text(sql_cmd, color="green")
self.callback_manager.on_text(sql_cmd, color="green")
result = self.database.run(sql_cmd)
if self.verbose:
print_text("\nSQLResult: ")
print_text(result, color="yellow")
print_text("\nAnswer:")
self.callback_manager.on_text("\nSQLResult: ")
self.callback_manager.on_text(result, color="yellow")
self.callback_manager.on_text("\nAnswer:")
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
llm_inputs["input"] = input_text
final_result = llm_chain.predict(**llm_inputs)
if self.verbose:
print_text(final_result, color="green")
self.callback_manager.on_text(final_result, color="green")
return {self.output_key: final_result}

@ -22,7 +22,7 @@ def _load_stuff_chain(
llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "text",
verbose: bool = False,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@ -44,7 +44,7 @@ def _load_map_reduce_chain(
collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None,
verbose: bool = False,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
@ -90,7 +90,7 @@ def _load_refine_chain(
document_variable_name: str = "text",
initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None,
verbose: bool = False,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
@ -108,7 +108,10 @@ def _load_refine_chain(
def load_summarize_chain(
llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any
llm: BaseLLM,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load summarizing chain.

@ -2,34 +2,55 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union
import yaml
from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, Field, validator
import langchain
from langchain.schema import Generation
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import Generation, LLMResult
class LLMResult(NamedTuple):
"""Class that contains all relevant information for an LLM Result."""
generations: List[List[Generation]]
"""List of the things generated. This is List[List[]] because
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
def _get_verbosity() -> bool:
return langchain.verbose
class BaseLLM(BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string."""
cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@abstractmethod
def _generate(
@ -48,7 +69,14 @@ class BaseLLM(BaseModel, ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
return self._generate(prompts, stop=stop)
if self.verbose:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts
)
output = self._generate(prompts, stop=stop)
if self.verbose:
self.callback_manager.on_llm_end(output)
return output
params = self._llm_dict()
params["stop"] = stop
llm_string = str(sorted([(k, v) for k, v in params.items()]))
@ -62,7 +90,11 @@ class BaseLLM(BaseModel, ABC):
else:
missing_prompts.append(prompt)
missing_prompt_idxs.append(i)
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts
)
new_results = self._generate(missing_prompts, stop=stop)
self.callback_manager.on_llm_end(new_results)
for i, result in enumerate(new_results.generations):
existing_prompts[i] = result
prompt = prompts[i]

@ -4,8 +4,8 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.llms.base import BaseLLM, LLMResult
from langchain.schema import Generation
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env

@ -1,71 +0,0 @@
"""BETA: everything in here is highly experimental, do not rely on."""
from typing import Any, Optional
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish
class BaseLogger:
"""Base logging interface."""
def log_agent_start(self, text: str, **kwargs: Any) -> None:
"""Log the start of an agent interaction."""
pass
def log_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Log the end of an agent interaction."""
pass
def log_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Log agent action decision."""
pass
def log_agent_observation(self, observation: str, **kwargs: Any) -> None:
"""Log agent observation."""
pass
def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
"""Log LLM inputs."""
pass
def log_llm_response(self, output: str, **kwargs: Any) -> None:
"""Log LLM response."""
pass
class StdOutLogger(BaseLogger):
"""Interface for printing things to stdout."""
def log_agent_start(self, text: str, **kwargs: Any) -> None:
"""Print the text to start the agent."""
print_text(text)
def log_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Print the log of the action in a certain color."""
print_text(action.log, color=color)
def log_agent_observation(
self,
observation: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Print the observation in a special color."""
print_text(f"\n{observation_prefix}")
print_text(observation, color=color)
print_text(f"\n{llm_prefix}")
def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
"""Print the prompt in green."""
print("Prompt after formatting:")
print_text(prompt, color="green", end="\n")
def log_agent_end(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Log the end of an agent interaction."""
print_text(finish.log, color=color)

@ -1,6 +1,6 @@
"""Common schema objects."""
from typing import NamedTuple
from typing import List, NamedTuple, Optional
class AgentAction(NamedTuple):
@ -24,3 +24,13 @@ class Generation(NamedTuple):
text: str
"""Generated text output."""
# TODO: add log probs
class LLMResult(NamedTuple):
"""Class that contains all relevant information for an LLM Result."""
generations: List[List[Generation]]
"""List of the things generated. This is List[List[]] because
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""

@ -4,8 +4,10 @@ from typing import Any, List, Mapping, Optional
from pydantic import BaseModel
from langchain.agents import Tool, initialize_agent
from langchain.agents import AgentExecutor, Tool, initialize_agent
from langchain.callbacks.base import CallbackManager
from langchain.llms.base import LLM
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeListLLM(LLM, BaseModel):
@ -31,8 +33,8 @@ class FakeListLLM(LLM, BaseModel):
return "fake_list"
def test_agent_bad_action() -> None:
"""Test react chain when bad action given."""
def _get_agent(**kwargs: Any) -> AgentExecutor:
"""Get agent for testing."""
bad_action_name = "BadAction"
responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
@ -44,30 +46,122 @@ def test_agent_bad_action() -> None:
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
]
agent = initialize_agent(
tools, fake_llm, agent="zero-shot-react-description", verbose=True
tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs
)
return agent
def test_agent_bad_action() -> None:
"""Test react chain when bad action given."""
agent = _get_agent()
output = agent.run("when was langchain made")
assert output == "curses foiled again"
def test_agent_stopped_early() -> None:
"""Test react chain when bad action given."""
bad_action_name = "BadAction"
agent = _get_agent(max_iterations=0)
output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations."
def test_agent_with_callbacks_global() -> None:
"""Test react chain with callbacks by setting verbose globally."""
import langchain
langchain.verbose = True
handler = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler])
tool = "Search"
responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses)
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
]
agent = initialize_agent(
tools,
fake_llm,
agent="zero-shot-react-description",
verbose=True,
max_iterations=0,
callback_manager=manager,
)
output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations."
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 6
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0
def test_agent_with_callbacks_local() -> None:
"""Test react chain with callbacks by setting verbose locally."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler])
tool = "Search"
responses = [
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
]
agent = initialize_agent(
tools,
fake_llm,
agent="zero-shot-react-description",
verbose=True,
callback_manager=manager,
)
agent.agent.llm_chain.verbose = True
output = agent.run("when was langchain made")
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 6
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0
def test_agent_with_callbacks_not_verbose() -> None:
"""Test react chain with callbacks but not verbose."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler])
tool = "Search"
responses = [
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
]
agent = initialize_agent(
tools,
fake_llm,
agent="zero-shot-react-description",
callback_manager=manager,
)
output = agent.run("when was langchain made")
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0

@ -0,0 +1 @@
"""Tests for correct functioning of callbacks."""

@ -0,0 +1,66 @@
"""A fake callback handler for testing purposes."""
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class FakeCallbackHandler(BaseCallbackHandler):
"""Fake callback handler for testing."""
starts: int = 0
ends: int = 0
errors: int = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.starts += 1
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
self.ends += 1
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.starts += 1
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
self.ends += 1
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.starts += 1
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.ends += 1
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent is ending."""
self.ends += 1
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.ends += 1

@ -0,0 +1,47 @@
"""Test CallbackManager."""
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentAction, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def _test_callback_manager(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [])
manager.on_llm_end(LLMResult(generations=[]))
manager.on_llm_error(Exception())
manager.on_chain_start({"name": "foo"}, {})
manager.on_chain_end({})
manager.on_chain_error(Exception())
manager.on_tool_start({}, AgentAction("", "", ""))
manager.on_tool_end("")
manager.on_tool_error(Exception())
for handler in handlers:
assert handler.starts == 3
assert handler.ends == 3
assert handler.errors == 3
def test_callback_manager() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
_test_callback_manager(manager, handler1, handler2)
def test_shared_callback_manager() -> None:
"""Test the SharedCallbackManager."""
manager1 = SharedCallbackManager()
manager2 = SharedCallbackManager()
assert manager1 is manager2
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager1.add_handler(handler1)
manager2.add_handler(handler2)
_test_callback_manager(manager1, handler1, handler2)

@ -4,7 +4,9 @@ from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from langchain.callbacks.base import CallbackManager
from langchain.chains.base import Chain, Memory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeMemory(Memory, BaseModel):
@ -133,3 +135,31 @@ def test_run_arg_with_memory() -> None:
"""Test run method works when arg is passed."""
chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory())
chain.run("bar")
def test_run_with_callback() -> None:
"""Test run method works when callback manager is passed."""
handler = FakeCallbackHandler()
chain = FakeChain(
callback_manager=CallbackManager(handlers=[handler]), verbose=True
)
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_run_with_callback_not_verbose() -> None:
"""Test run method works when callback manager is passed and not verbose."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0

@ -0,0 +1,30 @@
"""Test LLM callbacks."""
from langchain.callbacks.base import CallbackManager
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_llm_with_callbacks() -> None:
"""Test LLM callbacks."""
handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True)
output = llm("foo")
assert output == "foo"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_llm_with_callbacks_not_verbose() -> None:
"""Test LLM callbacks but not verbose."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]))
output = llm("foo")
assert output == "foo"
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0

@ -7,8 +7,8 @@ from pydantic import BaseModel
from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM, LLMResult
from langchain.schema import Generation
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
class FakeEmbeddings(Embeddings):

Loading…
Cancel
Save