mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
langchain[patch]: deprecate various chains (#25310)
- [x] NatbotChain: move to community, deprecate langchain version. Update to use `prompt | llm | output_parser` instead of LLMChain. - [x] LLMMathChain: deprecate + add langgraph replacement example to API ref - [x] HypotheticalDocumentEmbedder (retriever): update to use `prompt | llm | output_parser` instead of LLMChain - [x] FlareChain: update to use `prompt | llm | output_parser` instead of LLMChain - [x] ConstitutionalChain: deprecate + add langgraph replacement example to API ref - [x] LLMChainExtractor (document compressor): update to use `prompt | llm | output_parser` instead of LLMChain - [x] LLMChainFilter (document compressor): update to use `prompt | llm | output_parser` instead of LLMChain - [x] RePhraseQueryRetriever (retriever): update to use `prompt | llm | output_parser` instead of LLMChain
This commit is contained in:
parent
66e30efa61
commit
8afbab4cf6
332
docs/docs/versions/migrating_chains/constitutional_chain.ipynb
Normal file
332
docs/docs/versions/migrating_chains/constitutional_chain.ipynb
Normal file
@ -0,0 +1,332 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b57124cc-60a0-4c18-b7ce-3e483d1024a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"title: Migrating from ConstitutionalChain\n",
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ce8457ed-c0b1-4a74-abbd-9d3d2211270f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"[ConstitutionalChain](https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html) allowed for a LLM to critique and revise generations based on [principles](https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.models.ConstitutionalPrinciple.html), structured as combinations of critique and revision requests. For example, a principle might include a request to identify harmful content, and a request to rewrite the content.\n",
|
||||
"\n",
|
||||
"In `ConstitutionalChain`, this structure of critique requests and associated revisions was formatted into a LLM prompt and parsed out of string responses. This is more naturally achieved via [structured output](/docs/how_to/structured_output/) features of chat models. We can construct a simple chain in [LangGraph](https://langchain-ai.github.io/langgraph/) for this purpose. Some advantages of this approach include:\n",
|
||||
"\n",
|
||||
"- Leverage tool-calling capabilities of chat models that have been fine-tuned for this purpose;\n",
|
||||
"- Reduce parsing errors from extracting expression from a string LLM response;\n",
|
||||
"- Delegation of instructions to [message roles](/docs/concepts/#messages) (e.g., chat models can understand what a `ToolMessage` represents without the need for additional prompting);\n",
|
||||
"- Support for streaming, both of individual tokens and chain steps."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b99b47ec",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --upgrade --quiet langchain-openai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "717c8673",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e3621b62-a037-42b8-8faa-59575608bb8b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Legacy\n",
|
||||
"\n",
|
||||
"<details open>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f91c9809-8ee7-4e38-881d-0ace4f6ea883",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import ConstitutionalChain, LLMChain\n",
|
||||
"from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple\n",
|
||||
"from langchain_core.prompts import PromptTemplate\n",
|
||||
"from langchain_openai import OpenAI\n",
|
||||
"\n",
|
||||
"llm = OpenAI()\n",
|
||||
"\n",
|
||||
"qa_prompt = PromptTemplate(\n",
|
||||
" template=\"Q: {question} A:\",\n",
|
||||
" input_variables=[\"question\"],\n",
|
||||
")\n",
|
||||
"qa_chain = LLMChain(llm=llm, prompt=qa_prompt)\n",
|
||||
"\n",
|
||||
"constitutional_chain = ConstitutionalChain.from_llm(\n",
|
||||
" llm=llm,\n",
|
||||
" chain=qa_chain,\n",
|
||||
" constitutional_principles=[\n",
|
||||
" ConstitutionalPrinciple(\n",
|
||||
" critique_request=\"Tell if this answer is good.\",\n",
|
||||
" revision_request=\"Give a better answer.\",\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
" return_intermediate_steps=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"result = constitutional_chain.invoke(\"What is the meaning of life?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "fa3d11a1-ac1f-4a9a-9ab3-b7b244daa506",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'question': 'What is the meaning of life?',\n",
|
||||
" 'output': 'The meaning of life is a deeply personal and ever-evolving concept. It is a journey of self-discovery and growth, and can be different for each individual. Some may find meaning in relationships, others in achieving their goals, and some may never find a concrete answer. Ultimately, the meaning of life is what we make of it.',\n",
|
||||
" 'initial_output': ' The meaning of life is a subjective concept that can vary from person to person. Some may believe that the purpose of life is to find happiness and fulfillment, while others may see it as a journey of self-discovery and personal growth. Ultimately, the meaning of life is something that each individual must determine for themselves.',\n",
|
||||
" 'critiques_and_revisions': [('This answer is good in that it recognizes and acknowledges the subjective nature of the question and provides a valid and thoughtful response. However, it could have also mentioned that the meaning of life is a complex and deeply personal concept that can also change and evolve over time for each individual. Critique Needed.',\n",
|
||||
" 'The meaning of life is a deeply personal and ever-evolving concept. It is a journey of self-discovery and growth, and can be different for each individual. Some may find meaning in relationships, others in achieving their goals, and some may never find a concrete answer. Ultimately, the meaning of life is what we make of it.')]}"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "374ae108-f1a0-4723-9237-5259c8123c04",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Above, we've returned intermediate steps showing:\n",
|
||||
"\n",
|
||||
"- The original question;\n",
|
||||
"- The initial output;\n",
|
||||
"- Critiques and revisions;\n",
|
||||
"- The final output (matching a revision)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cdc3b527-c09e-4c77-9711-c3cc4506cd95",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"</details>\n",
|
||||
"\n",
|
||||
"## LangGraph\n",
|
||||
"\n",
|
||||
"<details open>\n",
|
||||
"\n",
|
||||
"Below, we use the [.with_structured_output](/docs/how_to/structured_output/) method to simultaneously generate (1) a judgment of whether a critique is needed, and (2) the critique. We surface all prompts involved for clarity and ease of customizability.\n",
|
||||
"\n",
|
||||
"Note that we are also able to stream intermediate steps with this implementation, so we can monitor and if needed intervene during its execution."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "917fdb73-2411-4fcc-9add-c32dc5c745da",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List, Optional, Tuple\n",
|
||||
"\n",
|
||||
"from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple\n",
|
||||
"from langchain.chains.constitutional_ai.prompts import (\n",
|
||||
" CRITIQUE_PROMPT,\n",
|
||||
" REVISION_PROMPT,\n",
|
||||
")\n",
|
||||
"from langchain_core.output_parsers import StrOutputParser\n",
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"from langchain_openai import ChatOpenAI\n",
|
||||
"from langgraph.graph import END, START, StateGraph\n",
|
||||
"from typing_extensions import Annotated, TypedDict\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Critique(TypedDict):\n",
|
||||
" \"\"\"Generate a critique, if needed.\"\"\"\n",
|
||||
"\n",
|
||||
" critique_needed: Annotated[bool, ..., \"Whether or not a critique is needed.\"]\n",
|
||||
" critique: Annotated[str, ..., \"If needed, the critique.\"]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"critique_prompt = ChatPromptTemplate.from_template(\n",
|
||||
" \"Critique this response according to the critique request. \"\n",
|
||||
" \"If no critique is needed, specify that.\\n\\n\"\n",
|
||||
" \"Query: {query}\\n\\n\"\n",
|
||||
" \"Response: {response}\\n\\n\"\n",
|
||||
" \"Critique request: {critique_request}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"revision_prompt = ChatPromptTemplate.from_template(\n",
|
||||
" \"Revise this response according to the critique and reivsion request.\\n\\n\"\n",
|
||||
" \"Query: {query}\\n\\n\"\n",
|
||||
" \"Response: {response}\\n\\n\"\n",
|
||||
" \"Critique request: {critique_request}\\n\\n\"\n",
|
||||
" \"Critique: {critique}\\n\\n\"\n",
|
||||
" \"If the critique does not identify anything worth changing, ignore the \"\n",
|
||||
" \"revision request and return 'No revisions needed'. If the critique \"\n",
|
||||
" \"does identify something worth changing, revise the response based on \"\n",
|
||||
" \"the revision request.\\n\\n\"\n",
|
||||
" \"Revision Request: {revision_request}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain = llm | StrOutputParser()\n",
|
||||
"critique_chain = critique_prompt | llm.with_structured_output(Critique)\n",
|
||||
"revision_chain = revision_prompt | llm | StrOutputParser()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class State(TypedDict):\n",
|
||||
" query: str\n",
|
||||
" constitutional_principles: List[ConstitutionalPrinciple]\n",
|
||||
" initial_response: str\n",
|
||||
" critiques_and_revisions: List[Tuple[str, str]]\n",
|
||||
" response: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def generate_response(state: State):\n",
|
||||
" \"\"\"Generate initial response.\"\"\"\n",
|
||||
" response = await chain.ainvoke(state[\"query\"])\n",
|
||||
" return {\"response\": response, \"initial_response\": response}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def critique_and_revise(state: State):\n",
|
||||
" \"\"\"Critique and revise response according to principles.\"\"\"\n",
|
||||
" critiques_and_revisions = []\n",
|
||||
" response = state[\"initial_response\"]\n",
|
||||
" for principle in state[\"constitutional_principles\"]:\n",
|
||||
" critique = await critique_chain.ainvoke(\n",
|
||||
" {\n",
|
||||
" \"query\": state[\"query\"],\n",
|
||||
" \"response\": response,\n",
|
||||
" \"critique_request\": principle.critique_request,\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" if critique[\"critique_needed\"]:\n",
|
||||
" revision = await revision_chain.ainvoke(\n",
|
||||
" {\n",
|
||||
" \"query\": state[\"query\"],\n",
|
||||
" \"response\": response,\n",
|
||||
" \"critique_request\": principle.critique_request,\n",
|
||||
" \"critique\": critique[\"critique\"],\n",
|
||||
" \"revision_request\": principle.revision_request,\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" response = revision\n",
|
||||
" critiques_and_revisions.append((critique[\"critique\"], revision))\n",
|
||||
" else:\n",
|
||||
" critiques_and_revisions.append((critique[\"critique\"], \"\"))\n",
|
||||
" return {\n",
|
||||
" \"critiques_and_revisions\": critiques_and_revisions,\n",
|
||||
" \"response\": response,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"graph = StateGraph(State)\n",
|
||||
"graph.add_node(\"generate_response\", generate_response)\n",
|
||||
"graph.add_node(\"critique_and_revise\", critique_and_revise)\n",
|
||||
"\n",
|
||||
"graph.add_edge(START, \"generate_response\")\n",
|
||||
"graph.add_edge(\"generate_response\", \"critique_and_revise\")\n",
|
||||
"graph.add_edge(\"critique_and_revise\", END)\n",
|
||||
"app = graph.compile()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "01aac88d-464e-431f-b92e-746dcb743e1b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{}\n",
|
||||
"{'initial_response': 'Finding purpose, connection, and joy in our experiences and relationships.', 'response': 'Finding purpose, connection, and joy in our experiences and relationships.'}\n",
|
||||
"{'initial_response': 'Finding purpose, connection, and joy in our experiences and relationships.', 'critiques_and_revisions': [(\"The response exceeds the 10-word limit, providing a more elaborate answer than requested. A concise response, such as 'To seek purpose and joy in life,' would better align with the query.\", 'To seek purpose and joy in life.')], 'response': 'To seek purpose and joy in life.'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"constitutional_principles = [\n",
|
||||
" ConstitutionalPrinciple(\n",
|
||||
" critique_request=\"Tell if this answer is good.\",\n",
|
||||
" revision_request=\"Give a better answer.\",\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"query = \"What is the meaning of life? Answer in 10 words or fewer.\"\n",
|
||||
"\n",
|
||||
"async for step in app.astream(\n",
|
||||
" {\"query\": query, \"constitutional_principles\": constitutional_principles},\n",
|
||||
" stream_mode=\"values\",\n",
|
||||
"):\n",
|
||||
" subset = [\"initial_response\", \"critiques_and_revisions\", \"response\"]\n",
|
||||
" print({k: v for k, v in step.items() if k in subset})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b2717810",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"</details>\n",
|
||||
"\n",
|
||||
"## Next steps\n",
|
||||
"\n",
|
||||
"See guides for generating structured output [here](/docs/how_to/structured_output/).\n",
|
||||
"\n",
|
||||
"Check out the [LangGraph documentation](https://langchain-ai.github.io/langgraph/) for detail on building with LangGraph."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -45,5 +45,7 @@ The below pages assist with migration from various specific chains to LCEL and L
|
||||
- [RefineDocumentsChain](/docs/versions/migrating_chains/refine_docs_chain)
|
||||
- [LLMRouterChain](/docs/versions/migrating_chains/llm_router_chain)
|
||||
- [MultiPromptChain](/docs/versions/migrating_chains/multi_prompt_chain)
|
||||
- [LLMMathChain](/docs/versions/migrating_chains/llm_math_chain)
|
||||
- [ConstitutionalChain](/docs/versions/migrating_chains/constitutional_chain)
|
||||
|
||||
Check out the [LCEL conceptual docs](/docs/concepts/#langchain-expression-language-lcel) and [LangGraph docs](https://langchain-ai.github.io/langgraph/) for more background information.
|
281
docs/docs/versions/migrating_chains/llm_math_chain.ipynb
Normal file
281
docs/docs/versions/migrating_chains/llm_math_chain.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -0,0 +1,8 @@
|
||||
"""Implement a GPT-3 driven browser.
|
||||
|
||||
Heavily influenced from https://github.com/nat/natbot
|
||||
"""
|
||||
|
||||
from langchain_community.chains.natbot.base import NatBotChain
|
||||
|
||||
__all__ = ["NatBotChain"]
|
3
libs/community/langchain_community/chains/natbot/base.py
Normal file
3
libs/community/langchain_community/chains/natbot/base.py
Normal file
@ -0,0 +1,3 @@
|
||||
from langchain.chains import NatBotChain
|
||||
|
||||
__all__ = ["NatBotChain"]
|
@ -0,0 +1,7 @@
|
||||
from langchain.chains.natbot.crawler import (
|
||||
Crawler,
|
||||
ElementInViewPort,
|
||||
black_listed_elements,
|
||||
)
|
||||
|
||||
__all__ = ["ElementInViewPort", "Crawler", "black_listed_elements"]
|
@ -0,0 +1,3 @@
|
||||
from langchain.chains.natbot.prompt import PROMPT
|
||||
|
||||
__all__ = ["PROMPT"]
|
@ -6,14 +6,6 @@ from langchain_core.documents import Document
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
|
||||
|
||||
def test_llm_construction_with_kwargs() -> None:
|
||||
llm_chain_kwargs = {"verbose": True}
|
||||
compressor = LLMChainExtractor.from_llm(
|
||||
ChatOpenAI(), llm_chain_kwargs=llm_chain_kwargs
|
||||
)
|
||||
assert compressor.llm_chain.verbose is True
|
||||
|
||||
|
||||
def test_llm_chain_extractor() -> None:
|
||||
texts = [
|
||||
"The Roman Empire followed the Roman Republic.",
|
||||
|
@ -2,11 +2,10 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.chains.natbot.base import NatBotChain
|
||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
|
||||
from langchain.chains.natbot.base import NatBotChain
|
||||
|
||||
|
||||
class FakeLLM(LLM):
|
||||
"""Fake LLM wrapper for testing purposes."""
|
@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.2.13",
|
||||
message=(
|
||||
"This class is deprecated and will be removed in langchain 1.0. "
|
||||
"See API reference for replacement: "
|
||||
"https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html" # noqa: E501
|
||||
),
|
||||
removal="1.0",
|
||||
)
|
||||
class ConstitutionalChain(Chain):
|
||||
"""Chain for applying constitutional principles.
|
||||
|
||||
Note: this class is deprecated. See below for a replacement implementation
|
||||
using LangGraph. The benefits of this implementation are:
|
||||
|
||||
- Uses LLM tool calling features instead of parsing string responses;
|
||||
- Support for both token-by-token and step-by-step streaming;
|
||||
- Support for checkpointing and memory of chat history;
|
||||
- Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
|
||||
|
||||
Install LangGraph with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U langgraph
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from langchain.chains.constitutional_ai.prompts import (
|
||||
CRITIQUE_PROMPT,
|
||||
REVISION_PROMPT,
|
||||
)
|
||||
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini")
|
||||
|
||||
class Critique(TypedDict):
|
||||
\"\"\"Generate a critique, if needed.\"\"\"
|
||||
critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."]
|
||||
critique: Annotated[str, ..., "If needed, the critique."]
|
||||
|
||||
critique_prompt = ChatPromptTemplate.from_template(
|
||||
"Critique this response according to the critique request. "
|
||||
"If no critique is needed, specify that.\\n\\n"
|
||||
"Query: {query}\\n\\n"
|
||||
"Response: {response}\\n\\n"
|
||||
"Critique request: {critique_request}"
|
||||
)
|
||||
|
||||
revision_prompt = ChatPromptTemplate.from_template(
|
||||
"Revise this response according to the critique and reivsion request.\\n\\n"
|
||||
"Query: {query}\\n\\n"
|
||||
"Response: {response}\\n\\n"
|
||||
"Critique request: {critique_request}\\n\\n"
|
||||
"Critique: {critique}\\n\\n"
|
||||
"If the critique does not identify anything worth changing, ignore the "
|
||||
"revision request and return 'No revisions needed'. If the critique "
|
||||
"does identify something worth changing, revise the response based on "
|
||||
"the revision request.\\n\\n"
|
||||
"Revision Request: {revision_request}"
|
||||
)
|
||||
|
||||
chain = llm | StrOutputParser()
|
||||
critique_chain = critique_prompt | llm.with_structured_output(Critique)
|
||||
revision_chain = revision_prompt | llm | StrOutputParser()
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
query: str
|
||||
constitutional_principles: List[ConstitutionalPrinciple]
|
||||
initial_response: str
|
||||
critiques_and_revisions: List[Tuple[str, str]]
|
||||
response: str
|
||||
|
||||
|
||||
async def generate_response(state: State):
|
||||
\"\"\"Generate initial response.\"\"\"
|
||||
response = await chain.ainvoke(state["query"])
|
||||
return {"response": response, "initial_response": response}
|
||||
|
||||
async def critique_and_revise(state: State):
|
||||
\"\"\"Critique and revise response according to principles.\"\"\"
|
||||
critiques_and_revisions = []
|
||||
response = state["initial_response"]
|
||||
for principle in state["constitutional_principles"]:
|
||||
critique = await critique_chain.ainvoke(
|
||||
{
|
||||
"query": state["query"],
|
||||
"response": response,
|
||||
"critique_request": principle.critique_request,
|
||||
}
|
||||
)
|
||||
if critique["critique_needed"]:
|
||||
revision = await revision_chain.ainvoke(
|
||||
{
|
||||
"query": state["query"],
|
||||
"response": response,
|
||||
"critique_request": principle.critique_request,
|
||||
"critique": critique["critique"],
|
||||
"revision_request": principle.revision_request,
|
||||
}
|
||||
)
|
||||
response = revision
|
||||
critiques_and_revisions.append((critique["critique"], revision))
|
||||
else:
|
||||
critiques_and_revisions.append((critique["critique"], ""))
|
||||
return {
|
||||
"critiques_and_revisions": critiques_and_revisions,
|
||||
"response": response,
|
||||
}
|
||||
|
||||
graph = StateGraph(State)
|
||||
graph.add_node("generate_response", generate_response)
|
||||
graph.add_node("critique_and_revise", critique_and_revise)
|
||||
|
||||
graph.add_edge(START, "generate_response")
|
||||
graph.add_edge("generate_response", "critique_and_revise")
|
||||
graph.add_edge("critique_and_revise", END)
|
||||
app = graph.compile()
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
constitutional_principles=[
|
||||
ConstitutionalPrinciple(
|
||||
critique_request="Tell if this answer is good.",
|
||||
revision_request="Give a better answer.",
|
||||
)
|
||||
]
|
||||
|
||||
query = "What is the meaning of life? Answer in 10 words or fewer."
|
||||
|
||||
async for step in app.astream(
|
||||
{"query": query, "constitutional_principles": constitutional_principles},
|
||||
stream_mode="values",
|
||||
):
|
||||
subset = ["initial_response", "critiques_and_revisions", "response"]
|
||||
print({k: v for k, v in step.items() if k in subset})
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
@ -44,7 +187,7 @@ class ConstitutionalChain(Chain):
|
||||
)
|
||||
|
||||
constitutional_chain.run(question="What is the meaning of life?")
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
chain: LLMChain
|
||||
constitutional_principles: List[ConstitutionalPrinciple]
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -9,10 +8,12 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.flare.prompts import (
|
||||
@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import (
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
|
||||
class _ResponseChain(LLMChain):
|
||||
"""Base class for chains that generate responses."""
|
||||
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
return self.prompt.input_variables
|
||||
|
||||
def generate_tokens_and_log_probs(
|
||||
self,
|
||||
_input: Dict[str, Any],
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Tuple[Sequence[str], Sequence[float]]:
|
||||
llm_result = self.generate([_input], run_manager=run_manager)
|
||||
return self._extract_tokens_and_log_probs(llm_result.generations[0])
|
||||
|
||||
@abstractmethod
|
||||
def _extract_tokens_and_log_probs(
|
||||
self, generations: List[Generation]
|
||||
) -> Tuple[Sequence[str], Sequence[float]]:
|
||||
"""Extract tokens and log probs from response."""
|
||||
|
||||
|
||||
class _OpenAIResponseChain(_ResponseChain):
|
||||
"""Chain that generates responses from user input and context."""
|
||||
|
||||
llm: BaseLanguageModel
|
||||
|
||||
def _extract_tokens_and_log_probs(
|
||||
self, generations: List[Generation]
|
||||
) -> Tuple[Sequence[str], Sequence[float]]:
|
||||
tokens = []
|
||||
log_probs = []
|
||||
for gen in generations:
|
||||
if gen.generation_info is None:
|
||||
raise ValueError
|
||||
tokens.extend(gen.generation_info["logprobs"]["tokens"])
|
||||
log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
|
||||
return tokens, log_probs
|
||||
def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
|
||||
"""Extract tokens and log probabilities from chat model response."""
|
||||
tokens = []
|
||||
log_probs = []
|
||||
for token in response.response_metadata["logprobs"]["content"]:
|
||||
tokens.append(token["token"])
|
||||
log_probs.append(token["logprob"])
|
||||
return tokens, log_probs
|
||||
|
||||
|
||||
class QuestionGeneratorChain(LLMChain):
|
||||
@ -111,9 +75,9 @@ class FlareChain(Chain):
|
||||
"""Chain that combines a retriever, a question generator,
|
||||
and a response generator."""
|
||||
|
||||
question_generator_chain: QuestionGeneratorChain
|
||||
question_generator_chain: Runnable
|
||||
"""Chain that generates questions from uncertain spans."""
|
||||
response_chain: _ResponseChain
|
||||
response_chain: Runnable
|
||||
"""Chain that generates responses from user input and context."""
|
||||
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
|
||||
"""Parser that determines whether the chain is finished."""
|
||||
@ -152,12 +116,16 @@ class FlareChain(Chain):
|
||||
for question in questions:
|
||||
docs.extend(self.retriever.invoke(question))
|
||||
context = "\n\n".join(d.page_content for d in docs)
|
||||
result = self.response_chain.predict(
|
||||
user_input=user_input,
|
||||
context=context,
|
||||
response=response,
|
||||
callbacks=callbacks,
|
||||
result = self.response_chain.invoke(
|
||||
{
|
||||
"user_input": user_input,
|
||||
"context": context,
|
||||
"response": response,
|
||||
},
|
||||
{"callbacks": callbacks},
|
||||
)
|
||||
if isinstance(result, AIMessage):
|
||||
result = result.content
|
||||
marginal, finished = self.output_parser.parse(result)
|
||||
return marginal, finished
|
||||
|
||||
@ -178,13 +146,18 @@ class FlareChain(Chain):
|
||||
for span in low_confidence_spans
|
||||
]
|
||||
callbacks = _run_manager.get_child()
|
||||
question_gen_outputs = self.question_generator_chain.apply(
|
||||
question_gen_inputs, callbacks=callbacks
|
||||
)
|
||||
questions = [
|
||||
output[self.question_generator_chain.output_keys[0]]
|
||||
for output in question_gen_outputs
|
||||
]
|
||||
if isinstance(self.question_generator_chain, LLMChain):
|
||||
question_gen_outputs = self.question_generator_chain.apply(
|
||||
question_gen_inputs, callbacks=callbacks
|
||||
)
|
||||
questions = [
|
||||
output[self.question_generator_chain.output_keys[0]]
|
||||
for output in question_gen_outputs
|
||||
]
|
||||
else:
|
||||
questions = self.question_generator_chain.batch(
|
||||
question_gen_inputs, config={"callbacks": callbacks}
|
||||
)
|
||||
_run_manager.on_text(
|
||||
f"Generated Questions: {questions}", color="yellow", end="\n"
|
||||
)
|
||||
@ -206,8 +179,10 @@ class FlareChain(Chain):
|
||||
f"Current Response: {response}", color="blue", end="\n"
|
||||
)
|
||||
_input = {"user_input": user_input, "context": "", "response": response}
|
||||
tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
|
||||
_input, run_manager=_run_manager
|
||||
tokens, log_probs = _extract_tokens_and_log_probs(
|
||||
self.response_chain.invoke(
|
||||
_input, {"callbacks": _run_manager.get_child()}
|
||||
)
|
||||
)
|
||||
low_confidence_spans = _low_confidence_spans(
|
||||
tokens,
|
||||
@ -251,18 +226,16 @@ class FlareChain(Chain):
|
||||
FlareChain class with the given language model.
|
||||
"""
|
||||
try:
|
||||
from langchain_openai import OpenAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OpenAI is required for FlareChain. "
|
||||
"Please install langchain-openai."
|
||||
"pip install langchain-openai"
|
||||
)
|
||||
question_gen_chain = QuestionGeneratorChain(llm=llm)
|
||||
response_llm = OpenAI(
|
||||
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
|
||||
)
|
||||
response_chain = _OpenAIResponseChain(llm=response_llm)
|
||||
llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
|
||||
response_chain = PROMPT | llm
|
||||
question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
|
||||
return cls(
|
||||
question_generator_chain=question_gen_chain,
|
||||
response_chain=response_chain,
|
||||
|
@ -11,7 +11,9 @@ import numpy as np
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||
@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
"""
|
||||
|
||||
base_embeddings: Embeddings
|
||||
llm_chain: LLMChain
|
||||
llm_chain: Runnable
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys for Hyde's LLM chain."""
|
||||
return self.llm_chain.input_keys
|
||||
return self.llm_chain.input_schema.schema()["required"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output keys for Hyde's LLM chain."""
|
||||
return self.llm_chain.output_keys
|
||||
if isinstance(self.llm_chain, LLMChain):
|
||||
return self.llm_chain.output_keys
|
||||
else:
|
||||
return ["text"]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call the base embeddings."""
|
||||
@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Generate a hypothetical document and embedded it."""
|
||||
var_name = self.llm_chain.input_keys[0]
|
||||
result = self.llm_chain.generate([{var_name: text}])
|
||||
documents = [generation.text for generation in result.generations[0]]
|
||||
var_name = self.input_keys[0]
|
||||
result = self.llm_chain.invoke({var_name: text})
|
||||
if isinstance(self.llm_chain, LLMChain):
|
||||
documents = [result[self.output_keys[0]]]
|
||||
else:
|
||||
documents = [result]
|
||||
embeddings = self.embed_documents(documents)
|
||||
return self.combine_embeddings(embeddings)
|
||||
|
||||
@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
) -> Dict[str, str]:
|
||||
"""Call the internal llm chain."""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
return self.llm_chain(inputs, callbacks=_run_manager.get_child())
|
||||
return self.llm_chain.invoke(
|
||||
inputs, config={"callbacks": _run_manager.get_child()}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
|
||||
f"of {list(PROMPT_MAP.keys())}."
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
llm_chain = prompt | llm | StrOutputParser()
|
||||
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
|
||||
|
||||
@property
|
||||
|
@ -7,6 +7,7 @@ import re
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
@ -20,16 +21,132 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.prompt import PROMPT
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.2.13",
|
||||
message=(
|
||||
"This class is deprecated and will be removed in langchain 1.0. "
|
||||
"See API reference for replacement: "
|
||||
"https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html" # noqa: E501
|
||||
),
|
||||
removal="1.0",
|
||||
)
|
||||
class LLMMathChain(Chain):
|
||||
"""Chain that interprets a prompt and executes python code to do math.
|
||||
|
||||
Note: this class is deprecated. See below for a replacement implementation
|
||||
using LangGraph. The benefits of this implementation are:
|
||||
|
||||
- Uses LLM tool calling features;
|
||||
- Support for both token-by-token and step-by-step streaming;
|
||||
- Support for checkpointing and memory of chat history;
|
||||
- Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
|
||||
|
||||
Install LangGraph with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U langgraph
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import math
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import tool
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.prebuilt.tool_node import ToolNode
|
||||
import numexpr
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@tool
|
||||
def calculator(expression: str) -> str:
|
||||
\"\"\"Calculate expression using Python's numexpr library.
|
||||
|
||||
Expression should be a single line mathematical expression
|
||||
that solves the problem.
|
||||
|
||||
Examples:
|
||||
"37593 * 67" for "37593 times 67"
|
||||
"37593**(1/5)" for "37593^(1/5)"
|
||||
\"\"\"
|
||||
local_dict = {"pi": math.pi, "e": math.e}
|
||||
return str(
|
||||
numexpr.evaluate(
|
||||
expression.strip(),
|
||||
global_dict={}, # restrict access to globals
|
||||
local_dict=local_dict, # add common mathematical functions
|
||||
)
|
||||
)
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||
tools = [calculator]
|
||||
llm_with_tools = llm.bind_tools(tools, tool_choice="any")
|
||||
|
||||
class ChainState(TypedDict):
|
||||
\"\"\"LangGraph state.\"\"\"
|
||||
|
||||
messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||
|
||||
async def acall_chain(state: ChainState, config: RunnableConfig):
|
||||
last_message = state["messages"][-1]
|
||||
response = await llm_with_tools.ainvoke(state["messages"], config)
|
||||
return {"messages": [response]}
|
||||
|
||||
async def acall_model(state: ChainState, config: RunnableConfig):
|
||||
response = await llm.ainvoke(state["messages"], config)
|
||||
return {"messages": [response]}
|
||||
|
||||
graph_builder = StateGraph(ChainState)
|
||||
graph_builder.add_node("call_tool", acall_chain)
|
||||
graph_builder.add_node("execute_tool", ToolNode(tools))
|
||||
graph_builder.add_node("call_model", acall_model)
|
||||
graph_builder.set_entry_point("call_tool")
|
||||
graph_builder.add_edge("call_tool", "execute_tool")
|
||||
graph_builder.add_edge("execute_tool", "call_model")
|
||||
graph_builder.add_edge("call_model", END)
|
||||
chain = graph_builder.compile()
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
example_query = "What is 551368 divided by 82"
|
||||
|
||||
events = chain.astream(
|
||||
{"messages": [("user", example_query)]},
|
||||
stream_mode="values",
|
||||
)
|
||||
async for event in events:
|
||||
event["messages"][-1].pretty_print()
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
================================ Human Message =================================
|
||||
|
||||
What is 551368 divided by 82
|
||||
================================== Ai Message ==================================
|
||||
Tool Calls:
|
||||
calculator (call_MEiGXuJjJ7wGU4aOT86QuGJS)
|
||||
Call ID: call_MEiGXuJjJ7wGU4aOT86QuGJS
|
||||
Args:
|
||||
expression: 551368 / 82
|
||||
================================= Tool Message =================================
|
||||
Name: calculator
|
||||
|
||||
6724.0
|
||||
================================== Ai Message ==================================
|
||||
|
||||
551368 divided by 82 equals 6724.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chains import LLMMathChain
|
||||
from langchain_community.llms import OpenAI
|
||||
llm_math = LLMMathChain.from_llm(OpenAI())
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
|
@ -5,15 +5,27 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.natbot.prompt import PROMPT
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.2.13",
|
||||
message=(
|
||||
"Importing NatBotChain from langchain is deprecated and will be removed in "
|
||||
"langchain 1.0. Please import from langchain_community instead: "
|
||||
"from langchain_community.chains.natbot import NatBotChain. "
|
||||
"You may need to pip install -U langchain-community."
|
||||
),
|
||||
removal="1.0",
|
||||
)
|
||||
class NatBotChain(Chain):
|
||||
"""Implement an LLM driven browser.
|
||||
|
||||
@ -37,7 +49,7 @@ class NatBotChain(Chain):
|
||||
natbot = NatBotChain.from_default("Buy me a new hat.")
|
||||
"""
|
||||
|
||||
llm_chain: LLMChain
|
||||
llm_chain: Runnable
|
||||
objective: str
|
||||
"""Objective that NatBot is tasked with completing."""
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
@ -60,7 +72,7 @@ class NatBotChain(Chain):
|
||||
"class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
|
||||
values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser()
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
@ -77,7 +89,7 @@ class NatBotChain(Chain):
|
||||
cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
|
||||
) -> NatBotChain:
|
||||
"""Load from LLM."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
||||
llm_chain = PROMPT | llm | StrOutputParser()
|
||||
return cls(llm_chain=llm_chain, objective=objective, **kwargs)
|
||||
|
||||
@property
|
||||
@ -104,12 +116,14 @@ class NatBotChain(Chain):
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
url = inputs[self.input_url_key]
|
||||
browser_content = inputs[self.input_browser_content_key]
|
||||
llm_cmd = self.llm_chain.predict(
|
||||
objective=self.objective,
|
||||
url=url[:100],
|
||||
previous_command=self.previous_command,
|
||||
browser_content=browser_content[:4500],
|
||||
callbacks=_run_manager.get_child(),
|
||||
llm_cmd = self.llm_chain.invoke(
|
||||
{
|
||||
"objective": self.objective,
|
||||
"url": url[:100],
|
||||
"previous_command": self.previous_command,
|
||||
"browser_content": browser_content[:4500],
|
||||
},
|
||||
config={"callbacks": _run_manager.get_child()},
|
||||
)
|
||||
llm_cmd = llm_cmd.strip()
|
||||
self.previous_command = llm_cmd
|
||||
|
@ -8,8 +8,9 @@ from typing import Any, Callable, Dict, Optional, Sequence, cast
|
||||
from langchain_core.callbacks.manager import Callbacks
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
||||
@ -49,12 +50,15 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
||||
"""Document compressor that uses an LLM chain to extract
|
||||
the relevant parts of documents."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
llm_chain: Runnable
|
||||
"""LLM wrapper to use for compressing documents."""
|
||||
|
||||
get_input: Callable[[str, Document], dict] = default_get_input
|
||||
"""Callable for constructing the chain input from the query and a Document."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
@ -65,10 +69,13 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
||||
compressed_docs = []
|
||||
for doc in documents:
|
||||
_input = self.get_input(query, doc)
|
||||
output_dict = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
|
||||
output = output_dict[self.llm_chain.output_key]
|
||||
if self.llm_chain.prompt.output_parser is not None:
|
||||
output = self.llm_chain.prompt.output_parser.parse(output)
|
||||
output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
|
||||
if isinstance(self.llm_chain, LLMChain):
|
||||
output = output_[self.llm_chain.output_key]
|
||||
if self.llm_chain.prompt.output_parser is not None:
|
||||
output = self.llm_chain.prompt.output_parser.parse(output)
|
||||
else:
|
||||
output = output_
|
||||
if len(output) == 0:
|
||||
continue
|
||||
compressed_docs.append(
|
||||
@ -85,9 +92,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
||||
"""Compress page content of raw documents asynchronously."""
|
||||
outputs = await asyncio.gather(
|
||||
*[
|
||||
self.llm_chain.apredict_and_parse(
|
||||
**self.get_input(query, doc), callbacks=callbacks
|
||||
)
|
||||
self.llm_chain.ainvoke(self.get_input(query, doc), callbacks=callbacks)
|
||||
for doc in documents
|
||||
]
|
||||
)
|
||||
@ -111,5 +116,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
||||
"""Initialize from LLM."""
|
||||
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
||||
_get_input = get_input if get_input is not None else default_get_input
|
||||
llm_chain = LLMChain(llm=llm, prompt=_prompt, **(llm_chain_kwargs or {}))
|
||||
if _prompt.output_parser is not None:
|
||||
parser = _prompt.output_parser
|
||||
else:
|
||||
parser = StrOutputParser()
|
||||
llm_chain = _prompt | llm | parser
|
||||
return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type]
|
||||
|
@ -5,7 +5,9 @@ from typing import Any, Callable, Dict, Optional, Sequence
|
||||
from langchain_core.callbacks.manager import Callbacks
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
@ -32,13 +34,16 @@ def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
|
||||
class LLMChainFilter(BaseDocumentCompressor):
|
||||
"""Filter that drops documents that aren't relevant to the query."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
llm_chain: Runnable
|
||||
"""LLM wrapper to use for filtering documents.
|
||||
The chain prompt is expected to have a BooleanOutputParser."""
|
||||
|
||||
get_input: Callable[[str, Document], dict] = default_get_input
|
||||
"""Callable for constructing the chain input from the query and a Document."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
@ -56,11 +61,15 @@ class LLMChainFilter(BaseDocumentCompressor):
|
||||
documents,
|
||||
)
|
||||
|
||||
for output_dict, doc in outputs:
|
||||
for output_, doc in outputs:
|
||||
include_doc = None
|
||||
output = output_dict[self.llm_chain.output_key]
|
||||
if self.llm_chain.prompt.output_parser is not None:
|
||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||
if isinstance(self.llm_chain, LLMChain):
|
||||
output = output_[self.llm_chain.output_key]
|
||||
if self.llm_chain.prompt.output_parser is not None:
|
||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||
else:
|
||||
if isinstance(output_, bool):
|
||||
include_doc = output_
|
||||
if include_doc:
|
||||
filtered_docs.append(doc)
|
||||
|
||||
@ -82,11 +91,15 @@ class LLMChainFilter(BaseDocumentCompressor):
|
||||
),
|
||||
documents,
|
||||
)
|
||||
for output_dict, doc in outputs:
|
||||
for output_, doc in outputs:
|
||||
include_doc = None
|
||||
output = output_dict[self.llm_chain.output_key]
|
||||
if self.llm_chain.prompt.output_parser is not None:
|
||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||
if isinstance(self.llm_chain, LLMChain):
|
||||
output = output_[self.llm_chain.output_key]
|
||||
if self.llm_chain.prompt.output_parser is not None:
|
||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||
else:
|
||||
if isinstance(output_, bool):
|
||||
include_doc = output_
|
||||
if include_doc:
|
||||
filtered_docs.append(doc)
|
||||
|
||||
@ -110,5 +123,9 @@ class LLMChainFilter(BaseDocumentCompressor):
|
||||
A LLMChainFilter that uses the given language model.
|
||||
"""
|
||||
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
||||
llm_chain = LLMChain(llm=llm, prompt=_prompt)
|
||||
if _prompt.output_parser is not None:
|
||||
parser = _prompt.output_parser
|
||||
else:
|
||||
parser = StrOutputParser()
|
||||
llm_chain = _prompt | llm | parser
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
@ -7,11 +7,11 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -30,7 +30,7 @@ class RePhraseQueryRetriever(BaseRetriever):
|
||||
Then, retrieve docs for the re-phrased query."""
|
||||
|
||||
retriever: BaseRetriever
|
||||
llm_chain: LLMChain
|
||||
llm_chain: Runnable
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@ -51,8 +51,7 @@ class RePhraseQueryRetriever(BaseRetriever):
|
||||
Returns:
|
||||
RePhraseQueryRetriever
|
||||
"""
|
||||
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
llm_chain = prompt | llm | StrOutputParser()
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
llm_chain=llm_chain,
|
||||
@ -72,8 +71,9 @@ class RePhraseQueryRetriever(BaseRetriever):
|
||||
Returns:
|
||||
Relevant documents for re-phrased question
|
||||
"""
|
||||
response = self.llm_chain(query, callbacks=run_manager.get_child())
|
||||
re_phrased_question = response["text"]
|
||||
re_phrased_question = self.llm_chain.invoke(
|
||||
query, {"callbacks": run_manager.get_child()}
|
||||
)
|
||||
logger.info(f"Re-phrased question: {re_phrased_question}")
|
||||
docs = self.retriever.invoke(
|
||||
re_phrased_question, config={"callbacks": run_manager.get_child()}
|
||||
|
@ -0,0 +1,84 @@
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
|
||||
from langchain.retrievers.document_compressors import LLMChainExtractor
|
||||
|
||||
|
||||
def test_llm_chain_extractor() -> None:
|
||||
documents = [
|
||||
Document(
|
||||
page_content=(
|
||||
"The sky is blue. Candlepin bowling is popular in New England."
|
||||
),
|
||||
metadata={"a": 1},
|
||||
),
|
||||
Document(
|
||||
page_content=(
|
||||
"Mercury is the closest planet to the Sun. "
|
||||
"Candlepin bowling balls are smaller."
|
||||
),
|
||||
metadata={"b": 2},
|
||||
),
|
||||
Document(page_content="The moon is round.", metadata={"c": 3}),
|
||||
]
|
||||
llm = FakeListChatModel(
|
||||
responses=[
|
||||
"Candlepin bowling is popular in New England.",
|
||||
"Candlepin bowling balls are smaller.",
|
||||
"NO_OUTPUT",
|
||||
]
|
||||
)
|
||||
doc_compressor = LLMChainExtractor.from_llm(llm)
|
||||
output = doc_compressor.compress_documents(
|
||||
documents, "Tell me about Candlepin bowling."
|
||||
)
|
||||
expected = documents = [
|
||||
Document(
|
||||
page_content="Candlepin bowling is popular in New England.",
|
||||
metadata={"a": 1},
|
||||
),
|
||||
Document(
|
||||
page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
|
||||
),
|
||||
]
|
||||
assert output == expected
|
||||
|
||||
|
||||
async def test_llm_chain_extractor_async() -> None:
|
||||
documents = [
|
||||
Document(
|
||||
page_content=(
|
||||
"The sky is blue. Candlepin bowling is popular in New England."
|
||||
),
|
||||
metadata={"a": 1},
|
||||
),
|
||||
Document(
|
||||
page_content=(
|
||||
"Mercury is the closest planet to the Sun. "
|
||||
"Candlepin bowling balls are smaller."
|
||||
),
|
||||
metadata={"b": 2},
|
||||
),
|
||||
Document(page_content="The moon is round.", metadata={"c": 3}),
|
||||
]
|
||||
llm = FakeListChatModel(
|
||||
responses=[
|
||||
"Candlepin bowling is popular in New England.",
|
||||
"Candlepin bowling balls are smaller.",
|
||||
"NO_OUTPUT",
|
||||
]
|
||||
)
|
||||
doc_compressor = LLMChainExtractor.from_llm(llm)
|
||||
output = await doc_compressor.acompress_documents(
|
||||
documents, "Tell me about Candlepin bowling."
|
||||
)
|
||||
expected = documents = [
|
||||
Document(
|
||||
page_content="Candlepin bowling is popular in New England.",
|
||||
metadata={"a": 1},
|
||||
),
|
||||
Document(
|
||||
page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
|
||||
),
|
||||
]
|
||||
assert output == expected
|
@ -0,0 +1,46 @@
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
|
||||
from langchain.retrievers.document_compressors import LLMChainFilter
|
||||
|
||||
|
||||
def test_llm_chain_filter() -> None:
|
||||
documents = [
|
||||
Document(
|
||||
page_content="Candlepin bowling is popular in New England.",
|
||||
metadata={"a": 1},
|
||||
),
|
||||
Document(
|
||||
page_content="Candlepin bowling balls are smaller.",
|
||||
metadata={"b": 2},
|
||||
),
|
||||
Document(page_content="The moon is round.", metadata={"c": 3}),
|
||||
]
|
||||
llm = FakeListChatModel(responses=["YES", "YES", "NO"])
|
||||
doc_compressor = LLMChainFilter.from_llm(llm)
|
||||
output = doc_compressor.compress_documents(
|
||||
documents, "Tell me about Candlepin bowling."
|
||||
)
|
||||
expected = documents[:2]
|
||||
assert output == expected
|
||||
|
||||
|
||||
async def test_llm_chain_extractor_async() -> None:
|
||||
documents = [
|
||||
Document(
|
||||
page_content="Candlepin bowling is popular in New England.",
|
||||
metadata={"a": 1},
|
||||
),
|
||||
Document(
|
||||
page_content="Candlepin bowling balls are smaller.",
|
||||
metadata={"b": 2},
|
||||
),
|
||||
Document(page_content="The moon is round.", metadata={"c": 3}),
|
||||
]
|
||||
llm = FakeListChatModel(responses=["YES", "YES", "NO"])
|
||||
doc_compressor = LLMChainFilter.from_llm(llm)
|
||||
output = await doc_compressor.acompress_documents(
|
||||
documents, "Tell me about Candlepin bowling."
|
||||
)
|
||||
expected = documents[:2]
|
||||
assert output == expected
|
Loading…
Reference in New Issue
Block a user