mirror of https://github.com/hwchase17/langchain
Tree of Thought introducing a new ToTChain. (#5167)
# [WIP] Tree of Thought introducing a new ToTChain. This PR adds a new chain called ToTChain that implements the ["Large Language Model Guided Tree-of-Though"](https://arxiv.org/pdf/2305.08291.pdf) paper. There's a notebook example `docs/modules/chains/examples/tot.ipynb` that shows how to use it. Implements #4975 ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: - @hwchase17 - @vowelparrot --------- Co-authored-by: Vadim Gubergrits <vgubergrits@outbox.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>pull/8339/head
parent
412e29d436
commit
e7e5cb9d08
@ -0,0 +1,239 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Tree of Thought (ToT) example\n",
|
||||||
|
"\n",
|
||||||
|
"The Tree of Thought (ToT) is a chain that allows you to query a Large Language Model (LLM) using the Tree of Thought technique. This is based on the paper [\"Large Language Model Guided Tree-of-Thought\"](https://arxiv.org/pdf/2305.08291.pdf)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.13) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from langchain.llms import OpenAI\n",
|
||||||
|
"\n",
|
||||||
|
"llm = OpenAI(temperature=1, max_tokens=512, model=\"text-davinci-003\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\n",
|
||||||
|
"\n",
|
||||||
|
"- This is a 4x4 Sudoku puzzle.\n",
|
||||||
|
"- The * represents a cell to be filled.\n",
|
||||||
|
"- The | character separates rows.\n",
|
||||||
|
"- At each step, replace one or more * with digits 1-4.\n",
|
||||||
|
"- There must be no duplicate digits in any row, column or 2x2 subgrid.\n",
|
||||||
|
"- Keep the known digits from previous valid thoughts in place.\n",
|
||||||
|
"- Each thought can be a partial or the final solution.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"sudoku_puzzle = \"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\"\n",
|
||||||
|
"sudoku_solution = \"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\"\n",
|
||||||
|
"problem_description = f\"\"\"\n",
|
||||||
|
"{sudoku_puzzle}\n",
|
||||||
|
"\n",
|
||||||
|
"- This is a 4x4 Sudoku puzzle.\n",
|
||||||
|
"- The * represents a cell to be filled.\n",
|
||||||
|
"- The | character separates rows.\n",
|
||||||
|
"- At each step, replace one or more * with digits 1-4.\n",
|
||||||
|
"- There must be no duplicate digits in any row, column or 2x2 subgrid.\n",
|
||||||
|
"- Keep the known digits from previous valid thoughts in place.\n",
|
||||||
|
"- Each thought can be a partial or the final solution.\n",
|
||||||
|
"\"\"\".strip()\n",
|
||||||
|
"print(problem_description)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Rules Based Checker\n",
|
||||||
|
"\n",
|
||||||
|
"Each thought is evaluated by the thought checker and is given a validity type: valid, invalid or partial. A simple checker can be rule based. For example, in the case of a sudoku puzzle, the checker can check if the puzzle is valid, invalid or partial.\n",
|
||||||
|
"\n",
|
||||||
|
"In the following code we implement a simple rule based checker for a specific 4x4 sudoku puzzle.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from typing import Tuple\n",
|
||||||
|
"from langchain_experimental.tot.checker import ToTChecker\n",
|
||||||
|
"from langchain_experimental.tot.thought import ThoughtValidity\n",
|
||||||
|
"import re\n",
|
||||||
|
"\n",
|
||||||
|
"class MyChecker(ToTChecker):\n",
|
||||||
|
" def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:\n",
|
||||||
|
" last_thought = thoughts[-1]\n",
|
||||||
|
" clean_solution = last_thought.replace(\" \", \"\").replace('\"', \"\")\n",
|
||||||
|
" regex_solution = clean_solution.replace(\"*\", \".\").replace(\"|\", \"\\\\|\")\n",
|
||||||
|
" if sudoku_solution in clean_solution:\n",
|
||||||
|
" return ThoughtValidity.VALID_FINAL\n",
|
||||||
|
" elif re.search(regex_solution, sudoku_solution):\n",
|
||||||
|
" return ThoughtValidity.VALID_INTERMEDIATE\n",
|
||||||
|
" else:\n",
|
||||||
|
" return ThoughtValidity.INVALID"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Just testing the MyChecker class above:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"checker = MyChecker()\n",
|
||||||
|
"assert checker.evaluate(\"\", (\"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\",)) == ThoughtValidity.VALID_INTERMEDIATE\n",
|
||||||
|
"assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\",)) == ThoughtValidity.VALID_FINAL\n",
|
||||||
|
"assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1\",)) == ThoughtValidity.VALID_INTERMEDIATE\n",
|
||||||
|
"assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1\",)) == ThoughtValidity.INVALID"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Tree of Thought Chain\n",
|
||||||
|
"\n",
|
||||||
|
"Initialize and run the ToT chain, with maximum number of interactions `k` set to `30` and the maximum number child thoughts `c` set to `8`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new ToTChain chain...\u001b[0m\n",
|
||||||
|
"Starting the ToT solve procedure.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/harrisonchase/workplace/langchain/libs/langchain/langchain/chains/llm.py:275: UserWarning: The predict_and_parse method is deprecated, instead pass an output parser directly to LLMChain.\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\u001b[31;1m\u001b[1;3mThought: 3*,*,2|1*,3,*|*,1,*,3|4,*,*,1\n",
|
||||||
|
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,*|*,1,*,3|4,*,*,1\n",
|
||||||
|
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,*,3|4,*,*,1\n",
|
||||||
|
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,2,3|4,*,*,1\n",
|
||||||
|
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|2,1,*,3|4,*,*,1\n",
|
||||||
|
"\u001b[0m"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Type <enum 'ThoughtValidity'> not serializable\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n",
|
||||||
|
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|*,3,2,*|*,1,*,3|4,1,*,*\n",
|
||||||
|
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n",
|
||||||
|
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|1,1,*,3|4,1,*,*\n",
|
||||||
|
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,1,3,*|1,1,*,3|4,1,*,*\n",
|
||||||
|
"\u001b[0m\u001b[33;1m\u001b[1;3mThought: 3,*,*,2|1,2,3,*|*,1,*,3|4,*,*,1\n",
|
||||||
|
"\u001b[0m\u001b[31;1m\u001b[1;3m Thought: 3,1,4,2|1,2,3,4|2,1,4,3|4,3,2,1\n",
|
||||||
|
"\u001b[0m\u001b[32;1m\u001b[1;3m Thought: 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\n",
|
||||||
|
"\u001b[0m\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from langchain_experimental.tot.base import ToTChain\n",
|
||||||
|
"\n",
|
||||||
|
"tot_chain = ToTChain(llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False)\n",
|
||||||
|
"tot_chain.run(problem_description=problem_description)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
from langchain_experimental.tot.base import ToTChain
|
||||||
|
from langchain_experimental.tot.checker import ToTChecker
|
||||||
|
|
||||||
|
__all__ = ["ToTChain", "ToTChecker"]
|
@ -0,0 +1,155 @@
|
|||||||
|
"""
|
||||||
|
This a Tree of Thought (ToT) chain based on the paper "Large Language Model
|
||||||
|
Guided Tree-of-Thought"
|
||||||
|
|
||||||
|
https://arxiv.org/pdf/2305.08291.pdf
|
||||||
|
|
||||||
|
The Tree of Thought (ToT) chain uses a tree structure to explore the space of
|
||||||
|
possible solutions to a problem.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from textwrap import indent
|
||||||
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
)
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from pydantic import Extra
|
||||||
|
|
||||||
|
from langchain_experimental.tot.checker import ToTChecker
|
||||||
|
from langchain_experimental.tot.controller import ToTController
|
||||||
|
from langchain_experimental.tot.memory import ToTDFSMemory
|
||||||
|
from langchain_experimental.tot.thought import Thought, ThoughtValidity
|
||||||
|
from langchain_experimental.tot.thought_generation import (
|
||||||
|
BaseThoughtGenerationStrategy,
|
||||||
|
ProposePromptStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToTChain(Chain):
    """
    A Chain implementing the Tree of Thought (ToT).

    The chain repeatedly asks a thought-generation strategy for the next
    candidate thought, classifies it with ``checker``, and lets
    ``tot_controller`` decide (via the DFS memory) which partial path to
    extend or backtrack from. It stops at the first VALID_FINAL thought or
    after ``k`` rounds.
    """

    llm: BaseLanguageModel
    """
    Language model to use. It must be set to produce different variations for
    the same prompt.
    """
    checker: ToTChecker
    """ToT Checker to use."""
    output_key: str = "response" #: :meta private:
    k: int = 10
    """The maximum number of conversation rounds"""
    c: int = 3
    """The number of children to explore at each node"""
    # NOTE(review): these are class-level default instances; whether pydantic
    # copies them per chain instance should be confirmed — otherwise two
    # ToTChain objects could share one memory/controller.
    tot_memory: ToTDFSMemory = ToTDFSMemory()
    tot_controller: ToTController = ToTController()
    tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
    verbose_llm: bool = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @classmethod
    def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
        """
        Create a ToTChain from a language model.

        :param llm: The language model to use.
        :param kwargs: Additional arguments to pass to the ToTChain constructor.
        """
        return cls(llm=llm, **kwargs)

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # Keep the controller's branching factor in sync with this chain's `c`.
        self.tot_controller.c = self.c

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return ["problem_description"]

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    def log_thought(
        self,
        thought: Thought,
        level: int,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> None:
        """Emit a colored, depth-indented log line for ``thought``."""
        if run_manager:
            # Color encodes the checker verdict: green = final solution,
            # yellow = valid intermediate step, red = invalid.
            colors = {
                ThoughtValidity.VALID_FINAL: "green",
                ThoughtValidity.VALID_INTERMEDIATE: "yellow",
                ThoughtValidity.INVALID: "red",
            }
            # Indent by tree depth so the DFS structure is visible in the log.
            text = indent(f"Thought: {thought.text}\n", prefix=" " * level)
            run_manager.on_text(
                text=text, color=colors[thought.validity], verbose=self.verbose
            )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the ToT search and return the solution (or a failure message)."""
        # NOTE(review): the code below mixes `run_manager` (may be None) and
        # `_run_manager` (noop fallback); logging is silently skipped when no
        # run manager is supplied — confirm that is intentional.
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        if run_manager:
            run_manager.on_text(text="Starting the ToT solve procedure.\n")

        problem_description = inputs["problem_description"]
        checker_inputs = {"problem_description": problem_description}
        thoughts_path: tuple[str, ...] = ()
        thought_generator = self.tot_strategy_class(
            llm=self.llm, c=self.c, verbose=self.verbose_llm
        )

        level = 0
        for _ in range(self.k):
            level = self.tot_memory.level
            # Propose the next thought as a continuation of the current path.
            thought_text = thought_generator.next_thought(
                problem_description, thoughts_path, callbacks=_run_manager.get_child()
            )
            checker_inputs["thoughts"] = thoughts_path + (thought_text,)
            thought_validity = self.checker(
                checker_inputs, callbacks=_run_manager.get_child()
            )["validity"]
            thought = Thought(text=thought_text, validity=thought_validity)
            if thought.validity == ThoughtValidity.VALID_FINAL:
                self.log_thought(thought, level, run_manager)
                return {self.output_key: thought.text}
            self.tot_memory.store(thought)
            self.log_thought(thought, level, run_manager)
            # The controller mutates the memory (backtracking when needed) and
            # returns the path of thought texts to condition the next proposal.
            thoughts_path = self.tot_controller(self.tot_memory)

        # Round budget `k` exhausted without reaching a VALID_FINAL thought.
        return {self.output_key: "No solution found"}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async execution is not supported yet."""
        raise NotImplementedError("Async not implemented yet")

    @property
    def _chain_type(self) -> str:
        return "tot"
|
@ -0,0 +1,52 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
|
from langchain_experimental.tot.thought import ThoughtValidity
|
||||||
|
|
||||||
|
|
||||||
|
class ToTChecker(Chain, ABC):
    """
    Tree of Thought (ToT) checker.

    This is an abstract ToT checker that must be implemented by the user. You
    can implement a simple rule-based checker or a more sophisticated
    neural network based classifier.
    """

    output_key: str = "validity" #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """The checker input keys.

        :meta private:
        """
        return ["problem_description", "thoughts"]

    @property
    def output_keys(self) -> List[str]:
        """The checker output keys.

        :meta private:
        """
        return [self.output_key]

    @abstractmethod
    def evaluate(
        self,
        problem_description: str,
        thoughts: Tuple[str, ...] = (),
    ) -> ThoughtValidity:
        """
        Evaluate the response to the problem description and return the
        solution type.
        """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, ThoughtValidity]:
        # Adapt the Chain call protocol onto the user-implemented `evaluate`:
        # inputs carry exactly the keys declared in `input_keys`.
        return {self.output_key: self.evaluate(**inputs)}
|
@ -0,0 +1,54 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from langchain_experimental.tot.memory import ToTDFSMemory
|
||||||
|
from langchain_experimental.tot.thought import ThoughtValidity
|
||||||
|
|
||||||
|
|
||||||
|
class ToTController:
    """
    Tree of Thought (ToT) controller.

    This is a version of a ToT controller, dubbed in the paper as a "Simple
    Controller".

    It has one parameter `c` which is the number of children to explore for
    each thought.
    """

    def __init__(self, c: int = 3):
        """
        Initialize the controller.

        Args:
            c: The number of children to explore at each node.
        """
        self.c = c

    def __call__(self, memory: ToTDFSMemory) -> Tuple[str, ...]:
        """Apply the backtracking rules to `memory` and return the new path."""
        current = memory.top()
        parent = memory.top_parent()
        if current is None:
            validity = ThoughtValidity.VALID_INTERMEDIATE
        else:
            validity = current.validity

        if validity == ThoughtValidity.INVALID:
            # Rule 1: the current partial solution is invalid — drop it, and
            # if the newly exposed thought already spawned `c` children,
            # abandon it as well.
            memory.pop()
            exposed = memory.top()
            if exposed and len(exposed.children) >= self.c:
                memory.pop()
        elif (
            validity == ThoughtValidity.VALID_INTERMEDIATE
            and parent
            and len(parent.children) >= self.c
        ):
            # Rule 2: the partial solution is valid but `c` children were
            # explored without finding a final answer — backtrack past both
            # the current thought and its parent.
            memory.pop(2)

        return tuple(node.text for node in memory.current_path())
|
@ -0,0 +1,46 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from langchain_experimental.tot.thought import Thought
|
||||||
|
|
||||||
|
|
||||||
|
class ToTDFSMemory:
    """
    Memory for the Tree of Thought (ToT) chain.

    Implemented as a stack of thoughts: the stack always holds the path from
    the root to the thought currently being explored, enabling a depth first
    search (DFS) of the ToT.
    """

    def __init__(self, stack: Optional[List[Thought]] = None):
        self.stack: List[Thought] = stack if stack else []

    def top(self) -> Optional[Thought]:
        "Get the top of the stack without popping it."
        if not self.stack:
            return None
        return self.stack[-1]

    def pop(self, n: int = 1) -> Optional[Thought]:
        "Pop the top n elements of the stack and return the last one."
        if len(self.stack) < n:
            return None
        removed = [self.stack.pop() for _ in range(n)]
        return removed[-1]

    def top_parent(self) -> Optional[Thought]:
        "Get the parent of the top of the stack without popping it."
        if len(self.stack) > 1:
            return self.stack[-2]
        return None

    def store(self, node: Thought) -> None:
        "Add a node on the top of the stack."
        if self.stack:
            # Link the new thought as a child of the current top so the
            # controller can count explored children per node.
            self.stack[-1].children.add(node)
        self.stack.append(node)

    @property
    def level(self) -> int:
        "Return the current level of the stack."
        return len(self.stack)

    def current_path(self) -> List[Thought]:
        "Return the thoughts path."
        return list(self.stack)
|
@ -0,0 +1,137 @@
|
|||||||
|
import json
|
||||||
|
from textwrap import dedent
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from langchain.schema import BaseOutputParser
|
||||||
|
|
||||||
|
from langchain_experimental.tot.thought import ThoughtValidity
|
||||||
|
|
||||||
|
# Chain-of-Thought prompt: asks the model to extend the (optional) list of
# prior thoughts one step at a time. Rendered with jinja2 because the
# template uses {% %} control blocks.
COT_PROMPT = PromptTemplate(
    template_format="jinja2",
    input_variables=["problem_description", "thoughts"],
    template=dedent(
        """
        You are an intelligent agent that is generating one thought at a time in
        a tree of thoughts setting.

        PROBLEM

        {{problem_description}}

        {% if thoughts %}
        THOUGHTS

        {% for thought in thoughts %}
        {{ thought }}
        {% endfor %}
        {% endif %}

        Let's think step by step.
        """
    ).strip(),
)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONListOutputParser(BaseOutputParser):
    """Class to parse the output of a PROPOSE_PROMPT response.

    Expects the LLM response to contain a fenced ```json code block holding a
    JSON list of strings. Returns the parsed list, or [] for malformed output.
    """

    @property
    def _type(self) -> str:
        return "json_list"

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call."""

        # Bug fix: the original indexed split("```json")[1] unconditionally,
        # which raised IndexError whenever the fence was missing. A missing
        # fence now degrades to an empty list, like any other parse failure.
        _, fence, rest = text.partition("```json")
        if not fence:
            return []
        json_string = rest.strip().strip("```").strip()
        try:
            return json.loads(json_string)
        except json.JSONDecodeError:
            return []
|
||||||
|
|
||||||
|
|
||||||
|
# Propose prompt: asks the model for up to {{ n }} distinct next thoughts in a
# single completion, returned as a fenced JSON list (parsed by
# JSONListOutputParser). Rendered with jinja2 for the {% %} control blocks.
PROPOSE_PROMPT = PromptTemplate(
    template_format="jinja2",
    input_variables=["problem_description", "thoughts", "n"],
    output_parser=JSONListOutputParser(),
    template=dedent(
        """
        You are an intelligent agent that is generating thoughts in a tree of
        thoughts setting.

        The output should be a markdown code snippet formatted as a JSON list of
        strings, including the leading and trailing "```json" and "```":

        ```json
        [
            "<thought-1>",
            "<thought-2>",
            "<thought-3>"
        ]
        ```

        PROBLEM

        {{ problem_description }}

        {% if thoughts %}
        VALID THOUGHTS

        {% for thought in thoughts %}
        {{ thought }}
        {% endfor %}

        Possible next {{ n }} valid thoughts based on the last valid thought:
        {% else %}

        Possible next {{ n }} valid thoughts based on the PROBLEM:
        {%- endif -%}
        """
    ).strip(),
)
|
||||||
|
|
||||||
|
|
||||||
|
class CheckerOutputParser(BaseOutputParser):
    """Map the checker LLM's one-word verdict onto a ThoughtValidity value."""

    def parse(self, text: str) -> ThoughtValidity:
        """Parse the output of the language model."""
        verdict = text.upper()
        # Order matters: "INVALID" and "INTERMEDIATE" both contain the
        # substring "VALID", so the more specific tokens are checked first.
        # Anything unrecognised is treated as INVALID.
        for token, validity in (
            ("INVALID", ThoughtValidity.INVALID),
            ("INTERMEDIATE", ThoughtValidity.VALID_INTERMEDIATE),
            ("VALID", ThoughtValidity.VALID_FINAL),
        ):
            if token in verdict:
                return validity
        return ThoughtValidity.INVALID

    @property
    def _type(self) -> str:
        return "tot_llm_checker_output"
|
||||||
|
|
||||||
|
|
||||||
|
# Prompt asking the LLM to judge the last thought in a chain as VALID,
# INVALID, or INTERMEDIATE; the one-word answer is mapped to ThoughtValidity
# by CheckerOutputParser.
CHECKER_PROMPT = PromptTemplate(
    input_variables=["problem_description", "thoughts"],
    template=dedent(
        """
        You are an intelligent agent, validating thoughts of another intelligent agent.

        PROBLEM

        {problem_description}

        THOUGHTS

        {thoughts}

        Evaluate the thoughts and respond with one word.

        - Respond VALID if the last thought is a valid final solution to the
        problem.
        - Respond INVALID if the last thought is invalid.
        - Respond INTERMEDIATE if the last thought is valid but not the final
        solution to the problem.

        This chain of thoughts is"""
    ).strip(),
    output_parser=CheckerOutputParser(),
)
|
@ -0,0 +1,21 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ThoughtValidity(Enum):
    """Checker verdict for a single thought in the ToT search."""

    # Consistent so far, but not yet a complete solution.
    VALID_INTERMEDIATE = 0
    # A complete, correct final solution — terminates the search.
    VALID_FINAL = 1
    # Violates the problem constraints — triggers backtracking.
    INVALID = 2
|
||||||
|
|
||||||
|
|
||||||
|
class Thought(BaseModel):
    """One node in the ToT search tree: a piece of text plus its verdict."""

    text: str
    validity: ThoughtValidity
    children: Set[Thought] = Field(default_factory=set)

    def __hash__(self) -> int:
        # Identity-based hash so thoughts with identical fields remain
        # distinct members of a parent's `children` set.
        return id(self)
|
@ -0,0 +1,94 @@
|
|||||||
|
"""
|
||||||
|
We provide two strategies for generating thoughts in the Tree of Thoughts (ToT)
|
||||||
|
framework to avoid repetition:
|
||||||
|
|
||||||
|
These strategies ensure that the language model generates diverse and
|
||||||
|
non-repeating thoughts, which are crucial for problem-solving tasks that require
|
||||||
|
exploration.
|
||||||
|
"""
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain_experimental.tot.prompts import COT_PROMPT, PROPOSE_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
class BaseThoughtGenerationStrategy(LLMChain):
    """
    Base class for a thought generation strategy.

    Subclasses implement `next_thought`, which proposes the next candidate
    thought for a given problem and path of prior thoughts.
    """

    c: int = 3
    """The number of children thoughts to propose at each step."""

    @abstractmethod
    def next_thought(
        self,
        problem_description: str,
        thoughts_path: Tuple[str, ...] = (),
        **kwargs: Any
    ) -> str:
        """
        Generate the next thought given the problem description and the thoughts
        generated so far.
        """
|
||||||
|
|
||||||
|
|
||||||
|
class SampleCoTStrategy(BaseThoughtGenerationStrategy):
    """
    Sample thoughts from a Chain-of-Thought (CoT) prompt.

    This strategy works better when the thought space is rich, such as when each
    thought is a paragraph. Independent and identically distributed samples
    lead to diversity, which helps to avoid repetition.
    """

    # Overrides the LLMChain prompt with the CoT sampling template.
    prompt: BasePromptTemplate = COT_PROMPT

    def next_thought(
        self,
        problem_description: str,
        thoughts_path: Tuple[str, ...] = (),
        **kwargs: Any
    ) -> str:
        """Sample one new thought conditioned on the path so far."""
        response_text = self.predict_and_parse(
            problem_description=problem_description, thoughts=thoughts_path, **kwargs
        )
        # predict_and_parse may return parsed (non-str) output depending on
        # the prompt's output parser; anything non-str is treated as no
        # thought.
        return response_text if isinstance(response_text, str) else ""
|
||||||
|
|
||||||
|
|
||||||
|
class ProposePromptStrategy(BaseThoughtGenerationStrategy):
    """
    Propose thoughts sequentially using a "propose prompt".

    This strategy works better when the thought space is more constrained, such
    as when each thought is just a word or a line. Proposing different thoughts
    in the same prompt completion helps to avoid duplication.
    """

    prompt: BasePromptTemplate = PROPOSE_PROMPT
    # Cache of pending proposals per thoughts path: each LLM call yields up to
    # `c` candidates, stored reversed so list.pop() below hands them out in
    # the order the model produced them.
    tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)

    def next_thought(
        self,
        problem_description: str,
        thoughts_path: Tuple[str, ...] = (),
        **kwargs: Any
    ) -> str:
        """Return the next cached proposal for `thoughts_path`, querying the
        LLM for a fresh batch when the cache for this path is empty."""
        if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
            new_thoughts = self.predict_and_parse(
                problem_description=problem_description,
                thoughts=thoughts_path,
                n=self.c,
                **kwargs
            )
            if not new_thoughts:
                return ""
            if isinstance(new_thoughts, list):
                # Reverse so pop() yields proposals in generation order.
                self.tot_memory[thoughts_path] = new_thoughts[::-1]
            else:
                # Parser returned something unexpected (e.g. a raw string).
                return ""
        return self.tot_memory[thoughts_path].pop()
