|
|
|
@ -5,7 +5,7 @@ import inspect
|
|
|
|
|
import warnings
|
|
|
|
|
from abc import abstractmethod
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
|
|
|
|
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForChainRun,
|
|
|
|
@ -18,7 +18,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
|
|
|
|
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.chains.question_answering import load_qa_chain
|
|
|
|
|
from langchain.pydantic_v1 import Extra, Field, root_validator
|
|
|
|
|
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
|
|
|
|
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
|
|
|
|
|
from langchain.schema.language_model import BaseLanguageModel
|
|
|
|
|
from langchain.schema.messages import BaseMessage
|
|
|
|
@ -50,6 +50,11 @@ def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
|
|
|
|
|
return buffer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InputType(BaseModel):
|
|
|
|
|
question: str
|
|
|
|
|
chat_history: List[CHAT_TURN_TYPE]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseConversationalRetrievalChain(Chain):
|
|
|
|
|
"""Chain for chatting with an index."""
|
|
|
|
|
|
|
|
|
@ -87,6 +92,10 @@ class BaseConversationalRetrievalChain(Chain):
|
|
|
|
|
"""Input keys."""
|
|
|
|
|
return ["question", "chat_history"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def input_schema(self) -> Type[BaseModel]:
|
|
|
|
|
return InputType
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def output_keys(self) -> List[str]:
|
|
|
|
|
"""Return the output keys.
|
|
|
|
|