mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
7cf2d2759d
Added missing docstrings. Formatted docstrings into a consistent form.
416 lines
15 KiB
Python
416 lines
15 KiB
Python
"""Callback Handler that prints to streamlit."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
|
|
|
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
from langchain_core.outputs import LLMResult
|
|
|
|
from langchain_community.callbacks.streamlit.mutable_expander import MutableExpander
|
|
|
|
if TYPE_CHECKING:
|
|
from streamlit.delta_generator import DeltaGenerator
|
|
|
|
|
|
def _convert_newlines(text: str) -> str:
|
|
"""Convert newline characters to markdown newline sequences
|
|
(space, space, newline).
|
|
"""
|
|
return text.replace("\n", " \n")
|
|
|
|
|
|
# Prefixes used when building LLMThought container labels.
CHECKMARK_EMOJI = "✅"  # finished thoughts / completed tools
THINKING_EMOJI = ":thinking_face:"  # thoughts still in progress
HISTORY_EMOJI = ":books:"  # the "History" overflow container
EXCEPTION_EMOJI = "⚠️"  # tool records named "_Exception" (parsing errors)
|
|
|
|
|
|
class LLMThoughtState(Enum):
    """Lifecycle stage of a single LLMThought."""

    # Still deciding what to do next; no tool has been picked yet.
    THINKING = "THINKING"
    # A tool has been chosen and is running; its result is not available yet.
    RUNNING_TOOL = "RUNNING_TOOL"
    # The tool has produced its result.
    COMPLETE = "COMPLETE"
|
|
|
|
|
|
class ToolRecord(NamedTuple):
    """Immutable pair describing a tool call: its name and its input string."""

    name: str  # name of the tool
    input_str: str  # input string handed to the tool
|
|
|
|
|
|
class LLMThoughtLabeler:
    """Generates markdown labels for LLMThought containers.

    Pass a custom subclass of this to StreamlitCallbackHandler to override its
    default labeling logic.
    """

    def get_initial_label(self) -> str:
        """Return the markdown label for a new LLMThought that doesn't have
        an associated tool yet.
        """
        return f"{THINKING_EMOJI} **Thinking...**"

    def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
        """Return the label for an LLMThought that has an associated tool.

        Args:
            tool: The tool's ToolRecord.
            is_complete: True if the thought is complete; False if the thought
                is still receiving input.

        Returns:
            The markdown label for the thought's container.
        """
        # Maximum number of input characters shown before truncating with "...".
        max_input_chars = 60

        name = tool.name
        emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
        # The "_Exception" pseudo-tool is shown as a parsing error, with a
        # warning emoji regardless of completion state.
        if name == "_Exception":
            emoji = EXCEPTION_EMOJI
            name = "Parsing error"

        tool_input = tool.input_str
        if len(tool_input) > max_input_chars:
            tool_input = tool_input[:max_input_chars] + "..."
        # Flatten newlines so the label renders on a single line.
        tool_input = tool_input.replace("\n", " ")
        return f"{emoji} **{name}:** {tool_input}"

    def get_history_label(self) -> str:
        """Return a markdown label for the special 'history' container
        that contains overflow thoughts.
        """
        return f"{HISTORY_EMOJI} **History**"

    def get_final_agent_thought_label(self) -> str:
        """Return the markdown label for the agent's final thought -
        the "Now I have the answer" thought, that doesn't involve
        a tool.
        """
        return f"{CHECKMARK_EMOJI} **Complete!**"
|
|
|
|
|
|
class LLMThought:
    """A thought in the LLM's thought stream.

    Each thought owns one MutableExpander into which LLM tokens, tool output
    and errors are written, and whose label tracks the thought's state.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ):
        """Initialize the LLMThought.

        Args:
            parent_container: The container we're writing into.
            labeler: The labeler to use for this thought.
            expanded: Whether the thought should be expanded by default.
            collapse_on_complete: Whether the thought should be collapsed when
                it is completed.
        """
        self._container = MutableExpander(
            parent_container=parent_container,
            label=labeler.get_initial_label(),
            expanded=expanded,
        )
        self._state = LLMThoughtState.THINKING
        # Accumulated streamed tokens, and the index returned by the container
        # for the markdown element that displays them (None until first write).
        self._llm_token_stream = ""
        self._llm_token_writer_idx: Optional[int] = None
        self._last_tool: Optional[ToolRecord] = None
        self._collapse_on_complete = collapse_on_complete
        self._labeler = labeler

    @property
    def container(self) -> MutableExpander:
        """The container we're writing into."""
        return self._container

    @property
    def last_tool(self) -> Optional[ToolRecord]:
        """The last tool executed by this thought."""
        return self._last_tool

    def _reset_llm_token_stream(self) -> None:
        """Discard accumulated token text and the element index it was written to."""
        self._llm_token_stream = ""
        self._llm_token_writer_idx = None

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Reset the token stream at the start of an LLM call.

        Extra keyword arguments are accepted (and ignored) for consistency with
        the other ``on_*`` hooks.
        """
        self._reset_llm_token_stream()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Append a streamed token and re-render the accumulated text."""
        # This is only called when the LLM is initialized with `streaming=True`
        self._llm_token_stream += _convert_newlines(token)
        # Passing back the index returned by the previous call presumably makes
        # MutableExpander overwrite that element instead of appending a new one.
        self._llm_token_writer_idx = self._container.markdown(
            self._llm_token_stream, index=self._llm_token_writer_idx
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Reset the token stream; the response itself is not rendered."""
        # `response` is the concatenation of all the tokens received by the LLM.
        # If we're receiving streaming tokens from `on_llm_new_token`, this response
        # data is redundant
        self._reset_llm_token_stream()

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Render an LLM error message followed by the exception."""
        self._container.markdown("**LLM encountered an error...**")
        self._container.exception(error)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Record the tool about to run and relabel the container with it.

        Args:
            serialized: Serialized tool; its ``"name"`` key is read.
            input_str: The input passed to the tool.
        """
        # Called with the name of the tool we're about to run (in `serialized[name]`),
        # and its input. We change our container's label to be the tool name.
        self._state = LLMThoughtState.RUNNING_TOOL
        tool_name = serialized["name"]
        self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
        self._container.update(
            new_label=self._labeler.get_tool_label(self._last_tool, is_complete=False)
        )

    def on_tool_end(
        self,
        output: Any,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Render the tool's output in bold."""
        self._container.markdown(f"**{str(output)}**")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Render a tool error message followed by the exception."""
        self._container.markdown("**Tool encountered an error...**")
        self._container.exception(error)

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """No-op; the same data arrives via ``on_tool_start``."""
        # Called when we're about to kick off a new tool. The `action` data
        # tells us the tool we're about to use, and the input we'll give it.
        # We don't output anything here, because we'll receive this same data
        # when `on_tool_start` is called immediately after.
        pass

    def complete(self, final_label: Optional[str] = None) -> None:
        """Finish the thought.

        Args:
            final_label: New label for the container. If None while a tool is
                running, the labeler's completed-tool label is used instead.

        Raises:
            RuntimeError: If the thought is in the RUNNING_TOOL state but no
                tool has been recorded (an internal invariant violation).
        """
        if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
            # Explicit check rather than `assert`, which is stripped under -O.
            if self._last_tool is None:
                raise RuntimeError(
                    "_last_tool should never be null when _state == RUNNING_TOOL"
                )
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )
        self._state = LLMThoughtState.COMPLETE
        if self._collapse_on_complete:
            self._container.update(new_label=final_label, new_expanded=False)
        else:
            self._container.update(new_label=final_label)

    def clear(self) -> None:
        """Remove the thought from the screen. A cleared thought can't be reused."""
        self._container.clear()
|
|
|
|
|
|
class StreamlitCallbackHandler(BaseCallbackHandler):
    """Callback handler that renders an agent's progress into a Streamlit app.

    Each LLM "thought" is drawn in its own expander. Once more than
    ``max_thought_containers`` containers are on screen, the oldest completed
    thoughts are copied into a collapsed "History" expander.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = True,
        collapse_completed_thoughts: bool = True,
        thought_labeler: Optional[LLMThoughtLabeler] = None,
    ):
        """Create a StreamlitCallbackHandler instance.

        Args:
            parent_container: The `st.container` that will contain all the
                Streamlit elements that the Handler creates.
            max_thought_containers: The max number of completed LLM thought
                containers to show at once. When this threshold is reached, a
                new thought will cause the oldest thoughts to be collapsed into
                a "History" expander. Values below 1 are treated as 1.
                Defaults to 4.
            expand_new_thoughts: Each LLM "thought" gets its own `st.expander`.
                This param controls whether that expander is expanded by
                default. Defaults to True.
            collapse_completed_thoughts: If True, LLM thought expanders will be
                collapsed when completed. Defaults to True.
            thought_labeler: An optional custom LLMThoughtLabeler instance. If
                unspecified, the handler will use the default thought labeling
                logic. Defaults to None.
        """
        self._parent_container = parent_container
        # Reserve a slot at the top of the parent for the History expander.
        self._history_parent = parent_container.container()
        self._history_container: Optional[MutableExpander] = None
        self._current_thought: Optional[LLMThought] = None
        self._completed_thoughts: List[LLMThought] = []
        self._max_thought_containers = max(max_thought_containers, 1)
        self._expand_new_thoughts = expand_new_thoughts
        self._collapse_completed_thoughts = collapse_completed_thoughts
        self._thought_labeler = thought_labeler or LLMThoughtLabeler()

    def _require_current_thought(self) -> LLMThought:
        """Return the in-progress LLMThought.

        Raises:
            RuntimeError: If there is no thought in progress.
        """
        thought = self._current_thought
        if thought is None:
            raise RuntimeError("Current LLMThought is unexpectedly None!")
        return thought

    def _get_last_completed_thought(self) -> Optional[LLMThought]:
        """Return the most recently completed LLMThought, or None if there is none."""
        return self._completed_thoughts[-1] if self._completed_thoughts else None

    @property
    def _num_thought_containers(self) -> int:
        """Count the thought containers on screen: completed thoughts, plus the
        history container and the current thought when each exists.
        """
        return (
            len(self._completed_thoughts)
            + (self._history_container is not None)
            + (self._current_thought is not None)
        )

    def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
        """Complete the in-progress thought (optionally relabeling it) and move
        it onto the completed list.
        """
        finished = self._require_current_thought()
        finished.complete(final_label)
        self._completed_thoughts.append(finished)
        self._current_thought = None

    def _prune_old_thought_containers(self) -> None:
        """Move the oldest completed thoughts into the history container until
        the on-screen count fits within the configured maximum.
        """
        while (
            self._num_thought_containers > self._max_thought_containers
            and self._completed_thoughts
        ):
            # Lazily create the History expander. With a limit of 1 there is
            # no room to show it, so evicted thoughts are simply dropped.
            if self._history_container is None and self._max_thought_containers > 1:
                self._history_container = MutableExpander(
                    self._history_parent,
                    label=self._thought_labeler.get_history_label(),
                    expanded=False,
                )

            evicted = self._completed_thoughts.pop(0)
            history = self._history_container
            if history is not None:
                history.markdown(evicted.container.label)
                history.append_copy(evicted.container)
            evicted.clear()

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Open a new thought if none is in progress and forward the event to it."""
        thought = self._current_thought
        if thought is None:
            thought = LLMThought(
                parent_container=self._parent_container,
                labeler=self._thought_labeler,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
            )
            self._current_thought = thought

        thought.on_llm_start(serialized, prompts)

        # Pruning is deliberately skipped here: the new container is not
        # visible until it has a child.

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a streamed token to the current thought, then prune."""
        self._require_current_thought().on_llm_new_token(token, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Forward the end of an LLM call to the current thought, then prune."""
        self._require_current_thought().on_llm_end(response, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Forward an LLM error to the current thought, then prune."""
        self._require_current_thought().on_llm_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Forward the start of a tool run to the current thought, then prune."""
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_end(
        self,
        output: Any,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Forward the stringified tool output, then complete the current thought."""
        self._require_current_thought().on_tool_end(
            str(output), color, observation_prefix, llm_prefix, **kwargs
        )
        self._complete_current_thought()

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Forward a tool error to the current thought, then prune."""
        self._require_current_thought().on_tool_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        """Ignored; arbitrary text is not rendered."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Ignored; chain events are not rendered."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Ignored; chain events are not rendered."""

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Ignored; chain events are not rendered."""

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Forward an agent action to the current thought, then prune."""
        self._require_current_thought().on_agent_action(action, color, **kwargs)
        self._prune_old_thought_containers()

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Complete the in-progress thought, if any, with the final-answer label.

        Unlike ``_complete_current_thought``, the finished thought is not added
        to the completed-thoughts list.
        """
        thought = self._current_thought
        if thought is None:
            return
        thought.complete(self._thought_labeler.get_final_agent_thought_label())
        self._current_thought = None
|