Harrison/chat history formatter1 (#1538)

Co-authored-by: Youssef A. Abukwaik <yousseb@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-03-08 20:46:37 -08:00 committed by GitHub
parent 31303d0b11
commit 523ad8d2e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 19 deletions

View File

@ -268,48 +268,44 @@
},
{
"cell_type": "markdown",
"id": "4f49beab",
"metadata": {},
"source": [
"## Chat Vector DB with `search_distance`\n",
"If you are using a vector store that supports filtering by search distance, you can add a threshold value parameter."
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ed8d612",
"metadata": {},
"outputs": [],
"source": [
"vectordbkwargs = {\"search_distance\": 0.9}"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a7b3459",
"metadata": {},
"outputs": [],
"source": [
"qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore, return_source_documents=True)\n",
"chat_history = []\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"result = qa({\"question\": query, \"chat_history\": chat_history, \"vectordbkwargs\": vectordbkwargs})"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"id": "99b96dae",
"metadata": {},
"source": [
"## Chat Vector DB with `map_reduce`\n",
"We can also use different types of combine document chains with the Chat Vector DB chain."
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "code",
@ -524,6 +520,71 @@
"query = \"Did he mention who she suceeded\"\n",
"result = qa({\"question\": query, \"chat_history\": chat_history})\n"
]
},
{
"cell_type": "markdown",
"id": "f793d56b",
"metadata": {},
"source": [
"## get_chat_history Function\n",
"You can also specify a `get_chat_history` function, which can be used to format the chat_history string."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a7ba9d8c",
"metadata": {},
"outputs": [],
"source": [
"def get_chat_history(inputs) -> str:\n",
" res = []\n",
" for human, ai in inputs:\n",
" res.append(f\"Human:{human}\\nAI:{ai}\")\n",
" return \"\\n\".join(res)\n",
"qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore, get_chat_history=get_chat_history)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a3e33c0d",
"metadata": {},
"outputs": [],
"source": [
"chat_history = []\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"result = qa({\"question\": query, \"chat_history\": chat_history})"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "936dc62f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result['answer']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8c26901",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@ -1,7 +1,8 @@
"""Chain for chatting with a vector database."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
@ -33,6 +34,7 @@ class ChatVectorDBChain(Chain, BaseModel):
output_key: str = "answer"
return_source_documents: bool = False
top_k_docs_for_context: int = 4
get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
"""Return the source documents."""
@property
@ -81,7 +83,8 @@ class ChatVectorDBChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
question = inputs["question"]
chat_history_str = _get_chat_history(inputs["chat_history"])
get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs["chat_history"])
vectordbkwargs = inputs.get("vectordbkwargs", {})
if chat_history_str:
new_question = self.question_generator.run(
@ -103,7 +106,8 @@ class ChatVectorDBChain(Chain, BaseModel):
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
question = inputs["question"]
chat_history_str = _get_chat_history(inputs["chat_history"])
get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs["chat_history"])
vectordbkwargs = inputs.get("vectordbkwargs", {})
if chat_history_str:
new_question = await self.question_generator.arun(
@ -123,3 +127,8 @@ class ChatVectorDBChain(Chain, BaseModel):
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
def save(self, file_path: Union[Path, str]) -> None:
if self.get_chat_history:
raise ValueError("Chain not savable when `get_chat_history` is not None.")
super().save(file_path)