diff --git a/docs/modules/indexes/chain_examples/chat_vector_db.ipynb b/docs/modules/indexes/chain_examples/chat_vector_db.ipynb index 62a007b9..d5b7148d 100644 --- a/docs/modules/indexes/chain_examples/chat_vector_db.ipynb +++ b/docs/modules/indexes/chain_examples/chat_vector_db.ipynb @@ -268,48 +268,44 @@ }, { "cell_type": "markdown", + "id": "4f49beab", + "metadata": {}, "source": [ "## Chat Vector DB with `search_distance`\n", "If you are using a vector store that supports filtering by search distance, you can add a threshold value parameter." - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "id": "5ed8d612", + "metadata": {}, "outputs": [], "source": [ "vectordbkwargs = {\"search_distance\": 0.9}" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "id": "6a7b3459", + "metadata": {}, "outputs": [], "source": [ "qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore, return_source_documents=True)\n", "chat_history = []\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n", "result = qa({\"question\": query, \"chat_history\": chat_history, \"vectordbkwargs\": vectordbkwargs})" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "markdown", + "id": "99b96dae", + "metadata": {}, "source": [ "## Chat Vector DB with `map_reduce`\n", "We can also use different types of combine document chains with the Chat Vector DB chain." - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", @@ -524,6 +520,71 @@ "query = \"Did he mention who she suceeded\"\n", "result = qa({\"question\": query, \"chat_history\": chat_history})\n" ] + }, + { + "cell_type": "markdown", + "id": "f793d56b", + "metadata": {}, + "source": [ + "## get_chat_history Function\n", + "You can also specify a `get_chat_history` function, which can be used to format the chat_history string." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a7ba9d8c", + "metadata": {}, + "outputs": [], + "source": [ + "def get_chat_history(inputs) -> str:\n", + " res = []\n", + " for human, ai in inputs:\n", + " res.append(f\"Human:{human}\\nAI:{ai}\")\n", + " return \"\\n\".join(res)\n", + "qa = ChatVectorDBChain.from_llm(OpenAI(temperature=0), vectorstore, get_chat_history=get_chat_history)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a3e33c0d", + "metadata": {}, + "outputs": [], + "source": [ + "chat_history = []\n", + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "result = qa({\"question\": query, \"chat_history\": chat_history})" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "936dc62f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. 
He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\"" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result['answer']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8c26901", + "metadata": {}, + "outputs": [], + "source": [] + } ], "metadata": { diff --git a/langchain/chains/chat_vector_db/base.py b/langchain/chains/chat_vector_db/base.py index 22cbc5dd..4dd18b68 100644 --- a/langchain/chains/chat_vector_db/base.py +++ b/langchain/chains/chat_vector_db/base.py @@ -1,7 +1,8 @@ """Chain for chatting with a vector database.""" from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from pydantic import BaseModel @@ -33,6 +34,7 @@ class ChatVectorDBChain(Chain, BaseModel): output_key: str = "answer" return_source_documents: bool = False top_k_docs_for_context: int = 4 """Return the source documents.""" + get_chat_history: Optional[Callable[[List[Tuple[str, str]]], str]] = None @property @@ -81,7 +83,8 @@ class ChatVectorDBChain(Chain, BaseModel): def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: question = inputs["question"] - chat_history_str = _get_chat_history(inputs["chat_history"]) + get_chat_history = self.get_chat_history or _get_chat_history + chat_history_str = get_chat_history(inputs["chat_history"]) vectordbkwargs = inputs.get("vectordbkwargs", {}) if chat_history_str: new_question = self.question_generator.run( @@ -103,7 +106,8 @@ class ChatVectorDBChain(Chain, BaseModel): async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: question = inputs["question"] - chat_history_str = _get_chat_history(inputs["chat_history"]) + get_chat_history = self.get_chat_history or _get_chat_history + chat_history_str = 
get_chat_history(inputs["chat_history"]) vectordbkwargs = inputs.get("vectordbkwargs", {}) if chat_history_str: new_question = await self.question_generator.arun( @@ -123,3 +127,8 @@ class ChatVectorDBChain(Chain, BaseModel): return {self.output_key: answer, "source_documents": docs} else: return {self.output_key: answer} + + def save(self, file_path: Union[Path, str]) -> None: + if self.get_chat_history: + raise ValueError("Chain not savable when `get_chat_history` is not None.") + super().save(file_path)