@ -1,7 +1,8 @@
""" Chain for chatting with a vector database. """
from __future__ import annotations
from typing import Any , Dict , List , Optional , Tuple
from pathlib import Path
from typing import Any , Callable , Dict , List , Optional , Tuple , Union
from pydantic import BaseModel
@ -33,6 +34,7 @@ class ChatVectorDBChain(Chain, BaseModel):
output_key : str = " answer "
return_source_documents : bool = False
top_k_docs_for_context : int = 4
get_chat_history : Optional [ Callable [ [ Tuple [ str , str ] ] , str ] ] = None
""" Return the source documents. """
@property
@ -81,7 +83,8 @@ class ChatVectorDBChain(Chain, BaseModel):
def _call ( self , inputs : Dict [ str , Any ] ) - > Dict [ str , Any ] :
question = inputs [ " question " ]
chat_history_str = _get_chat_history ( inputs [ " chat_history " ] )
get_chat_history = self . get_chat_history or _get_chat_history
chat_history_str = get_chat_history ( inputs [ " chat_history " ] )
vectordbkwargs = inputs . get ( " vectordbkwargs " , { } )
if chat_history_str :
new_question = self . question_generator . run (
@ -103,7 +106,8 @@ class ChatVectorDBChain(Chain, BaseModel):
async def _acall ( self , inputs : Dict [ str , Any ] ) - > Dict [ str , Any ] :
question = inputs [ " question " ]
chat_history_str = _get_chat_history ( inputs [ " chat_history " ] )
get_chat_history = self . get_chat_history or _get_chat_history
chat_history_str = get_chat_history ( inputs [ " chat_history " ] )
vectordbkwargs = inputs . get ( " vectordbkwargs " , { } )
if chat_history_str :
new_question = await self . question_generator . arun (
@ -123,3 +127,8 @@ class ChatVectorDBChain(Chain, BaseModel):
return { self . output_key : answer , " source_documents " : docs }
else :
return { self . output_key : answer }
def save ( self , file_path : Union [ Path , str ] ) - > None :
if self . get_chat_history :
raise ValueError ( " Chain not savable when `get_chat_history` is not None. " )
super ( ) . save ( file_path )