diff --git a/langchain/callbacks/streamlit.py b/langchain/callbacks/streamlit.py index a603765f..c543f1cd 100644 --- a/langchain/callbacks/streamlit.py +++ b/langchain/callbacks/streamlit.py @@ -10,6 +10,10 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult class StreamlitCallbackHandler(BaseCallbackHandler): """Callback Handler that logs to streamlit.""" + def __init__(self) -> None: + self.tokens_area = st.empty() + self.tokens_stream = "" + def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: @@ -19,8 +23,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler): st.write(prompt) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass + """Run on new LLM token. Only available when streaming is enabled.""" + self.tokens_stream += token + self.tokens_area.write(self.tokens_stream) def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Do nothing."""