From 9f39c23a13ebc031d852b69e1a694b446c8997e6 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 11 Oct 2023 14:13:48 -0700 Subject: [PATCH] add input type for convo retrieval chain (#11679) --- .../chains/conversational_retrieval/base.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index a8beab3e9e..aee5ffa2ae 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -5,7 +5,7 @@ import inspect import warnings from abc import abstractmethod from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, @@ -18,7 +18,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain -from langchain.pydantic_v1 import Extra, Field, root_validator +from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.schema import BasePromptTemplate, BaseRetriever, Document from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import BaseMessage @@ -50,6 +50,11 @@ def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str: return buffer +class InputType(BaseModel): + question: str + chat_history: List[CHAT_TURN_TYPE] + + class BaseConversationalRetrievalChain(Chain): """Chain for chatting with an index.""" @@ -87,6 +92,10 @@ class BaseConversationalRetrievalChain(Chain): """Input keys.""" return ["question", "chat_history"] + @property + def input_schema(self) -> Type[BaseModel]: + return InputType + @property def output_keys(self) -> List[str]: """Return the output keys.