callback updates

harrison/callback-updates
Harrison Chase 1 year ago
parent 23b8cfc123
commit 7eb33690a9

@ -259,7 +259,7 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.verbose:
self.callback_manager.on_text(output.log, color="green")
self.callback_manager.on_agent_end(output, color="green")
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from langchain.schema import AgentAction, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseCallbackHandler(ABC):
@ -56,7 +56,11 @@ class BaseCallbackHandler(ABC):
@abstractmethod
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent ends."""
"""Run on arbitrary text."""
@abstractmethod
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
class BaseCallbackManager(BaseCallbackHandler, ABC):
@ -141,6 +145,11 @@ class CallbackManager(BaseCallbackManager):
for handler in self.handlers:
handler.on_text(text, **kwargs)
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
for handler in self.handlers:
handler.on_agent_end(finish, **kwargs)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)

@ -8,7 +8,7 @@ from langchain.callbacks.base import (
BaseCallbackManager,
CallbackManager,
)
from langchain.schema import AgentAction, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult
class Singleton:
@ -89,10 +89,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
self._callback_manager.on_tool_error(error)
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent ends."""
"""Run on arbitrary text."""
with self._lock:
self._callback_manager.on_text(text, **kwargs)
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
with self._lock:
self._callback_manager.on_agent_end(finish, **kwargs)
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a callback to the callback manager."""
with self._lock:

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult
class StdOutCallbackHandler(BaseCallbackHandler):
@ -76,3 +76,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
) -> None:
"""Run when agent ends."""
print_text(text, color=color, end=end)
def on_agent_end(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text(finish.log, color=color, end="\n")

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult
class StreamlitCallbackHandler(BaseCallbackHandler):
@ -70,3 +70,8 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
"""Run on text."""
# st.write requires two spaces before a newline to render it
st.write(text.replace("\n", " \n"))
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
# st.write requires two spaces before a newline to render it
st.write(finish.log.replace("\n", " \n"))

@ -2,7 +2,7 @@
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult
from langchain.schema import AgentAction, LLMResult, AgentFinish
class FakeCallbackHandler(BaseCallbackHandler):
@ -62,3 +62,8 @@ class FakeCallbackHandler(BaseCallbackHandler):
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent is ending."""
self.ends += 1
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.ends += 1

Loading…
Cancel
Save