mirror of https://github.com/hwchase17/langchain
Tree of Thought introducing a new ToTChain. (#5167)
# [WIP] Tree of Thought introducing a new ToTChain.

This PR adds a new chain called ToTChain that implements the ["Large Language Model Guided Tree-of-Thought"](https://arxiv.org/pdf/2305.08291.pdf) paper. There's a notebook example `docs/modules/chains/examples/tot.ipynb` that shows how to use it.

Implements #4975

## Who can review?

Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested:

- @hwchase17
- @vowelparrot

---------

Co-authored-by: Vadim Gubergrits <vgubergrits@outbox.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
parent
412e29d436
commit
e7e5cb9d08
@ -0,0 +1,239 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tree of Thought (ToT) example\n",
|
||||
"\n",
|
||||
"The Tree of Thought (ToT) is a chain that allows you to query a Large Language Model (LLM) using the Tree of Thought technique. This is based on the paper [\"Large Language Model Guided Tree-of-Thought\"](https://arxiv.org/pdf/2305.08291.pdf)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.13) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"llm = OpenAI(temperature=1, max_tokens=512, model=\"text-davinci-003\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\n",
|
||||
"\n",
|
||||
"- This is a 4x4 Sudoku puzzle.\n",
|
||||
"- The * represents a cell to be filled.\n",
|
||||
"- The | character separates rows.\n",
|
||||
"- At each step, replace one or more * with digits 1-4.\n",
|
||||
"- There must be no duplicate digits in any row, column or 2x2 subgrid.\n",
|
||||
"- Keep the known digits from previous valid thoughts in place.\n",
|
||||
"- Each thought can be a partial or the final solution.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sudoku_puzzle = \"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\"\n",
|
||||
"sudoku_solution = \"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\"\n",
|
||||
"problem_description = f\"\"\"\n",
|
||||
"{sudoku_puzzle}\n",
|
||||
"\n",
|
||||
"- This is a 4x4 Sudoku puzzle.\n",
|
||||
"- The * represents a cell to be filled.\n",
|
||||
"- The | character separates rows.\n",
|
||||
"- At each step, replace one or more * with digits 1-4.\n",
|
||||
"- There must be no duplicate digits in any row, column or 2x2 subgrid.\n",
|
||||
"- Keep the known digits from previous valid thoughts in place.\n",
|
||||
"- Each thought can be a partial or the final solution.\n",
|
||||
"\"\"\".strip()\n",
|
||||
"print(problem_description)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Rules Based Checker\n",
|
||||
"\n",
|
||||
"Each thought is evaluated by the thought checker and is given a validity type: valid, invalid or partial. A simple checker can be rule-based. For example, in the case of a sudoku puzzle, the checker can check if the puzzle is valid, invalid or partial.\n",
|
||||
"\n",
|
||||
"In the following code we implement a simple rule-based checker for a specific 4x4 sudoku puzzle.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Tuple\n",
|
||||
"from langchain_experimental.tot.checker import ToTChecker\n",
|
||||
"from langchain_experimental.tot.thought import ThoughtValidity\n",
|
||||
"import re\n",
|
||||
"\n",
|
||||
"class MyChecker(ToTChecker):\n",
|
||||
" def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:\n",
|
||||
" last_thought = thoughts[-1]\n",
|
||||
" clean_solution = last_thought.replace(\" \", \"\").replace('\"', \"\")\n",
|
||||
" regex_solution = clean_solution.replace(\"*\", \".\").replace(\"|\", \"\\\\|\")\n",
|
||||
" if sudoku_solution in clean_solution:\n",
|
||||
" return ThoughtValidity.VALID_FINAL\n",
|
||||
" elif re.search(regex_solution, sudoku_solution):\n",
|
||||
" return ThoughtValidity.VALID_INTERMEDIATE\n",
|
||||
" else:\n",
|
||||
" return ThoughtValidity.INVALID"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Just testing the MyChecker class above:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"checker = MyChecker()\n",
|
||||
"assert checker.evaluate(\"\", (\"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\",)) == ThoughtValidity.VALID_INTERMEDIATE\n",
|
||||
"assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\",)) == ThoughtValidity.VALID_FINAL\n",
|
||||
"assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1\",)) == ThoughtValidity.VALID_INTERMEDIATE\n",
|
||||
"assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1\",)) == ThoughtValidity.INVALID"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tree of Thought Chain\n",
|
||||
"\n",
|
||||
"Initialize and run the ToT chain, with the maximum number of iterations `k` set to `30` and the maximum number of child thoughts `c` set to `5`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new ToTChain chain...\u001b[0m\n",
|
||||
"Starting the ToT solve procedure.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/harrisonchase/workplace/langchain/libs/langchain/langchain/chains/llm.py:275: UserWarning: The predict_and_parse method is deprecated, instead pass an output parser directly to LLMChain.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[31;1m\u001b[1;3mThought: 3*,*,2|1*,3,*|*,1,*,3|4,*,*,1\n",
|
||||
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,*|*,1,*,3|4,*,*,1\n",
|
||||
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,*,3|4,*,*,1\n",
|
||||
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,2,3|4,*,*,1\n",
|
||||
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|2,1,*,3|4,*,*,1\n",
|
||||
"\u001b[0m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Type <enum 'ThoughtValidity'> not serializable\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n",
|
||||
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|*,3,2,*|*,1,*,3|4,1,*,*\n",
|
||||
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n",
|
||||
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|1,1,*,3|4,1,*,*\n",
|
||||
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,1,3,*|1,1,*,3|4,1,*,*\n",
|
||||
"\u001b[0m\u001b[33;1m\u001b[1;3mThought: 3,*,*,2|1,2,3,*|*,1,*,3|4,*,*,1\n",
|
||||
"\u001b[0m\u001b[31;1m\u001b[1;3m Thought: 3,1,4,2|1,2,3,4|2,1,4,3|4,3,2,1\n",
|
||||
"\u001b[0m\u001b[32;1m\u001b[1;3m Thought: 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_experimental.tot.base import ToTChain\n",
|
||||
"\n",
|
||||
"tot_chain = ToTChain(llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False)\n",
|
||||
"tot_chain.run(problem_description=problem_description)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
from langchain_experimental.tot.base import ToTChain
|
||||
from langchain_experimental.tot.checker import ToTChecker
|
||||
|
||||
__all__ = ["ToTChain", "ToTChecker"]
|
@ -0,0 +1,155 @@
|
||||
"""
|
||||
This is a Tree of Thought (ToT) chain based on the paper "Large Language Model
|
||||
Guided Tree-of-Thought"
|
||||
|
||||
https://arxiv.org/pdf/2305.08291.pdf
|
||||
|
||||
The Tree of Thought (ToT) chain uses a tree structure to explore the space of
|
||||
possible solutions to a problem.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from textwrap import indent
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain_experimental.tot.checker import ToTChecker
|
||||
from langchain_experimental.tot.controller import ToTController
|
||||
from langchain_experimental.tot.memory import ToTDFSMemory
|
||||
from langchain_experimental.tot.thought import Thought, ThoughtValidity
|
||||
from langchain_experimental.tot.thought_generation import (
|
||||
BaseThoughtGenerationStrategy,
|
||||
ProposePromptStrategy,
|
||||
)
|
||||
|
||||
|
||||
class ToTChain(Chain):
|
||||
"""
|
||||
A Chain implementing the Tree of Thought (ToT).
|
||||
"""
|
||||
|
||||
llm: BaseLanguageModel
|
||||
"""
|
||||
Language model to use. It must be set to produce different variations for
|
||||
the same prompt.
|
||||
"""
|
||||
checker: ToTChecker
|
||||
"""ToT Checker to use."""
|
||||
output_key: str = "response" #: :meta private:
|
||||
k: int = 10
|
||||
"""The maximum number of conversation rounds"""
|
||||
c: int = 3
|
||||
"""The number of children to explore at each node"""
|
||||
tot_memory: ToTDFSMemory = ToTDFSMemory()
|
||||
tot_controller: ToTController = ToTController()
|
||||
tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
|
||||
verbose_llm: bool = False
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
|
||||
"""
|
||||
Create a ToTChain from a language model.
|
||||
|
||||
:param llm: The language model to use.
|
||||
:param kwargs: Additional arguments to pass to the ToTChain constructor.
|
||||
"""
|
||||
return cls(llm=llm, **kwargs)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self.tot_controller.c = self.c
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return ["problem_description"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def log_thought(
|
||||
self,
|
||||
thought: Thought,
|
||||
level: int,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> None:
|
||||
if run_manager:
|
||||
colors = {
|
||||
ThoughtValidity.VALID_FINAL: "green",
|
||||
ThoughtValidity.VALID_INTERMEDIATE: "yellow",
|
||||
ThoughtValidity.INVALID: "red",
|
||||
}
|
||||
text = indent(f"Thought: {thought.text}\n", prefix=" " * level)
|
||||
run_manager.on_text(
|
||||
text=text, color=colors[thought.validity], verbose=self.verbose
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
if run_manager:
|
||||
run_manager.on_text(text="Starting the ToT solve procedure.\n")
|
||||
|
||||
problem_description = inputs["problem_description"]
|
||||
checker_inputs = {"problem_description": problem_description}
|
||||
thoughts_path: tuple[str, ...] = ()
|
||||
thought_generator = self.tot_strategy_class(
|
||||
llm=self.llm, c=self.c, verbose=self.verbose_llm
|
||||
)
|
||||
|
||||
level = 0
|
||||
for _ in range(self.k):
|
||||
level = self.tot_memory.level
|
||||
thought_text = thought_generator.next_thought(
|
||||
problem_description, thoughts_path, callbacks=_run_manager.get_child()
|
||||
)
|
||||
checker_inputs["thoughts"] = thoughts_path + (thought_text,)
|
||||
thought_validity = self.checker(
|
||||
checker_inputs, callbacks=_run_manager.get_child()
|
||||
)["validity"]
|
||||
thought = Thought(text=thought_text, validity=thought_validity)
|
||||
if thought.validity == ThoughtValidity.VALID_FINAL:
|
||||
self.log_thought(thought, level, run_manager)
|
||||
return {self.output_key: thought.text}
|
||||
self.tot_memory.store(thought)
|
||||
self.log_thought(thought, level, run_manager)
|
||||
thoughts_path = self.tot_controller(self.tot_memory)
|
||||
|
||||
return {self.output_key: "No solution found"}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
raise NotImplementedError("Async not implemented yet")
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "tot"
|
@ -0,0 +1,52 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
from langchain_experimental.tot.thought import ThoughtValidity
|
||||
|
||||
|
||||
class ToTChecker(Chain, ABC):
|
||||
"""
|
||||
Tree of Thought (ToT) checker.
|
||||
|
||||
This is an abstract ToT checker that must be implemented by the user. You
|
||||
can implement a simple rule-based checker or a more sophisticated
|
||||
neural network based classifier.
|
||||
"""
|
||||
|
||||
output_key: str = "validity" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""The checker input keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return ["problem_description", "thoughts"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""The checker output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
problem_description: str,
|
||||
thoughts: Tuple[str, ...] = (),
|
||||
) -> ThoughtValidity:
|
||||
"""
|
||||
Evaluate the response to the problem description and return the solution type.
|
||||
"""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, ThoughtValidity]:
|
||||
return {self.output_key: self.evaluate(**inputs)}
|
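The abstract `evaluate` above is implemented per problem. As a standalone sketch of the notebook's rule-based sudoku checker (no langchain imports; `Validity` is a stand-in for `ThoughtValidity`), the trick is turning a partial solution into a regex, with `*` cells as wildcards, and matching it against the known solution:

```python
import re
from enum import Enum

class Validity(Enum):  # stand-in for ThoughtValidity
    VALID_INTERMEDIATE = 0
    VALID_FINAL = 1
    INVALID = 2

SOLUTION = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"

def evaluate_sudoku_thought(thought: str) -> Validity:
    """Rule-based check: '*' cells become regex wildcards, '|' is escaped."""
    clean = thought.replace(" ", "").replace('"', "")
    pattern = clean.replace("*", ".").replace("|", "\\|")
    if SOLUTION in clean:
        return Validity.VALID_FINAL
    if re.search(pattern, SOLUTION):
        return Validity.VALID_INTERMEDIATE
    return Validity.INVALID
```

Any thought whose filled cells contradict the target solution fails the regex search and is reported invalid, which is what lets the controller prune that branch.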
@ -0,0 +1,54 @@
|
||||
from typing import Tuple
|
||||
|
||||
from langchain_experimental.tot.memory import ToTDFSMemory
|
||||
from langchain_experimental.tot.thought import ThoughtValidity
|
||||
|
||||
|
||||
class ToTController:
|
||||
"""
|
||||
Tree of Thought (ToT) controller.
|
||||
|
||||
This is a version of a ToT controller, referred to in the paper as a "Simple
|
||||
Controller".
|
||||
|
||||
It has one parameter `c` which is the number of children to explore for each
|
||||
thought.
|
||||
"""
|
||||
|
||||
def __init__(self, c: int = 3):
|
||||
"""
|
||||
Initialize the controller.
|
||||
|
||||
Args:
|
||||
c: The number of children to explore at each node.
|
||||
"""
|
||||
self.c = c
|
||||
|
||||
def __call__(self, memory: ToTDFSMemory) -> Tuple[str, ...]:
|
||||
next_thought = memory.top()
|
||||
parent_thought = memory.top_parent()
|
||||
validity = (
|
||||
ThoughtValidity.VALID_INTERMEDIATE
|
||||
if next_thought is None
|
||||
else next_thought.validity
|
||||
)
|
||||
|
||||
# 1 if the current partial solution is invalid, backtrack to the parent
|
||||
# thought.
|
||||
if validity == ThoughtValidity.INVALID:
|
||||
memory.pop()
|
||||
next_thought = memory.top()
|
||||
if next_thought and len(next_thought.children) >= self.c:
|
||||
memory.pop()
|
||||
|
||||
# 2 if the current partial solution is valid but C children were
|
||||
# explored and yet failed to find a final solution, backtrack to the
|
||||
# parent thought.
|
||||
elif (
|
||||
validity == ThoughtValidity.VALID_INTERMEDIATE
|
||||
and parent_thought
|
||||
and len(parent_thought.children) >= self.c
|
||||
):
|
||||
memory.pop(2)
|
||||
|
||||
return tuple(thought.text for thought in memory.current_path())
|
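The two backtracking rules above can be exercised in isolation. A minimal sketch with stand-in types (plain dataclasses, not the library's `Thought`/`ToTDFSMemory`) mirrors `ToTController.__call__` on a plain list used as the DFS stack:

```python
from dataclasses import dataclass, field

VALID_INTERMEDIATE, VALID_FINAL, INVALID = range(3)

@dataclass(eq=False)  # eq=False keeps identity hashing so nodes can live in sets
class Node:  # stand-in for Thought
    text: str
    validity: int
    children: set = field(default_factory=set)

def next_path(stack: list, c: int = 3) -> tuple:
    """Mirror of the controller: prune invalid branches, back off exhausted ones."""
    top = stack[-1] if stack else None
    parent = stack[-2] if len(stack) > 1 else None
    validity = VALID_INTERMEDIATE if top is None else top.validity
    if validity == INVALID:
        # Rule 1: drop the invalid thought; also drop a parent with c children.
        stack.pop()
        if stack and len(stack[-1].children) >= c:
            stack.pop()
    elif validity == VALID_INTERMEDIATE and parent and len(parent.children) >= c:
        # Rule 2: parent explored c children without a final solution; go up two.
        del stack[-2:]
    return tuple(n.text for n in stack)
```

For example, with a path a -> b -> c3 where b already has three explored children and `c=3`, rule 2 fires and the next prompt path is just `("a",)`.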
@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_experimental.tot.thought import Thought
|
||||
|
||||
|
||||
class ToTDFSMemory:
|
||||
"""
|
||||
Memory for the Tree of Thought (ToT) chain. Implemented as a stack of
|
||||
thoughts. This allows for a depth first search (DFS) of the ToT.
|
||||
"""
|
||||
|
||||
def __init__(self, stack: Optional[List[Thought]] = None):
|
||||
self.stack: list[Thought] = stack or []
|
||||
|
||||
def top(self) -> Optional[Thought]:
|
||||
"Get the top of the stack without popping it."
|
||||
return self.stack[-1] if len(self.stack) > 0 else None
|
||||
|
||||
def pop(self, n: int = 1) -> Optional[Thought]:
|
||||
"Pop the top n elements of the stack and return the last one."
|
||||
if len(self.stack) < n:
|
||||
return None
|
||||
for _ in range(n):
|
||||
node = self.stack.pop()
|
||||
return node
|
||||
|
||||
def top_parent(self) -> Optional[Thought]:
|
||||
"Get the parent of the top of the stack without popping it."
|
||||
return self.stack[-2] if len(self.stack) > 1 else None
|
||||
|
||||
def store(self, node: Thought) -> None:
|
||||
"Add a node on the top of the stack."
|
||||
if len(self.stack) > 0:
|
||||
self.stack[-1].children.add(node)
|
||||
self.stack.append(node)
|
||||
|
||||
@property
|
||||
def level(self) -> int:
|
||||
"Return the current level of the stack."
|
||||
return len(self.stack)
|
||||
|
||||
def current_path(self) -> List[Thought]:
|
||||
"Return the thoughts path."
|
||||
return self.stack[:]
|
@ -0,0 +1,137 @@
|
||||
import json
|
||||
from textwrap import dedent
|
||||
from typing import List
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
from langchain_experimental.tot.thought import ThoughtValidity
|
||||
|
||||
COT_PROMPT = PromptTemplate(
|
||||
template_format="jinja2",
|
||||
input_variables=["problem_description", "thoughts"],
|
||||
template=dedent(
|
||||
"""
|
||||
You are an intelligent agent that is generating one thought at a time in
|
||||
a tree of thoughts setting.
|
||||
|
||||
PROBLEM
|
||||
|
||||
{{problem_description}}
|
||||
|
||||
{% if thoughts %}
|
||||
THOUGHTS
|
||||
|
||||
{% for thought in thoughts %}
|
||||
{{ thought }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
Let's think step by step.
|
||||
"""
|
||||
).strip(),
|
||||
)
|
||||
|
||||
|
||||
class JSONListOutputParser(BaseOutputParser):
|
||||
"""Class to parse the output of a PROPOSE_PROMPT response."""
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "json_list"
|
||||
|
||||
def parse(self, text: str) -> List[str]:
|
||||
"""Parse the output of an LLM call."""
|
||||
|
||||
if "```json" not in text:
    return []
json_string = text.split("```json")[1].strip().strip("```").strip()
|
||||
try:
|
||||
return json.loads(json_string)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
|
||||
PROPOSE_PROMPT = PromptTemplate(
|
||||
template_format="jinja2",
|
||||
input_variables=["problem_description", "thoughts", "n"],
|
||||
output_parser=JSONListOutputParser(),
|
||||
template=dedent(
|
||||
"""
|
||||
You are an intelligent agent that is generating thoughts in a tree of
|
||||
thoughts setting.
|
||||
|
||||
The output should be a markdown code snippet formatted as a JSON list of
|
||||
strings, including the leading and trailing "```json" and "```":
|
||||
|
||||
```json
|
||||
[
|
||||
"<thought-1>",
|
||||
"<thought-2>",
|
||||
"<thought-3>"
|
||||
]
|
||||
```
|
||||
|
||||
PROBLEM
|
||||
|
||||
{{ problem_description }}
|
||||
|
||||
{% if thoughts %}
|
||||
VALID THOUGHTS
|
||||
|
||||
{% for thought in thoughts %}
|
||||
{{ thought }}
|
||||
{% endfor %}
|
||||
|
||||
Possible next {{ n }} valid thoughts based on the last valid thought:
|
||||
{% else %}
|
||||
|
||||
Possible next {{ n }} valid thoughts based on the PROBLEM:
|
||||
{%- endif -%}
|
||||
"""
|
||||
).strip(),
|
||||
)
|
||||
|
||||
|
||||
class CheckerOutputParser(BaseOutputParser):
|
||||
def parse(self, text: str) -> ThoughtValidity:
|
||||
"""Parse the output of the language model."""
|
||||
text = text.upper()
|
||||
if "INVALID" in text:
|
||||
return ThoughtValidity.INVALID
|
||||
elif "INTERMEDIATE" in text:
|
||||
return ThoughtValidity.VALID_INTERMEDIATE
|
||||
elif "VALID" in text:
|
||||
return ThoughtValidity.VALID_FINAL
|
||||
else:
|
||||
return ThoughtValidity.INVALID
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "tot_llm_checker_output"
|
||||
|
||||
|
||||
CHECKER_PROMPT = PromptTemplate(
|
||||
input_variables=["problem_description", "thoughts"],
|
||||
template=dedent(
|
||||
"""
|
||||
You are an intelligent agent, validating thoughts of another intelligent agent.
|
||||
|
||||
PROBLEM
|
||||
|
||||
{problem_description}
|
||||
|
||||
THOUGHTS
|
||||
|
||||
{thoughts}
|
||||
|
||||
Evaluate the thoughts and respond with one word.
|
||||
|
||||
- Respond VALID if the last thought is a valid final solution to the
|
||||
problem.
|
||||
- Respond INVALID if the last thought is invalid.
|
||||
- Respond INTERMEDIATE if the last thought is valid but not the final
|
||||
solution to the problem.
|
||||
|
||||
This chain of thoughts is"""
|
||||
).strip(),
|
||||
output_parser=CheckerOutputParser(),
|
||||
)
|
@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Set
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ThoughtValidity(Enum):
|
||||
VALID_INTERMEDIATE = 0
|
||||
VALID_FINAL = 1
|
||||
INVALID = 2
|
||||
|
||||
|
||||
class Thought(BaseModel):
|
||||
text: str
|
||||
validity: ThoughtValidity
|
||||
children: Set[Thought] = Field(default_factory=set)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
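`Thought.__hash__` above hashes by `id(self)`: two thoughts with identical text are still distinct tree nodes, and a thought can sit in its parent's mutable `children` set despite having mutable fields. A standalone sketch of the same design choice (plain class, not pydantic):

```python
class Node:
    """Hash by identity: equal-text nodes remain distinct members of a set."""

    def __init__(self, text: str):
        self.text = text
        self.children: set = set()

    def __hash__(self) -> int:
        return id(self)

root = Node("3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1")
a = Node("3,4,*,2|1,*,3,*|*,1,*,3|4,*,*,1")
b = Node("3,4,*,2|1,*,3,*|*,1,*,3|4,*,*,1")  # same text, different node
root.children.update({a, b})
```

Identity hashing matters here because the controller counts `len(parent.children)` to decide when a branch is exhausted; content-based hashing would silently collapse duplicate proposals.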
@ -0,0 +1,94 @@
|
||||
"""
|
||||
We provide two strategies for generating thoughts in the Tree of Thoughts (ToT)
framework to avoid repetition: sampling thoughts from a Chain-of-Thought prompt
(SampleCoTStrategy) and proposing several candidates in a single completion
(ProposePromptStrategy).
|
||||
|
||||
These strategies ensure that the language model generates diverse and
|
||||
non-repeating thoughts, which are crucial for problem-solving tasks that require
|
||||
exploration.
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_experimental.tot.prompts import COT_PROMPT, PROPOSE_PROMPT
|
||||
|
||||
|
||||
class BaseThoughtGenerationStrategy(LLMChain):
|
||||
"""
|
||||
Base class for a thought generation strategy.
|
||||
"""
|
||||
|
||||
c: int = 3
|
||||
"""The number of children thoughts to propose at each step."""
|
||||
|
||||
@abstractmethod
|
||||
def next_thought(
|
||||
self,
|
||||
problem_description: str,
|
||||
thoughts_path: Tuple[str, ...] = (),
|
||||
**kwargs: Any
|
||||
) -> str:
|
||||
"""
|
||||
Generate the next thought given the problem description and the thoughts
|
||||
generated so far.
|
||||
"""
|
||||
|
||||
|
||||
class SampleCoTStrategy(BaseThoughtGenerationStrategy):
|
||||
"""
|
||||
Sample thoughts from a Chain-of-Thought (CoT) prompt.
|
||||
|
||||
This strategy works better when the thought space is rich, such as when each
|
||||
thought is a paragraph. Independent and identically distributed samples
|
||||
lead to diversity, which helps to avoid repetition.
|
||||
"""
|
||||
|
||||
prompt: BasePromptTemplate = COT_PROMPT
|
||||
|
||||
def next_thought(
|
||||
self,
|
||||
problem_description: str,
|
||||
thoughts_path: Tuple[str, ...] = (),
|
||||
**kwargs: Any
|
||||
) -> str:
|
||||
response_text = self.predict_and_parse(
|
||||
problem_description=problem_description, thoughts=thoughts_path, **kwargs
|
||||
)
|
||||
return response_text if isinstance(response_text, str) else ""
|
||||
|
||||
|
||||
class ProposePromptStrategy(BaseThoughtGenerationStrategy):
|
||||
"""
|
||||
Propose thoughts sequentially using a "propose prompt".
|
||||
|
||||
This strategy works better when the thought space is more constrained, such
|
||||
as when each thought is just a word or a line. Proposing different thoughts
|
||||
in the same prompt completion helps to avoid duplication.
|
||||
"""
|
||||
|
||||
prompt: BasePromptTemplate = PROPOSE_PROMPT
|
||||
tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)
|
||||
|
||||
def next_thought(
|
||||
self,
|
||||
problem_description: str,
|
||||
thoughts_path: Tuple[str, ...] = (),
|
||||
**kwargs: Any
|
||||
) -> str:
|
||||
if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
|
||||
new_thoughts = self.predict_and_parse(
|
||||
problem_description=problem_description,
|
||||
thoughts=thoughts_path,
|
||||
n=self.c,
|
||||
**kwargs
|
||||
)
|
||||
if not new_thoughts:
|
||||
return ""
|
||||
if isinstance(new_thoughts, list):
|
||||
self.tot_memory[thoughts_path] = new_thoughts[::-1]
|
||||
else:
|
||||
return ""
|
||||
return self.tot_memory[thoughts_path].pop()
|
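`ProposePromptStrategy` above asks the LLM for `c` candidates in one completion, caches them keyed by the thoughts path, and pops one per call, so sibling thoughts come from a single completion and don't repeat. A standalone sketch of that memoization pattern, with a hypothetical `propose` callable standing in for the LLM call:

```python
from typing import Callable, Dict, List, Tuple

def make_next_thought(
    propose: Callable[[Tuple[str, ...], int], List[str]], c: int = 3
) -> Callable[[Tuple[str, ...]], str]:
    """Cache c proposals per path; pop one per call, re-proposing when exhausted."""
    memo: Dict[Tuple[str, ...], List[str]] = {}

    def next_thought(path: Tuple[str, ...] = ()) -> str:
        if not memo.get(path):
            # Reversed so list.pop() returns proposals in their original order.
            memo[path] = propose(path, c)[::-1]
        return memo[path].pop() if memo[path] else ""

    return next_thought

calls = []

def fake_propose(path, n):  # hypothetical stand-in for predict_and_parse
    calls.append(path)
    return [f"{'.'.join(path) or 'root'}-{i}" for i in range(n)]

nt = make_next_thought(fake_propose, c=2)
```

Two successive calls on the same path draw from one `fake_propose` invocation; only once the cached list is empty does the strategy ask for a fresh batch.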
@ -0,0 +1,61 @@
|
||||
"""Fake LLM wrapper for testing purposes."""
|
||||
from typing import Any, Dict, List, Mapping, Optional, cast
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from pydantic import validator
|
||||
|
||||
|
||||
class FakeLLM(LLM):
|
||||
"""Fake LLM wrapper for testing purposes."""
|
||||
|
||||
queries: Optional[Mapping] = None
|
||||
sequential_responses: Optional[bool] = False
|
||||
response_index: int = 0
|
||||
|
||||
@validator("queries", always=True)
|
||||
def check_queries_required(
|
||||
cls, queries: Optional[Mapping], values: Mapping[str, Any]
|
||||
) -> Optional[Mapping]:
|
||||
if values.get("sequential_responses") and not queries:
|
||||
raise ValueError(
|
||||
"queries is required when sequential_responses is set to True"
|
||||
)
|
||||
return queries
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Return number of tokens."""
|
||||
return len(text.split())
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "fake"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if self.sequential_responses:
|
||||
return self._get_next_response_in_sequence
|
||||
|
||||
if self.queries is not None:
|
||||
return self.queries[prompt]
|
||||
if stop is None:
|
||||
return "foo"
|
||||
else:
|
||||
return "bar"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def _get_next_response_in_sequence(self) -> str:
|
||||
queries = cast(Mapping, self.queries)
|
||||
response = queries[list(queries.keys())[self.response_index]]
|
||||
self.response_index = self.response_index + 1
|
||||
return response
|
@ -0,0 +1,151 @@
|
||||
import re
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_experimental.tot.base import ToTChain
|
||||
from langchain_experimental.tot.checker import ToTChecker
|
||||
from langchain_experimental.tot.controller import ToTController
|
||||
from langchain_experimental.tot.memory import ToTDFSMemory
|
||||
from langchain_experimental.tot.thought import Thought, ThoughtValidity
|
||||
from langchain_experimental.tot.thought_generation import SampleCoTStrategy
|
||||
from tests.unit_tests.fake_llm import FakeLLM
|
||||
|
||||
sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"
|
||||
solutions = [
|
||||
"3,*,4,2|1,*,3,*|*,1,*,3|4,*,*,1", # VALID_INTERMEDIATE
|
||||
" 3,4,1,2|1,6,3,*|*,1,*,3|4,*,*,1", # INVALID c=1
|
||||
" 3,4,1,2|1,7,3,*|*,1,*,3|4,*,*,1", # INVALID c=2
|
||||
" 3,4,1,2|1,8,3,*|*,1,*,3|4,*,*,1", # INVALID c=3
|
||||
" 3,4,1,2|1,2,3,*|*,1,*,3|4,*,*,1", # VALID_INTERMEDIATE c=4 (rollback)
|
||||
"3,1,4,2|1,*,3,*|*,1,*,3|4,*,*,1", # INVALID (rollback)
|
||||
"3,4,1,2|1,2,3,4|*,1,*,3|4,*,*,1", # VALID_INTERMEDIATE
|
||||
" 3,4,1,2|1,2,3,4|4,1,*,3|4,*,*,1", # INVALID (rollback)
|
||||
" 3,4,1,2|1,2,3,4|2,1,4,3|4,*,*,1", # VALID_INTERMEDIATE
|
||||
" 3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1", # VALID_INTERMEDIATE
|
||||
" 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1", # VALID_FINAL
|
||||
]
|
||||
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm_sudoku() -> FakeLLM:
|
||||
"""This is a fake LLM that responds to the sudoku problem."""
|
||||
queries = {i: next_step.strip() for i, next_step in enumerate(solutions)}
|
||||
return FakeLLM(queries=queries, sequential_responses=True)
|
||||
|
||||
|
||||
class SudokuChecker(ToTChecker):
|
||||
def evaluate(
|
||||
self, problem_description: str, thoughts: Tuple[str, ...] = ()
|
||||
) -> ThoughtValidity:
|
||||
last_thought = thoughts[-1]
|
||||
clean_solution = last_thought.replace(" ", "").replace('"', "")
|
||||
regex_solution = clean_solution.replace("*", ".").replace("|", "\\|")
|
||||
if sudoku_solution in clean_solution:
|
||||
return ThoughtValidity.VALID_FINAL
|
||||
elif re.search(regex_solution, sudoku_solution):
|
||||
return ThoughtValidity.VALID_INTERMEDIATE
|
||||
else:
|
||||
return ThoughtValidity.INVALID
|
||||
|
||||
|
||||
def test_solve_sudoku(fake_llm_sudoku: FakeLLM) -> None:
|
||||
"""Test that the ToT chain solves the sudoku puzzle."""
|
||||
tot_chain = ToTChain(
|
||||
llm=fake_llm_sudoku,
|
||||
checker=SudokuChecker(),
|
||||
k=len(solutions),
|
||||
c=4,
|
||||
tot_strategy_class=SampleCoTStrategy,
|
||||
)
|
||||
output = tot_chain.run({"problem_description": ""})
|
||||
assert output == sudoku_solution
|
||||
|
||||
|
||||
def test_solve_sudoku_k_too_small(fake_llm_sudoku: FakeLLM) -> None:
|
||||
"""Test that the ToT chain fails to find the solution when k is too small."""
|
||||
tot_chain = ToTChain(
|
||||
llm=fake_llm_sudoku,
|
||||
checker=SudokuChecker(),
|
||||
k=len(solutions) - 1,
|
||||
c=4,
|
||||
tot_strategy_class=SampleCoTStrategy,
|
||||
)
|
||||
output = tot_chain.run({"problem_description": ""})
|
||||
assert output != sudoku_solution
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm_checker() -> FakeLLM:
|
||||
"""This is a fake LLM that responds with a thought validity."""
|
||||
responses = [
|
||||
"VALID",
|
||||
"valid",
|
||||
"INVALID",
|
||||
"invalid",
|
||||
"INTERMEDIATE",
|
||||
"intermediate",
|
||||
"SOMETHING ELSE",
|
||||
]
|
||||
queries = dict(enumerate(responses))
|
||||
return FakeLLM(queries=queries, sequential_responses=True)
|
||||
|
||||
|
||||
class ControllerTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.controller = ToTController(c=3)
|
||||
|
||||
def test_empty(self) -> None:
|
||||
memory = ToTDFSMemory([])
|
||||
self.assertEqual(self.controller(memory), ())
|
||||
|
||||
def test_one_thought(self) -> None:
|
||||
thoughts = [Thought(text="a", validity=ThoughtValidity.VALID_FINAL)]
|
||||
memory = ToTDFSMemory(thoughts)
|
||||
self.assertEqual(self.controller(memory), ("a",))
|
||||
|
||||
def test_two_thoughts(self) -> None:
|
||||
memory = ToTDFSMemory(
|
||||
[
|
||||
Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE),
|
||||
Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE),
|
||||
]
|
||||
)
|
||||
self.assertEqual(self.controller(memory), ("a", "b"))
|
||||
|
||||
def test_two_thoughts_invalid(self) -> None:
|
||||
memory = ToTDFSMemory(
|
||||
[
|
||||
Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE),
|
||||
Thought(text="b", validity=ThoughtValidity.INVALID),
|
||||
]
|
||||
)
|
||||
self.assertEqual(self.controller(memory), ("a",))
|
||||
|
||||
def test_thoughts_rollback(self) -> None:
|
||||
a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
|
||||
b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
|
||||
c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
|
||||
c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
|
||||
c_3 = Thought(text="c_3", validity=ThoughtValidity.VALID_INTERMEDIATE)
|
||||
|
||||
a.children = {b}
|
||||
b.children = {c_1, c_2, c_3}
|
||||
|
||||
memory = ToTDFSMemory([a, b, c_3])
|
||||
self.assertEqual(self.controller(memory), ("a",))
|
||||
|
||||
def test_thoughts_rollback_invalid(self) -> None:
|
||||
a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
|
||||
b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
|
||||
c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
|
||||
c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
|
||||
c_3 = Thought(text="c_3", validity=ThoughtValidity.INVALID)
|
||||
|
||||
a.children = {b}
|
||||
b.children = {c_1, c_2, c_3}
|
||||
|
||||
memory = ToTDFSMemory([a, b, c_3])
|
||||
self.assertEqual(self.controller(memory), ("a",))
|