From b3c49e94a0ada8a1dacc5116611abcde0560846c Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Tue, 20 Jun 2023 18:41:59 -0700 Subject: [PATCH] Make streamlit import optional (#6510) --- langchain/callbacks/streamlit.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/langchain/callbacks/streamlit.py b/langchain/callbacks/streamlit.py index c543f1cd..08d68ec8 100644 --- a/langchain/callbacks/streamlit.py +++ b/langchain/callbacks/streamlit.py @@ -1,8 +1,6 @@ """Callback Handler that logs to streamlit.""" from typing import Any, Dict, List, Optional, Union -import streamlit as st - from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -11,16 +9,25 @@ class StreamlitCallbackHandler(BaseCallbackHandler): """Callback Handler that logs to streamlit.""" def __init__(self) -> None: + try: + import streamlit as st + except ImportError as e: + raise ImportError( + "Could not import streamlit Python package. " + "Please install it with `pip install streamlit`." + ) from e + self.tokens_area = st.empty() self.tokens_stream = "" + self.st = st def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Print out the prompts.""" - st.write("Prompts after formatting:") + self.st.write("Prompts after formatting:") for prompt in prompts: - st.write(prompt) + self.st.write(prompt) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: """Run on new LLM token. Only available when streaming is enabled.""" @@ -42,11 +49,11 @@ class StreamlitCallbackHandler(BaseCallbackHandler): ) -> None: """Print out that we are entering a chain.""" class_name = serialized["name"] - st.write(f"Entering new {class_name} chain...") + self.st.write(f"Entering new {class_name} chain...") def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: """Print out that we finished a chain.""" - st.write("Finished chain.") + self.st.write("Finished chain.") def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any @@ -66,7 +73,8 @@ class StreamlitCallbackHandler(BaseCallbackHandler): def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run on agent action.""" # st.write requires two spaces before a newline to render it - st.markdown(action.log.replace("\n", " \n")) + + self.st.markdown(action.log.replace("\n", " \n")) def on_tool_end( self, @@ -76,8 +84,8 @@ class StreamlitCallbackHandler(BaseCallbackHandler): **kwargs: Any, ) -> None: """If not the final action, print out observation.""" - st.write(f"{observation_prefix}{output}") - st.write(llm_prefix) + self.st.write(f"{observation_prefix}{output}") + self.st.write(llm_prefix) def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any @@ -88,9 +96,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler): def on_text(self, text: str, **kwargs: Any) -> None: """Run on text.""" # st.write requires two spaces before a newline to render it - st.write(text.replace("\n", " \n")) + self.st.write(text.replace("\n", " \n")) def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: """Run on agent end.""" # st.write requires two spaces before a newline to render it - st.write(finish.log.replace("\n", " \n")) + self.st.write(finish.log.replace("\n", " \n"))