Patch Chat History Formatting (#3236)

While we work on solidifying the memory interfaces, handle common chat
history formats.

This may break linting on anyone who has been passing in
`get_chat_history` .

Somewhat handles #3077

Alternative to #3078 that updates the typing
This commit is contained in:
Zander Chase 2023-04-20 13:31:30 -07:00 committed by GitHub
parent 8f22949dc4
commit daee0b2b97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -15,16 +15,32 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
# Depending on the memory type and configuration, the chat history format may differ.
# This needs to be consolidated.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str:
_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}
def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
buffer = ""
for human_s, ai_s in chat_history:
human = "Human: " + human_s
ai = "Assistant: " + ai_s
buffer += "\n" + "\n".join([human, ai])
for dialogue_turn in chat_history:
if isinstance(dialogue_turn, BaseMessage):
role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
buffer += f"\n{role_prefix}{dialogue_turn.content}"
elif isinstance(dialogue_turn, tuple):
human = "Human: " + dialogue_turn[0]
ai = "Assistant: " + dialogue_turn[1]
buffer += "\n" + "\n".join([human, ai])
else:
raise ValueError(
f"Unsupported chat history format: {type(dialogue_turn)}."
f" Full chat history: {chat_history} "
)
return buffer
@ -35,7 +51,7 @@ class BaseConversationalRetrievalChain(Chain):
question_generator: LLMChain
output_key: str = "answer"
return_source_documents: bool = False
get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
get_chat_history: Optional[Callable[[CHAT_TURN_TYPE], str]] = None
"""Return the source documents."""
class Config: