Compare commits

...

22 Commits

Author SHA1 Message Date
Harrison Chase
d3a6387ab9 cr 2023-01-03 14:21:22 -08:00
Harrison Chase
3efee27e56 cr 2023-01-03 14:12:43 -08:00
Harrison Chase
7d0b1cafd7 cr 2023-01-03 14:00:44 -08:00
Harrison Chase
6fe6af7048 cr 2023-01-03 12:39:57 -08:00
Harrison Chase
6953c2e707 cr 2023-01-03 12:32:20 -08:00
Harrison Chase
3086b752a3 Merge branch 'ankush/callbackhandler' into harrison/callback-updates 2023-01-03 12:31:45 -08:00
Harrison Chase
03e3cd468b cr 2023-01-03 12:31:29 -08:00
Harrison Chase
7eb33690a9 callback updates 2023-01-03 12:26:52 -08:00
Harrison Chase
23b8cfc123 cr 2023-01-03 12:14:25 -08:00
Ankush Gola
db5c8e0c42
explicitly set global verbosity flag in unit tests that test callbacks (#502)
If another file anywhere in `unit_tests` sets `langchain.verbose =
True`, it messes up all of the tests that check for no callbacks because
the `None` for verbose gets overridden by the global verbosity flag. By
explicitly setting it in unit tests, we bypass that potential issue
2023-01-02 14:06:46 -08:00
Ankush Gola
aae3609aa8
fix ontext comment (#500) 2022-12-31 06:06:16 -05:00
Harrison Chase
a3d2a2ec2a
Harrison/streamlit handler (#488)
also add a set handler method

usage is:
```
from langchain.callbacks.streamlit import StreamlitCallbackHandler
import langchain
langchain.set_handler(StreamlitCallbackHandler())
```

produces the following output


![Screen Shot 2022-12-29 at 10 50 33
PM](https://user-images.githubusercontent.com/11986836/210032762-7f53fffa-cb2f-4dac-af39-7d4cf81e55dd.png)

only works for agent stuff currently
2022-12-30 14:43:28 -05:00
Harrison Chase
45d6de177e
remove logger (#491)
remove old logging class (no longer used anyways)
2022-12-30 13:57:47 -05:00
Harrison Chase
175a248506
Harrison/get rid of prints (#490)
deprecate all prints in favor of callback_manager.on_text (open to
better naming)
2022-12-30 13:55:30 -05:00
Harrison Chase
b902bddb8a
fix verbosity (#496)
1. remove verbose from someplace it didnt relaly belong
2. everywhere else, make verbose Optional[bool] with default to None
3. make base classes accept None, and then look up globabl verbosity if
thats the case
2022-12-30 13:26:41 -05:00
Ankush Gola
164806a844
quick comment fixes (#494) 2022-12-30 07:20:13 -05:00
Harrison Chase
e3edd74eab
switch up defaults (#485)
i kinda like this just because we call `self.callback_manager` so many
times, and thats nicer than `self._get_callback_manager()`?
2022-12-29 23:07:55 -05:00
Harrison Chase
52490e2dcd
add explicit agent end method (#486) 2022-12-29 22:23:15 -05:00
Harrison Chase
7e36f28e78
Harrison/not verbose (#487) 2022-12-29 22:23:02 -05:00
Harrison Chase
5d43246694
WIP: stdout callback (#479)
first pass at stdout callback

for the most part, went pretty smoothly. aside from the code here, here
are some comments/observations.

1. 
should somehow default to stdouthandler so i dont have to do 
```
from langchain.callbacks import get_callback_manager
from langchain.callbacks.stdout import StdOutCallbackHandler

get_callback_manager().add_handler(StdOutCallbackHandler())
```

2. I kept around the verbosity flag. 1) this is pretty important for
getting the stdout to look good for agents (and other things). 2) I
actually added this for LLM class since it didn't have it.

3. The only part that isn't basically perfectly moved over is the end of
the agent run. Here's a screenshot of the new stdout tracing
![Screen Shot 2022-12-29 at 4 03 50
PM](https://user-images.githubusercontent.com/11986836/210011538-6a74551a-2e61-437b-98d3-674212dede56.png)

Noticing it is missing logging of the final thought, eg before this is
what it looked like
![Screen Shot 2022-12-29 at 4 13 07
PM](https://user-images.githubusercontent.com/11986836/210011635-de68b3f5-e2b0-4cd3-9f1a-3afe970a8716.png)

The reason its missing is that this was previously logged as part of
agent end (lines 205 and 206)

this is probably only relevant for the std out logger? any thoughts for
how to get it back in?
2022-12-29 21:34:47 -05:00
Ankush Gola
36922318d3
allow for optional CallbackManager in LLM, Chain, and Agent (#482) 2022-12-29 17:30:31 -08:00
Ankush Gola
46b31626b5
Add BaseCallbackHandler and CallbackManager (#476) 2022-12-29 15:11:37 -05:00
31 changed files with 946 additions and 173 deletions

View File

@ -51,7 +51,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"id": "0728f0d9", "id": "0728f0d9",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -69,7 +69,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"id": "ba4e7618", "id": "ba4e7618",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -87,7 +87,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"id": "03208e2b", "id": "03208e2b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -105,7 +105,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"id": "244ee75c", "id": "244ee75c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -131,7 +131,7 @@
"\u001b[0m\n", "\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\u001b[0m\n", "Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\u001b[0m\n",
"\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
}, },
{ {
@ -140,7 +140,7 @@
"\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\"" "\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.1520202182226886.\""
] ]
}, },
"execution_count": 6, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -166,12 +166,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.0" "version": "3.10.9"
},
"vscode": {
"interpreter": {
"hash": "b1677b440931f40d89ef8be7bf03acb108ce003de0ac9b18e8d43753ea2e7103"
}
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -4,6 +4,7 @@ from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache from langchain.cache import BaseCache
from langchain.callbacks import set_default_callback_manager, set_handler
from langchain.chains import ( from langchain.chains import (
ConversationChain, ConversationChain,
LLMBashChain, LLMBashChain,
@ -19,7 +20,6 @@ from langchain.chains import (
from langchain.docstore import InMemoryDocstore, Wikipedia from langchain.docstore import InMemoryDocstore, Wikipedia
from langchain.llms import Cohere, HuggingFaceHub, OpenAI from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.llms.huggingface_pipeline import HuggingFacePipeline from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.logger import BaseLogger, StdOutLogger
from langchain.prompts import ( from langchain.prompts import (
BasePromptTemplate, BasePromptTemplate,
FewShotPromptTemplate, FewShotPromptTemplate,
@ -31,9 +31,9 @@ from langchain.sql_database import SQLDatabase
from langchain.utilities.google_search import GoogleSearchAPIWrapper from langchain.utilities.google_search import GoogleSearchAPIWrapper
from langchain.vectorstores import FAISS, ElasticVectorSearch from langchain.vectorstores import FAISS, ElasticVectorSearch
logger: BaseLogger = StdOutLogger()
verbose: bool = False verbose: bool = False
llm_cache: Optional[BaseCache] = None llm_cache: Optional[BaseCache] = None
set_default_callback_manager()
__all__ = [ __all__ = [
"LLMChain", "LLMChain",
@ -65,4 +65,5 @@ __all__ = [
"VectorDBQAWithSourcesChain", "VectorDBQAWithSourcesChain",
"QAWithSourcesChain", "QAWithSourcesChain",
"PALChain", "PALChain",
"set_handler",
] ]

View File

@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
import langchain
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping from langchain.input import get_color_mapping
@ -46,7 +46,7 @@ class Agent(BaseModel):
def plan( def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentFinish, AgentAction]: ) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do. """Given input, decided what to do.
Args: Args:
@ -132,10 +132,19 @@ class Agent(BaseModel):
pass pass
@classmethod @classmethod
def from_llm_and_tools(cls, llm: BaseLLM, tools: List[Tool]) -> Agent: def from_llm_and_tools(
cls,
llm: BaseLLM,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
) -> Agent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
cls._validate_tools(tools) cls._validate_tools(tools)
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) llm_chain = LLMChain(
llm=llm,
prompt=cls.create_prompt(tools),
callback_manager=callback_manager,
)
return cls(llm_chain=llm_chain) return cls(llm_chain=llm_chain)
def return_stopped_response( def return_stopped_response(
@ -194,10 +203,16 @@ class AgentExecutor(Chain, BaseModel):
@classmethod @classmethod
def from_agent_and_tools( def from_agent_and_tools(
cls, agent: Agent, tools: List[Tool], **kwargs: Any cls,
agent: Agent,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> AgentExecutor: ) -> AgentExecutor:
"""Create from agent and tools.""" """Create from agent and tools."""
return cls(agent=agent, tools=tools, **kwargs) return cls(
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
)
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
@ -244,24 +259,31 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
if self.verbose: if self.verbose:
langchain.logger.log_agent_end(output, color="green") self.callback_manager.on_agent_end(output, color="green")
final_output = output.return_values final_output = output.return_values
if self.return_intermediate_steps: if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps final_output["intermediate_steps"] = intermediate_steps
return final_output return final_output
if self.verbose:
langchain.logger.log_agent_action(output, color="green")
# And then we lookup the tool # And then we lookup the tool
if output.tool in name_to_tool_map: if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool] chain = name_to_tool_map[output.tool]
if self.verbose:
self.callback_manager.on_tool_start(
{"name": str(chain)[:60] + "..."}, output, color="green"
)
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
observation = chain(output.tool_input) observation = chain(output.tool_input)
color = color_mapping[output.tool] color = color_mapping[output.tool]
else: else:
if self.verbose:
self.callback_manager.on_tool_start(
{"name": "N/A"}, output, color="green"
)
observation = f"{output.tool} is not a valid tool, try another one." observation = f"{output.tool} is not a valid tool, try another one."
color = None color = None
if self.verbose: if self.verbose:
langchain.logger.log_agent_observation( self.callback_manager.on_tool_end(
observation, observation,
color=color, color=color,
observation_prefix=self.agent.observation_prefix, observation_prefix=self.agent.observation_prefix,
@ -272,6 +294,8 @@ class AgentExecutor(Chain, BaseModel):
output = self.agent.return_stopped_response( output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs self.early_stopping_method, intermediate_steps, **inputs
) )
if self.verbose:
self.callback_manager.on_agent_end(output, color="green")
final_output = output.return_values final_output = output.return_values
if self.return_intermediate_steps: if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps final_output["intermediate_steps"] = intermediate_steps

View File

@ -1,11 +1,12 @@
"""Load agent.""" """Load agent."""
from typing import Any, List from typing import Any, List, Optional
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
AGENT_TO_CLASS = { AGENT_TO_CLASS = {
@ -19,6 +20,7 @@ def initialize_agent(
tools: List[Tool], tools: List[Tool],
llm: BaseLLM, llm: BaseLLM,
agent: str = "zero-shot-react-description", agent: str = "zero-shot-react-description",
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any, **kwargs: Any,
) -> AgentExecutor: ) -> AgentExecutor:
"""Load agent given tools and LLM. """Load agent given tools and LLM.
@ -28,6 +30,8 @@ def initialize_agent(
llm: Language model to use as the agent. llm: Language model to use as the agent.
agent: The agent to use. Valid options are: agent: The agent to use. Valid options are:
`zero-shot-react-description`, `react-docstore`, `self-ask-with-search`. `zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
callback_manager: CallbackManager to use. Global callback manager is used if
not provided. Defaults to None.
**kwargs: Additional key word arguments to pass to the agent. **kwargs: Additional key word arguments to pass to the agent.
Returns: Returns:
@ -39,5 +43,12 @@ def initialize_agent(
f"Valid types are: {AGENT_TO_CLASS.keys()}." f"Valid types are: {AGENT_TO_CLASS.keys()}."
) )
agent_cls = AGENT_TO_CLASS[agent] agent_cls = AGENT_TO_CLASS[agent]
agent_obj = agent_cls.from_llm_and_tools(llm, tools) agent_obj = agent_cls.from_llm_and_tools(
return AgentExecutor.from_agent_and_tools(agent=agent_obj, tools=tools, **kwargs) llm, tools, callback_manager=callback_manager
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)

View File

@ -0,0 +1,20 @@
"""Callback handlers that allow listening to events in LangChain."""
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
def get_callback_manager() -> BaseCallbackManager:
"""Return the shared callback manager."""
return SharedCallbackManager()
def set_handler(handler: BaseCallbackHandler) -> None:
"""Set handler."""
callback = get_callback_manager()
callback.set_handler(handler)
def set_default_callback_manager() -> None:
"""Set default callback manager."""
set_handler(StdOutCallbackHandler())

177
langchain/callbacks/base.py Normal file
View File

@ -0,0 +1,177 @@
"""Base callback handler that can be used to handle callbacks from langchain."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from pydantic import BaseModel
from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseCallbackHandler(BaseModel, ABC):
"""Base callback handler that can be used to handle callbacks from langchain."""
ignore_llm: bool = False
ignore_chain: bool = False
ignore_agent: bool = False
@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
@abstractmethod
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
@abstractmethod
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
@abstractmethod
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
@abstractmethod
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
@abstractmethod
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
@abstractmethod
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
@abstractmethod
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
@abstractmethod
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
@abstractmethod
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
class BaseCallbackManager(BaseCallbackHandler, ABC):
"""Base callback manager that can be used to handle callbacks from LangChain."""
@abstractmethod
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
@abstractmethod
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
@abstractmethod
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
handlers: List[BaseCallbackHandler]
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
handler.on_llm_start(serialized, prompts, **kwargs)
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
for handler in self.handlers:
if not handler.ignore_llm:
handler.on_llm_end(response)
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
for handler in self.handlers:
if not handler.ignore_llm:
handler.on_llm_error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
for handler in self.handlers:
if not handler.ignore_chain:
handler.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
for handler in self.handlers:
if not handler.ignore_chain:
handler.on_chain_end(outputs)
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
for handler in self.handlers:
if not handler.ignore_chain:
handler.on_chain_error(error)
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
handler.on_tool_start(serialized, action, **kwargs)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
for handler in self.handlers:
if not handler.ignore_agent:
handler.on_tool_end(output, **kwargs)
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
for handler in self.handlers:
if not handler.ignore_agent:
handler.on_tool_error(error)
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on additional input from chains and agents."""
for handler in self.handlers:
handler.on_text(text, **kwargs)
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
for handler in self.handlers:
if not handler.ignore_agent:
handler.on_agent_end(finish, **kwargs)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
self.handlers = [handler]

View File

@ -0,0 +1,114 @@
"""A shared CallbackManager."""
import threading
from typing import Any, Dict, List
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
class Singleton:
"""A thread-safe singleton class that can be inherited from."""
_instance = None
_lock = threading.Lock()
def __new__(cls) -> Any:
"""Create a new shared instance of the class."""
if cls._instance is None:
with cls._lock:
# Another thread could have created the instance
# before we acquired the lock. So check that the
# instance is still nonexistent.
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
class SharedCallbackManager(Singleton, BaseCallbackManager):
"""A thread-safe singleton CallbackManager."""
_callback_manager: CallbackManager = CallbackManager(handlers=[])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
with self._lock:
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
with self._lock:
self._callback_manager.on_llm_end(response)
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
with self._lock:
self._callback_manager.on_llm_error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
with self._lock:
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
with self._lock:
self._callback_manager.on_chain_end(outputs)
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
with self._lock:
self._callback_manager.on_chain_error(error)
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
with self._lock:
self._callback_manager.on_tool_start(serialized, action, **kwargs)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
with self._lock:
self._callback_manager.on_tool_end(output, **kwargs)
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
with self._lock:
self._callback_manager.on_tool_error(error)
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
with self._lock:
self._callback_manager.on_text(text, **kwargs)
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
with self._lock:
self._callback_manager.on_agent_end(finish, **kwargs)
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a callback to the callback manager."""
with self._lock:
self._callback_manager.add_handler(callback)
def remove_handler(self, callback: BaseCallbackHandler) -> None:
"""Remove a callback from the callback manager."""
with self._lock:
self._callback_manager.remove_handler(callback)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
with self._lock:
self._callback_manager.handlers = [handler]

View File

@ -0,0 +1,84 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print("Prompts after formatting:")
for prompt in prompts:
print_text(prompt, color="green", end="\n")
def on_llm_end(self, response: LLMResult) -> None:
"""Do nothing."""
pass
def on_llm_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
color: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
print_text(action.log, color=color)
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
print_text(f"\n{observation_prefix}")
print_text(output, color=color)
print_text(f"\n{llm_prefix}")
def on_tool_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run when agent ends."""
print_text(text, color=color, end=end)
def on_agent_end(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text(finish.log, color=color, end="\n")

View File

@ -0,0 +1,77 @@
"""Callback Handler that logs to streamlit."""
from typing import Any, Dict, List, Optional
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class StreamlitCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to streamlit."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
st.write("Prompts after formatting:")
for prompt in prompts:
st.write(prompt)
def on_llm_end(self, response: LLMResult) -> None:
"""Do nothing."""
pass
def on_llm_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
st.write(f"Entering new {class_name} chain...")
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Print out that we finished a chain."""
st.write("Finished chain.")
def on_chain_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
# st.write requires two spaces before a newline to render it
st.markdown(action.log.replace("\n", " \n"))
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
st.write(f"{observation_prefix}{output}")
st.write(llm_prefix)
def on_tool_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on text."""
# st.write requires two spaces before a newline to render it
st.write(text.replace("\n", " \n"))
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
# st.write requires two spaces before a newline to render it
st.write(finish.log.replace("\n", " \n"))

View File

@ -8,7 +8,6 @@ from pydantic import BaseModel, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.input import print_text
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.requests import RequestsWrapper from langchain.requests import RequestsWrapper
@ -67,10 +66,10 @@ class APIChain(Chain, BaseModel):
question=question, api_docs=self.api_docs question=question, api_docs=self.api_docs
) )
if self.verbose: if self.verbose:
print_text(api_url, color="green", end="\n") self.callback_manager.on_text(api_url, color="green", end="\n")
api_response = self.requests_wrapper.run(api_url) api_response = self.requests_wrapper.run(api_url)
if self.verbose: if self.verbose:
print_text(api_response, color="yellow", end="\n") self.callback_manager.on_text(api_response, color="yellow", end="\n")
answer = self.api_answer_chain.predict( answer = self.api_answer_chain.predict(
question=question, question=question,
api_docs=self.api_docs, api_docs=self.api_docs,

View File

@ -2,9 +2,11 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Extra, Field from pydantic import BaseModel, Extra, Field, validator
import langchain import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
class Memory(BaseModel, ABC): class Memory(BaseModel, ABC):
@ -42,9 +44,36 @@ class Chain(BaseModel, ABC):
"""Base interface that all chains should implement.""" """Base interface that all chains should implement."""
memory: Optional[Memory] = None memory: Optional[Memory] = None
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
verbose: bool = Field(default_factory=_get_verbosity) class Config:
"""Whether to print out response text.""" """Configuration for this pydantic object."""
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property @property
@abstractmethod @abstractmethod
@ -106,12 +135,12 @@ class Chain(BaseModel, ABC):
inputs = dict(inputs, **external_context) inputs = dict(inputs, **external_context)
self._validate_inputs(inputs) self._validate_inputs(inputs)
if self.verbose: if self.verbose:
print( self.callback_manager.on_chain_start(
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m" {"name": self.__class__.__name__}, inputs
) )
outputs = self._call(inputs) outputs = self._call(inputs)
if self.verbose: if self.verbose:
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m") self.callback_manager.on_chain_end(outputs)
self._validate_outputs(outputs) self._validate_outputs(outputs)
if self.memory is not None: if self.memory is not None:
self.memory.save_context(inputs, outputs) self.memory.save_context(inputs, outputs)

View File

@ -3,10 +3,10 @@ from typing import Any, Dict, List, Sequence, Union
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
import langchain
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.llms.base import BaseLLM, LLMResult from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import LLMResult
class LLMChain(Chain, BaseModel): class LLMChain(Chain, BaseModel):
@ -60,8 +60,6 @@ class LLMChain(Chain, BaseModel):
for inputs in input_list: for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format(**selected_inputs) prompt = self.prompt.format(**selected_inputs)
if self.verbose:
langchain.logger.log_llm_inputs(selected_inputs, prompt)
if "stop" in inputs and inputs["stop"] != stop: if "stop" in inputs and inputs["stop"] != stop:
raise ValueError( raise ValueError(
"If `stop` is present in any inputs, should be present in all." "If `stop` is present in any inputs, should be present in all."
@ -77,8 +75,6 @@ class LLMChain(Chain, BaseModel):
for generation in response.generations: for generation in response.generations:
# Get the text of the top generated string. # Get the text of the top generated string.
response_str = generation[0].text response_str = generation[0].text
if self.verbose:
langchain.logger.log_llm_response(response_str)
outputs.append({self.output_key: response_str}) outputs.append({self.output_key: response_str})
return outputs return outputs

View File

@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT from langchain.chains.llm_bash.prompt import PROMPT
from langchain.input import print_text
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.utilities.bash import BashProcess from langchain.utilities.bash import BashProcess
@ -52,11 +51,11 @@ class LLMBashChain(Chain, BaseModel):
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
bash_executor = BashProcess() bash_executor = BashProcess()
if self.verbose: if self.verbose:
print_text(inputs[self.input_key]) self.callback_manager.on_text(inputs[self.input_key])
t = llm_executor.predict(question=inputs[self.input_key]) t = llm_executor.predict(question=inputs[self.input_key])
if self.verbose: if self.verbose:
print_text(t, color="green") self.callback_manager.on_text(t, color="green")
t = t.strip() t = t.strip()
if t.startswith("```bash"): if t.startswith("```bash"):
@ -69,8 +68,8 @@ class LLMBashChain(Chain, BaseModel):
output = bash_executor.run(command_list) output = bash_executor.run(command_list)
if self.verbose: if self.verbose:
print_text("\nAnswer: ") self.callback_manager.on_text("\nAnswer: ")
print_text(output, color="yellow") self.callback_manager.on_text(output, color="yellow")
else: else:
raise ValueError(f"unknown format from LLM: {t}") raise ValueError(f"unknown format from LLM: {t}")

View File

@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT from langchain.chains.llm_math.prompt import PROMPT
from langchain.input import print_text
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.python import PythonREPL from langchain.python import PythonREPL
@ -52,17 +51,17 @@ class LLMMathChain(Chain, BaseModel):
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
python_executor = PythonREPL() python_executor = PythonREPL()
if self.verbose: if self.verbose:
print_text(inputs[self.input_key]) self.callback_manager.on_text(inputs[self.input_key])
t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"]) t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"])
if self.verbose: if self.verbose:
print_text(t, color="green") self.callback_manager.on_text(t, color="green")
t = t.strip() t = t.strip()
if t.startswith("```python"): if t.startswith("```python"):
code = t[9:-4] code = t[9:-4]
output = python_executor.run(code) output = python_executor.run(code)
if self.verbose: if self.verbose:
print_text("\nAnswer: ") self.callback_manager.on_text("\nAnswer: ")
print_text(output, color="yellow") self.callback_manager.on_text(output, color="yellow")
answer = "Answer: " + output answer = "Answer: " + output
elif t.startswith("Answer:"): elif t.startswith("Answer:"):
answer = t answer = t

View File

@ -12,7 +12,6 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.input import print_text
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.python import PythonREPL from langchain.python import PythonREPL
@ -53,7 +52,7 @@ class PALChain(Chain, BaseModel):
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
code = llm_chain.predict(stop=[self.stop], **inputs) code = llm_chain.predict(stop=[self.stop], **inputs)
if self.verbose: if self.verbose:
print_text(code, color="green", end="\n") self.callback_manager.on_text(code, color="green", end="\n")
repl = PythonREPL() repl = PythonREPL()
res = repl.run(code + f"\n{self.get_answer_expr}") res = repl.run(code + f"\n{self.get_answer_expr}")
return {self.output_key: res.strip()} return {self.output_key: res.strip()}

View File

@ -26,7 +26,7 @@ def _load_stuff_chain(
llm: BaseLLM, llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT, prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "summaries", document_variable_name: str = "summaries",
verbose: bool = False, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> StuffDocumentsChain: ) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@ -49,7 +49,7 @@ def _load_map_reduce_chain(
collapse_prompt: Optional[BasePromptTemplate] = None, collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None, reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None,
verbose: bool = False, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@ -97,7 +97,7 @@ def _load_refine_chain(
document_variable_name: str = "context_str", document_variable_name: str = "context_str",
initial_response_name: str = "existing_answer", initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None, refine_llm: Optional[BaseLLM] = None,
verbose: bool = False, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> RefineDocumentsChain: ) -> RefineDocumentsChain:
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@ -115,7 +115,10 @@ def _load_refine_chain(
def load_qa_with_sources_chain( def load_qa_with_sources_chain(
llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any llm: BaseLLM,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain: ) -> BaseCombineDocumentsChain:
"""Load question answering with sources chain. """Load question answering with sources chain.

View File

@ -26,7 +26,7 @@ def _load_stuff_chain(
llm: BaseLLM, llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT, prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "context", document_variable_name: str = "context",
verbose: bool = False, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> StuffDocumentsChain: ) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@ -48,7 +48,7 @@ def _load_map_reduce_chain(
collapse_prompt: Optional[BasePromptTemplate] = None, collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None, reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None,
verbose: bool = False, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@ -94,7 +94,7 @@ def _load_refine_chain(
document_variable_name: str = "context_str", document_variable_name: str = "context_str",
initial_response_name: str = "existing_answer", initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None, refine_llm: Optional[BaseLLM] = None,
verbose: bool = False, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> RefineDocumentsChain: ) -> RefineDocumentsChain:
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@ -111,7 +111,10 @@ def _load_refine_chain(
def load_qa_chain( def load_qa_chain(
llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any llm: BaseLLM,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain: ) -> BaseCombineDocumentsChain:
"""Load question answering chain. """Load question answering chain.

View File

@ -5,7 +5,7 @@ from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.input import get_color_mapping, print_text from langchain.input import get_color_mapping
class SequentialChain(Chain, BaseModel): class SequentialChain(Chain, BaseModel):
@ -133,5 +133,7 @@ class SimpleSequentialChain(Chain, BaseModel):
if self.strip_outputs: if self.strip_outputs:
_input = _input.strip() _input = _input.strip()
if self.verbose: if self.verbose:
print_text(_input, color=color_mapping[str(i)], end="\n") self.callback_manager.on_text(
_input, color=color_mapping[str(i)], end="\n"
)
return {self.output_key: _input} return {self.output_key: _input}

View File

@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import PROMPT from langchain.chains.sql_database.prompt import PROMPT
from langchain.input import print_text
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.sql_database import SQLDatabase from langchain.sql_database import SQLDatabase
@ -55,7 +54,7 @@ class SQLDatabaseChain(Chain, BaseModel):
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
input_text = f"{inputs[self.input_key]} \nSQLQuery:" input_text = f"{inputs[self.input_key]} \nSQLQuery:"
if self.verbose: if self.verbose:
print_text(input_text) self.callback_manager.on_text(input_text)
llm_inputs = { llm_inputs = {
"input": input_text, "input": input_text,
"dialect": self.database.dialect, "dialect": self.database.dialect,
@ -64,15 +63,15 @@ class SQLDatabaseChain(Chain, BaseModel):
} }
sql_cmd = llm_chain.predict(**llm_inputs) sql_cmd = llm_chain.predict(**llm_inputs)
if self.verbose: if self.verbose:
print_text(sql_cmd, color="green") self.callback_manager.on_text(sql_cmd, color="green")
result = self.database.run(sql_cmd) result = self.database.run(sql_cmd)
if self.verbose: if self.verbose:
print_text("\nSQLResult: ") self.callback_manager.on_text("\nSQLResult: ")
print_text(result, color="yellow") self.callback_manager.on_text(result, color="yellow")
print_text("\nAnswer:") self.callback_manager.on_text("\nAnswer:")
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:" input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
llm_inputs["input"] = input_text llm_inputs["input"] = input_text
final_result = llm_chain.predict(**llm_inputs) final_result = llm_chain.predict(**llm_inputs)
if self.verbose: if self.verbose:
print_text(final_result, color="green") self.callback_manager.on_text(final_result, color="green")
return {self.output_key: final_result} return {self.output_key: final_result}

View File

@ -22,7 +22,7 @@ def _load_stuff_chain(
llm: BaseLLM, llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT, prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "text", document_variable_name: str = "text",
verbose: bool = False, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> StuffDocumentsChain: ) -> StuffDocumentsChain:
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@ -44,7 +44,7 @@ def _load_map_reduce_chain(
collapse_prompt: Optional[BasePromptTemplate] = None, collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None, reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLLM] = None,
verbose: bool = False, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose) map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
@ -90,7 +90,7 @@ def _load_refine_chain(
document_variable_name: str = "text", document_variable_name: str = "text",
initial_response_name: str = "existing_answer", initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None, refine_llm: Optional[BaseLLM] = None,
verbose: bool = False, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> RefineDocumentsChain: ) -> RefineDocumentsChain:
@ -108,7 +108,10 @@ def _load_refine_chain(
def load_summarize_chain( def load_summarize_chain(
llm: BaseLLM, chain_type: str = "stuff", verbose: bool = False, **kwargs: Any llm: BaseLLM,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain: ) -> BaseCombineDocumentsChain:
"""Load summarizing chain. """Load summarizing chain.

View File

@ -2,34 +2,55 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union from typing import Any, Dict, List, Mapping, Optional, Union
import yaml import yaml
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra, Field, validator
import langchain import langchain
from langchain.schema import Generation from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import Generation, LLMResult
class LLMResult(NamedTuple): def _get_verbosity() -> bool:
"""Class that contains all relevant information for an LLM Result.""" return langchain.verbose
generations: List[List[Generation]]
"""List of the things generated. This is List[List[]] because
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
class BaseLLM(BaseModel, ABC): class BaseLLM(BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string.""" """LLM wrapper should take in a prompt and return a string."""
cache: Optional[bool] = None cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@abstractmethod @abstractmethod
def _generate( def _generate(
@ -48,7 +69,14 @@ class BaseLLM(BaseModel, ABC):
raise ValueError( raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`." "Asked to cache, but no cache found at `langchain.cache`."
) )
return self._generate(prompts, stop=stop) if self.verbose:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts
)
output = self._generate(prompts, stop=stop)
if self.verbose:
self.callback_manager.on_llm_end(output)
return output
params = self._llm_dict() params = self._llm_dict()
params["stop"] = stop params["stop"] = stop
llm_string = str(sorted([(k, v) for k, v in params.items()])) llm_string = str(sorted([(k, v) for k, v in params.items()]))
@ -62,7 +90,11 @@ class BaseLLM(BaseModel, ABC):
else: else:
missing_prompts.append(prompt) missing_prompts.append(prompt)
missing_prompt_idxs.append(i) missing_prompt_idxs.append(i)
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts
)
new_results = self._generate(missing_prompts, stop=stop) new_results = self._generate(missing_prompts, stop=stop)
self.callback_manager.on_llm_end(new_results)
for i, result in enumerate(new_results.generations): for i, result in enumerate(new_results.generations):
existing_prompts[i] = result existing_prompts[i] = result
prompt = prompts[i] prompt = prompts[i]

View File

@ -4,8 +4,8 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
from langchain.llms.base import BaseLLM, LLMResult from langchain.llms.base import BaseLLM
from langchain.schema import Generation from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env

View File

@ -1,71 +0,0 @@
"""BETA: everything in here is highly experimental, do not rely on."""
from typing import Any, Optional
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish
class BaseLogger:
"""Base logging interface."""
def log_agent_start(self, text: str, **kwargs: Any) -> None:
"""Log the start of an agent interaction."""
pass
def log_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Log the end of an agent interaction."""
pass
def log_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Log agent action decision."""
pass
def log_agent_observation(self, observation: str, **kwargs: Any) -> None:
"""Log agent observation."""
pass
def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
"""Log LLM inputs."""
pass
def log_llm_response(self, output: str, **kwargs: Any) -> None:
"""Log LLM response."""
pass
class StdOutLogger(BaseLogger):
"""Interface for printing things to stdout."""
def log_agent_start(self, text: str, **kwargs: Any) -> None:
"""Print the text to start the agent."""
print_text(text)
def log_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Print the log of the action in a certain color."""
print_text(action.log, color=color)
def log_agent_observation(
self,
observation: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Print the observation in a special color."""
print_text(f"\n{observation_prefix}")
print_text(observation, color=color)
print_text(f"\n{llm_prefix}")
def log_llm_inputs(self, inputs: dict, prompt: str, **kwargs: Any) -> None:
"""Print the prompt in green."""
print("Prompt after formatting:")
print_text(prompt, color="green", end="\n")
def log_agent_end(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Log the end of an agent interaction."""
print_text(finish.log, color=color)

View File

@ -1,6 +1,6 @@
"""Common schema objects.""" """Common schema objects."""
from typing import NamedTuple from typing import List, NamedTuple, Optional
class AgentAction(NamedTuple): class AgentAction(NamedTuple):
@ -24,3 +24,13 @@ class Generation(NamedTuple):
text: str text: str
"""Generated text output.""" """Generated text output."""
# TODO: add log probs # TODO: add log probs
class LLMResult(NamedTuple):
"""Class that contains all relevant information for an LLM Result."""
generations: List[List[Generation]]
"""List of the things generated. This is List[List[]] because
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""

View File

@ -4,8 +4,10 @@ from typing import Any, List, Mapping, Optional
from pydantic import BaseModel from pydantic import BaseModel
from langchain.agents import Tool, initialize_agent from langchain.agents import AgentExecutor, Tool, initialize_agent
from langchain.callbacks.base import CallbackManager
from langchain.llms.base import LLM from langchain.llms.base import LLM
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeListLLM(LLM, BaseModel): class FakeListLLM(LLM, BaseModel):
@ -31,8 +33,8 @@ class FakeListLLM(LLM, BaseModel):
return "fake_list" return "fake_list"
def test_agent_bad_action() -> None: def _get_agent(**kwargs: Any) -> AgentExecutor:
"""Test react chain when bad action given.""" """Get agent for testing."""
bad_action_name = "BadAction" bad_action_name = "BadAction"
responses = [ responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
@ -44,30 +46,122 @@ def test_agent_bad_action() -> None:
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"), Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
] ]
agent = initialize_agent( agent = initialize_agent(
tools, fake_llm, agent="zero-shot-react-description", verbose=True tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs
) )
return agent
def test_agent_bad_action() -> None:
"""Test react chain when bad action given."""
agent = _get_agent()
output = agent.run("when was langchain made") output = agent.run("when was langchain made")
assert output == "curses foiled again" assert output == "curses foiled again"
def test_agent_stopped_early() -> None: def test_agent_stopped_early() -> None:
"""Test react chain when bad action given.""" """Test react chain when bad action given."""
bad_action_name = "BadAction" agent = _get_agent(max_iterations=0)
output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations."
def test_agent_with_callbacks_global() -> None:
"""Test react chain with callbacks by setting verbose globally."""
import langchain
langchain.verbose = True
handler = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler])
tool = "Search"
responses = [ responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again", "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
] ]
fake_llm = FakeListLLM(responses=responses) fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
tools = [ tools = [
Tool("Search", lambda x: x, "Useful for searching"), Tool("Search", lambda x: x, "Useful for searching"),
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
] ]
agent = initialize_agent( agent = initialize_agent(
tools, tools,
fake_llm, fake_llm,
agent="zero-shot-react-description", agent="zero-shot-react-description",
verbose=True, verbose=True,
max_iterations=0, callback_manager=manager,
) )
output = agent.run("when was langchain made") output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations." assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 6
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0
def test_agent_with_callbacks_local() -> None:
"""Test react chain with callbacks by setting verbose locally."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler])
tool = "Search"
responses = [
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
]
agent = initialize_agent(
tools,
fake_llm,
agent="zero-shot-react-description",
verbose=True,
callback_manager=manager,
)
agent.agent.llm_chain.verbose = True
output = agent.run("when was langchain made")
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 6
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0
def test_agent_with_callbacks_not_verbose() -> None:
"""Test react chain with callbacks but not verbose."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler])
tool = "Search"
responses = [
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
]
agent = initialize_agent(
tools,
fake_llm,
agent="zero-shot-react-description",
callback_manager=manager,
)
output = agent.run("when was langchain made")
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0

View File

@ -0,0 +1 @@
"""Tests for correct functioning of callbacks."""

View File

@ -0,0 +1,66 @@
"""A fake callback handler for testing purposes."""
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class FakeCallbackHandler(BaseCallbackHandler):
"""Fake callback handler for testing."""
starts: int = 0
ends: int = 0
errors: int = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.starts += 1
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
self.ends += 1
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.starts += 1
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
self.ends += 1
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.starts += 1
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.ends += 1
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent is ending."""
self.ends += 1
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.ends += 1

View File

@ -0,0 +1,47 @@
"""Test CallbackManager."""
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentAction, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def _test_callback_manager(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [])
manager.on_llm_end(LLMResult(generations=[]))
manager.on_llm_error(Exception())
manager.on_chain_start({"name": "foo"}, {})
manager.on_chain_end({})
manager.on_chain_error(Exception())
manager.on_tool_start({}, AgentAction("", "", ""))
manager.on_tool_end("")
manager.on_tool_error(Exception())
for handler in handlers:
assert handler.starts == 3
assert handler.ends == 3
assert handler.errors == 3
def test_callback_manager() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
_test_callback_manager(manager, handler1, handler2)
def test_shared_callback_manager() -> None:
"""Test the SharedCallbackManager."""
manager1 = SharedCallbackManager()
manager2 = SharedCallbackManager()
assert manager1 is manager2
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager1.add_handler(handler1)
manager2.add_handler(handler2)
_test_callback_manager(manager1, handler1, handler2)

View File

@ -4,7 +4,9 @@ from typing import Any, Dict, List
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from langchain.callbacks.base import CallbackManager
from langchain.chains.base import Chain, Memory from langchain.chains.base import Chain, Memory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeMemory(Memory, BaseModel): class FakeMemory(Memory, BaseModel):
@ -133,3 +135,31 @@ def test_run_arg_with_memory() -> None:
"""Test run method works when arg is passed.""" """Test run method works when arg is passed."""
chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory()) chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory())
chain.run("bar") chain.run("bar")
def test_run_with_callback() -> None:
"""Test run method works when callback manager is passed."""
handler = FakeCallbackHandler()
chain = FakeChain(
callback_manager=CallbackManager(handlers=[handler]), verbose=True
)
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_run_with_callback_not_verbose() -> None:
"""Test run method works when callback manager is passed and not verbose."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0

View File

@ -0,0 +1,30 @@
"""Test LLM callbacks."""
from langchain.callbacks.base import CallbackManager
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_llm_with_callbacks() -> None:
"""Test LLM callbacks."""
handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True)
output = llm("foo")
assert output == "foo"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_llm_with_callbacks_not_verbose() -> None:
"""Test LLM callbacks but not verbose."""
import langchain
langchain.verbose = False
handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]))
output = llm("foo")
assert output == "foo"
assert handler.starts == 0
assert handler.ends == 0
assert handler.errors == 0

View File

@ -7,8 +7,8 @@ from pydantic import BaseModel
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
from langchain.embeddings.hyde.prompts import PROMPT_MAP from langchain.embeddings.hyde.prompts import PROMPT_MAP
from langchain.llms.base import BaseLLM, LLMResult from langchain.llms.base import BaseLLM
from langchain.schema import Generation from langchain.schema import Generation, LLMResult
class FakeEmbeddings(Embeddings): class FakeEmbeddings(Embeddings):