@@ -15,16 +15,32 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseRetriever, Document
+from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore
 
+# Depending on the memory type and configuration, the chat history format may differ.
+# This needs to be consolidated.
+CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
+
-def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str:
+_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}
+
+
+def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
     buffer = ""
-    for human_s, ai_s in chat_history:
-        human = "Human: " + human_s
-        ai = "Assistant: " + ai_s
-        buffer += "\n" + "\n".join([human, ai])
+    for dialogue_turn in chat_history:
+        if isinstance(dialogue_turn, BaseMessage):
+            role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
+            buffer += f"\n{role_prefix}{dialogue_turn.content}"
+        elif isinstance(dialogue_turn, tuple):
+            human = "Human: " + dialogue_turn[0]
+            ai = "Assistant: " + dialogue_turn[1]
+            buffer += "\n" + "\n".join([human, ai])
+        else:
+            raise ValueError(
+                f"Unsupported chat history format: {type(dialogue_turn)}."
+                f" Full chat history: {chat_history}"
+            )
     return buffer
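The new `_get_chat_history` accepts both the legacy `(human, ai)` string tuples and `BaseMessage` objects, using the message's `type` as a fallback role prefix for anything not in `_ROLE_MAP` and raising on unsupported entries. A minimal usage sketch of the two accepted formats (illustration only, not part of the patch; it assumes the patched module is importable as `langchain.chains.conversational_retrieval.base` and that `HumanMessage`/`AIMessage` are exported from `langchain.schema`, as in this version):

```python
# Illustration only: both history formats render to the same transcript string.
from langchain.chains.conversational_retrieval.base import _get_chat_history
from langchain.schema import AIMessage, HumanMessage

tuple_history = [("What is LangChain?", "A framework for building LLM apps.")]
message_history = [
    HumanMessage(content="What is LangChain?"),
    AIMessage(content="A framework for building LLM apps."),
]

# Both calls produce:
# "\nHuman: What is LangChain?\nAssistant: A framework for building LLM apps."
print(_get_chat_history(tuple_history))
print(_get_chat_history(message_history))
```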
@@ -35,7 +51,7 @@ class BaseConversationalRetrievalChain(Chain):
     question_generator: LLMChain
     output_key: str = "answer"
     return_source_documents: bool = False
-    get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
+    get_chat_history: Optional[Callable[[CHAT_TURN_TYPE], str]] = None
     """Return the source documents."""
 
     class Config: