mirror of https://github.com/hwchase17/langchain
Add create_conv_retrieval_chain func (#15084)
``` +----------+ | MapInput | **+----------+**** **** **** **** *** ** **** +------------------------------------+ ** | Lambda(itemgetter('chat_history')) | * +------------------------------------+ * * * * * * * +---------------------------+ +--------------------------------+ | Lambda(_get_chat_history) | | Lambda(itemgetter('question')) | +---------------------------+ +--------------------------------+ * * * * * * +----------------------------+ +------------------------+ | ContextSet('chat_history') | | ContextSet('question') | +----------------------------+ +------------------------+ **** **** **** **** ** ** +-----------+ | MapOutput | +-----------+ * * * +----------------+ | PromptTemplate | +----------------+ * * * +-------------+ | FakeListLLM | +-------------+ * * * +-----------------+ | StrOutputParser | +-----------------+ * * * +----------------------------+ | ContextSet('new_question') | +----------------------------+ * * * +---------------------+ | SequentialRetriever | +---------------------+ * * * +------------------------------------+ | Lambda(_reduce_tokens_below_limit) | +------------------------------------+ * * * +-------------------------------+ | ContextSet('input_documents') | +-------------------------------+ * * * +----------+ ***| MapInput |**** ******* +----------+ ******** ******** * ******* ******* * ******** **** * **** +-------------------------------+ +----------------------------+ +----------------------------+ | ContextGet('input_documents') | | ContextGet('chat_history') | | ContextGet('new_question') | +-------------------------------+**** +----------------------------+ +----------------------------+ ********* * ******* ******** * ****** ***** * **** +-----------+ | MapOutput | +-----------+ * * * +-------------+ | FakeListLLM | +-------------+ * * * +----------+ ***| MapInput |*** ******** +----------+ ****** ******* * ***** ******** * ****** **** * *** +-------------------------------+ +----------------------------+ 
+-------------+ | ContextGet('input_documents') | | ContextGet('new_question') | **| Passthrough | +-------------------------------+ +----------------------------+ ******* +-------------+ ******* * ****** ****** * ******* **** * **** +-----------+ | MapOutput | +-----------+ ``` --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>pull/15196/head
parent
4ad77f777e
commit
f36ef0739d
@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.language_models import LanguageModelLike
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.retrievers import RetrieverLike, RetrieverOutputLike
|
||||
from langchain_core.runnables import RunnableBranch
|
||||
|
||||
|
||||
def create_history_aware_retriever(
    llm: LanguageModelLike,
    retriever: RetrieverLike,
    prompt: BasePromptTemplate,
) -> RetrieverOutputLike:
    """Create a chain that takes conversation history and returns documents.

    If there is no `chat_history`, then the `input` is just passed directly to the
    retriever. If there is `chat_history`, then the prompt and LLM will be used
    to generate a search query. That search query is then passed to the retriever.

    Args:
        llm: Language model to use for generating a search term given chat history.
        retriever: RetrieverLike object that takes a string as input and outputs
            a list of Documents.
        prompt: The prompt used to generate the search query for the retriever.
            It must declare `input` as one of its input variables.

    Returns:
        An LCEL Runnable. The runnable input must take in `input`, and if there
        is chat history should take it in the form of `chat_history`.
        The Runnable output is a list of Documents.

    Raises:
        ValueError: If `input` is not one of the prompt's input variables.

    Example:
        .. code-block:: python

            # pip install -U langchain langchain-community

            from langchain_community.chat_models import ChatOpenAI
            from langchain.chains import create_history_aware_retriever
            from langchain import hub

            rephrase_prompt = hub.pull("langchain-ai/chat-langchain-rephrase")
            llm = ChatOpenAI()
            retriever = ...
            chat_retriever_chain = create_history_aware_retriever(
                llm, retriever, rephrase_prompt
            )

            chat_retriever_chain.invoke({"input": "...", "chat_history": []})

    """
    # Fail fast: both branches below read the `input` key, and the LLM branch
    # formats the prompt with it.
    if "input" not in prompt.input_variables:
        raise ValueError(
            "Expected `input` to be a prompt variable, "
            f"but got {prompt.input_variables}"
        )

    retrieve_documents: RetrieverOutputLike = RunnableBranch(
        (
            # Both empty string and empty list evaluate to False
            lambda x: not x.get("chat_history", False),
            # If no chat history, then we just pass input to retriever
            (lambda x: x["input"]) | retriever,
        ),
        # If chat history, then we pass inputs to LLM chain, then to retriever
        prompt | llm | StrOutputParser() | retriever,
    ).with_config(run_name="chat_retriever_chain")
    return retrieve_documents
|
@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from langchain_core.retrievers import (
|
||||
BaseRetriever,
|
||||
RetrieverOutput,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
|
||||
|
||||
def create_retrieval_chain(
    retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
    combine_docs_chain: Runnable[Dict[str, Any], str],
) -> Runnable:
    """Create retrieval chain that retrieves documents and then passes them on.

    Args:
        retriever: Retriever-like object that returns list of documents. Should
            either be a subclass of BaseRetriever or a Runnable that returns
            a list of documents. If a subclass of BaseRetriever, then it
            is expected that an `input` key be passed in - this is what
            will be used to pass into the retriever. If this is NOT a
            subclass of BaseRetriever, then all the inputs will be passed
            into this runnable, meaning that runnable should take a dictionary
            as input.
        combine_docs_chain: Runnable that takes inputs and produces a string output.
            The inputs to this will be any original inputs to this chain, a new
            `context` key with the retrieved documents, and `chat_history` (if not
            present in the inputs) with a value of `[]` (to easily enable
            conversational retrieval).

    Returns:
        An LCEL Runnable. The Runnable return is a dictionary containing at the very
        least a `context` and `answer` key.

    Example:
        .. code-block:: python

            # pip install -U langchain langchain-community

            from langchain_community.chat_models import ChatOpenAI
            from langchain.chains.combine_documents import create_stuff_documents_chain
            from langchain.chains import create_retrieval_chain
            from langchain import hub

            retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
            llm = ChatOpenAI()
            retriever = ...
            combine_docs_chain = create_stuff_documents_chain(
                llm, retrieval_qa_chat_prompt
            )
            retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)

            retrieval_chain.invoke({"input": "..."})

    """
    if not isinstance(retriever, BaseRetriever):
        # Arbitrary runnable: it receives the whole input dictionary.
        retrieval_docs: Runnable[dict, RetrieverOutput] = retriever
    else:
        # Plain retriever: feed it only the `input` value.
        retrieval_docs = (lambda x: x["input"]) | retriever

    retrieval_chain = (
        RunnablePassthrough.assign(
            context=retrieval_docs.with_config(run_name="retrieve_documents"),
            # Default to an empty history so downstream prompts can always
            # reference `chat_history`.
            chat_history=lambda x: x.get("chat_history", []),
        )
        | RunnablePassthrough.assign(answer=combine_docs_chain)
    ).with_config(run_name="retrieval_chain")

    return retrieval_chain
|
@ -0,0 +1,26 @@
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from langchain.chains import create_history_aware_retriever
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
from tests.unit_tests.retrievers.parrot_retriever import FakeParrotRetriever
|
||||
|
||||
|
||||
def test_create() -> None:
    """Exercise the history-aware retriever with and without chat history."""
    user_input = "What is the answer?"
    llm_reply = "I know the answer!"
    chain = create_history_aware_retriever(
        FakeListLLM(responses=[llm_reply]),
        FakeParrotRetriever(),
        PromptTemplate.from_template("hi! {input} {chat_history}"),
    )

    # Empty or missing history: the raw input is parroted back as the document.
    parroted_input = [Document(page_content="What is the answer?")]
    assert chain.invoke({"input": user_input, "chat_history": []}) == parroted_input
    assert chain.invoke({"input": user_input}) == parroted_input

    # Non-empty history: the fake LLM's response is parroted back instead.
    parroted_reply = [Document(page_content="I know the answer!")]
    with_history = {"input": user_input, "chat_history": ["hi", "hi"]}
    assert chain.invoke(with_history) == parroted_reply
|
@ -0,0 +1,32 @@
|
||||
"""Test conversation chain and memory."""
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
from langchain.chains import create_retrieval_chain
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
from tests.unit_tests.retrievers.parrot_retriever import FakeParrotRetriever
|
||||
|
||||
|
||||
def test_create() -> None:
    """Check the retrieval chain output with and without `chat_history`."""
    question = "What is the answer?"
    llm_reply = "I know the answer!"
    prompt = PromptTemplate.from_template("hi! {input} {chat_history}")
    chain = create_retrieval_chain(
        FakeParrotRetriever(),
        prompt | FakeListLLM(responses=[llm_reply]),
    )

    # Keys that are the same in both scenarios below.
    shared = {
        "answer": "I know the answer!",
        "context": [Document(page_content="What is the answer?")],
        "input": "What is the answer?",
    }

    # No history supplied: the output carries an empty list for it.
    result = chain.invoke({"input": question})
    assert result == {**shared, "chat_history": []}

    # History supplied: the output carries the caller's value unchanged.
    result = chain.invoke({"input": question, "chat_history": "foo"})
    assert result == {**shared, "chat_history": "foo"}
|
@ -0,0 +1,20 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class FakeParrotRetriever(BaseRetriever):
    """Test helper retriever that echoes the query back as a single document."""

    def _get_relevant_documents(  # type: ignore[override]
        self,
        query: str,
    ) -> List[Document]:
        """Return a one-element list whose document content is `query`."""
        parroted = Document(page_content=query)
        return [parroted]

    async def _aget_relevant_documents(  # type: ignore[override]
        self,
        query: str,
    ) -> List[Document]:
        """Async counterpart; yields the same result as the sync method."""
        return self._get_relevant_documents(query)
|
Loading…
Reference in New Issue