langchain[patch]: deprecate various chains (#25310)

- [x] NatbotChain: move to community, deprecate langchain version.
Update to use `prompt | llm | output_parser` instead of LLMChain.
- [x] LLMMathChain: deprecate + add langgraph replacement example to API
ref
- [x] HypotheticalDocumentEmbedder (retriever): update to use `prompt |
llm | output_parser` instead of LLMChain
- [x] FlareChain: update to use `prompt | llm | output_parser` instead
of LLMChain
- [x] ConstitutionalChain: deprecate + add langgraph replacement example
to API ref
- [x] LLMChainExtractor (document compressor): update to use `prompt |
llm | output_parser` instead of LLMChain
- [x] LLMChainFilter (document compressor): update to use `prompt | llm
| output_parser` instead of LLMChain
- [x] RePhraseQueryRetriever (retriever): update to use `prompt | llm |
output_parser` instead of LLMChain
ccurme 2024-08-15 10:49:26 -04:00 committed by GitHub
parent 66e30efa61
commit 8afbab4cf6
19 changed files with 1166 additions and 126 deletions

View File

@ -0,0 +1,332 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "b57124cc-60a0-4c18-b7ce-3e483d1024a2",
"metadata": {},
"source": [
"---\n",
"title: Migrating from ConstitutionalChain\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "ce8457ed-c0b1-4a74-abbd-9d3d2211270f",
"metadata": {},
"source": [
"[ConstitutionalChain](https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html) allowed for a LLM to critique and revise generations based on [principles](https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.models.ConstitutionalPrinciple.html), structured as combinations of critique and revision requests. For example, a principle might include a request to identify harmful content, and a request to rewrite the content.\n",
"\n",
"In `ConstitutionalChain`, this structure of critique requests and associated revisions was formatted into a LLM prompt and parsed out of string responses. This is more naturally achieved via [structured output](/docs/how_to/structured_output/) features of chat models. We can construct a simple chain in [LangGraph](https://langchain-ai.github.io/langgraph/) for this purpose. Some advantages of this approach include:\n",
"\n",
"- Leverage tool-calling capabilities of chat models that have been fine-tuned for this purpose;\n",
"- Reduce parsing errors from extracting expression from a string LLM response;\n",
"- Delegation of instructions to [message roles](/docs/concepts/#messages) (e.g., chat models can understand what a `ToolMessage` represents without the need for additional prompting);\n",
"- Support for streaming, both of individual tokens and chain steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b99b47ec",
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain-openai"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "717c8673",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from getpass import getpass\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass()"
]
},
{
"cell_type": "markdown",
"id": "e3621b62-a037-42b8-8faa-59575608bb8b",
"metadata": {},
"source": [
"## Legacy\n",
"\n",
"<details open>"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f91c9809-8ee7-4e38-881d-0ace4f6ea883",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import ConstitutionalChain, LLMChain\n",
"from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple\n",
"from langchain_core.prompts import PromptTemplate\n",
"from langchain_openai import OpenAI\n",
"\n",
"llm = OpenAI()\n",
"\n",
"qa_prompt = PromptTemplate(\n",
" template=\"Q: {question} A:\",\n",
" input_variables=[\"question\"],\n",
")\n",
"qa_chain = LLMChain(llm=llm, prompt=qa_prompt)\n",
"\n",
"constitutional_chain = ConstitutionalChain.from_llm(\n",
" llm=llm,\n",
" chain=qa_chain,\n",
" constitutional_principles=[\n",
" ConstitutionalPrinciple(\n",
" critique_request=\"Tell if this answer is good.\",\n",
" revision_request=\"Give a better answer.\",\n",
" )\n",
" ],\n",
" return_intermediate_steps=True,\n",
")\n",
"\n",
"result = constitutional_chain.invoke(\"What is the meaning of life?\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fa3d11a1-ac1f-4a9a-9ab3-b7b244daa506",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'What is the meaning of life?',\n",
" 'output': 'The meaning of life is a deeply personal and ever-evolving concept. It is a journey of self-discovery and growth, and can be different for each individual. Some may find meaning in relationships, others in achieving their goals, and some may never find a concrete answer. Ultimately, the meaning of life is what we make of it.',\n",
" 'initial_output': ' The meaning of life is a subjective concept that can vary from person to person. Some may believe that the purpose of life is to find happiness and fulfillment, while others may see it as a journey of self-discovery and personal growth. Ultimately, the meaning of life is something that each individual must determine for themselves.',\n",
" 'critiques_and_revisions': [('This answer is good in that it recognizes and acknowledges the subjective nature of the question and provides a valid and thoughtful response. However, it could have also mentioned that the meaning of life is a complex and deeply personal concept that can also change and evolve over time for each individual. Critique Needed.',\n",
" 'The meaning of life is a deeply personal and ever-evolving concept. It is a journey of self-discovery and growth, and can be different for each individual. Some may find meaning in relationships, others in achieving their goals, and some may never find a concrete answer. Ultimately, the meaning of life is what we make of it.')]}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
{
"cell_type": "markdown",
"id": "374ae108-f1a0-4723-9237-5259c8123c04",
"metadata": {},
"source": [
"Above, we've returned intermediate steps showing:\n",
"\n",
"- The original question;\n",
"- The initial output;\n",
"- Critiques and revisions;\n",
"- The final output (matching a revision)."
]
},
{
"cell_type": "markdown",
"id": "cdc3b527-c09e-4c77-9711-c3cc4506cd95",
"metadata": {},
"source": [
"</details>\n",
"\n",
"## LangGraph\n",
"\n",
"<details open>\n",
"\n",
"Below, we use the [.with_structured_output](/docs/how_to/structured_output/) method to simultaneously generate (1) a judgment of whether a critique is needed, and (2) the critique. We surface all prompts involved for clarity and ease of customizability.\n",
"\n",
"Note that we are also able to stream intermediate steps with this implementation, so we can monitor and if needed intervene during its execution."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "917fdb73-2411-4fcc-9add-c32dc5c745da",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Tuple\n",
"\n",
"from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple\n",
"from langchain.chains.constitutional_ai.prompts import (\n",
" CRITIQUE_PROMPT,\n",
" REVISION_PROMPT,\n",
")\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_openai import ChatOpenAI\n",
"from langgraph.graph import END, START, StateGraph\n",
"from typing_extensions import Annotated, TypedDict\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
"\n",
"\n",
"class Critique(TypedDict):\n",
" \"\"\"Generate a critique, if needed.\"\"\"\n",
"\n",
" critique_needed: Annotated[bool, ..., \"Whether or not a critique is needed.\"]\n",
" critique: Annotated[str, ..., \"If needed, the critique.\"]\n",
"\n",
"\n",
"critique_prompt = ChatPromptTemplate.from_template(\n",
" \"Critique this response according to the critique request. \"\n",
" \"If no critique is needed, specify that.\\n\\n\"\n",
" \"Query: {query}\\n\\n\"\n",
" \"Response: {response}\\n\\n\"\n",
" \"Critique request: {critique_request}\"\n",
")\n",
"\n",
"revision_prompt = ChatPromptTemplate.from_template(\n",
" \"Revise this response according to the critique and reivsion request.\\n\\n\"\n",
" \"Query: {query}\\n\\n\"\n",
" \"Response: {response}\\n\\n\"\n",
" \"Critique request: {critique_request}\\n\\n\"\n",
" \"Critique: {critique}\\n\\n\"\n",
" \"If the critique does not identify anything worth changing, ignore the \"\n",
" \"revision request and return 'No revisions needed'. If the critique \"\n",
" \"does identify something worth changing, revise the response based on \"\n",
" \"the revision request.\\n\\n\"\n",
" \"Revision Request: {revision_request}\"\n",
")\n",
"\n",
"chain = llm | StrOutputParser()\n",
"critique_chain = critique_prompt | llm.with_structured_output(Critique)\n",
"revision_chain = revision_prompt | llm | StrOutputParser()\n",
"\n",
"\n",
"class State(TypedDict):\n",
" query: str\n",
" constitutional_principles: List[ConstitutionalPrinciple]\n",
" initial_response: str\n",
" critiques_and_revisions: List[Tuple[str, str]]\n",
" response: str\n",
"\n",
"\n",
"async def generate_response(state: State):\n",
" \"\"\"Generate initial response.\"\"\"\n",
" response = await chain.ainvoke(state[\"query\"])\n",
" return {\"response\": response, \"initial_response\": response}\n",
"\n",
"\n",
"async def critique_and_revise(state: State):\n",
" \"\"\"Critique and revise response according to principles.\"\"\"\n",
" critiques_and_revisions = []\n",
" response = state[\"initial_response\"]\n",
" for principle in state[\"constitutional_principles\"]:\n",
" critique = await critique_chain.ainvoke(\n",
" {\n",
" \"query\": state[\"query\"],\n",
" \"response\": response,\n",
" \"critique_request\": principle.critique_request,\n",
" }\n",
" )\n",
" if critique[\"critique_needed\"]:\n",
" revision = await revision_chain.ainvoke(\n",
" {\n",
" \"query\": state[\"query\"],\n",
" \"response\": response,\n",
" \"critique_request\": principle.critique_request,\n",
" \"critique\": critique[\"critique\"],\n",
" \"revision_request\": principle.revision_request,\n",
" }\n",
" )\n",
" response = revision\n",
" critiques_and_revisions.append((critique[\"critique\"], revision))\n",
" else:\n",
" critiques_and_revisions.append((critique[\"critique\"], \"\"))\n",
" return {\n",
" \"critiques_and_revisions\": critiques_and_revisions,\n",
" \"response\": response,\n",
" }\n",
"\n",
"\n",
"graph = StateGraph(State)\n",
"graph.add_node(\"generate_response\", generate_response)\n",
"graph.add_node(\"critique_and_revise\", critique_and_revise)\n",
"\n",
"graph.add_edge(START, \"generate_response\")\n",
"graph.add_edge(\"generate_response\", \"critique_and_revise\")\n",
"graph.add_edge(\"critique_and_revise\", END)\n",
"app = graph.compile()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "01aac88d-464e-431f-b92e-746dcb743e1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{}\n",
"{'initial_response': 'Finding purpose, connection, and joy in our experiences and relationships.', 'response': 'Finding purpose, connection, and joy in our experiences and relationships.'}\n",
"{'initial_response': 'Finding purpose, connection, and joy in our experiences and relationships.', 'critiques_and_revisions': [(\"The response exceeds the 10-word limit, providing a more elaborate answer than requested. A concise response, such as 'To seek purpose and joy in life,' would better align with the query.\", 'To seek purpose and joy in life.')], 'response': 'To seek purpose and joy in life.'}\n"
]
}
],
"source": [
"constitutional_principles = [\n",
" ConstitutionalPrinciple(\n",
" critique_request=\"Tell if this answer is good.\",\n",
" revision_request=\"Give a better answer.\",\n",
" )\n",
"]\n",
"\n",
"query = \"What is the meaning of life? Answer in 10 words or fewer.\"\n",
"\n",
"async for step in app.astream(\n",
" {\"query\": query, \"constitutional_principles\": constitutional_principles},\n",
" stream_mode=\"values\",\n",
"):\n",
" subset = [\"initial_response\", \"critiques_and_revisions\", \"response\"]\n",
" print({k: v for k, v in step.items() if k in subset})"
]
},
{
"cell_type": "markdown",
"id": "b2717810",
"metadata": {},
"source": [
"</details>\n",
"\n",
"## Next steps\n",
"\n",
"See guides for generating structured output [here](/docs/how_to/structured_output/).\n",
"\n",
"Check out the [LangGraph documentation](https://langchain-ai.github.io/langgraph/) for detail on building with LangGraph."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -45,5 +45,7 @@ The below pages assist with migration from various specific chains to LCEL and L
- [RefineDocumentsChain](/docs/versions/migrating_chains/refine_docs_chain)
- [LLMRouterChain](/docs/versions/migrating_chains/llm_router_chain)
- [MultiPromptChain](/docs/versions/migrating_chains/multi_prompt_chain)
- [LLMMathChain](/docs/versions/migrating_chains/llm_math_chain)
- [ConstitutionalChain](/docs/versions/migrating_chains/constitutional_chain)
Check out the [LCEL conceptual docs](/docs/concepts/#langchain-expression-language-lcel) and [LangGraph docs](https://langchain-ai.github.io/langgraph/) for more background information.

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,8 @@
"""Implement a GPT-3 driven browser.
Heavily influenced by https://github.com/nat/natbot
"""
from langchain_community.chains.natbot.base import NatBotChain
__all__ = ["NatBotChain"]

View File

@ -0,0 +1,3 @@
from langchain.chains import NatBotChain
__all__ = ["NatBotChain"]

View File

@ -0,0 +1,7 @@
from langchain.chains.natbot.crawler import (
Crawler,
ElementInViewPort,
black_listed_elements,
)
__all__ = ["ElementInViewPort", "Crawler", "black_listed_elements"]

View File

@ -0,0 +1,3 @@
from langchain.chains.natbot.prompt import PROMPT
__all__ = ["PROMPT"]

View File

@ -6,14 +6,6 @@ from langchain_core.documents import Document
from langchain_community.chat_models import ChatOpenAI
def test_llm_construction_with_kwargs() -> None:
llm_chain_kwargs = {"verbose": True}
compressor = LLMChainExtractor.from_llm(
ChatOpenAI(), llm_chain_kwargs=llm_chain_kwargs
)
assert compressor.llm_chain.verbose is True
def test_llm_chain_extractor() -> None:
texts = [
"The Roman Empire followed the Roman Republic.",

View File

@ -2,11 +2,10 @@
from typing import Any, Dict, List, Optional
from langchain.chains.natbot.base import NatBotChain
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain.chains.natbot.base import NatBotChain
class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes."""

View File

@ -2,6 +2,7 @@
from typing import Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
@ -13,9 +14,151 @@ from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION
from langchain.chains.llm import LLMChain
@deprecated(
since="0.2.13",
message=(
"This class is deprecated and will be removed in langchain 1.0. "
"See API reference for replacement: "
"https://api.python.langchain.com/en/latest/chains/langchain.chains.constitutional_ai.base.ConstitutionalChain.html" # noqa: E501
),
removal="1.0",
)
class ConstitutionalChain(Chain):
"""Chain for applying constitutional principles.
Note: this class is deprecated. See below for a replacement implementation
using LangGraph. The benefits of this implementation are:
- Uses LLM tool calling features instead of parsing string responses;
- Support for both token-by-token and step-by-step streaming;
- Support for checkpointing and memory of chat history;
- Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
Install LangGraph with:
.. code-block:: bash
pip install -U langgraph
.. code-block:: python
from typing import List, Optional, Tuple
from langchain.chains.constitutional_ai.prompts import (
CRITIQUE_PROMPT,
REVISION_PROMPT,
)
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from typing_extensions import Annotated, TypedDict
llm = ChatOpenAI(model="gpt-4o-mini")
class Critique(TypedDict):
\"\"\"Generate a critique, if needed.\"\"\"
critique_needed: Annotated[bool, ..., "Whether or not a critique is needed."]
critique: Annotated[str, ..., "If needed, the critique."]
critique_prompt = ChatPromptTemplate.from_template(
"Critique this response according to the critique request. "
"If no critique is needed, specify that.\\n\\n"
"Query: {query}\\n\\n"
"Response: {response}\\n\\n"
"Critique request: {critique_request}"
)
revision_prompt = ChatPromptTemplate.from_template(
"Revise this response according to the critique and reivsion request.\\n\\n"
"Query: {query}\\n\\n"
"Response: {response}\\n\\n"
"Critique request: {critique_request}\\n\\n"
"Critique: {critique}\\n\\n"
"If the critique does not identify anything worth changing, ignore the "
"revision request and return 'No revisions needed'. If the critique "
"does identify something worth changing, revise the response based on "
"the revision request.\\n\\n"
"Revision Request: {revision_request}"
)
chain = llm | StrOutputParser()
critique_chain = critique_prompt | llm.with_structured_output(Critique)
revision_chain = revision_prompt | llm | StrOutputParser()
class State(TypedDict):
query: str
constitutional_principles: List[ConstitutionalPrinciple]
initial_response: str
critiques_and_revisions: List[Tuple[str, str]]
response: str
async def generate_response(state: State):
\"\"\"Generate initial response.\"\"\"
response = await chain.ainvoke(state["query"])
return {"response": response, "initial_response": response}
async def critique_and_revise(state: State):
\"\"\"Critique and revise response according to principles.\"\"\"
critiques_and_revisions = []
response = state["initial_response"]
for principle in state["constitutional_principles"]:
critique = await critique_chain.ainvoke(
{
"query": state["query"],
"response": response,
"critique_request": principle.critique_request,
}
)
if critique["critique_needed"]:
revision = await revision_chain.ainvoke(
{
"query": state["query"],
"response": response,
"critique_request": principle.critique_request,
"critique": critique["critique"],
"revision_request": principle.revision_request,
}
)
response = revision
critiques_and_revisions.append((critique["critique"], revision))
else:
critiques_and_revisions.append((critique["critique"], ""))
return {
"critiques_and_revisions": critiques_and_revisions,
"response": response,
}
graph = StateGraph(State)
graph.add_node("generate_response", generate_response)
graph.add_node("critique_and_revise", critique_and_revise)
graph.add_edge(START, "generate_response")
graph.add_edge("generate_response", "critique_and_revise")
graph.add_edge("critique_and_revise", END)
app = graph.compile()
.. code-block:: python
constitutional_principles=[
ConstitutionalPrinciple(
critique_request="Tell if this answer is good.",
revision_request="Give a better answer.",
)
]
query = "What is the meaning of life? Answer in 10 words or fewer."
async for step in app.astream(
{"query": query, "constitutional_principles": constitutional_principles},
stream_mode="values",
):
subset = ["initial_response", "critiques_and_revisions", "response"]
print({k: v for k, v in step.items() if k in subset})
Example:
.. code-block:: python
@ -44,7 +187,7 @@ class ConstitutionalChain(Chain):
)
constitutional_chain.run(question="What is the meaning of life?")
"""
""" # noqa: E501
chain: LLMChain
constitutional_principles: List[ConstitutionalPrinciple]

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import re
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
@ -9,10 +8,12 @@ from langchain_core.callbacks import (
CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import Generation
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
@ -23,51 +24,14 @@ from langchain.chains.flare.prompts import (
from langchain.chains.llm import LLMChain
class _ResponseChain(LLMChain):
"""Base class for chains that generate responses."""
prompt: BasePromptTemplate = PROMPT
@classmethod
def is_lc_serializable(cls) -> bool:
return False
@property
def input_keys(self) -> List[str]:
return self.prompt.input_variables
def generate_tokens_and_log_probs(
self,
_input: Dict[str, Any],
*,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[Sequence[str], Sequence[float]]:
llm_result = self.generate([_input], run_manager=run_manager)
return self._extract_tokens_and_log_probs(llm_result.generations[0])
@abstractmethod
def _extract_tokens_and_log_probs(
self, generations: List[Generation]
) -> Tuple[Sequence[str], Sequence[float]]:
"""Extract tokens and log probs from response."""
class _OpenAIResponseChain(_ResponseChain):
"""Chain that generates responses from user input and context."""
llm: BaseLanguageModel
def _extract_tokens_and_log_probs(
self, generations: List[Generation]
) -> Tuple[Sequence[str], Sequence[float]]:
tokens = []
log_probs = []
for gen in generations:
if gen.generation_info is None:
raise ValueError
tokens.extend(gen.generation_info["logprobs"]["tokens"])
log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
return tokens, log_probs
def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
"""Extract tokens and log probabilities from chat model response."""
tokens = []
log_probs = []
for token in response.response_metadata["logprobs"]["content"]:
tokens.append(token["token"])
log_probs.append(token["logprob"])
return tokens, log_probs
class QuestionGeneratorChain(LLMChain):
@ -111,9 +75,9 @@ class FlareChain(Chain):
"""Chain that combines a retriever, a question generator,
and a response generator."""
question_generator_chain: QuestionGeneratorChain
question_generator_chain: Runnable
"""Chain that generates questions from uncertain spans."""
response_chain: _ResponseChain
response_chain: Runnable
"""Chain that generates responses from user input and context."""
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
"""Parser that determines whether the chain is finished."""
@ -152,12 +116,16 @@ class FlareChain(Chain):
for question in questions:
docs.extend(self.retriever.invoke(question))
context = "\n\n".join(d.page_content for d in docs)
result = self.response_chain.predict(
user_input=user_input,
context=context,
response=response,
callbacks=callbacks,
result = self.response_chain.invoke(
{
"user_input": user_input,
"context": context,
"response": response,
},
{"callbacks": callbacks},
)
if isinstance(result, AIMessage):
result = result.content
marginal, finished = self.output_parser.parse(result)
return marginal, finished
@ -178,13 +146,18 @@ class FlareChain(Chain):
for span in low_confidence_spans
]
callbacks = _run_manager.get_child()
question_gen_outputs = self.question_generator_chain.apply(
question_gen_inputs, callbacks=callbacks
)
questions = [
output[self.question_generator_chain.output_keys[0]]
for output in question_gen_outputs
]
if isinstance(self.question_generator_chain, LLMChain):
question_gen_outputs = self.question_generator_chain.apply(
question_gen_inputs, callbacks=callbacks
)
questions = [
output[self.question_generator_chain.output_keys[0]]
for output in question_gen_outputs
]
else:
questions = self.question_generator_chain.batch(
question_gen_inputs, config={"callbacks": callbacks}
)
_run_manager.on_text(
f"Generated Questions: {questions}", color="yellow", end="\n"
)
@ -206,8 +179,10 @@ class FlareChain(Chain):
f"Current Response: {response}", color="blue", end="\n"
)
_input = {"user_input": user_input, "context": "", "response": response}
tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
_input, run_manager=_run_manager
tokens, log_probs = _extract_tokens_and_log_probs(
self.response_chain.invoke(
_input, {"callbacks": _run_manager.get_child()}
)
)
low_confidence_spans = _low_confidence_spans(
tokens,
@ -251,18 +226,16 @@ class FlareChain(Chain):
FlareChain class with the given language model.
"""
try:
from langchain_openai import OpenAI
from langchain_openai import ChatOpenAI
except ImportError:
raise ImportError(
"OpenAI is required for FlareChain. "
"Please install langchain-openai."
"pip install langchain-openai"
)
question_gen_chain = QuestionGeneratorChain(llm=llm)
response_llm = OpenAI(
max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
)
response_chain = _OpenAIResponseChain(llm=response_llm)
llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
response_chain = PROMPT | llm
question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
return cls(
question_generator_chain=question_gen_chain,
response_chain=response_chain,

View File

@ -11,7 +11,9 @@ import numpy as np
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable
from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
@ -25,7 +27,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
"""
base_embeddings: Embeddings
llm_chain: LLMChain
llm_chain: Runnable
class Config:
arbitrary_types_allowed = True
@ -34,12 +36,15 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
@property
def input_keys(self) -> List[str]:
"""Input keys for Hyde's LLM chain."""
return self.llm_chain.input_keys
return self.llm_chain.input_schema.schema()["required"]
@property
def output_keys(self) -> List[str]:
"""Output keys for Hyde's LLM chain."""
return self.llm_chain.output_keys
if isinstance(self.llm_chain, LLMChain):
return self.llm_chain.output_keys
else:
return ["text"]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call the base embeddings."""
@ -51,9 +56,12 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
def embed_query(self, text: str) -> List[float]:
"""Generate a hypothetical document and embedded it."""
var_name = self.llm_chain.input_keys[0]
result = self.llm_chain.generate([{var_name: text}])
documents = [generation.text for generation in result.generations[0]]
var_name = self.input_keys[0]
result = self.llm_chain.invoke({var_name: text})
if isinstance(self.llm_chain, LLMChain):
documents = [result[self.output_keys[0]]]
else:
documents = [result]
embeddings = self.embed_documents(documents)
return self.combine_embeddings(embeddings)
@ -64,7 +72,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
) -> Dict[str, str]:
"""Call the internal llm chain."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
return self.llm_chain(inputs, callbacks=_run_manager.get_child())
return self.llm_chain.invoke(
inputs, config={"callbacks": _run_manager.get_child()}
)
@classmethod
def from_llm(
@ -86,7 +96,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
f"of {list(PROMPT_MAP.keys())}."
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = prompt | llm | StrOutputParser()
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
@property

View File

@ -7,6 +7,7 @@ import re
import warnings
from typing import Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@ -20,16 +21,132 @@ from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
@deprecated(
since="0.2.13",
message=(
"This class is deprecated and will be removed in langchain 1.0. "
"See API reference for replacement: "
"https://api.python.langchain.com/en/latest/chains/langchain.chains.llm_math.base.LLMMathChain.html" # noqa: E501
),
removal="1.0",
)
class LLMMathChain(Chain):
"""Chain that interprets a prompt and executes python code to do math.
Note: this class is deprecated. See below for a replacement implementation
using LangGraph. The benefits of this implementation are:
- Uses LLM tool calling features;
- Support for both token-by-token and step-by-step streaming;
- Support for checkpointing and memory of chat history;
- Easier to modify or extend (e.g., with additional tools, structured responses, etc.)
Install LangGraph with:
.. code-block:: bash
pip install -U langgraph
.. code-block:: python
import math
from typing import Annotated, Sequence
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolNode
import numexpr
from typing_extensions import TypedDict
@tool
def calculator(expression: str) -> str:
\"\"\"Calculate expression using Python's numexpr library.
Expression should be a single line mathematical expression
that solves the problem.
Examples:
"37593 * 67" for "37593 times 67"
"37593**(1/5)" for "37593^(1/5)"
\"\"\"
local_dict = {"pi": math.pi, "e": math.e}
return str(
numexpr.evaluate(
expression.strip(),
global_dict={}, # restrict access to globals
local_dict=local_dict, # add common mathematical functions
)
)
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
tools = [calculator]
llm_with_tools = llm.bind_tools(tools, tool_choice="any")
class ChainState(TypedDict):
\"\"\"LangGraph state.\"\"\"
messages: Annotated[Sequence[BaseMessage], add_messages]
async def acall_chain(state: ChainState, config: RunnableConfig):
last_message = state["messages"][-1]
response = await llm_with_tools.ainvoke(state["messages"], config)
return {"messages": [response]}
async def acall_model(state: ChainState, config: RunnableConfig):
response = await llm.ainvoke(state["messages"], config)
return {"messages": [response]}
graph_builder = StateGraph(ChainState)
graph_builder.add_node("call_tool", acall_chain)
graph_builder.add_node("execute_tool", ToolNode(tools))
graph_builder.add_node("call_model", acall_model)
graph_builder.set_entry_point("call_tool")
graph_builder.add_edge("call_tool", "execute_tool")
graph_builder.add_edge("execute_tool", "call_model")
graph_builder.add_edge("call_model", END)
chain = graph_builder.compile()
.. code-block:: python
example_query = "What is 551368 divided by 82"
events = chain.astream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
async for event in events:
event["messages"][-1].pretty_print()
.. code-block:: none
================================ Human Message =================================
What is 551368 divided by 82
================================== Ai Message ==================================
Tool Calls:
calculator (call_MEiGXuJjJ7wGU4aOT86QuGJS)
Call ID: call_MEiGXuJjJ7wGU4aOT86QuGJS
Args:
expression: 551368 / 82
================================= Tool Message =================================
Name: calculator
6724.0
================================== Ai Message ==================================
551368 divided by 82 equals 6724.
Example:
.. code-block:: python
from langchain.chains import LLMMathChain
from langchain_community.llms import OpenAI
llm_math = LLMMathChain.from_llm(OpenAI())
"""
""" # noqa: E501
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None

View File

@ -5,15 +5,27 @@ from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.natbot.prompt import PROMPT
@deprecated(
since="0.2.13",
message=(
"Importing NatBotChain from langchain is deprecated and will be removed in "
"langchain 1.0. Please import from langchain_community instead: "
"from langchain_community.chains.natbot import NatBotChain. "
"You may need to pip install -U langchain-community."
),
removal="1.0",
)
class NatBotChain(Chain):
"""Implement an LLM driven browser.
@ -37,7 +49,7 @@ class NatBotChain(Chain):
natbot = NatBotChain.from_default("Buy me a new hat.")
"""
llm_chain: LLMChain
llm_chain: Runnable
objective: str
"""Objective that NatBot is tasked with completing."""
llm: Optional[BaseLanguageModel] = None
@ -60,7 +72,7 @@ class NatBotChain(Chain):
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
values["llm_chain"] = PROMPT | values["llm"] | StrOutputParser()
return values
@classmethod
@ -77,7 +89,7 @@ class NatBotChain(Chain):
cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
) -> NatBotChain:
"""Load from LLM."""
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
llm_chain = PROMPT | llm | StrOutputParser()
return cls(llm_chain=llm_chain, objective=objective, **kwargs)
@property
@ -104,12 +116,14 @@ class NatBotChain(Chain):
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
url = inputs[self.input_url_key]
browser_content = inputs[self.input_browser_content_key]
llm_cmd = self.llm_chain.predict(
objective=self.objective,
url=url[:100],
previous_command=self.previous_command,
browser_content=browser_content[:4500],
callbacks=_run_manager.get_child(),
llm_cmd = self.llm_chain.invoke(
{
"objective": self.objective,
"url": url[:100],
"previous_command": self.previous_command,
"browser_content": browser_content[:4500],
},
config={"callbacks": _run_manager.get_child()},
)
llm_cmd = llm_cmd.strip()
self.previous_command = llm_cmd

View File

@ -8,8 +8,9 @@ from typing import Any, Callable, Dict, Optional, Sequence, cast
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable
from langchain.chains.llm import LLMChain
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
@ -49,12 +50,15 @@ class LLMChainExtractor(BaseDocumentCompressor):
"""Document compressor that uses an LLM chain to extract
the relevant parts of documents."""
llm_chain: LLMChain
llm_chain: Runnable
"""LLM wrapper to use for compressing documents."""
get_input: Callable[[str, Document], dict] = default_get_input
"""Callable for constructing the chain input from the query and a Document."""
class Config:
arbitrary_types_allowed = True
def compress_documents(
self,
documents: Sequence[Document],
@ -65,10 +69,13 @@ class LLMChainExtractor(BaseDocumentCompressor):
compressed_docs = []
for doc in documents:
_input = self.get_input(query, doc)
output_dict = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
output = output_dict[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
output = self.llm_chain.prompt.output_parser.parse(output)
output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
if isinstance(self.llm_chain, LLMChain):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
output = self.llm_chain.prompt.output_parser.parse(output)
else:
output = output_
if len(output) == 0:
continue
compressed_docs.append(
@ -85,9 +92,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
"""Compress page content of raw documents asynchronously."""
outputs = await asyncio.gather(
*[
self.llm_chain.apredict_and_parse(
**self.get_input(query, doc), callbacks=callbacks
)
self.llm_chain.ainvoke(self.get_input(query, doc), callbacks=callbacks)
for doc in documents
]
)
@ -111,5 +116,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
"""Initialize from LLM."""
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
_get_input = get_input if get_input is not None else default_get_input
llm_chain = LLMChain(llm=llm, prompt=_prompt, **(llm_chain_kwargs or {}))
if _prompt.output_parser is not None:
parser = _prompt.output_parser
else:
parser = StrOutputParser()
llm_chain = _prompt | llm | parser
return cls(llm_chain=llm_chain, get_input=_get_input) # type: ignore[arg-type]

View File

@ -5,7 +5,9 @@ from typing import Any, Callable, Dict, Optional, Sequence
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import RunnableConfig
from langchain.chains import LLMChain
@ -32,13 +34,16 @@ def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
class LLMChainFilter(BaseDocumentCompressor):
"""Filter that drops documents that aren't relevant to the query."""
llm_chain: LLMChain
llm_chain: Runnable
"""LLM wrapper to use for filtering documents.
The chain prompt is expected to have a BooleanOutputParser."""
get_input: Callable[[str, Document], dict] = default_get_input
"""Callable for constructing the chain input from the query and a Document."""
class Config:
arbitrary_types_allowed = True
def compress_documents(
self,
documents: Sequence[Document],
@ -56,11 +61,15 @@ class LLMChainFilter(BaseDocumentCompressor):
documents,
)
for output_dict, doc in outputs:
for output_, doc in outputs:
include_doc = None
output = output_dict[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
if isinstance(self.llm_chain, LLMChain):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
else:
if isinstance(output_, bool):
include_doc = output_
if include_doc:
filtered_docs.append(doc)
@ -82,11 +91,15 @@ class LLMChainFilter(BaseDocumentCompressor):
),
documents,
)
for output_dict, doc in outputs:
for output_, doc in outputs:
include_doc = None
output = output_dict[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
if isinstance(self.llm_chain, LLMChain):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
else:
if isinstance(output_, bool):
include_doc = output_
if include_doc:
filtered_docs.append(doc)
@ -110,5 +123,9 @@ class LLMChainFilter(BaseDocumentCompressor):
A LLMChainFilter that uses the given language model.
"""
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
llm_chain = LLMChain(llm=llm, prompt=_prompt)
if _prompt.output_parser is not None:
parser = _prompt.output_parser
else:
parser = StrOutputParser()
llm_chain = _prompt | llm | parser
return cls(llm_chain=llm_chain, **kwargs)

View File

@ -7,11 +7,11 @@ from langchain_core.callbacks import (
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain.chains.llm import LLMChain
from langchain_core.runnables import Runnable
logger = logging.getLogger(__name__)
@ -30,7 +30,7 @@ class RePhraseQueryRetriever(BaseRetriever):
Then, retrieve docs for the re-phrased query."""
retriever: BaseRetriever
llm_chain: LLMChain
llm_chain: Runnable
@classmethod
def from_llm(
@ -51,8 +51,7 @@ class RePhraseQueryRetriever(BaseRetriever):
Returns:
RePhraseQueryRetriever
"""
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = prompt | llm | StrOutputParser()
return cls(
retriever=retriever,
llm_chain=llm_chain,
@ -72,8 +71,9 @@ class RePhraseQueryRetriever(BaseRetriever):
Returns:
Relevant documents for re-phrased question
"""
response = self.llm_chain(query, callbacks=run_manager.get_child())
re_phrased_question = response["text"]
re_phrased_question = self.llm_chain.invoke(
query, {"callbacks": run_manager.get_child()}
)
logger.info(f"Re-phrased question: {re_phrased_question}")
docs = self.retriever.invoke(
re_phrased_question, config={"callbacks": run_manager.get_child()}

View File

@ -0,0 +1,84 @@
from langchain_core.documents import Document
from langchain_core.language_models import FakeListChatModel
from langchain.retrievers.document_compressors import LLMChainExtractor
def test_llm_chain_extractor() -> None:
documents = [
Document(
page_content=(
"The sky is blue. Candlepin bowling is popular in New England."
),
metadata={"a": 1},
),
Document(
page_content=(
"Mercury is the closest planet to the Sun. "
"Candlepin bowling balls are smaller."
),
metadata={"b": 2},
),
Document(page_content="The moon is round.", metadata={"c": 3}),
]
llm = FakeListChatModel(
responses=[
"Candlepin bowling is popular in New England.",
"Candlepin bowling balls are smaller.",
"NO_OUTPUT",
]
)
doc_compressor = LLMChainExtractor.from_llm(llm)
output = doc_compressor.compress_documents(
documents, "Tell me about Candlepin bowling."
)
expected = [
Document(
page_content="Candlepin bowling is popular in New England.",
metadata={"a": 1},
),
Document(
page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
),
]
assert output == expected
async def test_llm_chain_extractor_async() -> None:
documents = [
Document(
page_content=(
"The sky is blue. Candlepin bowling is popular in New England."
),
metadata={"a": 1},
),
Document(
page_content=(
"Mercury is the closest planet to the Sun. "
"Candlepin bowling balls are smaller."
),
metadata={"b": 2},
),
Document(page_content="The moon is round.", metadata={"c": 3}),
]
llm = FakeListChatModel(
responses=[
"Candlepin bowling is popular in New England.",
"Candlepin bowling balls are smaller.",
"NO_OUTPUT",
]
)
doc_compressor = LLMChainExtractor.from_llm(llm)
output = await doc_compressor.acompress_documents(
documents, "Tell me about Candlepin bowling."
)
expected = [
Document(
page_content="Candlepin bowling is popular in New England.",
metadata={"a": 1},
),
Document(
page_content="Candlepin bowling balls are smaller.", metadata={"b": 2}
),
]
assert output == expected

View File

@ -0,0 +1,46 @@
from langchain_core.documents import Document
from langchain_core.language_models import FakeListChatModel
from langchain.retrievers.document_compressors import LLMChainFilter
def test_llm_chain_filter() -> None:
documents = [
Document(
page_content="Candlepin bowling is popular in New England.",
metadata={"a": 1},
),
Document(
page_content="Candlepin bowling balls are smaller.",
metadata={"b": 2},
),
Document(page_content="The moon is round.", metadata={"c": 3}),
]
llm = FakeListChatModel(responses=["YES", "YES", "NO"])
doc_compressor = LLMChainFilter.from_llm(llm)
output = doc_compressor.compress_documents(
documents, "Tell me about Candlepin bowling."
)
expected = documents[:2]
assert output == expected
async def test_llm_chain_extractor_async() -> None:
documents = [
Document(
page_content="Candlepin bowling is popular in New England.",
metadata={"a": 1},
),
Document(
page_content="Candlepin bowling balls are smaller.",
metadata={"b": 2},
),
Document(page_content="The moon is round.", metadata={"c": 3}),
]
llm = FakeListChatModel(responses=["YES", "YES", "NO"])
doc_compressor = LLMChainFilter.from_llm(llm)
output = await doc_compressor.acompress_documents(
documents, "Tell me about Candlepin bowling."
)
expected = documents[:2]
assert output == expected