add input type for convo retrieval chain (#11679)

pull/11682/head
Harrison Chase 1 year ago committed by GitHub
parent d5e762d328
commit 9f39c23a13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,7 +5,7 @@ import inspect
import warnings
from abc import abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@ -18,7 +18,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
@ -50,6 +50,11 @@ def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
return buffer
class InputType(BaseModel):
question: str
chat_history: List[CHAT_TURN_TYPE]
class BaseConversationalRetrievalChain(Chain):
"""Chain for chatting with an index."""
@ -87,6 +92,10 @@ class BaseConversationalRetrievalChain(Chain):
"""Input keys."""
return ["question", "chat_history"]
@property
def input_schema(self) -> Type[BaseModel]:
return InputType
@property
def output_keys(self) -> List[str]:
"""Return the output keys.

Loading…
Cancel
Save