Harrison/streamlit handler (#488)

also add a set handler method

usage is:
```
from langchain.callbacks.streamlit import StreamlitCallbackHandler
import langchain
langchain.set_handler(StreamlitCallbackHandler())
```

produces the following output


![Screen Shot 2022-12-29 at 10 50 33
PM](https://user-images.githubusercontent.com/11986836/210032762-7f53fffa-cb2f-4dac-af39-7d4cf81e55dd.png)

only works for agent stuff currently
harrison/callback-updates
Harrison Chase 1 year ago committed by GitHub
parent 45d6de177e
commit a3d2a2ec2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,7 @@ from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.callbacks import set_default_callback_manager
from langchain.callbacks import set_default_callback_manager, set_handler
from langchain.chains import (
ConversationChain,
LLMBashChain,
@ -63,4 +63,5 @@ __all__ = [
"VectorDBQAWithSourcesChain",
"QAWithSourcesChain",
"PALChain",
"set_handler",
]

@ -1,5 +1,5 @@
"""Callback handlers that allow listening to events in LangChain."""
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
@ -13,3 +13,9 @@ def set_default_callback_manager() -> None:
"""Set default callback manager."""
callback = get_callback_manager()
callback.add_handler(StdOutCallbackHandler())
def set_handler(handler: BaseCallbackHandler) -> None:
"""Set handler."""
callback = get_callback_manager()
callback.set_handler(handler)

@ -70,6 +70,10 @@ class BaseCallbackManager(BaseCallbackHandler, ABC):
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
@abstractmethod
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
@ -144,3 +148,7 @@ class CallbackManager(BaseCallbackManager):
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
self.handlers = [handler]

@ -102,3 +102,8 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
"""Remove a callback from the callback manager."""
with self._lock:
self._callback_manager.remove_handler(callback)
def set_handler(self, handler: BaseCallbackHandler) -> None:
"""Set handler as the only handler on the callback manager."""
with self._lock:
self._callback_manager.handlers = [handler]

@ -0,0 +1,72 @@
"""Callback Handler that logs to streamlit."""
from typing import Any, Dict, List, Optional
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult
class StreamlitCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to streamlit."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
st.write("Prompts after formatting:")
for prompt in prompts:
st.write(prompt)
def on_llm_end(self, response: LLMResult) -> None:
"""Do nothing."""
pass
def on_llm_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
st.write(f"Entering new {class_name} chain...")
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Print out that we finished a chain."""
st.write("Finished chain.")
def on_chain_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
action: AgentAction,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
# st.write requires two spaces before a newline to render it
st.markdown(action.log.replace("\n", " \n"))
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
st.write(f"{observation_prefix}{output}")
st.write(llm_prefix)
def on_tool_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on text."""
# st.write requires two spaces before a newline to render it
st.write(text.replace("\n", " \n"))
Loading…
Cancel
Save