Make streamlit import optional (#6510)

master
Davis Chase 12 months ago committed by GitHub
parent cece8c8bf0
commit b3c49e94a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,8 +1,6 @@
"""Callback Handler that logs to streamlit.""" """Callback Handler that logs to streamlit."""
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -11,16 +9,25 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to streamlit.""" """Callback Handler that logs to streamlit."""
def __init__(self) -> None: def __init__(self) -> None:
try:
import streamlit as st
except ImportError as e:
raise ImportError(
"Could not import streamlit Python package. "
"Please install it with `pip install streamlit`."
) from e
self.tokens_area = st.empty() self.tokens_area = st.empty()
self.tokens_stream = "" self.tokens_stream = ""
self.st = st
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None: ) -> None:
"""Print out the prompts.""" """Print out the prompts."""
st.write("Prompts after formatting:") self.st.write("Prompts after formatting:")
for prompt in prompts: for prompt in prompts:
st.write(prompt) self.st.write(prompt)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None: def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled.""" """Run on new LLM token. Only available when streaming is enabled."""
@ -42,11 +49,11 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
"""Print out that we are entering a chain.""" """Print out that we are entering a chain."""
class_name = serialized["name"] class_name = serialized["name"]
st.write(f"Entering new {class_name} chain...") self.st.write(f"Entering new {class_name} chain...")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.""" """Print out that we finished a chain."""
st.write("Finished chain.") self.st.write("Finished chain.")
def on_chain_error( def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@ -66,7 +73,8 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action.""" """Run on agent action."""
# st.write requires two spaces before a newline to render it # st.write requires two spaces before a newline to render it
st.markdown(action.log.replace("\n", " \n"))
self.st.markdown(action.log.replace("\n", " \n"))
def on_tool_end( def on_tool_end(
self, self,
@ -76,8 +84,8 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""If not the final action, print out observation.""" """If not the final action, print out observation."""
st.write(f"{observation_prefix}{output}") self.st.write(f"{observation_prefix}{output}")
st.write(llm_prefix) self.st.write(llm_prefix)
def on_tool_error( def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@ -88,9 +96,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on text.""" """Run on text."""
# st.write requires two spaces before a newline to render it # st.write requires two spaces before a newline to render it
st.write(text.replace("\n", " \n")) self.st.write(text.replace("\n", " \n"))
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end.""" """Run on agent end."""
# st.write requires two spaces before a newline to render it # st.write requires two spaces before a newline to render it
st.write(finish.log.replace("\n", " \n")) self.st.write(finish.log.replace("\n", " \n"))

Loading…
Cancel
Save