|
@ -0,0 +1,61 @@
|
|||||||
|
"""Fake LLM wrapper for testing purposes."""
|
||||||
|
from typing import Any, Dict, List, Mapping, Optional, cast
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from pydantic import validator
|
||||||
|
|
||||||
|
|
||||||
|
class FakeLLM(LLM):
    """Fake LLM wrapper for testing purposes."""

    # Canned prompt -> response mapping; in sequential mode its values are
    # returned in insertion order regardless of the prompt.
    queries: Optional[Mapping] = None
    # When True, ignore the prompt and return the next canned response.
    sequential_responses: Optional[bool] = False
    # Cursor into `queries` for sequential mode.
    response_index: int = 0

    @validator("queries", always=True)
    def check_queries_required(
        cls, queries: Optional[Mapping], values: Mapping[str, Any]
    ) -> Optional[Mapping]:
        """Require `queries` whenever sequential responses are enabled."""
        # Bug fix: the field is named `sequential_responses`; the original
        # looked up the misspelled key "sequential_response", so this
        # validation could never trigger.
        if values.get("sequential_responses") and not queries:
            raise ValueError(
                "queries is required when sequential_responses is set to True"
            )
        return queries

    def get_num_tokens(self, text: str) -> int:
        """Return number of tokens."""
        return len(text.split())

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fake"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.sequential_responses:
            # Property access: returns the next canned response and advances
            # the cursor.
            return self._get_next_response_in_sequence

        if self.queries is not None:
            return self.queries[prompt]
        # Fixed fallbacks make tests deterministic: "foo" without a stop
        # list, "bar" with one.
        return "foo" if stop is None else "bar"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {}

    @property
    def _get_next_response_in_sequence(self) -> str:
        """Return the response at the current cursor and advance it."""
        queries = cast(Mapping, self.queries)
        response = queries[list(queries.keys())[self.response_index]]
        self.response_index = self.response_index + 1
        return response
