|
|
|
@@ -10,6 +10,10 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
|
|
|
|
|
class StreamlitCallbackHandler(BaseCallbackHandler):
|
|
|
|
|
"""Callback Handler that logs to streamlit."""
|
|
|
|
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
self.tokens_area = st.empty()
|
|
|
|
|
self.tokens_stream = ""
|
|
|
|
|
|
|
|
|
|
def on_llm_start(
|
|
|
|
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
|
|
|
|
) -> None:
|
|
|
|
@@ -19,8 +23,9 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
|
|
|
|
st.write(prompt)
|
|
|
|
|
|
|
|
|
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
|
|
|
|
"""Do nothing."""
|
|
|
|
|
pass
|
|
|
|
|
"""Run on new LLM token. Only available when streaming is enabled."""
|
|
|
|
|
self.tokens_stream += token
|
|
|
|
|
self.tokens_area.write(self.tokens_stream)
|
|
|
|
|
|
|
|
|
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
|
|
|
|
"""Do nothing."""
|
|
|
|
|