@ -4,13 +4,10 @@ import asyncio
from typing import (
from typing import (
TYPE_CHECKING ,
TYPE_CHECKING ,
Any ,
Any ,
Callable ,
Dict ,
List ,
List ,
Optional ,
Optional ,
Sequence ,
Sequence ,
Type ,
Type ,
Union ,
)
)
from langchain_core . chat_history import BaseChatMessageHistory
from langchain_core . chat_history import BaseChatMessageHistory
@ -28,6 +25,9 @@ if TYPE_CHECKING:
from langchain_core . runnables . config import RunnableConfig
from langchain_core . runnables . config import RunnableConfig
from langchain_core . tracers . schemas import Run
from langchain_core . tracers . schemas import Run
import inspect
from typing import Callable , Dict , Union
MessagesOrDictWithMessages = Union [ Sequence [ " BaseMessage " ] , Dict [ str , Any ] ]
MessagesOrDictWithMessages = Union [ Sequence [ " BaseMessage " ] , Dict [ str , Any ] ]
GetSessionHistoryCallable = Callable [ . . . , BaseChatMessageHistory ]
GetSessionHistoryCallable = Callable [ . . . , BaseChatMessageHistory ]
@ -38,8 +38,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
Base runnable must have inputs and outputs that can be converted to a list of
Base runnable must have inputs and outputs that can be converted to a list of
BaseMessages .
BaseMessages .
RunnableWithMessageHistory must always be called with a config that contains session_id , e . g . :
RunnableWithMessageHistory must always be called with a config that contains
` ` { " configurable " : { " session_id " : " <SESSION_ID> " } } ` `
session_id , e . g . :
` ` { " configurable " : { " session_id " : " <SESSION_ID> " } } `
Example ( dict input ) :
Example ( dict input ) :
. . code - block : : python
. . code - block : : python
@ -79,12 +81,66 @@ class RunnableWithMessageHistory(RunnableBindingBase):
)
)
# -> "The inverse of cosine is called arccosine ..."
# -> "The inverse of cosine is called arccosine ..."
Here ' s an example that uses an in memory chat history, and a factory that
takes in two keys ( user_id and conversation id ) to create a chat history instance .
. . code - block : : python
store = { }
def get_session_history (
user_id : str , conversation_id : str
) - > ChatMessageHistory :
if ( user_id , conversation_id ) not in store :
store [ ( user_id , conversation_id ) ] = ChatMessageHistory ( )
return store [ ( user_id , conversation_id ) ]
prompt = ChatPromptTemplate . from_messages ( [
( " system " , " You ' re an assistant who ' s good at {ability} " ) ,
MessagesPlaceholder ( variable_name = " history " ) ,
( " human " , " {question} " ) ,
] )
chain = prompt | ChatAnthropic ( model = " claude-2 " )
with_message_history = RunnableWithMessageHistory (
chain ,
get_session_history = get_session_history ,
input_messages_key = " messages " ,
history_messages_key = " history " ,
history_factory_config = [
ConfigurableFieldSpec (
id = " user_id " ,
annotation = str ,
name = " User ID " ,
description = " Unique identifier for the user. " ,
default = " " ,
is_shared = True ,
) ,
ConfigurableFieldSpec (
id = " conversation_id " ,
annotation = str ,
name = " Conversation ID " ,
description = " Unique identifier for the conversation. " ,
default = " " ,
is_shared = True ,
) ,
] ,
)
chain_with_history . invoke (
{ " ability " : " math " , " question " : " What does cosine mean? " } ,
config = { " configurable " : { " user_id " : " 123 " , " conversation_id " : " 1 " } }
)
""" # noqa: E501
""" # noqa: E501
get_session_history : GetSessionHistoryCallable
get_session_history : GetSessionHistoryCallable
input_messages_key : Optional [ str ] = None
input_messages_key : Optional [ str ] = None
output_messages_key : Optional [ str ] = None
output_messages_key : Optional [ str ] = None
history_messages_key : Optional [ str ] = None
history_messages_key : Optional [ str ] = None
history_factory_config : Sequence [ ConfigurableFieldSpec ]
@classmethod
@classmethod
def get_lc_namespace ( cls ) - > List [ str ] :
def get_lc_namespace ( cls ) - > List [ str ] :
@ -102,6 +158,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
input_messages_key : Optional [ str ] = None ,
input_messages_key : Optional [ str ] = None ,
output_messages_key : Optional [ str ] = None ,
output_messages_key : Optional [ str ] = None ,
history_messages_key : Optional [ str ] = None ,
history_messages_key : Optional [ str ] = None ,
history_factory_config : Optional [ Sequence [ ConfigurableFieldSpec ] ] = None ,
* * kwargs : Any ,
* * kwargs : Any ,
) - > None :
) - > None :
""" Initialize RunnableWithMessageHistory.
""" Initialize RunnableWithMessageHistory.
@ -121,10 +178,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
- A BaseMessage or sequence of BaseMessages
- A BaseMessage or sequence of BaseMessages
- A dict with a key for a BaseMessage or sequence of BaseMessages
- A dict with a key for a BaseMessage or sequence of BaseMessages
get_session_history : Function that returns a new BaseChatMessageHistory
get_session_history : Function that returns a new BaseChatMessageHistory .
given a session id . Should take a single
This function should either take a single positional argument
positional argument ` session_id ` which is a string and a named argument
` session_id ` of type string and return a corresponding
` user_id ` which can be a string or None . e . g . :
chat message history instance .
` ` ` python
` ` ` python
def get_session_history (
def get_session_history (
@ -135,12 +192,29 @@ class RunnableWithMessageHistory(RunnableBindingBase):
. . .
. . .
` ` `
` ` `
Or it should take keyword arguments that match the keys of
` session_history_config_specs ` and return a corresponding
chat message history instance .
` ` ` python
def get_session_history (
* ,
user_id : str ,
thread_id : str ,
) - > BaseChatMessageHistory :
. . .
` ` `
input_messages_key : Must be specified if the base runnable accepts a dict
input_messages_key : Must be specified if the base runnable accepts a dict
as input .
as input .
output_messages_key : Must be specified if the base runnable returns a dict
output_messages_key : Must be specified if the base runnable returns a dict
as output .
as output .
history_messages_key : Must be specified if the base runnable accepts a dict
history_messages_key : Must be specified if the base runnable accepts a dict
as input and expects a separate key for historical messages .
as input and expects a separate key for historical messages .
history_factory_config : Configure fields that should be passed to the
chat history factory . See ` ` ConfigurableFieldSpec ` ` for more details .
Specifying these allows you to pass multiple config keys
into the get_session_history factory .
* * kwargs : Arbitrary additional kwargs to pass to parent class
* * kwargs : Arbitrary additional kwargs to pass to parent class
` ` RunnableBindingBase ` ` init .
` ` RunnableBindingBase ` ` init .
""" # noqa: E501
""" # noqa: E501
@ -155,29 +229,36 @@ class RunnableWithMessageHistory(RunnableBindingBase):
bound = (
bound = (
history_chain | runnable . with_listeners ( on_end = self . _exit_history )
history_chain | runnable . with_listeners ( on_end = self . _exit_history )
) . with_config ( run_name = " RunnableWithMessageHistory " )
) . with_config ( run_name = " RunnableWithMessageHistory " )
if history_factory_config :
_config_specs = history_factory_config
else :
# If not provided, then we'll use the default session_id field
_config_specs = [
ConfigurableFieldSpec (
id = " session_id " ,
annotation = str ,
name = " Session ID " ,
description = " Unique identifier for a session. " ,
default = " " ,
is_shared = True ,
) ,
]
super ( ) . __init__ (
super ( ) . __init__ (
get_session_history = get_session_history ,
get_session_history = get_session_history ,
input_messages_key = input_messages_key ,
input_messages_key = input_messages_key ,
output_messages_key = output_messages_key ,
output_messages_key = output_messages_key ,
bound = bound ,
bound = bound ,
history_messages_key = history_messages_key ,
history_messages_key = history_messages_key ,
history_factory_config = _config_specs ,
* * kwargs ,
* * kwargs ,
)
)
@property
@property
def config_specs ( self ) - > List [ ConfigurableFieldSpec ] :
def config_specs ( self ) - > List [ ConfigurableFieldSpec ] :
return get_unique_config_specs (
return get_unique_config_specs (
super ( ) . config_specs
super ( ) . config_specs + list ( self . history_factory_config )
+ [
ConfigurableFieldSpec (
id = " session_id " ,
annotation = str ,
name = " Session ID " ,
description = " Unique identifier for a session. " ,
default = " " ,
is_shared = True ,
) ,
]
)
)
def get_input_schema (
def get_input_schema (
@ -278,16 +359,46 @@ class RunnableWithMessageHistory(RunnableBindingBase):
def _merge_configs ( self , * configs : Optional [ RunnableConfig ] ) - > RunnableConfig :
def _merge_configs ( self , * configs : Optional [ RunnableConfig ] ) - > RunnableConfig :
config = super ( ) . _merge_configs ( * configs )
config = super ( ) . _merge_configs ( * configs )
# extract session_id
expected_keys = [ field_spec . id for field_spec in self . history_factory_config ]
if " session_id " not in config . get ( " configurable " , { } ) :
configurable = config . get ( " configurable " , { } )
missing_keys = set ( expected_keys ) - set ( configurable . keys ( ) )
if missing_keys :
example_input = { self . input_messages_key : " foo " }
example_input = { self . input_messages_key : " foo " }
example_config = { " configurable " : { " session_id " : " 123 " } }
example_configurable = {
missing_key : " [your-value-here] " for missing_key in missing_keys
}
example_config = { " configurable " : example_configurable }
raise ValueError (
raise ValueError (
" session_id is required. "
f " Missing keys { sorted ( missing_keys ) } in config[ ' configurable ' ] "
" Pass it in as part of the config argument to .invoke() or .stream() "
f " Expected keys are { sorted ( expected_keys ) } . "
f " \n eg. chain.invoke( { example_input } , { example_config } ) "
f " When using via .invoke() or .stream(), pass in a config; "
f " e.g., chain.invoke( { example_input } , { example_config } ) "
)
)
# attach message_history
session_id = config [ " configurable " ] [ " session_id " ]
parameter_names = _get_parameter_names ( self . get_session_history )
config [ " configurable " ] [ " message_history " ] = self . get_session_history ( session_id )
if len ( expected_keys ) == 1 :
# If arity = 1, then invoke function by positional arguments
message_history = self . get_session_history ( configurable [ expected_keys [ 0 ] ] )
else :
# otherwise verify that names of keys patch and invoke by named arguments
if set ( expected_keys ) != set ( parameter_names ) :
raise ValueError (
f " Expected keys { sorted ( expected_keys ) } do not match parameter "
f " names { sorted ( parameter_names ) } of get_session_history. "
)
message_history = self . get_session_history (
* * { key : configurable [ key ] for key in expected_keys }
)
config [ " configurable " ] [ " message_history " ] = message_history
return config
return config
def _get_parameter_names ( callable_ : GetSessionHistoryCallable ) - > List [ str ] :
""" Get the parameter names of the callable. """
sig = inspect . signature ( callable_ )
return list ( sig . parameters . keys ( ) )