|
@ -0,0 +1,151 @@
|
|||||||
|
import re
|
||||||
|
import unittest
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_experimental.tot.base import ToTChain
|
||||||
|
from langchain_experimental.tot.checker import ToTChecker
|
||||||
|
from langchain_experimental.tot.controller import ToTController
|
||||||
|
from langchain_experimental.tot.memory import ToTDFSMemory
|
||||||
|
from langchain_experimental.tot.thought import Thought, ThoughtValidity
|
||||||
|
from langchain_experimental.tot.thought_generation import SampleCoTStrategy
|
||||||
|
from tests.unit_tests.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
# A 4x4 sudoku: rows separated by "|", unknown cells marked "*".
sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"
# Scripted sequence of "thoughts" the fake LLM emits, one per call.  The
# trailing comments record the validity the checker assigns to each step
# and where the controller is expected to roll back.
solutions = [
    "3,*,4,2|1,*,3,*|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,6,3,*|*,1,*,3|4,*,*,1",  # INVALID c=1
    " 3,4,1,2|1,7,3,*|*,1,*,3|4,*,*,1",  # INVALID c=2
    " 3,4,1,2|1,8,3,*|*,1,*,3|4,*,*,1",  # INVALID c=3
    " 3,4,1,2|1,2,3,*|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE c=4 (rollback)
    "3,1,4,2|1,*,3,*|*,1,*,3|4,*,*,1",  # INVALID (rollback)
    "3,4,1,2|1,2,3,4|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,2,3,4|4,1,*,3|4,*,*,1",  # INVALID (rollback)
    " 3,4,1,2|1,2,3,4|2,1,4,3|4,*,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",  # VALID_FINAL
]
# Fully solved grid the scripted run is expected to converge on.
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fake_llm_sudoku() -> FakeLLM:
    """Fake LLM that replays the scripted sudoku solving steps in order."""
    scripted = dict(enumerate(step.strip() for step in solutions))
    return FakeLLM(queries=scripted, sequential_responses=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SudokuChecker(ToTChecker):
    """Checker that grades sudoku thoughts against the known solution."""

    def evaluate(
        self, problem_description: str, thoughts: Tuple[str, ...] = ()
    ) -> ThoughtValidity:
        # Only the most recent thought is graded.
        candidate = thoughts[-1].replace(" ", "").replace('"', "")
        if sudoku_solution in candidate:
            return ThoughtValidity.VALID_FINAL
        # Treat "*" cells as wildcards: if the partial grid matches the
        # solution, it can still be completed.
        pattern = candidate.replace("*", ".").replace("|", "\\|")
        if re.search(pattern, sudoku_solution):
            return ThoughtValidity.VALID_INTERMEDIATE
        return ThoughtValidity.INVALID
|
||||||
|
|
||||||
|
|
||||||
|
def test_solve_sudoku(fake_llm_sudoku: FakeLLM) -> None:
    """With enough steps (k) the chain reaches the sudoku solution."""
    chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=len(solutions),
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    result = chain.run({"problem_description": ""})
    assert result == sudoku_solution
|
||||||
|
|
||||||
|
|
||||||
|
def test_solve_sudoku_k_too_small(fake_llm_sudoku: FakeLLM) -> None:
    """With one step too few (k - 1) the chain must not reach the solution."""
    chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=len(solutions) - 1,
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    result = chain.run({"problem_description": ""})
    assert result != sudoku_solution
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fake_llm_checker() -> FakeLLM:
    """Fake LLM that emits assorted validity strings, one per call."""
    validity_strings = [
        "VALID",
        "valid",
        "INVALID",
        "invalid",
        "INTERMEDIATE",
        "intermediate",
        "SOMETHING ELSE",
    ]
    return FakeLLM(
        queries={i: s for i, s in enumerate(validity_strings)},
        sequential_responses=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ControllerTestCase(unittest.TestCase):
    """Tests for ToTController path selection and rollback.

    Fix: the test names ``test_one_thoghts`` and ``test_two_thoghts``
    were misspelled (inconsistent with ``test_two_thoughts_invalid``);
    renamed to ``test_one_thought`` / ``test_two_thoughts``.
    """

    def setUp(self) -> None:
        # c=3: the controller rolls back once a thought has 3 explored
        # children (see test_thoughts_rollback*).
        self.controller = ToTController(c=3)

    def test_empty(self) -> None:
        """Empty memory yields an empty thought path."""
        memory = ToTDFSMemory([])
        self.assertEqual(self.controller(memory), ())

    def test_one_thought(self) -> None:
        """A single thought is returned as a one-element path."""
        thoughts = [Thought(text="a", validity=ThoughtValidity.VALID_FINAL)]
        memory = ToTDFSMemory(thoughts)
        self.assertEqual(self.controller(memory), ("a",))

    def test_two_thoughts(self) -> None:
        """Two valid intermediate thoughts form a two-step path."""
        memory = ToTDFSMemory(
            [
                Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE),
                Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE),
            ]
        )
        self.assertEqual(self.controller(memory), ("a", "b"))

    def test_two_thoughts_invalid(self) -> None:
        """An invalid last thought is dropped from the path."""
        memory = ToTDFSMemory(
            [
                Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE),
                Thought(text="b", validity=ThoughtValidity.INVALID),
            ]
        )
        self.assertEqual(self.controller(memory), ("a",))

    def test_thoughts_rollback(self) -> None:
        """After c children of b were explored, the path rolls back to a."""
        a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_3 = Thought(text="c_3", validity=ThoughtValidity.VALID_INTERMEDIATE)

        a.children = {b}
        b.children = {c_1, c_2, c_3}

        memory = ToTDFSMemory([a, b, c_3])
        self.assertEqual(self.controller(memory), ("a",))

    def test_thoughts_rollback_invalid(self) -> None:
        """Rollback also happens when the last explored child is invalid."""
        a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_3 = Thought(text="c_3", validity=ThoughtValidity.INVALID)

        a.children = {b}
        b.children = {c_1, c_2, c_3}

        memory = ToTDFSMemory([a, b, c_3])
        self.assertEqual(self.controller(memory), ("a",))
|
Loading…
Reference in New Issue