"""Callback Handler that prints to streamlit.""" from __future__ import annotations from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import BaseCallbackHandler from langchain_core.outputs import LLMResult from langchain_community.callbacks.streamlit.mutable_expander import MutableExpander if TYPE_CHECKING: from streamlit.delta_generator import DeltaGenerator def _convert_newlines(text: str) -> str: """Convert newline characters to markdown newline sequences (space, space, newline). """ return text.replace("\n", " \n") CHECKMARK_EMOJI = "✅" THINKING_EMOJI = ":thinking_face:" HISTORY_EMOJI = ":books:" EXCEPTION_EMOJI = "⚠️" class LLMThoughtState(Enum): """Enumerator of the LLMThought state.""" # The LLM is thinking about what to do next. We don't know which tool we'll run. THINKING = "THINKING" # The LLM has decided to run a tool. We don't have results from the tool yet. RUNNING_TOOL = "RUNNING_TOOL" # We have results from the tool. COMPLETE = "COMPLETE" class ToolRecord(NamedTuple): """The tool record as a NamedTuple.""" name: str input_str: str class LLMThoughtLabeler: """ Generates markdown labels for LLMThought containers. Pass a custom subclass of this to StreamlitCallbackHandler to override its default labeling logic. """ def get_initial_label(self) -> str: """Return the markdown label for a new LLMThought that doesn't have an associated tool yet. """ return f"{THINKING_EMOJI} **Thinking...**" def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str: """Return the label for an LLMThought that has an associated tool. Parameters ---------- tool The tool's ToolRecord is_complete True if the thought is complete; False if the thought is still receiving input. Returns ------- The markdown label for the thought's container. """ input = tool.input_str name = tool.name emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI if name == "_Exception": emoji = EXCEPTION_EMOJI name = "Parsing error" idx = min([60, len(input)]) input = input[0:idx] if len(tool.input_str) > idx: input = input + "..." input = input.replace("\n", " ") label = f"{emoji} **{name}:** {input}" return label def get_history_label(self) -> str: """Return a markdown label for the special 'history' container that contains overflow thoughts. """ return f"{HISTORY_EMOJI} **History**" def get_final_agent_thought_label(self) -> str: """Return the markdown label for the agent's final thought - the "Now I have the answer" thought, that doesn't involve a tool. """ return f"{CHECKMARK_EMOJI} **Complete!**" class LLMThought: """A thought in the LLM's thought stream.""" def __init__( self, parent_container: DeltaGenerator, labeler: LLMThoughtLabeler, expanded: bool, collapse_on_complete: bool, ): """Initialize the LLMThought. Args: parent_container: The container we're writing into. labeler: The labeler to use for this thought. expanded: Whether the thought should be expanded by default. collapse_on_complete: Whether the thought should be collapsed. 
""" self._container = MutableExpander( parent_container=parent_container, label=labeler.get_initial_label(), expanded=expanded, ) self._state = LLMThoughtState.THINKING self._llm_token_stream = "" self._llm_token_writer_idx: Optional[int] = None self._last_tool: Optional[ToolRecord] = None self._collapse_on_complete = collapse_on_complete self._labeler = labeler @property def container(self) -> MutableExpander: """The container we're writing into.""" return self._container @property def last_tool(self) -> Optional[ToolRecord]: """The last tool executed by this thought""" return self._last_tool def _reset_llm_token_stream(self) -> None: self._llm_token_stream = "" self._llm_token_writer_idx = None def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None: self._reset_llm_token_stream() def on_llm_new_token(self, token: str, **kwargs: Any) -> None: # This is only called when the LLM is initialized with `streaming=True` self._llm_token_stream += _convert_newlines(token) self._llm_token_writer_idx = self._container.markdown( self._llm_token_stream, index=self._llm_token_writer_idx ) def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: # `response` is the concatenation of all the tokens received by the LLM. # If we're receiving streaming tokens from `on_llm_new_token`, this response # data is redundant self._reset_llm_token_stream() def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: self._container.markdown("**LLM encountered an error...**") self._container.exception(error) def on_tool_start( self, serialized: Dict[str, Any], input_str: str, **kwargs: Any ) -> None: # Called with the name of the tool we're about to run (in `serialized[name]`), # and its input. We change our container's label to be the tool name. self._state = LLMThoughtState.RUNNING_TOOL tool_name = serialized["name"] self._last_tool = ToolRecord(name=tool_name, input_str=input_str) self._container.update( new_label=self._labeler.get_tool_label(self._last_tool, is_complete=False) ) def on_tool_end( self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any, ) -> None: self._container.markdown(f"**{output}**") def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: self._container.markdown("**Tool encountered an error...**") self._container.exception(error) def on_agent_action( self, action: AgentAction, color: Optional[str] = None, **kwargs: Any ) -> Any: # Called when we're about to kick off a new tool. The `action` data # tells us the tool we're about to use, and the input we'll give it. # We don't output anything here, because we'll receive this same data # when `on_tool_start` is called immediately after. pass def complete(self, final_label: Optional[str] = None) -> None: """Finish the thought.""" if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL: assert ( self._last_tool is not None ), "_last_tool should never be null when _state == RUNNING_TOOL" final_label = self._labeler.get_tool_label( self._last_tool, is_complete=True ) self._state = LLMThoughtState.COMPLETE if self._collapse_on_complete: self._container.update(new_label=final_label, new_expanded=False) else: self._container.update(new_label=final_label) def clear(self) -> None: """Remove the thought from the screen. 


class StreamlitCallbackHandler(BaseCallbackHandler):
    """Callback handler that writes to a Streamlit app."""

    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = True,
        collapse_completed_thoughts: bool = True,
        thought_labeler: Optional[LLMThoughtLabeler] = None,
    ):
        """Create a StreamlitCallbackHandler instance.

        Parameters
        ----------
        parent_container
            The `st.container` that will contain all the Streamlit elements that the
            Handler creates.

        max_thought_containers
            The max number of completed LLM thought containers to show at once. When
            this threshold is reached, a new thought will cause the oldest thoughts to
            be collapsed into a "History" expander. Defaults to 4.

        expand_new_thoughts
            Each LLM "thought" gets its own `st.expander`. This param controls whether
            that expander is expanded by default. Defaults to True.

        collapse_completed_thoughts
            If True, LLM thought expanders will be collapsed when completed.
            Defaults to True.

        thought_labeler
            An optional custom LLMThoughtLabeler instance. If unspecified, the handler
            will use the default thought labeling logic. Defaults to None.
        """
        self._parent_container = parent_container
        self._history_parent = parent_container.container()
        self._history_container: Optional[MutableExpander] = None
        self._current_thought: Optional[LLMThought] = None
        self._completed_thoughts: List[LLMThought] = []

        self._max_thought_containers = max(max_thought_containers, 1)
        self._expand_new_thoughts = expand_new_thoughts
        self._collapse_completed_thoughts = collapse_completed_thoughts
        self._thought_labeler = thought_labeler or LLMThoughtLabeler()

    def _require_current_thought(self) -> LLMThought:
        """Return our current LLMThought. Raise an error if we have no current
        thought.
        """
        if self._current_thought is None:
            raise RuntimeError("Current LLMThought is unexpectedly None!")
        return self._current_thought

    def _get_last_completed_thought(self) -> Optional[LLMThought]:
        """Return our most recent completed LLMThought, or None if we don't have
        one.
        """
        if len(self._completed_thoughts) > 0:
            return self._completed_thoughts[len(self._completed_thoughts) - 1]
        return None

    @property
    def _num_thought_containers(self) -> int:
        """The number of 'thought containers' we're currently showing: the
        number of completed thought containers, the history container (if it
        exists), and the current thought container (if it exists).
        """
        count = len(self._completed_thoughts)
        if self._history_container is not None:
            count += 1
        if self._current_thought is not None:
            count += 1
        return count

    def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
        """Complete the current thought, optionally assigning it a new label.
        Add it to our _completed_thoughts list.
        """
        thought = self._require_current_thought()
        thought.complete(final_label)
        self._completed_thoughts.append(thought)
        self._current_thought = None

    def _prune_old_thought_containers(self) -> None:
        """If we have too many thoughts onscreen, move older thoughts to the
        'history container.'
        """
        while (
            self._num_thought_containers > self._max_thought_containers
            and len(self._completed_thoughts) > 0
        ):
            # Create our history container if it doesn't exist, and if
            # max_thought_containers is > 1. (if max_thought_containers is 1, we don't
            # have room to show history.)
            if self._history_container is None and self._max_thought_containers > 1:
                self._history_container = MutableExpander(
                    self._history_parent,
                    label=self._thought_labeler.get_history_label(),
                    expanded=False,
                )

            oldest_thought = self._completed_thoughts.pop(0)
            if self._history_container is not None:
                self._history_container.markdown(oldest_thought.container.label)
                self._history_container.append_copy(oldest_thought.container)
            oldest_thought.clear()

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        if self._current_thought is None:
            self._current_thought = LLMThought(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

        # We don't prune_old_thought_containers here, because our container won't
        # be visible until it has a child.

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_new_token(token, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_end(response, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        self._require_current_thought().on_llm_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        self._require_current_thought().on_tool_end(
            output, color, observation_prefix, llm_prefix, **kwargs
        )
        self._complete_current_thought()

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        self._require_current_thought().on_tool_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        self._require_current_thought().on_agent_action(action, color, **kwargs)
        self._prune_old_thought_containers()

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        if self._current_thought is not None:
            self._current_thought.complete(
                self._thought_labeler.get_final_agent_thought_label()
            )
            self._current_thought = None
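

# Example usage (a minimal sketch, assuming an already-constructed LangChain agent
# named `agent_executor` and a running Streamlit app; neither is defined in this
# module):
#
#     import streamlit as st
#     from langchain_community.callbacks import StreamlitCallbackHandler
#
#     if prompt := st.chat_input("Ask the agent"):
#         st.chat_message("user").write(prompt)
#         with st.chat_message("assistant"):
#             st_callback = StreamlitCallbackHandler(st.container())
#             response = agent_executor.invoke(
#                 {"input": prompt}, {"callbacks": [st_callback]}
#             )
#             st.write(response["output"])
#
# Note: token-by-token streaming inside each thought expander only happens when the
# underlying LLM is constructed with `streaming=True` (see `on_llm_new_token` above).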