Harrison/chat history formatter1 (#1538)

Co-authored-by: Youssef A. Abukwaik <yousseb@users.noreply.github.com>
Author: Harrison Chase
Date: 2023-03-08 20:46:37 -08:00 (committed by GitHub)
Commit: 523ad8d2e2
Parent: 31303d0b11
2 changed files with 89 additions and 19 deletions

File 1 of 2: chat_vector_db.ipynb (docs example notebook)

@@ -268,48 +268,44 @@
   },
   {
    "cell_type": "markdown",
+   "id": "4f49beab",
+   "metadata": {},
    "source": [
     "## Chat Vector DB with `search_distance`\n",
     "If you are using a vector store that supports filtering by search distance, you can add a threshold value parameter."
-   ],
-   "metadata": {
-    "collapsed": false
-   }
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
+   "id": "5ed8d612",
+   "metadata": {},
    "outputs": [],
    "source": [
     "vectordbkwargs = {\"search_distance\": 0.9}"
-   ],
-   "metadata": {
-    "collapsed": false
-   }
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
+   "id": "6a7b3459",
+   "metadata": {},
    "outputs": [],
    "source": [
     "qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore, return_source_documents=True)\n",
     "chat_history = []\n",
     "query = \"What did the president say about Ketanji Brown Jackson\"\n",
     "result = qa({\"question\": query, \"chat_history\": chat_history, \"vectordbkwargs\": vectordbkwargs})"
-   ],
-   "metadata": {
-    "collapsed": false
-   }
+   ]
   },
   {
    "cell_type": "markdown",
+   "id": "99b96dae",
+   "metadata": {},
    "source": [
     "## Chat Vector DB with `map_reduce`\n",
     "We can also use different types of combine document chains with the Chat Vector DB chain."
-   ],
-   "metadata": {
-    "collapsed": false
-   }
+   ]
   },
   {
    "cell_type": "code",
@@ -524,6 +520,71 @@
    "query = \"Did he mention who she suceeded\"\n",
    "result = qa({\"question\": query, \"chat_history\": chat_history})\n"
   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f793d56b",
+   "metadata": {},
+   "source": [
+    "## get_chat_history Function\n",
+    "You can also specify a `get_chat_history` function, which can be used to format the chat_history string."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "a7ba9d8c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_chat_history(inputs) -> str:\n",
+    "    res = []\n",
+    "    for human, ai in inputs:\n",
+    "        res.append(f\"Human:{human}\\nAI:{ai}\")\n",
+    "    return \"\\n\".join(res)\n",
+    "qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore, get_chat_history=get_chat_history)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "a3e33c0d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat_history = []\n",
+    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
+    "result = qa({\"question\": query, \"chat_history\": chat_history})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "936dc62f",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "result['answer']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b8c26901",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
 "metadata": {

File 2 of 2: langchain/chains/chat_vector_db/base.py

@ -1,7 +1,8 @@
"""Chain for chatting with a vector database.""" """Chain for chatting with a vector database."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -33,6 +34,7 @@ class ChatVectorDBChain(Chain, BaseModel):
output_key: str = "answer" output_key: str = "answer"
return_source_documents: bool = False return_source_documents: bool = False
top_k_docs_for_context: int = 4 top_k_docs_for_context: int = 4
get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
"""Return the source documents.""" """Return the source documents."""
@property @property
@ -81,7 +83,8 @@ class ChatVectorDBChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
question = inputs["question"] question = inputs["question"]
chat_history_str = _get_chat_history(inputs["chat_history"]) get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs["chat_history"])
vectordbkwargs = inputs.get("vectordbkwargs", {}) vectordbkwargs = inputs.get("vectordbkwargs", {})
if chat_history_str: if chat_history_str:
new_question = self.question_generator.run( new_question = self.question_generator.run(
@ -103,7 +106,8 @@ class ChatVectorDBChain(Chain, BaseModel):
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
question = inputs["question"] question = inputs["question"]
chat_history_str = _get_chat_history(inputs["chat_history"]) get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs["chat_history"])
vectordbkwargs = inputs.get("vectordbkwargs", {}) vectordbkwargs = inputs.get("vectordbkwargs", {})
if chat_history_str: if chat_history_str:
new_question = await self.question_generator.arun( new_question = await self.question_generator.arun(
@ -123,3 +127,8 @@ class ChatVectorDBChain(Chain, BaseModel):
return {self.output_key: answer, "source_documents": docs} return {self.output_key: answer, "source_documents": docs}
else: else:
return {self.output_key: answer} return {self.output_key: answer}
def save(self, file_path: Union[Path, str]) -> None:
if self.get_chat_history:
raise ValueError("Chain not savable when `get_chat_history` is not None.")
super().save(file_path)
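
The new `save` override exists because `get_chat_history` is an arbitrary callable, which the chain's config serialization cannot round-trip; saving would otherwise silently drop the custom formatting. A minimal sketch of the fallback-plus-guard pattern this diff adds, independent of LangChain (all names below are illustrative stand-ins, not the library's API):

```python
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union


def _default_get_chat_history(history: List[Tuple[str, str]]) -> str:
    # Stand-in for the module-level _get_chat_history default formatter.
    return "\n".join(f"Human:{h}\nAI:{a}" for h, a in history)


class ChainSketch:
    """Illustrative only: optional formatter with `or` fallback and save guard."""

    def __init__(self, get_chat_history: Optional[Callable] = None) -> None:
        self.get_chat_history = get_chat_history

    def format_history(self, history: List[Tuple[str, str]]) -> str:
        # Same pattern _call/_acall use: prefer the user callable, else default.
        get_chat_history = self.get_chat_history or _default_get_chat_history
        return get_chat_history(history)

    def save(self, file_path: Union[Path, str]) -> None:
        # A function reference cannot be written to the saved config, so refuse
        # rather than save a chain that would reload with different behavior.
        if self.get_chat_history:
            raise ValueError("Chain not savable when `get_chat_history` is not None.")
        Path(file_path).write_text("{}")  # placeholder for real serialization
```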