diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index 97424ecb..b7fb299e 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -15,16 +15,32 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_ from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel, BaseRetriever, Document +from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document from langchain.vectorstores.base import VectorStore +# Depending on the memory type and configuration, the chat history format may differ. +# This needs to be consolidated. +CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage] -def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str: + +_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "} + + +def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str: buffer = "" - for human_s, ai_s in chat_history: - human = "Human: " + human_s - ai = "Assistant: " + ai_s - buffer += "\n" + "\n".join([human, ai]) + for dialogue_turn in chat_history: + if isinstance(dialogue_turn, BaseMessage): + role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ") + buffer += f"\n{role_prefix}{dialogue_turn.content}" + elif isinstance(dialogue_turn, tuple): + human = "Human: " + dialogue_turn[0] + ai = "Assistant: " + dialogue_turn[1] + buffer += "\n" + "\n".join([human, ai]) + else: + raise ValueError( + f"Unsupported chat history format: {type(dialogue_turn)}." + f" Full chat history: {chat_history} " + ) return buffer @@ -35,7 +51,7 @@ class BaseConversationalRetrievalChain(Chain): question_generator: LLMChain output_key: str = "answer" return_source_documents: bool = False - get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None + get_chat_history: Optional[Callable[[CHAT_TURN_TYPE], str]] = None """Return the source documents.""" class Config: